Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added tokenizers as process units to the metadata schema #18

Merged
merged 1 commit into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 112 additions & 53 deletions tensorflow_lite_support/metadata/cc/metadata_version.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ namespace {
// of 1.0.0.
enum class SchemaMembers {
// VOCABULARY value of AssociatedFileType; GetMemberVersion() maps it to 1.0.1.
kAssociatedFileTypeVocabulary = 0,
// input_process_units field of SubGraphMetadata; GetMemberVersion() maps it
// to 1.1.0.
kSubGraphMetadataInputProcessUnits = 1,
// output_process_units field of SubGraphMetadata; GetMemberVersion() maps it
// to 1.1.0.
kSubGraphMetadataOutputProcessUnits = 2,
// BertTokenizerOptions member of the ProcessUnitOptions union;
// GetMemberVersion() maps it to 1.1.0.
kProcessUnitOptionsBertTokenizerOptions = 3,
// SentencePieceTokenizerOptions member of the ProcessUnitOptions union;
// GetMemberVersion() maps it to 1.1.0.
kProcessUnitOptionsSentencePieceTokenizerOptions = 4
};

// Helper class to compare semantic versions in terms of three integers, major,
Expand Down Expand Up @@ -86,10 +90,21 @@ Version GetMemberVersion(SchemaMembers member) {
switch (member) {
case SchemaMembers::kAssociatedFileTypeVocabulary:
return Version(1, 0, 1);
case SchemaMembers::kSubGraphMetadataInputProcessUnits:
return Version(1, 1, 0);
case SchemaMembers::kSubGraphMetadataOutputProcessUnits:
return Version(1, 1, 0);
case SchemaMembers::kProcessUnitOptionsBertTokenizerOptions:
return Version(1, 1, 0);
case SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions:
return Version(1, 1, 0);
default:
// Should never happen.
TFLITE_LOG(FATAL) << "Unsupported schema member: "
<< static_cast<int>(member);
}
// Should never happen.
return Version(0, 0, 0);
}

// Updates min_version if it precedes the new_version.
Expand All @@ -100,78 +115,120 @@ inline void UpdateMinimumVersion(const Version& new_version,
}
}

void UpdateMinimumVersionForAssociatedFile(
const tflite::AssociatedFile* associated_file, Version* min_version) {
if (associated_file == nullptr) return;
template <typename T>
void UpdateMinimumVersionForTable(const T* table, Version* min_version);

if (associated_file->type() == AssociatedFileType_VOCABULARY) {
// Applies UpdateMinimumVersionForTable<T>() to every element of a Flatbuffers
// vector of T tables.
//
// array: the vector to inspect; may be nullptr, in which case nothing is done.
// min_version: passed through unchanged to UpdateMinimumVersionForTable<T>()
//   for each element.
template <typename T>
void UpdateMinimumVersionForArray(
    const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
    Version* min_version) {
  if (array == nullptr) return;

  // Iterate with the vector's own iterators rather than an `int` index, so no
  // signed counter is ever compared against the vector's unsigned size().
  for (const T* element : *array) {
    UpdateMinimumVersionForTable<T>(element, min_version);
  }
}

// Specialization for AssociatedFile tables. When the file's type is
// VOCABULARY, min_version is updated against the version registered for
// SchemaMembers::kAssociatedFileTypeVocabulary. A null table, or a file of any
// other type, leaves min_version untouched.
template <>
void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
    const tflite::AssociatedFile* table, Version* min_version) {
  const bool is_vocabulary =
      table != nullptr && table->type() == AssociatedFileType_VOCABULARY;
  if (!is_vocabulary) return;

  const Version vocabulary_version =
      GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary);
  UpdateMinimumVersion(vocabulary_version, min_version);
}

void UpdateMinimumVersionForAssociatedFileArray(
const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
associated_files,
Version* min_version) {
if (associated_files == nullptr) return;
template <>
void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
const tflite::ProcessUnit* table, Version* min_version) {
if (table == nullptr) return;

for (int i = 0; i < associated_files->size(); ++i) {
UpdateMinimumVersionForAssociatedFile(associated_files->Get(i),
min_version);
tflite::ProcessUnitOptions process_unit_type = table->options_type();
if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
UpdateMinimumVersion(
GetMemberVersion(
SchemaMembers::kProcessUnitOptionsBertTokenizerOptions),
min_version);
}
if (process_unit_type == ProcessUnitOptions_SentencePieceTokenizerOptions) {
UpdateMinimumVersion(
GetMemberVersion(
SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions),
min_version);
}
}

void UpdateMinimumVersionForTensorMetadata(
const tflite::TensorMetadata* tensor_metadata, Version* min_version) {
if (tensor_metadata == nullptr) return;
template <>
void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
const tflite::TensorMetadata* table, Version* min_version) {
if (table == nullptr) return;

// Checks the associated_files field.
UpdateMinimumVersionForAssociatedFileArray(
tensor_metadata->associated_files(), min_version);
}
UpdateMinimumVersionForArray<tflite::AssociatedFile>(
table->associated_files(), min_version);

void UpdateMinimumVersionForTensorMetadataArray(
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
tensor_metadata_array,
Version* min_version) {
if (tensor_metadata_array == nullptr) return;

for (int i = 0; i < tensor_metadata_array->size(); ++i) {
UpdateMinimumVersionForTensorMetadata(tensor_metadata_array->Get(i),
min_version);
}
// Checks the process_units field.
UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(),
min_version);
}

void UpdateMinimumVersionForSubGraphMetadata(
const tflite::SubGraphMetadata* subgraph_metadata, Version* min_version) {
if (subgraph_metadata == nullptr) return;
template <>
void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
const tflite::SubGraphMetadata* table, Version* min_version) {
if (table == nullptr) return;

// Checks in the input/output metadata arrays.
UpdateMinimumVersionForTensorMetadataArray(
subgraph_metadata->input_tensor_metadata(), min_version);
UpdateMinimumVersionForTensorMetadataArray(
subgraph_metadata->output_tensor_metadata(), min_version);
UpdateMinimumVersionForArray<tflite::TensorMetadata>(
table->input_tensor_metadata(), min_version);
UpdateMinimumVersionForArray<tflite::TensorMetadata>(
table->output_tensor_metadata(), min_version);

// Checks the associated_files field.
UpdateMinimumVersionForAssociatedFileArray(
subgraph_metadata->associated_files(), min_version);
UpdateMinimumVersionForArray<tflite::AssociatedFile>(
table->associated_files(), min_version);

// Checks for the input_process_units field.
if (table->input_process_units() != nullptr) {
UpdateMinimumVersion(
GetMemberVersion(SchemaMembers::kSubGraphMetadataInputProcessUnits),
min_version);
UpdateMinimumVersionForArray<tflite::ProcessUnit>(
table->input_process_units(), min_version);
}

// Checks for the output_process_units field.
if (table->output_process_units() != nullptr) {
UpdateMinimumVersion(
GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputProcessUnits),
min_version);
UpdateMinimumVersionForArray<tflite::ProcessUnit>(
table->output_process_units(), min_version);
}
}

void UpdateMinimumVersionForModelMetadata(
const tflite::ModelMetadata& model_metadata, Version* min_version) {
template <>
void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
const tflite::ModelMetadata* table, Version* min_version) {
if (table == nullptr) {
// Should never happen, because VerifyModelMetadataBuffer has verified it.
TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
return;
}

// Checks the subgraph_metadata field.
if (model_metadata.subgraph_metadata() != nullptr) {
for (int i = 0; i < model_metadata.subgraph_metadata()->size(); ++i) {
UpdateMinimumVersionForSubGraphMetadata(
model_metadata.subgraph_metadata()->Get(i), min_version);
if (table->subgraph_metadata() != nullptr) {
for (int i = 0; i < table->subgraph_metadata()->size(); ++i) {
UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
table->subgraph_metadata()->Get(i), min_version);
}
}

// Checks the associated_files field.
UpdateMinimumVersionForAssociatedFileArray(model_metadata.associated_files(),
min_version);
UpdateMinimumVersionForArray<tflite::AssociatedFile>(
table->associated_files(), min_version);
}

} // namespace
Expand All @@ -196,15 +253,17 @@ TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);

// All tables in the metadata schema should have their dedicated
// UpdateMinimumVersionFor**() methods, respectively. We'll gradually add
// these methods when new fields show up in later schema versions.
// UpdateMinimumVersionForTable<Foo>() methods, respectively. We'll gradually
// add these methods when new fields show up in later schema versions.
//
// UpdateMinimumVersionFor<Foo>() takes a const pointer of Foo. The pointer
// can be a nullptr if Foo is not populated into the corresponding table of
// the Flatbuffer object. In this case, UpdateMinimumVersionFor<Foo>() will be
// skipped. An exception is UpdateMinimumVersionForModelMetadata(), where
// ModelMetadata is the root table, and it won't be null.
UpdateMinimumVersionForModelMetadata(*model_metadata, &min_version);
// UpdateMinimumVersionForTable<Foo>() takes a const pointer to Foo. The
// pointer can be a nullptr if Foo is not populated into the corresponding
// table of the Flatbuffer object. In this case,
// UpdateMinimumVersionForTable<Foo>() will be skipped. An exception is
// UpdateMinimumVersionForTable<ModelMetadata>(), where ModelMetadata is the
// root table, and it won't be null.
UpdateMinimumVersionForTable<tflite::ModelMetadata>(model_metadata,
&min_version);

*min_version_str = min_version.ToString();
return kTfLiteOk;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public final class MetadataParser {
* The version of the metadata parser that this metadata extractor library is depending on. The
* value should match the value of "Schema Semantic version" in metadata_schema.fbs.
*/
public static final String VERSION = "1.0.1";
public static final String VERSION = "1.1.0";

private MetadataParser() {}
}
45 changes: 44 additions & 1 deletion tensorflow_lite_support/metadata/metadata_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace tflite;
// for which they were added.
//
// LINT.IfChange
// Schema Semantic version: 1.0.1
// Schema Semantic version: 1.1.0
// LINT.ThenChange(//tensorflow_lite_support/\
// metadata/java/src/java/org/tensorflow/lite/support/metadata/\
// MetadataParser.java)
Expand All @@ -62,6 +62,10 @@ file_identifier "M001";

// History:
// 1.0.1 - Added VOCABULARY type to AssociatedFileType.
// 1.1.0 - Added BertTokenizerOptions to ProcessUnitOptions.
// Added SentencePieceTokenizerOptions to ProcessUnitOptions.
// Added input_process_units to SubGraphMetadata.
// Added output_process_units to SubGraphMetadata.

// File extension of any written files.
file_extension "tflitemeta";
Expand Down Expand Up @@ -414,11 +418,34 @@ table ScoreThresholdingOptions {
global_score_threshold:float;
}

// Performs Bert tokenization as in tf.text.BertTokenizer
// (https://github.com/tensorflow/text/blob/3599f6fcd2b780a2dc413b90fb9315464f10b314/docs/api_docs/python/text/BertTokenizer.md).
// Added in: 1.1.0
table BertTokenizerOptions {
// The vocabulary files used in the BertTokenizer.
vocab_file:[AssociatedFile];
}

// Performs SentencePiece tokenization as in tf.text.SentencepieceTokenizer
// (https://github.com/tensorflow/text/blob/3599f6fcd2b780a2dc413b90fb9315464f10b314/docs/api_docs/python/text/SentencepieceTokenizer.md).
// Added in: 1.1.0
table SentencePieceTokenizerOptions {
// The SentencePiece model files used in the SentencePieceTokenizer.
sentencePiece_model:[AssociatedFile];

// The optional vocabulary model files used in the SentencePieceTokenizer.
vocab_file:[AssociatedFile];
}

// Options that are used when processing the tensor.
union ProcessUnitOptions {
NormalizationOptions,
ScoreCalibrationOptions,
ScoreThresholdingOptions,
// Added in: 1.1.0
BertTokenizerOptions,
// Added in: 1.1.0
SentencePieceTokenizerOptions
}

// A process unit that is used to process the tensor out-of-graph.
Expand Down Expand Up @@ -519,6 +546,22 @@ table SubGraphMetadata {

// A list of associated files of this subgraph.
associated_files:[AssociatedFile];

// Input process units of the subgraph. Some models may have complex pre- and
// post-processing logic where the process units do not work on one tensor at
// a time, but in a way similar to a TFLite graph. For example, in the
// MobileBert model (https://www.tensorflow.org/lite/models/bert_qa/overview),
// the inputs are: ids / mask / segment ids;
// the outputs are: end logits / start logits.
// The preprocessing converts the query string and the context string to the
// model inputs, and the post-processing converts the model outputs to the
// answer string.
// Added in: 1.1.0
input_process_units:[ProcessUnit];

// Output process units of the subgraph.
// Added in: 1.1.0
output_process_units:[ProcessUnit];
}

table ModelMetadata {
Expand Down