Skip to content

Commit

Permalink
BertNLClassifier - a NLClassifier extension for bert models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 321883577
  • Loading branch information
flamearrow authored and tflite-support-robot committed Jul 18, 2020
1 parent 7eeef21 commit d40c3fe
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 60 deletions.
39 changes: 39 additions & 0 deletions tensorflow_lite_support/cc/task/text/nlclassifier/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,52 @@ cc_library(
],
deps = [
"//tensorflow_lite_support/cc:common",
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/core:base_task_api",
"//tensorflow_lite_support/cc/task/core:category",
"//tensorflow_lite_support/cc/task/core:task_api_factory",
"//tensorflow_lite_support/cc/task/core:task_utils",
"//tensorflow_lite_support/cc/utils:common_utils",
"//tensorflow_lite_support/metadata/cc:metadata_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:string",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/core/api",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
],
)

# BertNLClassifier: an NLClassifier extension that tokenizes the input text and
# feeds the "ids" / "mask" / "segment_ids" tensors of a BERT model.
cc_library(
    name = "bert_nl_classifier",
    srcs = [
        "bert_nl_classifier.cc",
    ],
    hdrs = [
        "bert_nl_classifier.h",
    ],
    deps = [
        ":nl_classifier",
        "//tensorflow_lite_support/cc:common",
        "//tensorflow_lite_support/cc/port:status_macros",
        "//tensorflow_lite_support/cc/task/core:category",
        "//tensorflow_lite_support/cc/task/core:task_api_factory",
        "//tensorflow_lite_support/cc/task/core:task_utils",
        "//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
        "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",
        "//tensorflow_lite_support/metadata/cc:metadata_extractor",
        # bert_nl_classifier.h uses absl::make_unique in its default arguments.
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        # bert_nl_classifier.cc uses absl::Span / absl::MakeSpan.
        "@com_google_absl//absl/types:span",
        "@org_tensorflow//tensorflow/lite:string",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/core/api",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
    ],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"

#include <stddef.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

namespace tflite {
namespace support {
namespace task {
namespace text {
namespace nlclassifier {

using ::tflite::support::task::core::FindTensorByName;
using ::tflite::support::task::core::PopulateTensor;
using ::tflite::support::text::tokenizer::CreateTokenizerFromMetadata;
using ::tflite::support::text::tokenizer::TokenizerResult;

namespace {
// Metadata names used by Preprocess to locate the three input tensors it
// fills.
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
// Metadata name used by Postprocess to locate the output tensor it reads.
constexpr char kScoreTensorName[] = "probability";
// Marker tokens that Preprocess places before and after the tokenized input.
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
} // namespace

// Tokenizes |input| and fills the three model input tensors.
//
// The input is ASCII lower-cased, tokenized with |tokenizer_|, truncated to
// kMaxSeqLen - 2 subwords and wrapped as "[CLS] <subwords> [SEP]". The
// resulting layout is:
//
//                |<-----------kMaxSeqLen---------->|
//   input_ids    [CLS] s1  s2...  sn [SEP]  0  0...  0
//   input_mask     1    1   1...   1    1   0  0...  0
//   segment_ids    0    0   0...   0    0   0  0...  0
//
// Args:
//   input_tensors: the model input tensors; the "ids", "mask" and
//     "segment_ids" tensors are located among them through the metadata.
//   input: the text to classify.
//
// Returns:
//   kInvalidArgument if any of the three tensors cannot be located,
//   OkStatus otherwise.
absl::Status BertNLClassifier::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  auto* input_tensor_metadatas =
      GetMetadataExtractor()->GetInputTensorMetadata();
  auto* ids_tensor =
      FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName);
  auto* mask_tensor =
      FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName);
  auto* segment_ids_tensor = FindTensorByName(
      input_tensors, input_tensor_metadatas, kSegmentIdsTensorName);

  // FindTensorByName hands back a pointer; fail with a status instead of
  // passing a null tensor on to PopulateTensor.
  // NOTE(review): presumably null means "no tensor with that metadata name" -
  // confirm against task_utils.h.
  if (ids_tensor == nullptr || mask_tensor == nullptr ||
      segment_ids_tensor == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("BertNLClassifier models are expected to have input "
                        "tensors named \"%s\", \"%s\" and \"%s\" in their "
                        "metadata",
                        kIdsTensorName, kMaskTensorName,
                        kSegmentIdsTensorName),
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }

  std::string processed_input = input;
  absl::AsciiStrToLower(&processed_input);

  const TokenizerResult input_tokenize_results =
      tokenizer_->Tokenize(processed_input);
  const std::vector<std::string>& subwords = input_tokenize_results.subwords;

  // 2 accounts for [CLS], [SEP].
  const size_t num_query_tokens =
      std::min(static_cast<size_t>(kMaxSeqLen - 2), subwords.size());

  std::vector<int> input_ids(kMaxSeqLen, 0);
  std::vector<int> input_mask(kMaxSeqLen, 0);

  // Writes the id of |token| at the next free position and marks that
  // position as real (non-padding) input. At most kMaxSeqLen tokens are
  // appended, so |next_position| never runs past the vectors.
  size_t next_position = 0;
  const auto append_token = [&](const std::string& token) {
    // The lookup result is deliberately not checked, as before: a token that
    // cannot be resolved leaves the id at its initial value 0.
    tokenizer_->LookupId(token, &input_ids[next_position]);
    input_mask[next_position] = 1;
    ++next_position;
  };

  append_token(kClassificationToken);
  for (size_t i = 0; i < num_query_tokens; ++i) {
    append_token(subwords[i]);
  }
  append_token(kSeparator);

  PopulateTensor(input_ids, ids_tensor);
  PopulateTensor(input_mask, mask_tensor);
  // Single-sentence classification: every position belongs to segment 0.
  PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor);

  return absl::OkStatus();
}

// Converts the model output into a list of categories.
//
// Args:
//   output_tensors: the model output tensors; exactly one is expected, and it
//     is located through the metadata name "probability".
//   (unnamed) input: the original text, unused.
//
// Returns:
//   kInvalidArgument if the number of output tensors is not 1 or the score
//   tensor cannot be located; otherwise whatever BuildResults produces.
StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const std::string& /*input*/) {
  if (output_tensors.size() != 1) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("BertNLClassifier models are expected to have only 1 "
                        "output, found %d",
                        output_tensors.size()),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  const TfLiteTensor* scores = FindTensorByName(
      output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
      kScoreTensorName);
  // Fail with a status instead of handing a null tensor to BuildResults.
  if (scores == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("BertNLClassifier models are expected to have an "
                        "output tensor named \"%s\" in their metadata",
                        kScoreTensorName),
        TfLiteSupportStatus::kOutputTensorNotFoundError);
  }

  // No explicit label tensor is passed; labels are optional and, when present,
  // are the ones InitializeFromMetadata tried to load from the metadata.
  return BuildResults(scores, /*labels=*/nullptr);
}

// Creates a BertNLClassifier from the model file at
// |path_to_model_with_metadata|, using |resolver| to resolve the model's ops.
// Any error from creating the task or from reading its metadata is returned
// to the caller.
StatusOr<std::unique_ptr<BertNLClassifier>>
BertNLClassifier::CreateBertNLClassifierWithMetadata(
    const std::string& path_to_model_with_metadata,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<BertNLClassifier> classifier;
  ASSIGN_OR_RETURN(classifier,
                   core::TaskAPIFactory::CreateFromFile<BertNLClassifier>(
                       path_to_model_with_metadata, std::move(resolver)));

  // The tokenizer (mandatory) and labels (optional) come from the metadata.
  RETURN_IF_ERROR(classifier->InitializeFromMetadata());

  // Explicit move: the unique_ptr is converted into the StatusOr return value.
  return std::move(classifier);
}

// Creates a BertNLClassifier from an in-memory model: the buffer starts at
// |model_with_metadata_buffer_data| and spans
// |model_with_metadata_buffer_size| bytes. |resolver| resolves the model's
// ops. Any error from creating the task or from reading its metadata is
// returned to the caller.
StatusOr<std::unique_ptr<BertNLClassifier>>
BertNLClassifier::CreateBertNLClassifierWithMetadataFromBinary(
    const char* model_with_metadata_buffer_data,
    size_t model_with_metadata_buffer_size,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<BertNLClassifier> classifier;
  ASSIGN_OR_RETURN(classifier,
                   core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>(
                       model_with_metadata_buffer_data,
                       model_with_metadata_buffer_size, std::move(resolver)));

  // The tokenizer (mandatory) and labels (optional) come from the metadata.
  RETURN_IF_ERROR(classifier->InitializeFromMetadata());

  // Explicit move: the unique_ptr is converted into the StatusOr return value.
  return std::move(classifier);
}

// Configures this classifier from the model metadata.
//
// The tokenizer is mandatory: a failure to create it is returned to the
// caller. Labels are optional: a failure to load them is ignored.
absl::Status BertNLClassifier::InitializeFromMetadata() {
  // Mandatory: tokenizer described by the metadata.
  ASSIGN_OR_RETURN(tokenizer_,
                   CreateTokenizerFromMetadata(*GetMetadataExtractor()));

  // Optional: labels attached to the output tensor metadata. This is best
  // effort, so the resulting status is dropped on purpose.
  const auto* output_tensor_metadata =
      GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex);
  TrySetLabelFromMetadata(output_tensor_metadata).IgnoreError();

  return absl::OkStatus();
}

} // namespace nlclassifier
} // namespace text
} // namespace task
} // namespace support
} // namespace tflite
101 changes: 101 additions & 0 deletions tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_

#include <stddef.h>

#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"

namespace tflite {
namespace support {
namespace task {
namespace text {
namespace nlclassifier {

// Classifier API for NLClassification tasks with Bert models, categorizes
// string into different classes.
//
// The API expects a Bert based TFLite model with metadata populated.
// The metadata should contain the following information:
// - input_process_units for Wordpiece/Sentencepiece Tokenizer
// - 3 input tensors with names "ids", "mask" and "segment_ids"
// - 1 output tensor of type float32[1, 2], with an optionally attached label
// file. If a label file is attached, the file should be a plain text file
// with one label per line, the number of labels should match the number of
// categories the model outputs.
//
// NOTE(review): the implementation locates the output tensor through the
// metadata name "probability"; models are presumably expected to use that
// name - confirm and document it in the list above.

class BertNLClassifier : public NLClassifier {
public:
// Inherit the constructors of NLClassifier.
using NLClassifier::NLClassifier;
// Max number of tokens to pass to the model. This includes the "[CLS]" and
// "[SEP]" markers added around the tokenized input, so at most
// kMaxSeqLen - 2 subwords of the input itself are kept.
static constexpr int kMaxSeqLen = 128;

// Factory function to create a BertNLClassifier from TFLite model with
// metadata.
//
// path_to_model_with_metadata: path of the model file to load.
// resolver: op resolver used for the model; defaults to a BuiltinOpResolver.
//
// Returns the classifier, or an error status if the model cannot be loaded or
// the tokenizer cannot be created from its metadata.
static StatusOr<std::unique_ptr<BertNLClassifier>>
CreateBertNLClassifierWithMetadata(
const std::string& path_to_model_with_metadata,
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

// Factory function to create a BertNLClassifier from in memory buffer of a
// TFLite model with metadata.
//
// model_with_metadata_buffer_data: start of the model buffer.
// model_with_metadata_buffer_size: size of the model buffer.
// resolver: op resolver used for the model; defaults to a BuiltinOpResolver.
//
// Returns the classifier, or an error status if the model cannot be loaded or
// the tokenizer cannot be created from its metadata.
static StatusOr<std::unique_ptr<BertNLClassifier>>
CreateBertNLClassifierWithMetadataFromBinary(
const char* model_with_metadata_buffer_data,
size_t model_with_metadata_buffer_size,
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

protected:
// Run tokenization on input text and construct three input tensors ids, mask
// and segment_ids for the model input. The input is ASCII lower-cased before
// tokenization.
absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
const std::string& input) override;

// Extract model output and create results with label file attached in
// metadata. If no label file is attached, use output score index as labels.
// Returns an error status if the model does not have exactly one output
// tensor. The input string is not used.
StatusOr<std::vector<core::Category>> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
const std::string& input) override;

private:
// Initialize the API with the tokenizer and label files set in the metadata.
// A missing tokenizer is an error; failing to load labels is ignored.
absl::Status InitializeFromMetadata();

// Tokenizer created from the model metadata by InitializeFromMetadata and
// used by Preprocess.
std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
};

} // namespace nlclassifier
} // namespace text
} // namespace task
} // namespace support
} // namespace tflite

#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_
Loading

0 comments on commit d40c3fe

Please sign in to comment.