diff --git a/tensorflow_lite_support/metadata/metadata.py b/tensorflow_lite_support/metadata/metadata.py index 6ef5c1632..ec7792ce1 100644 --- a/tensorflow_lite_support/metadata/metadata.py +++ b/tensorflow_lite_support/metadata/metadata.py @@ -222,12 +222,15 @@ def load_metadata_buffer(self, metadata_buf): Raises: ValueError: The metadata to be populated is empty. ValueError: The metadata does not have the expected flatbuffer identifer. - ValueError: Error occurs when getting the minimum metadata parser version. + ValueError: Cannot get minimum metadata parser version. + ValueError: The number of SubgraphMetadata is not 1. + ValueError: The number of input/output tensors does not match the number + of input/output tensor metadata. """ if not metadata_buf: raise ValueError("The metadata to be populated is empty.") - _assert_metadata_buffer_identifier(metadata_buf) + self._validate_metadata(metadata_buf) # Gets the minimum metadata parser version of the metadata_buf. min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion( @@ -252,7 +255,12 @@ def load_metadata_file(self, metadata_file): Raises: IOError: File not found. + ValueError: The metadata to be populated is empty. ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: Cannot get minimum metadata parser version. + ValueError: The number of SubgraphMetadata is not 1. + ValueError: The number of input/output tensors does not match the number + of input/output tensor metadata. """ _assert_exist(metadata_file) with open(metadata_file, "rb") as f: @@ -399,6 +407,40 @@ def _populate_metadata_buffer(self): with open(self._model_file, "wb") as f: f.write(model_buf) + def _validate_metadata(self, metadata_buf): + """Validates the metadata to be populated.""" + _assert_metadata_buffer_identifier(metadata_buf) + + # Verify the number of SubgraphMetadata is exactly one. + # TFLite currently only support one subgraph. + model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata( + metadata_buf, 0) + if model_meta.SubgraphMetadataLength() != 1: + raise ValueError("The number of SubgraphMetadata should be exactly one, " + "but got {0}.".format( + model_meta.SubgraphMetadataLength())) + + # Verify if the number of tensor metadata matches the number of tensors. + with open(self._model_file, "rb") as f: + model_buf = f.read() + model = _schema_fb.Model.GetRootAsModel(model_buf, 0) + + num_input_tensors = model.Subgraphs(0).InputsLength() + num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength() + if num_input_tensors != num_input_meta: + raise ValueError( + "The number of input tensors ({0}) should match the number of " + "input tensor metadata ({1})".format(num_input_tensors, + num_input_meta)) + num_output_tensors = model.Subgraphs(0).OutputsLength() + num_output_meta = model_meta.SubgraphMetadata( + 0).OutputTensorMetadataLength() + if num_output_tensors != num_output_meta: + raise ValueError( + "The number of output tensors ({0}) should match the number of " + "output tensor metadata ({1})".format(num_output_tensors, + num_output_meta)) + class _MetadataPopulatorWithBuffer(MetadataPopulator): """Subclass of MetadtaPopulator that populates metadata to a model buffer. diff --git a/tensorflow_lite_support/metadata/metadata_test.py b/tensorflow_lite_support/metadata/metadata_test.py index 9caf4bc94..a16a59a0b 100644 --- a/tensorflow_lite_support/metadata/metadata_test.py +++ b/tensorflow_lite_support/metadata/metadata_test.py @@ -37,11 +37,10 @@ def setUp(self): super(MetadataTest, self).setUp() self._invalid_model_buf = None self._invalid_file = "not_existed_file" - self._empty_model_buf = self._create_empty_model_buf() - self._empty_model_file = self.create_tempfile().full_path - with open(self._empty_model_file, "wb") as f: - f.write(self._empty_model_buf) - self._model_file = self._create_model_file_with_metadata_and_buf_fields() + self._model_buf = self._create_model_buf() + self._model_file = self.create_tempfile().full_path + with open(self._model_file, "wb") as f: + f.write(self._model_buf) self._metadata_file = self._create_metadata_file() self._metadata_file_with_version = self._create_metadata_file_with_version( self._metadata_file, "1.0.0") @@ -49,31 +48,26 @@ def setUp(self): self._file2 = self.create_tempfile("file2").full_path self._file3 = self.create_tempfile("file3").full_path - def _create_empty_model_buf(self): - model = _schema_fb.ModelT() - model_builder = flatbuffers.Builder(0) - model_builder.Finish( - model.Pack(model_builder), - _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) - return model_builder.Output() - - def _create_model_file_with_metadata_and_buf_fields(self): + def _create_model_buf(self): + # Create a model with two inputs and one output, which matches the metadata + # created by _create_metadata_file(). metadata_field = _schema_fb.MetadataT() + subgraph = _schema_fb.SubGraphT() + subgraph.inputs = [0, 1] + subgraph.outputs = [2] + metadata_field.name = "meta" buffer_field = _schema_fb.BufferT() model = _schema_fb.ModelT() + model.subgraphs = [subgraph] + # Creates the metadata and buffer fields for testing purposes. model.metadata = [metadata_field, metadata_field] model.buffers = [buffer_field, buffer_field, buffer_field] model_builder = flatbuffers.Builder(0) model_builder.Finish( model.Pack(model_builder), _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) - - mnodel_file = self.create_tempfile().full_path - with open(mnodel_file, "wb") as f: - f.write(model_builder.Output()) - - return mnodel_file + return model_builder.Output() def _create_metadata_file(self): associated_file1 = _metadata_fb.AssociatedFileT() @@ -85,9 +79,12 @@ def _create_metadata_file(self): six.ensure_str(associated_file2.name) ] + input_meta = _metadata_fb.TensorMetadataT() output_meta = _metadata_fb.TensorMetadataT() output_meta.associatedFiles = [associated_file2] subgraph = _metadata_fb.SubGraphMetadataT() + # Create a model with two inputs and one output. + subgraph.inputTensorMetadata = [input_meta, input_meta] subgraph.outputTensorMetadata = [output_meta] model_meta = _metadata_fb.ModelMetadataT() @@ -160,8 +157,7 @@ def _create_metadata_file_with_version(self, metadata_file, min_version): class MetadataPopulatorTest(MetadataTest): def testToValidModelFile(self): - populator = _metadata.MetadataPopulator.with_model_file( - self._empty_model_file) + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) self.assertIsInstance(populator, _metadata.MetadataPopulator) def testToInvalidModelFile(self): @@ -171,8 +167,7 @@ def testToInvalidModelFile(self): str(error.exception)) def testToValidModelBuffer(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) self.assertIsInstance(populator, _metadata.MetadataPopulator) def testToInvalidModelBuffer(self): @@ -189,8 +184,7 @@ def testToModelBufferWithWrongIdentifier(self): "may not be a valid TFLite model.", str(error.exception)) def testSinglePopulateAssociatedFile(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) populator.load_associated_files([self._file1]) populator.populate() @@ -199,8 +193,7 @@ def testSinglePopulateAssociatedFile(self): self.assertEqual(set(packed_files), set(expected_packed_files)) def testRepeatedPopulateAssociatedFile(self): - populator = _metadata.MetadataPopulator.with_model_file( - self._empty_model_file) + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) populator.load_associated_files([self._file1, self._file2]) # Loads file2 multiple times. populator.load_associated_files([self._file2]) @@ -216,22 +209,20 @@ def testRepeatedPopulateAssociatedFile(self): # Check if the model buffer read from file is the same as that read from # get_model_buffer(). - with open(self._empty_model_file, "rb") as f: + with open(self._model_file, "rb") as f: model_buf_from_file = f.read() model_buf_from_getter = populator.get_model_buffer() self.assertEqual(model_buf_from_file, model_buf_from_getter) def testPopulateInvalidAssociatedFile(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) with self.assertRaises(IOError) as error: populator.load_associated_files([self._invalid_file]) self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), str(error.exception)) def testPopulatePackedAssociatedFile(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) populator.load_associated_files([self._file1]) populator.populate() with self.assertRaises(ValueError) as error: @@ -242,22 +233,22 @@ def testPopulatePackedAssociatedFile(self): os.path.basename(self._file1)), str(error.exception)) def testGetPackedAssociatedFileList(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) packed_files = populator.get_packed_associated_file_list() self.assertEqual(packed_files, []) def testPopulateMetadataFileToEmptyModelFile(self): - populator = _metadata.MetadataPopulator.with_model_file( - self._empty_model_file) + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) populator.load_metadata_file(self._metadata_file) populator.load_associated_files([self._file1, self._file2]) populator.populate() - with open(self._empty_model_file, "rb") as f: + with open(self._model_file, "rb") as f: model_buf_from_file = f.read() model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0) - metadata_field = model.Metadata(0) + # self._model_file already has two elements in the metadata field, so the + # populated TFLite metadata will be the third element. + metadata_field = model.Metadata(2) self.assertEqual( six.ensure_str(metadata_field.Name()), six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME)) @@ -279,8 +270,7 @@ def testPopulateMetadataFileToEmptyModelFile(self): self.assertEqual(model_buf_from_file, model_buf_from_getter) def testPopulateMetadataFileWithoutAssociatedFiles(self): - populator = _metadata.MetadataPopulator.with_model_file( - self._empty_model_file) + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) populator.load_metadata_file(self._metadata_file) populator.load_associated_files([self._file1]) # Suppose to populate self._file2, because it is recorded in the metadta. @@ -319,27 +309,35 @@ def _assert_golden_metadata(self, model_file): self.assertEqual(metadata_buf, expected_metadata_buf) def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self): - # First, creates a dummy metadata. Populates it and the associated files - # into the model. + # First, creates a dummy metadata different from self._metadata_file. It + # needs to have the same input/output tensor numbers as self._model_file. + # Populates it and the associated files into the model. + input_meta = _metadata_fb.TensorMetadataT() + output_meta = _metadata_fb.TensorMetadataT() + subgraph = _metadata_fb.SubGraphMetadataT() + # Create a model with two inputs and one output. + subgraph.inputTensorMetadata = [input_meta, input_meta] + subgraph.outputTensorMetadata = [output_meta] model_meta = _metadata_fb.ModelMetadataT() - model_meta.name = "Mobilenet_quantized" + model_meta.subgraphMetadata = [subgraph] b = flatbuffers.Builder(0) b.Finish( model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) metadata_buf = b.Output() + # Populate the metadata. populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file) populator1.load_metadata_buffer(metadata_buf) populator1.load_associated_files([self._file1, self._file2]) populator1.populate() - # Then, populates the metadata again. + # Then, populate the metadata again. populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file) populator2.load_metadata_file(self._metadata_file) populator2.populate() - # Tests if the metadata is populated correctly. + # Test if the metadata is populated correctly. self._assert_golden_metadata(self._model_file) def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self): @@ -362,37 +360,93 @@ def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self): self.assertEqual(model_buf_from_file, model_buf_from_getter) def testPopulateInvalidMetadataFile(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) with self.assertRaises(IOError) as error: populator.load_metadata_file(self._invalid_file) self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), str(error.exception)) def testPopulateInvalidMetadataBuffer(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) with self.assertRaises(ValueError) as error: populator.load_metadata_buffer([]) self.assertEqual("The metadata to be populated is empty.", str(error.exception)) def testGetModelBufferBeforePopulatingData(self): - populator = _metadata.MetadataPopulator.with_model_buffer( - self._empty_model_buf) + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) model_buf = populator.get_model_buffer() - expected_model_buf = self._empty_model_buf + expected_model_buf = self._model_buf self.assertEqual(model_buf, expected_model_buf) + def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self): + # Create a dummy metadata without Subgraph. + model_meta = _metadata_fb.ModelMetadataT() + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + "The number of SubgraphMetadata should be exactly one, but got 0.", + str(error.exception)) + + def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self): + # Create a dummy metadata with no input tensor metadata, while the expected + # number is 2. + output_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.outputTensorMetadata = [output_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of input tensors (2) should match the number of " + "input tensor metadata (0)"), str(error.exception)) + + def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self): + # Create a dummy metadata with no output tensor metadata, while the expected + # number is 1. + input_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.inputTensorMetadata = [input_meta, input_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of output tensors (1) should match the number of " + "output tensor metadata (0)"), str(error.exception)) + class MetadataDisplayerTest(MetadataTest): def setUp(self): super(MetadataDisplayerTest, self).setUp() - self._model_file = self._create_model_with_metadata_and_associated_files() + self._model_with_meta_file = ( + self._create_model_with_metadata_and_associated_files()) def _create_model_with_metadata_and_associated_files(self): - model_buf = self._create_empty_model_buf() + model_buf = self._create_model_buf() model_file = self.create_tempfile().full_path with open(model_file, "wb") as f: f.write(model_buf) @@ -439,26 +493,27 @@ def test_load_model_file_invalidModelFile_throwsException(self): def test_load_model_file_modelWithoutMetadata_throwsException(self): with self.assertRaises(ValueError) as error: - _metadata.MetadataDisplayer.with_model_file(self._empty_model_file) + _metadata.MetadataDisplayer.with_model_file(self._model_file) self.assertEqual("The model does not have metadata.", str(error.exception)) def test_load_model_file_modelWithMetadata(self): - displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file) + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) self.assertIsInstance(displayer, _metadata.MetadataDisplayer) def test_load_model_buffer_modelWithOutMetadata_throwsException(self): with self.assertRaises(ValueError) as error: - _metadata.MetadataDisplayer.with_model_buffer( - self._create_empty_model_buf()) + _metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf()) self.assertEqual("The model does not have metadata.", str(error.exception)) def test_load_model_buffer_modelWithMetadata(self): displayer = _metadata.MetadataDisplayer.with_model_buffer( - open(self._model_file, "rb").read()) + open(self._model_with_meta_file, "rb").read()) self.assertIsInstance(displayer, _metadata.MetadataDisplayer) def test_get_metadata_json_modelWithMetadata(self): - displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file) + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) actual = displayer.get_metadata_json() # Verifies the generated json file. @@ -469,7 +524,8 @@ def test_get_metadata_json_modelWithMetadata(self): self.assertEqual(actual, expected) def test_get_packed_associated_file_list_modelWithMetadata(self): - displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file) + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) packed_files = displayer.get_packed_associated_file_list() expected_packed_files = [ diff --git a/tensorflow_lite_support/metadata/testdata/golden_json.json b/tensorflow_lite_support/metadata/testdata/golden_json.json index 9ff5581fb..601a5976c 100644 --- a/tensorflow_lite_support/metadata/testdata/golden_json.json +++ b/tensorflow_lite_support/metadata/testdata/golden_json.json @@ -2,6 +2,12 @@ "name": "Mobilenet_quantized", "subgraph_metadata": [ { + "input_tensor_metadata": [ + { + }, + { + } + ], "output_tensor_metadata": [ { "associated_files": [