Skip to content

Commit

Permalink
Verify the metadata_buf in MetadataPopulator
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 319148318
  • Loading branch information
lu-wang-g authored and tflite-support-robot committed Jul 13, 2020
1 parent ca5fba5 commit 9f2bdf9
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 64 deletions.
46 changes: 44 additions & 2 deletions tensorflow_lite_support/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,15 @@ def load_metadata_buffer(self, metadata_buf):
Raises:
ValueError: The metadata to be populated is empty.
ValueError: The metadata does not have the expected flatbuffer identifer.
ValueError: Error occurs when getting the minimum metadata parser version.
ValueError: Cannot get minimum metadata parser version.
ValueError: The number of SubgraphMetadata is not 1.
ValueError: The number of input/output tensors does not match the number
of input/output tensor metadata.
"""
if not metadata_buf:
raise ValueError("The metadata to be populated is empty.")

_assert_metadata_buffer_identifier(metadata_buf)
self._validate_metadata(metadata_buf)

# Gets the minimum metadata parser version of the metadata_buf.
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
Expand All @@ -252,7 +255,12 @@ def load_metadata_file(self, metadata_file):
Raises:
IOError: File not found.
ValueError: The metadata to be populated is empty.
ValueError: The metadata does not have the expected flatbuffer identifer.
ValueError: Cannot get minimum metadata parser version.
ValueError: The number of SubgraphMetadata is not 1.
ValueError: The number of input/output tensors does not match the number
of input/output tensor metadata.
"""
_assert_exist(metadata_file)
with open(metadata_file, "rb") as f:
Expand Down Expand Up @@ -399,6 +407,40 @@ def _populate_metadata_buffer(self):
with open(self._model_file, "wb") as f:
f.write(model_buf)

def _validate_metadata(self, metadata_buf):
"""Validates the metadata to be populated."""
_assert_metadata_buffer_identifier(metadata_buf)

# Verify the number of SubgraphMetadata is exactly one.
# TFLite currently only support one subgraph.
model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
metadata_buf, 0)
if model_meta.SubgraphMetadataLength() != 1:
raise ValueError("The number of SubgraphMetadata should be exactly one, "
"but got {0}.".format(
model_meta.SubgraphMetadataLength()))

# Verify if the number of tensor metadata matches the number of tensors.
with open(self._model_file, "rb") as f:
model_buf = f.read()
model = _schema_fb.Model.GetRootAsModel(model_buf, 0)

num_input_tensors = model.Subgraphs(0).InputsLength()
num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
if num_input_tensors != num_input_meta:
raise ValueError(
"The number of input tensors ({0}) should match the number of "
"input tensor metadata ({1})".format(num_input_tensors,
num_input_meta))
num_output_tensors = model.Subgraphs(0).OutputsLength()
num_output_meta = model_meta.SubgraphMetadata(
0).OutputTensorMetadataLength()
if num_output_tensors != num_output_meta:
raise ValueError(
"The number of output tensors ({0}) should match the number of "
"output tensor metadata ({1})".format(num_output_tensors,
num_output_meta))


class _MetadataPopulatorWithBuffer(MetadataPopulator):
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
Expand Down
Loading

0 comments on commit 9f2bdf9

Please sign in to comment.