diff --git a/tensorflow_lite_support/python/test/BUILD b/tensorflow_lite_support/python/test/BUILD index cf5d45acf..52b142e01 100644 --- a/tensorflow_lite_support/python/test/BUILD +++ b/tensorflow_lite_support/python/test/BUILD @@ -5,13 +5,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -py_library( - name = "base_test", - testonly = 1, - srcs = ["base_test.py"], - srcs_version = "PY3", -) - py_library( name = "test_util", testonly = 1, diff --git a/tensorflow_lite_support/python/test/base_test.py b/tensorflow_lite_support/python/test/base_test.py deleted file mode 100644 index 72fc782e1..000000000 --- a/tensorflow_lite_support/python/test/base_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Base TestCase for the unit tests.""" - -import tensorflow as tf - -__unittest = True # Allows shorter stack trace for .assertDeepAlmostEqual pylint: disable=invalid-name - - -class BaseTestCase(tf.test.TestCase): - """Base test case.""" - - def assertDeepAlmostEqual(self, expected, actual, **kwargs): - """Compares lists, dicts and tuples recursively. - - Checks numeric values using test_case's - :py:meth:`unittest.TestCase.assertAlmostEqual` and checks all other values - with :py:meth:`unittest.TestCase.assertEqual`. Accepts additional keyword - arguments and pass those intact to assertAlmostEqual() (that's how you - specify comparison precision). - - Args: - expected: Expected object. - actual: Actual object. - **kwargs: Other parameters to be passed. - """ - if isinstance(expected, (int, float, complex)): - self.assertAlmostEqual(expected, actual, **kwargs) - elif isinstance(expected, (list, tuple)): - self.assertEqual(len(expected), len(actual)) - for index in range(len(expected)): - v1, v2 = expected[index], actual[index] - self.assertDeepAlmostEqual(v1, v2, **kwargs) - elif isinstance(expected, dict): - self.assertEqual(set(expected), set(actual)) - for key in expected: - self.assertDeepAlmostEqual(expected[key], actual[key], **kwargs) - else: - self.assertEqual(expected, actual) diff --git a/tensorflow_lite_support/python/test/task/audio/BUILD b/tensorflow_lite_support/python/test/task/audio/BUILD index eec5404f1..032e76c70 100644 --- a/tensorflow_lite_support/python/test/task/audio/BUILD +++ b/tensorflow_lite_support/python/test/task/audio/BUILD @@ -41,9 +41,7 @@ py_test( "//tensorflow_lite_support/python/task/processor/proto:class_pb2", "//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:classifications_pb2", - "//tensorflow_lite_support/python/test:base_test", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", - "@com_google_protobuf//:protobuf_python", ], ) diff --git a/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py b/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py index 8a76ab73d..80eed9145 100644 --- a/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py +++ b/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py @@ -14,21 +14,17 @@ """Tests for audio_classifier.""" import enum -import json from absl.testing import parameterized import tensorflow as tf -from google.protobuf import json_format import unittest from tensorflow_lite_support.python.task.audio import audio_classifier from tensorflow_lite_support.python.task.audio.core import audio_record from tensorflow_lite_support.python.task.audio.core import tensor_audio from tensorflow_lite_support.python.task.core.proto import base_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2 from tensorflow_lite_support.python.task.processor.proto import classifications_pb2 -from tensorflow_lite_support.python.test import base_test from tensorflow_lite_support.python.test import test_util _mock = unittest.mock @@ -38,58 +34,75 @@ _FIXED_INPUT_SIZE_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' _SPEECH_AUDIO_FILE = 'speech.wav' -_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = { - 'scores': [{ - 'index': 0, - 'score': 0.91796875, - 'class_name': 'Speech' - }, { - 'index': 500, - 'score': 0.05859375, - 'class_name': 'Inside, small room' - }, { - 'index': 494, - 'score': 0.01367188, - 'class_name': 'Silence' - }] +_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = """ +classifications { + classes { + index: 0 + score: 0.917969 + class_name: "Speech" + } + classes { + index: 500 + score: 0.058594 + class_name: "Inside, small room" + } + classes { + index: 494 + score: 0.013672 + class_name: "Silence" + } + head_index: 0 + head_name: "scores" } +""" _MULTIHEAD_MODEL_FILE = 'two_heads.tflite' _TWO_HEADS_AUDIO_FILE = 'two_heads.wav' -_MULTIHEAD_MODEL_CLASSIFICATIONS = { - 'yamnet_classification': [{ - 'index': 508, - 'score': 0.5486158, - 'class_name': 'Environmental noise' - }, { - 'index': 507, - 'score': 0.38086897, - 'class_name': 'Noise' - }, { - 'index': 106, - 'score': 0.25613675, - 'class_name': 'Bird' - }], - 'bird_classification': [{ - 'index': 4, - 'score': 0.93399656, - 'class_name': 'Chestnut-crowned Antpitta' - }, { - 'index': 1, - 'score': 0.065934494, - 'class_name': 'White-breasted Wood-Wren' - }, { - 'index': 0, - 'score': 6.1469495e-05, - 'class_name': 'Red Crossbill' - }] +_MULTIHEAD_MODEL_CLASSIFICATIONS = """ +classifications { + classes { + index: 508 + score: 0.548616 + class_name: "Environmental noise" + } + classes { + index: 507 + score: 0.380869 + class_name: "Noise" + } + classes { + index: 106 + score: 0.256137 + class_name: "Bird" + } + head_index: 0 + head_name: "yamnet_classification" } +classifications { + classes { + index: 4 + score: 0.933997 + class_name: "Chestnut-crowned Antpitta" + } + classes { + index: 1 + score: 0.065934 + class_name: "White-breasted Wood-Wren" + } + classes { + index: 0 + score: 6.1469495e-05 + class_name: "Red Crossbill" + } + head_index: 1 + head_name: "bird_classification" +} +""" _ALLOW_LIST = ['Speech', 'Inside, small room'] _DENY_LIST = ['Speech'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.005 class ModelFileType(enum.Enum): @@ -106,22 +119,7 @@ def _create_classifier_from_options(base_options, **classification_options): return classifier -def _build_test_data(classifications): - expected_result = classifications_pb2.ClassificationResult() - - for index, (head_name, categories) in enumerate(classifications.items()): - classifications = classifications_pb2.Classifications( - head_index=index, head_name=head_name) - classifications.classes.extend( - [class_pb2.Category(**args) for args in categories]) - expected_result.classifications.append(classifications) - - expected_result_dict = json.loads(json_format.MessageToJson(expected_result)) - - return expected_result_dict - - -class AudioClassifierTest(parameterized.TestCase, base_test.BaseTestCase): +class AudioClassifierTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -193,7 +191,7 @@ def test_create_audio_record_from_classifier_succeeds(self, _): (_MULTIHEAD_MODEL_FILE, ModelFileType.FILE_CONTENT, _TWO_HEADS_AUDIO_FILE, 3, _MULTIHEAD_MODEL_CLASSIFICATIONS)) def test_classify_model(self, model_name, model_file_type, audio_file_name, - max_results, expected_classifications): + max_results, expected_result_text_proto): # Creates classifier. model_path = test_util.get_test_data_path(model_name) if model_file_type is ModelFileType.FILE_NAME: @@ -216,14 +214,9 @@ def test_classify_model(self, model_name, model_file_type, audio_file_name, # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - # Builds test data. - expected_result_dict = _build_test_data(expected_classifications) # Comparing results. - self.assertDeepAlmostEqual( - audio_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + self.assertProtoEquals(expected_result_text_proto, audio_result) def test_max_results_option(self): # Creates classifier. @@ -238,9 +231,7 @@ def test_max_results_option(self): # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].classes self.assertLessEqual( len(categories), _MAX_RESULTS, 'Too many results returned.') @@ -258,14 +249,11 @@ def test_score_threshold_option(self): # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].classes for category in categories: - score = category['score'] self.assertGreaterEqual( - score, _SCORE_THRESHOLD, + category.score, _SCORE_THRESHOLD, 'Classification with score lower than threshold found. {0}'.format( category)) @@ -282,12 +270,10 @@ def test_allowlist_option(self): # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].classes for category in categories: - label = category['className'] + label = category.class_name self.assertIn( label, _ALLOW_LIST, 'Label "{0}" found but not in label allow list'.format(label)) @@ -305,12 +291,10 @@ def test_denylist_option(self): # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].classes for category in categories: - label = category['className'] + label = category.class_name self.assertNotIn(label, _DENY_LIST, 'Label "{0}" found but in deny list.'.format(label)) diff --git a/tensorflow_lite_support/python/test/task/vision/BUILD b/tensorflow_lite_support/python/test/task/vision/BUILD index e27e27db3..0dba55150 100644 --- a/tensorflow_lite_support/python/test/task/vision/BUILD +++ b/tensorflow_lite_support/python/test/task/vision/BUILD @@ -41,10 +41,8 @@ py_test( "//tensorflow_lite_support/python/task/processor/proto:classifications_pb2", "//tensorflow_lite_support/python/task/vision:image_classifier", "//tensorflow_lite_support/python/task/vision/core:tensor_image", - "//tensorflow_lite_support/python/test:base_test", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", - "@com_google_protobuf//:protobuf_python", ], ) @@ -78,15 +76,11 @@ py_test( deps = [ # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", - "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", - "//tensorflow_lite_support/python/task/processor/proto:class_pb2", "//tensorflow_lite_support/python/task/processor/proto:detection_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:detections_pb2", "//tensorflow_lite_support/python/task/vision:object_detector", "//tensorflow_lite_support/python/task/vision/core:tensor_image", - "//tensorflow_lite_support/python/test:base_test", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", - "@com_google_protobuf//:protobuf_python", ], ) diff --git a/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py b/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py index ffff35541..cd6712a02 100644 --- a/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py +++ b/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py @@ -14,20 +14,16 @@ """Tests for image_classifier.""" import enum -import json from absl.testing import parameterized import tensorflow as tf -from google.protobuf import json_format from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2 from tensorflow_lite_support.python.task.processor.proto import classifications_pb2 from tensorflow_lite_support.python.task.vision import image_classifier from tensorflow_lite_support.python.task.vision.core import tensor_image -from tensorflow_lite_support.python.test import base_test from tensorflow_lite_support.python.test import test_util _BaseOptions = base_options_pb2.BaseOptions @@ -40,7 +36,6 @@ _DENY_LIST = ['cheeseburger'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.000001 def _create_classifier_from_options(base_options, **classification_options): @@ -52,23 +47,12 @@ def _create_classifier_from_options(base_options, **classification_options): return classifier -def _build_test_data(expected_categories): - classifications = classifications_pb2.Classifications(head_index=0) - classifications.classes.extend( - [class_pb2.Category(**args) for args in expected_categories]) - expected_result = classifications_pb2.ClassificationResult() - expected_result.classifications.append(classifications) - expected_result_dict = json.loads(json_format.MessageToJson(expected_result)) - - return expected_result_dict - - class ModelFileType(enum.Enum): FILE_CONTENT = 1 FILE_NAME = 2 -class ImageClassifierTest(parameterized.TestCase, base_test.BaseTestCase): +class ImageClassifierTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -105,33 +89,47 @@ def test_create_from_options_succeeds_with_valid_model_content(self): classifier = _ImageClassifier.create_from_options(options) self.assertIsInstance(classifier, _ImageClassifier) - @parameterized.parameters((ModelFileType.FILE_NAME, 3, [{ - 'index': 934, - 'score': 0.7399742007255554, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.026928534731268883, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.025737214833498, - 'class_name': 'bagel' - }]), (ModelFileType.FILE_CONTENT, 3, [{ - 'index': 934, - 'score': 0.7399742007255554, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.026928534731268883, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.025737214833498, - 'class_name': 'bagel' - }])) + @parameterized.parameters((ModelFileType.FILE_NAME, 3, """ + classifications { + classes { + index: 934 + score: 0.739974 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026929 + class_name: "guacamole" + } + classes { + index: 932 + score: 0.025737 + class_name: "bagel" + } + head_index: 0 + } + """), (ModelFileType.FILE_CONTENT, 3, """ + classifications { + classes { + index: 934 + score: 0.739974 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026929 + class_name: "guacamole" + } + classes { + index: 932 + score: 0.025737 + class_name: "bagel" + } + head_index: 0 + } + """)) def test_classify_model(self, model_file_type, max_results, - expected_categories): + expected_result_text_proto): # Creates classifier. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -151,14 +149,9 @@ def test_classify_model(self, model_file_type, max_results, # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - # Builds test data. - expected_result_dict = _build_test_data(expected_categories) # Comparing results (classification w/o bounding box). - self.assertDeepAlmostEqual( - image_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + self.assertProtoEquals(expected_result_text_proto, image_result) def test_classify_model_with_bounding_box(self): # Creates classifier. @@ -175,29 +168,35 @@ def test_classify_model_with_bounding_box(self): # Classifies the input. image_result = classifier.classify(image, bounding_box) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) # Expected results. - expected_categories = [{ - 'index': 934, - 'score': 0.8815076351165771, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.019456762820482254, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.012489477172493935, - 'class_name': 'bagel' - }] - - # Builds test data. - expected_result_dict = _build_test_data(expected_categories) + expected_result_text_proto = """ + classifications { + classes { + index: 934 + score: 0.881507 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.019457 + class_name: "guacamole" + } + classes { + index: 932 + score: 0.012489 + class_name: "bagel" + } + head_index: 0 + } + """ # Comparing results (classification w/ bounding box). - self.assertDeepAlmostEqual( - image_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + expected_result = classifications_pb2.ClassificationResult() + test_util.parse_text_proto(expected_result_text_proto, expected_result) + classification_result = classifications_pb2.ClassificationResult() + classification_result.ParseFromString(image_result.SerializeToString()) + self.assertProtoEquals(classification_result, expected_result) def test_max_results_option(self): # Creates classifier. @@ -211,9 +210,7 @@ def test_max_results_option(self): # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].classes self.assertLessEqual( len(categories), _MAX_RESULTS, 'Too many results returned.') @@ -230,14 +227,11 @@ def test_score_threshold_option(self): # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].classes for category in categories: - score = category['score'] self.assertGreaterEqual( - score, _SCORE_THRESHOLD, + category.score, _SCORE_THRESHOLD, f'Classification with score lower than threshold found. {category}') def test_allowlist_option(self): @@ -252,12 +246,10 @@ def test_allowlist_option(self): # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].classes for category in categories: - label = category['className'] + label = category.class_name self.assertIn(label, _ALLOW_LIST, f'Label {label} found but not in label allow list') @@ -273,12 +265,10 @@ def test_denylist_option(self): # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].classes for category in categories: - label = category['className'] + label = category.class_name self.assertNotIn(label, _DENY_LIST, f'Label {label} found but in deny list.') diff --git a/tensorflow_lite_support/python/test/task/vision/object_detector_test.py b/tensorflow_lite_support/python/test/task/vision/object_detector_test.py index 4c406e2f5..eceb53e5a 100644 --- a/tensorflow_lite_support/python/test/task/vision/object_detector_test.py +++ b/tensorflow_lite_support/python/test/task/vision/object_detector_test.py @@ -14,20 +14,15 @@ """Tests for object detector.""" import enum -import json from absl.testing import parameterized import tensorflow as tf -from google.protobuf import json_format from tensorflow_lite_support.python.task.core.proto import base_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import detection_options_pb2 from tensorflow_lite_support.python.task.processor.proto import detections_pb2 from tensorflow_lite_support.python.task.vision import object_detector from tensorflow_lite_support.python.task.vision.core import tensor_image -from tensorflow_lite_support.python.test import base_test from tensorflow_lite_support.python.test import test_util _BaseOptions = base_options_pb2.BaseOptions @@ -36,53 +31,28 @@ _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' _IMAGE_FILE = 'cats_and_dogs.jpg' -_EXPECTED_DETECTIONS = [ - ({ - 'origin_x': 54, - 'origin_y': 396, - 'width': 393, - 'height': 196 - }, { - 'index': 16, - 'score': 0.64453125, - 'class_name': 'cat' - }), - ({ - 'origin_x': 602, - 'origin_y': 157, - 'width': 394, - 'height': 447 - }, { - 'index': 16, - 'score': 0.59765625, - 'class_name': 'cat' - }), - ({ - 'origin_x': 261, - 'origin_y': 394, - 'width': 179, - 'height': 209 - }, { - 'index': 16, - 'score': 0.5625, - 'class_name': 'cat' - }), - ({ - 'origin_x': 389, - 'origin_y': 197, - 'width': 276, - 'height': 409 - }, { - 'index': 17, - 'score': 0.51171875, - 'class_name': 'dog' - }) -] +_EXPECTED_DETECTIONS = """ +detections { + bounding_box { origin_x: 54 origin_y: 396 width: 393 height: 196 } + classes { index: 16 score: 0.64453125 class_name: "cat" } +} +detections { + bounding_box { origin_x: 602 origin_y: 157 width: 394 height: 447 } + classes { index: 16 score: 0.59765625 class_name: "cat" } +} +detections { + bounding_box { origin_x: 261 origin_y: 394 width: 179 height: 209 } + classes { index: 16 score: 0.5625 class_name: "cat" } +} +detections { + bounding_box { origin_x: 389 origin_y: 197 width: 276 height: 409 } + classes { index: 17 score: 0.51171875 class_name: "dog" } +} +""" _ALLOW_LIST = ['cat', 'dog'] _DENY_LIST = ['cat'] _SCORE_THRESHOLD = 0.3 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.000001 class ModelFileType(enum.Enum): @@ -99,23 +69,7 @@ def _create_detector_from_options(base_options, **detection_options): return detector -def _build_test_data(expected_detections): - expected_result = detections_pb2.DetectionResult() - - for index in range(len(expected_detections)): - bounding_box, category = expected_detections[index] - detection = detections_pb2.Detection() - detection.bounding_box.CopyFrom( - bounding_box_pb2.BoundingBox(**bounding_box)) - detection.classes.append(class_pb2.Category(**category)) - expected_result.detections.append(detection) - - expected_result_dict = json.loads(json_format.MessageToJson(expected_result)) - - return expected_result_dict - - -class ObjectDetectorTest(parameterized.TestCase, base_test.BaseTestCase): +class ObjectDetectorTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -156,7 +110,7 @@ def test_create_from_options_succeeds_with_valid_model_content(self): (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTIONS), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTIONS)) def test_detect_model(self, model_file_type, max_results, - expected_detections): + expected_result_text_proto): # Creates detector. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -176,14 +130,9 @@ def test_detect_model(self, model_file_type, max_results, # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - # Builds test data. - expected_result_dict = _build_test_data(expected_detections) # Comparing results. - self.assertDeepAlmostEqual( - image_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + self.assertProtoEquals(expected_result_text_proto, image_result) def test_score_threshold_option(self): # Creates detector. @@ -196,15 +145,13 @@ def test_score_threshold_option(self): # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['detections'] + detections = image_result.detections - for category in categories: - score = category['classes'][0]['score'] + for detection in detections: + score = detection.classes[0].score self.assertGreaterEqual( score, _SCORE_THRESHOLD, - f'Classification with score lower than threshold found. {category}') + f'Detection with score lower than threshold found. {detection}') def test_max_results_option(self): # Creates detector. @@ -217,8 +164,7 @@ def test_max_results_option(self): # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - detections = image_result_dict['detections'] + detections = image_result.detections self.assertLessEqual( len(detections), _MAX_RESULTS, 'Too many results returned.') @@ -234,12 +180,10 @@ def test_allow_list_option(self): # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) + detections = image_result.detections - categories = image_result_dict['detections'] - - for category in categories: - label = category['classes'][0]['className'] + for detection in detections: + label = detection.classes[0].class_name self.assertIn(label, _ALLOW_LIST, f'Label {label} found but not in label allow list') @@ -254,12 +198,10 @@ def test_deny_list_option(self): # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['detections'] + detections = image_result.detections - for category in categories: - label = category['classes'][0]['className'] + for detection in detections: + label = detection.classes[0].class_name self.assertNotIn(label, _DENY_LIST, f'Label {label} found but in deny list.')