Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task API Objc/Swift implementation #9

Merged
merged 1 commit into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@ http_archive(
],
)

# Apple and Swift rules.
# https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_rules_apple",
sha256 = "ee9e6073aeb5a65c100cb9c44b0017c937706a4ae03176e14a7e78620a198079",
strip_prefix = "rules_apple-5131f3d46794bf227d296c82f30c2499c9de3c5b",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz",
"https://github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz",
],
)

# https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "d0833bc6dad817a367936a5f902a0c11318160b5e80a20ece35fb85a5675c886",
strip_prefix = "rules_swift-3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz",
"https://github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz",
],
)

http_archive(
name = "org_tensorflow",
strip_prefix = "tensorflow-2.2.0",
Expand Down
47 changes: 47 additions & 0 deletions tensorflow_lite_support/ios/task/text/qa/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
load("//third_party/tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("//third_party/tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")

package(
default_visibility = ["//third_party/tensorflow_lite_support:users"],
licenses = ["notice"], # Apache 2.0
)

objc_library(
name = "TFLBertQuestionAnswerer",
srcs = ["Sources/TFLBertQuestionAnswerer.mm"],
hdrs = ["Sources/TFLBertQuestionAnswerer.h"],
module_name = "TFLBertQuestionAnswerer",
deps = [
"//third_party/objective_c/google_toolbox_for_mac:GTM_Defines",
"//third_party/tensorflow_lite_support/cc/task/text/qa:bert_question_answerer",
"//third_party/tensorflow_lite_support/ios/utils:TFLStringUtil",
],
)

swift_library(
name = "TFLBertQuestionAnswererTestLibrary",
testonly = 1,
srcs = glob(["Tests/*.swift"]),
data = [
"//third_party/tensorflow_lite_support/cc/testdata:albert_model",
"//third_party/tensorflow_lite_support/cc/testdata:mobile_bert_model",
],
tags = TFL_DEFAULT_TAGS,
deps = [
":TFLBertQuestionAnswerer",
"//third_party/swift/xctest",
],
)

ios_unit_test(
name = "TFLBertQuestionAnswererTest",
size = "large",
minimum_os_version = TFL_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":TFLBertQuestionAnswererTestLibrary",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>

NS_ASSUME_NONNULL_BEGIN
/**
* Struct to represent the logit and offset of the answer related to context.
*/
struct TFLPos {
int start;
int end;
float logit;
};

/**
* Class for the Answer to BertQuestionAnswerer.
*/
@interface TFLQAAnswer : NSObject
@property(nonatomic) struct TFLPos pos;
@property(nonatomic) NSString* text;
@end

/**
* BertQA task API, performs tokenization for models (BERT, Albert, etc.) in
* preprocess and returns most possible answers.
*
* In particular, the branch of BERT models use WordPiece tokenizer, and the
* branch of Albert models use SentencePiece tokenizer, respectively.
*/
@interface TFLBertQuestionAnswerer : NSObject

/**
* Creates a BertQuestionAnswerer instance with a mobilebert model and
* vocabulary file for wordpiece tokenization.
* One suitable model is:
* https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
* @param modelPath The file path to the mobilebert tflite model.
* @param vocabPath The file path to the vocab file for wordpiece tokenization.
*
* @return A BertQuestionAnswerer instance.
*/
+ (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
vocabPath:(NSString*)vocabPath
NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));

/**
* Creates a BertQuestionAnswerer instance with an albert model and spmodel file
* for sentencepiece tokenization.
* One suitable model is:
* https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
* @param modelPath The file path to the albert tflite model.
* @param setencepieceModelPath The file path to the model file for sentence piece tokenization.
*
* @return A BertQuestionAnswerer instance.
*/
+ (instancetype)albertQuestionAnswererWithModelPath:(NSString*)modelPath
setencepieceModelPath:(NSString*)setencepieceModelPath
NS_SWIFT_NAME(albertQuestionAnswerer(modelPath:setencepieceModelPath:));

/**
* Answers question based on the context.
* @param context Context the question bases on.
* @param question Question to ask.
*
* @return A list of answers to the question, reversely sorted by the probability of each answer.
*/
- (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
question:(NSString*)question
NS_SWIFT_NAME(answer(context:question:));
@end
NS_ASSUME_NONNULL_END
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "third_party/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h"
#import "third_party/objective_c/google_toolbox_for_mac/GTMDefines.h"
#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"

#include "third_party/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"

NS_ASSUME_NONNULL_BEGIN
using BertQuestionAnswererCPP = ::tflite::support::task::text::qa::BertQuestionAnswerer;
using QuestionAnswererCPP = ::tflite::support::task::text::qa::QuestionAnswerer;
using QaAnswerCPP = ::tflite::support::task::text::qa::QaAnswer;

@implementation TFLQAAnswer
@synthesize pos;
@synthesize text;
@end

@implementation TFLBertQuestionAnswerer {
std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer;
}

+ (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
vocabPath:(NSString *)vocabPath {
absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
MakeString(vocabPath));
_GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
return [[TFLBertQuestionAnswerer alloc]
initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
}

+ (instancetype)albertQuestionAnswererWithModelPath:(NSString *)modelPath
setencepieceModelPath:(NSString *)setencepieceModelPath {
absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
BertQuestionAnswererCPP::CreateAlbertQuestionAnswerer(MakeString(modelPath),
MakeString(setencepieceModelPath));
_GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
return [[TFLBertQuestionAnswerer alloc]
initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
}

- (instancetype)initWithQuestionAnswerer:
(std::unique_ptr<QuestionAnswererCPP>)bertQuestionAnswerer {
self = [super init];
if (self) {
_bertQuestionAnswerwer = std::move(bertQuestionAnswerer);
}
return self;
}

- (NSMutableArray<TFLQAAnswer *> *)arrayFromVector:(std::vector<QaAnswerCPP>)vector {
NSMutableArray<TFLQAAnswer *> *ret = [NSMutableArray arrayWithCapacity:vector.size()];

for (int i = 0; i < vector.size(); i++) {
QaAnswerCPP answerCpp = vector[i];
TFLQAAnswer *answer = [[TFLQAAnswer alloc] init];
[answer setPos:{.start = answerCpp.pos.start,
.end = answerCpp.pos.end,
.logit = answerCpp.pos.logit}];
[answer setText:MakeNSString(answerCpp.text)];
[ret addObject:answer];
}
return ret;
}

- (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
std::vector<QaAnswerCPP> results =
_bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question));
return [self arrayFromVector:results];
}
@end
NS_ASSUME_NONNULL_END
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import XCTest

@testable import TFLBertQuestionAnswerer

class TFLBertQuestionAnswererTest: XCTestCase {

static let bundle = Bundle(for: TFLBertQuestionAnswererTest.self)
static let mobileBertModelPath = bundle.path(forResource: "mobile_bert", ofType: "tflite")!
static let mobileBertVocabPath = bundle.path(forResource: "vocab", ofType: "txt")!

static let albertModelPath = bundle.path(forResource: "albert", ofType: "tflite")!
static let albertSPmodelPath = bundle.path(forResource: "30k-clean", ofType: "model")!

static let context = """
The role of teacher is often formal and ongoing, carried out at a school or other place of
formal education. In many countries, a person who wishes to become a teacher must first obtain
specified professional qualifications or credentials from a university or college. These
professional qualifications may include the study of pedagogy, the science of teaching.
Teachers, like other professionals, may have to continue their education after they qualify,
a process known as continuing professional development. Teachers may use a lesson plan to
facilitate student learning, providing a course of study which is called the curriculum.
"""
static let question = "What is a course of study called?"
static let answer = "the curriculum."

func testInitMobileBert() {
let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
modelPath: TFLBertQuestionAnswererTest.mobileBertModelPath,
vocabPath: TFLBertQuestionAnswererTest.mobileBertVocabPath)

XCTAssertNotNil(mobileBertAnswerer)

let answers = mobileBertAnswerer.answer(
context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)

XCTAssertNotNil(answers)
XCTAssertTrue(answers.count >= 0)
XCTAssertEqual(answers[0].text, TFLBertQuestionAnswererTest.answer)
}

func testInitAlbert() {
let albertAnswerer = TFLBertQuestionAnswerer.albertQuestionAnswerer(
modelPath: TFLBertQuestionAnswererTest.albertModelPath,
setencepieceModelPath: TFLBertQuestionAnswererTest.albertSPmodelPath)

XCTAssertNotNil(albertAnswerer)

let answers = albertAnswerer.answer(
context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)

XCTAssertNotNil(answers)
XCTAssertTrue(answers.count >= 0)
XCTAssertEqual(answers[0].text, TFLBertQuestionAnswererTest.answer)
}
}
Loading