Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verify the metadata_buf in MetadataPopulator #13

Merged
merged 1 commit into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 44 additions & 2 deletions tensorflow_lite_support/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,15 @@ def load_metadata_buffer(self, metadata_buf):
Raises:
ValueError: The metadata to be populated is empty.
ValueError: The metadata does not have the expected flatbuffer identifer.
ValueError: Error occurs when getting the minimum metadata parser version.
ValueError: Cannot get minimum metadata parser version.
ValueError: The number of SubgraphMetadata is not 1.
ValueError: The number of input/output tensors does not match the number
of input/output tensor metadata.
"""
if not metadata_buf:
raise ValueError("The metadata to be populated is empty.")

_assert_metadata_buffer_identifier(metadata_buf)
self._validate_metadata(metadata_buf)

# Gets the minimum metadata parser version of the metadata_buf.
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
Expand All @@ -252,7 +255,12 @@ def load_metadata_file(self, metadata_file):

Raises:
IOError: File not found.
ValueError: The metadata to be populated is empty.
ValueError: The metadata does not have the expected flatbuffer identifer.
ValueError: Cannot get minimum metadata parser version.
ValueError: The number of SubgraphMetadata is not 1.
ValueError: The number of input/output tensors does not match the number
of input/output tensor metadata.
"""
_assert_exist(metadata_file)
with open(metadata_file, "rb") as f:
Expand Down Expand Up @@ -399,6 +407,40 @@ def _populate_metadata_buffer(self):
with open(self._model_file, "wb") as f:
f.write(model_buf)

def _validate_metadata(self, metadata_buf):
"""Validates the metadata to be populated."""
_assert_metadata_buffer_identifier(metadata_buf)

# Verify the number of SubgraphMetadata is exactly one.
# TFLite currently only support one subgraph.
model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
metadata_buf, 0)
if model_meta.SubgraphMetadataLength() != 1:
raise ValueError("The number of SubgraphMetadata should be exactly one, "
"but got {0}.".format(
model_meta.SubgraphMetadataLength()))

# Verify if the number of tensor metadata matches the number of tensors.
with open(self._model_file, "rb") as f:
model_buf = f.read()
model = _schema_fb.Model.GetRootAsModel(model_buf, 0)

num_input_tensors = model.Subgraphs(0).InputsLength()
num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
if num_input_tensors != num_input_meta:
raise ValueError(
"The number of input tensors ({0}) should match the number of "
"input tensor metadata ({1})".format(num_input_tensors,
num_input_meta))
num_output_tensors = model.Subgraphs(0).OutputsLength()
num_output_meta = model_meta.SubgraphMetadata(
0).OutputTensorMetadataLength()
if num_output_tensors != num_output_meta:
raise ValueError(
"The number of output tensors ({0}) should match the number of "
"output tensor metadata ({1})".format(num_output_tensors,
num_output_meta))


class _MetadataPopulatorWithBuffer(MetadataPopulator):
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
Expand Down