diff --git a/tensorflow_lite_support/cc/task/processor/proto/BUILD b/tensorflow_lite_support/cc/task/processor/proto/BUILD index 9b64c8593..a2f2b83b9 100644 --- a/tensorflow_lite_support/cc/task/processor/proto/BUILD +++ b/tensorflow_lite_support/cc/task/processor/proto/BUILD @@ -44,6 +44,9 @@ support_py_proto_library( srcs = ["classifications.proto"], api_version = 2, proto_deps = [":classifications_proto"], + py_proto_deps = [ + ":class_py_pb2", + ], ) proto_library( diff --git a/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py b/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py index 45abcde86..80eed9145 100644 --- a/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py +++ b/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py @@ -23,7 +23,6 @@ from tensorflow_lite_support.python.task.audio.core import audio_record from tensorflow_lite_support.python.task.audio.core import tensor_audio from tensorflow_lite_support.python.task.core.proto import base_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2 from tensorflow_lite_support.python.task.processor.proto import classifications_pb2 from tensorflow_lite_support.python.test import test_util @@ -35,62 +34,75 @@ _FIXED_INPUT_SIZE_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' _SPEECH_AUDIO_FILE = 'speech.wav' -_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = { - 'scores': [ - { - 'index': 0, - 'score': 0.917969, - 'class_name': 'Speech' - }, - { - 'index': 500, - 'score': 0.058594, - 'class_name': 'Inside, small room' - }, - { - 'index': 494, - 'score': 0.015625, - 'class_name': 'Silence' - } - ] +_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = """ +classifications { + classes { + index: 0 + score: 0.917969 + class_name: "Speech" + } + classes { + index: 500 + score: 0.058594 + class_name: "Inside, small room" + } + classes { + index: 494 + score: 0.013672 + class_name: "Silence" + } + head_index: 0 + head_name: "scores" } +""" _MULTIHEAD_MODEL_FILE = 'two_heads.tflite' _TWO_HEADS_AUDIO_FILE = 'two_heads.wav' -_MULTIHEAD_MODEL_CLASSIFICATIONS = { - 'yamnet_classification': [{ - 'index': 508, - 'score': 0.548616, - 'class_name': 'Environmental noise' - }, { - 'index': 507, - 'score': 0.380869, - 'class_name': 'Noise' - }, { - 'index': 106, - 'score': 0.256137, - 'class_name': 'Bird' - }], - 'bird_classification': [{ - 'index': 4, - 'score': 0.933997, - 'class_name': 'Chestnut-crowned Antpitta' - }, { - 'index': 1, - 'score': 0.065934, - 'class_name': 'White-breasted Wood-Wren' - }, { - 'index': 0, - 'score': 6.1469495e-05, - 'class_name': 'Red Crossbill' - }] +_MULTIHEAD_MODEL_CLASSIFICATIONS = """ +classifications { + classes { + index: 508 + score: 0.548616 + class_name: "Environmental noise" + } + classes { + index: 507 + score: 0.380869 + class_name: "Noise" + } + classes { + index: 106 + score: 0.256137 + class_name: "Bird" + } + head_index: 0 + head_name: "yamnet_classification" } +classifications { + classes { + index: 4 + score: 0.933997 + class_name: "Chestnut-crowned Antpitta" + } + classes { + index: 1 + score: 0.065934 + class_name: "White-breasted Wood-Wren" + } + classes { + index: 0 + score: 6.1469495e-05 + class_name: "Red Crossbill" + } + head_index: 1 + head_name: "bird_classification" +} +""" _ALLOW_LIST = ['Speech', 'Inside, small room'] _DENY_LIST = ['Speech'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.005 class ModelFileType(enum.Enum): @@ -107,19 +119,6 @@ def _create_classifier_from_options(base_options, **classification_options): return classifier -def _build_test_data(classifications): - expected_result = classifications_pb2.ClassificationResult() - - for index, (head_name, categories) in enumerate(classifications.items()): - classifications = classifications_pb2.Classifications( - head_index=index, head_name=head_name) - classifications.classes.extend( - [class_pb2.Category(**args) for args in categories]) - expected_result.classifications.append(classifications) - - return expected_result - - class AudioClassifierTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): @@ -192,7 +191,7 @@ def test_create_audio_record_from_classifier_succeeds(self, _): (_MULTIHEAD_MODEL_FILE, ModelFileType.FILE_CONTENT, _TWO_HEADS_AUDIO_FILE, 3, _MULTIHEAD_MODEL_CLASSIFICATIONS)) def test_classify_model(self, model_name, model_file_type, audio_file_name, - max_results, expected_classifications): + max_results, expected_result_text_proto): # Creates classifier. model_path = test_util.get_test_data_path(model_name) if model_file_type is ModelFileType.FILE_NAME: @@ -216,13 +215,8 @@ def test_classify_model(self, model_name, model_file_type, audio_file_name, # Classifies the input. audio_result = classifier.classify(tensor) - # Builds test data. - expected_result = _build_test_data(expected_classifications) - # Comparing results. - classification_result = classifications_pb2.ClassificationResult() - classification_result.ParseFromString(audio_result.SerializeToString()) - self.assertProtoEquals(classification_result, expected_result) + self.assertProtoEquals(expected_result_text_proto, audio_result) def test_max_results_option(self): # Creates classifier. diff --git a/tensorflow_lite_support/python/test/task/vision/BUILD b/tensorflow_lite_support/python/test/task/vision/BUILD index eeb47668e..640f0b238 100644 --- a/tensorflow_lite_support/python/test/task/vision/BUILD +++ b/tensorflow_lite_support/python/test/task/vision/BUILD @@ -36,7 +36,6 @@ py_test( # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", - "//tensorflow_lite_support/python/task/processor/proto:class_pb2", "//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:classifications_pb2", "//tensorflow_lite_support/python/task/vision:image_classifier", @@ -75,8 +74,6 @@ py_test( deps = [ # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", - "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", - "//tensorflow_lite_support/python/task/processor/proto:class_pb2", "//tensorflow_lite_support/python/task/processor/proto:detection_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:detections_pb2", "//tensorflow_lite_support/python/task/vision:object_detector", diff --git a/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py b/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py index d6dcb60f0..b184ce7d2 100644 --- a/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py +++ b/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py @@ -20,7 +20,6 @@ from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2 from tensorflow_lite_support.python.task.processor.proto import classifications_pb2 from tensorflow_lite_support.python.task.vision import image_classifier @@ -37,7 +36,6 @@ _DENY_LIST = ['cheeseburger'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.000001 def _create_classifier_from_options(base_options, **classification_options): @@ -49,16 +47,6 @@ def _create_classifier_from_options(base_options, **classification_options): return classifier -def _build_test_data(expected_categories): - classifications = classifications_pb2.Classifications(head_index=0) - classifications.classes.extend( - [class_pb2.Category(**args) for args in expected_categories]) - expected_result = classifications_pb2.ClassificationResult() - expected_result.classifications.append(classifications) - - return expected_result - - class ModelFileType(enum.Enum): FILE_CONTENT = 1 FILE_NAME = 2 @@ -101,33 +89,47 @@ def test_create_from_options_succeeds_with_valid_model_content(self): classifier = _ImageClassifier.create_from_options(options) self.assertIsInstance(classifier, _ImageClassifier) - @parameterized.parameters((ModelFileType.FILE_NAME, 3, [{ - 'index': 934, - 'score': 0.739974, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.026929, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.025737, - 'class_name': 'bagel' - }]), (ModelFileType.FILE_CONTENT, 3, [{ - 'index': 934, - 'score': 0.739974, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.026929, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.025737, - 'class_name': 'bagel' - }])) + @parameterized.parameters((ModelFileType.FILE_NAME, 3, """ + classifications { + classes { + index: 934 + score: 0.739974 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026929 + class_name: "guacamole" + } + classes { + index: 932 + score: 0.025737 + class_name: "bagel" + } + head_index: 0 + } + """), (ModelFileType.FILE_CONTENT, 3, """ + classifications { + classes { + index: 934 + score: 0.739974 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026929 + class_name: "guacamole" + } + classes { + index: 932 + score: 0.025737 + class_name: "bagel" + } + head_index: 0 + } + """)) def test_classify_model(self, model_file_type, max_results, - expected_categories): + expected_result_text_proto): # Creates classifier. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -148,13 +150,8 @@ def test_classify_model(self, model_file_type, max_results, # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - # Builds test data. - expected_result = _build_test_data(expected_categories) - # Comparing results (classification w/o bounding box). - classification_result = classifications_pb2.ClassificationResult() - classification_result.ParseFromString(image_result.SerializeToString()) - self.assertProtoEquals(classification_result, expected_result) + self.assertProtoEquals(expected_result_text_proto, image_result) def test_classify_model_with_bounding_box(self): # Creates classifier. @@ -173,27 +170,29 @@ def test_classify_model_with_bounding_box(self): image_result = classifier.classify(image, bounding_box) # Expected results. - expected_categories = [{ - 'index': 934, - 'score': 0.881507, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.019457, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.012489, - 'class_name': 'bagel' - }] - - # Builds test data. - expected_result = _build_test_data(expected_categories) + expected_result_text_proto = """ + classifications { + classes { + index: 934 + score: 0.881507 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.019457 + class_name: "guacamole" + } + classes { + index: 932 + score: 0.012489 + class_name: "bagel" + } + head_index: 0 + } + """ # Comparing results (classification w/ bounding box). - classification_result = classifications_pb2.ClassificationResult() - classification_result.ParseFromString(image_result.SerializeToString()) - self.assertProtoEquals(classification_result, expected_result) + self.assertProtoEquals(expected_result_text_proto, image_result) def test_max_results_option(self): # Creates classifier. diff --git a/tensorflow_lite_support/python/test/task/vision/object_detector_test.py b/tensorflow_lite_support/python/test/task/vision/object_detector_test.py index 72012956d..eceb53e5a 100644 --- a/tensorflow_lite_support/python/test/task/vision/object_detector_test.py +++ b/tensorflow_lite_support/python/test/task/vision/object_detector_test.py @@ -19,8 +19,6 @@ import tensorflow as tf from tensorflow_lite_support.python.task.core.proto import base_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import detection_options_pb2 from tensorflow_lite_support.python.task.processor.proto import detections_pb2 from tensorflow_lite_support.python.task.vision import object_detector @@ -33,53 +31,28 @@ _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' _IMAGE_FILE = 'cats_and_dogs.jpg' -_EXPECTED_DETECTIONS = [ - ({ - 'origin_x': 54, - 'origin_y': 396, - 'width': 393, - 'height': 196 - }, { - 'index': 16, - 'score': 0.64453125, - 'class_name': 'cat' - }), - ({ - 'origin_x': 602, - 'origin_y': 157, - 'width': 394, - 'height': 447 - }, { - 'index': 16, - 'score': 0.59765625, - 'class_name': 'cat' - }), - ({ - 'origin_x': 261, - 'origin_y': 394, - 'width': 179, - 'height': 209 - }, { - 'index': 16, - 'score': 0.5625, - 'class_name': 'cat' - }), - ({ - 'origin_x': 389, - 'origin_y': 197, - 'width': 276, - 'height': 409 - }, { - 'index': 17, - 'score': 0.51171875, - 'class_name': 'dog' - }) -] +_EXPECTED_DETECTIONS = """ +detections { + bounding_box { origin_x: 54 origin_y: 396 width: 393 height: 196 } + classes { index: 16 score: 0.64453125 class_name: "cat" } +} +detections { + bounding_box { origin_x: 602 origin_y: 157 width: 394 height: 447 } + classes { index: 16 score: 0.59765625 class_name: "cat" } +} +detections { + bounding_box { origin_x: 261 origin_y: 394 width: 179 height: 209 } + classes { index: 16 score: 0.5625 class_name: "cat" } +} +detections { + bounding_box { origin_x: 389 origin_y: 197 width: 276 height: 409 } + classes { index: 17 score: 0.51171875 class_name: "dog" } +} +""" _ALLOW_LIST = ['cat', 'dog'] _DENY_LIST = ['cat'] _SCORE_THRESHOLD = 0.3 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.000001 class ModelFileType(enum.Enum): @@ -96,20 +69,6 @@ def _create_detector_from_options(base_options, **detection_options): return detector -def _build_test_data(expected_detections): - expected_result = detections_pb2.DetectionResult() - - for index in range(len(expected_detections)): - bounding_box, category = expected_detections[index] - detection = detections_pb2.Detection() - detection.bounding_box.CopyFrom( - bounding_box_pb2.BoundingBox(**bounding_box)) - detection.classes.append(class_pb2.Category(**category)) - expected_result.detections.append(detection) - - return expected_result - - class ObjectDetectorTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): @@ -151,7 +110,7 @@ def test_create_from_options_succeeds_with_valid_model_content(self): (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTIONS), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTIONS)) def test_detect_model(self, model_file_type, max_results, - expected_detections): + expected_result_text_proto): # Creates detector. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -172,13 +131,8 @@ def test_detect_model(self, model_file_type, max_results, # Performs object detection on the input. image_result = detector.detect(image) - # Builds test data. - expected_result = _build_test_data(expected_detections) - # Comparing results. - detection_result = detections_pb2.DetectionResult() - detection_result.ParseFromString(image_result.SerializeToString()) - self.assertProtoEquals(detection_result, expected_result) + self.assertProtoEquals(expected_result_text_proto, image_result) def test_score_threshold_option(self): # Creates detector.