Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorflow_lite_support/cc/task/processor/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ support_py_proto_library(
srcs = ["classifications.proto"],
api_version = 2,
proto_deps = [":classifications_proto"],
py_proto_deps = [
":class_py_pb2",
],
)

proto_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_lite_support.python.task.audio.core import audio_record
from tensorflow_lite_support.python.task.audio.core import tensor_audio
from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.processor.proto import class_pb2
from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2
from tensorflow_lite_support.python.task.processor.proto import classifications_pb2
from tensorflow_lite_support.python.test import test_util
Expand All @@ -35,62 +34,75 @@

_FIXED_INPUT_SIZE_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
_SPEECH_AUDIO_FILE = 'speech.wav'
_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = {
'scores': [
{
'index': 0,
'score': 0.917969,
'class_name': 'Speech'
},
{
'index': 500,
'score': 0.058594,
'class_name': 'Inside, small room'
},
{
'index': 494,
'score': 0.015625,
'class_name': 'Silence'
}
]
_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = """
classifications {
classes {
index: 0
score: 0.917969
class_name: "Speech"
}
classes {
index: 500
score: 0.058594
class_name: "Inside, small room"
}
classes {
index: 494
score: 0.013672
class_name: "Silence"
}
head_index: 0
head_name: "scores"
}
"""

_MULTIHEAD_MODEL_FILE = 'two_heads.tflite'
_TWO_HEADS_AUDIO_FILE = 'two_heads.wav'
_MULTIHEAD_MODEL_CLASSIFICATIONS = {
'yamnet_classification': [{
'index': 508,
'score': 0.548616,
'class_name': 'Environmental noise'
}, {
'index': 507,
'score': 0.380869,
'class_name': 'Noise'
}, {
'index': 106,
'score': 0.256137,
'class_name': 'Bird'
}],
'bird_classification': [{
'index': 4,
'score': 0.933997,
'class_name': 'Chestnut-crowned Antpitta'
}, {
'index': 1,
'score': 0.065934,
'class_name': 'White-breasted Wood-Wren'
}, {
'index': 0,
'score': 6.1469495e-05,
'class_name': 'Red Crossbill'
}]
_MULTIHEAD_MODEL_CLASSIFICATIONS = """
classifications {
classes {
index: 508
score: 0.548616
class_name: "Environmental noise"
}
classes {
index: 507
score: 0.380869
class_name: "Noise"
}
classes {
index: 106
score: 0.256137
class_name: "Bird"
}
head_index: 0
head_name: "yamnet_classification"
}
classifications {
classes {
index: 4
score: 0.933997
class_name: "Chestnut-crowned Antpitta"
}
classes {
index: 1
score: 0.065934
class_name: "White-breasted Wood-Wren"
}
classes {
index: 0
score: 6.1469495e-05
class_name: "Red Crossbill"
}
head_index: 1
head_name: "bird_classification"
}
"""

_ALLOW_LIST = ['Speech', 'Inside, small room']
_DENY_LIST = ['Speech']
_SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3
_ACCEPTABLE_ERROR_RANGE = 0.005


class ModelFileType(enum.Enum):
Expand All @@ -107,19 +119,6 @@ def _create_classifier_from_options(base_options, **classification_options):
return classifier


def _build_test_data(classifications):
expected_result = classifications_pb2.ClassificationResult()

for index, (head_name, categories) in enumerate(classifications.items()):
classifications = classifications_pb2.Classifications(
head_index=index, head_name=head_name)
classifications.classes.extend(
[class_pb2.Category(**args) for args in categories])
expected_result.classifications.append(classifications)

return expected_result


class AudioClassifierTest(parameterized.TestCase, tf.test.TestCase):

def setUp(self):
Expand Down Expand Up @@ -192,7 +191,7 @@ def test_create_audio_record_from_classifier_succeeds(self, _):
(_MULTIHEAD_MODEL_FILE, ModelFileType.FILE_CONTENT, _TWO_HEADS_AUDIO_FILE,
3, _MULTIHEAD_MODEL_CLASSIFICATIONS))
def test_classify_model(self, model_name, model_file_type, audio_file_name,
max_results, expected_classifications):
max_results, expected_result_text_proto):
# Creates classifier.
model_path = test_util.get_test_data_path(model_name)
if model_file_type is ModelFileType.FILE_NAME:
Expand All @@ -216,13 +215,8 @@ def test_classify_model(self, model_name, model_file_type, audio_file_name,
# Classifies the input.
audio_result = classifier.classify(tensor)

# Builds test data.
expected_result = _build_test_data(expected_classifications)

# Comparing results.
classification_result = classifications_pb2.ClassificationResult()
classification_result.ParseFromString(audio_result.SerializeToString())
self.assertProtoEquals(classification_result, expected_result)
self.assertProtoEquals(expected_result_text_proto, audio_result)

def test_max_results_option(self):
# Creates classifier.
Expand Down
3 changes: 0 additions & 3 deletions tensorflow_lite_support/python/test/task/vision/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ py_test(
# build rule placeholder: tensorflow dep,
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2",
"//tensorflow_lite_support/python/task/processor/proto:class_pb2",
"//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:classifications_pb2",
"//tensorflow_lite_support/python/task/vision:image_classifier",
Expand Down Expand Up @@ -75,8 +74,6 @@ py_test(
deps = [
# build rule placeholder: tensorflow dep,
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2",
"//tensorflow_lite_support/python/task/processor/proto:class_pb2",
"//tensorflow_lite_support/python/task/processor/proto:detection_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:detections_pb2",
"//tensorflow_lite_support/python/task/vision:object_detector",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2
from tensorflow_lite_support.python.task.processor.proto import class_pb2
from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2
from tensorflow_lite_support.python.task.processor.proto import classifications_pb2
from tensorflow_lite_support.python.task.vision import image_classifier
Expand All @@ -37,7 +36,6 @@
_DENY_LIST = ['cheeseburger']
_SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3
_ACCEPTABLE_ERROR_RANGE = 0.000001


def _create_classifier_from_options(base_options, **classification_options):
Expand All @@ -49,16 +47,6 @@ def _create_classifier_from_options(base_options, **classification_options):
return classifier


def _build_test_data(expected_categories):
classifications = classifications_pb2.Classifications(head_index=0)
classifications.classes.extend(
[class_pb2.Category(**args) for args in expected_categories])
expected_result = classifications_pb2.ClassificationResult()
expected_result.classifications.append(classifications)

return expected_result


class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
Expand Down Expand Up @@ -101,33 +89,47 @@ def test_create_from_options_succeeds_with_valid_model_content(self):
classifier = _ImageClassifier.create_from_options(options)
self.assertIsInstance(classifier, _ImageClassifier)

@parameterized.parameters((ModelFileType.FILE_NAME, 3, [{
'index': 934,
'score': 0.739974,
'class_name': 'cheeseburger'
}, {
'index': 925,
'score': 0.026929,
'class_name': 'guacamole'
}, {
'index': 932,
'score': 0.025737,
'class_name': 'bagel'
}]), (ModelFileType.FILE_CONTENT, 3, [{
'index': 934,
'score': 0.739974,
'class_name': 'cheeseburger'
}, {
'index': 925,
'score': 0.026929,
'class_name': 'guacamole'
}, {
'index': 932,
'score': 0.025737,
'class_name': 'bagel'
}]))
@parameterized.parameters((ModelFileType.FILE_NAME, 3, """
classifications {
classes {
index: 934
score: 0.739974
class_name: "cheeseburger"
}
classes {
index: 925
score: 0.026929
class_name: "guacamole"
}
classes {
index: 932
score: 0.025737
class_name: "bagel"
}
head_index: 0
}
"""), (ModelFileType.FILE_CONTENT, 3, """
classifications {
classes {
index: 934
score: 0.739974
class_name: "cheeseburger"
}
classes {
index: 925
score: 0.026929
class_name: "guacamole"
}
classes {
index: 932
score: 0.025737
class_name: "bagel"
}
head_index: 0
}
"""))
def test_classify_model(self, model_file_type, max_results,
expected_categories):
expected_result_text_proto):
# Creates classifier.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(file_name=self.model_path)
Expand All @@ -148,13 +150,8 @@ def test_classify_model(self, model_file_type, max_results,
# Classifies the input.
image_result = classifier.classify(image, bounding_box=None)

# Builds test data.
expected_result = _build_test_data(expected_categories)

# Comparing results (classification w/o bounding box).
classification_result = classifications_pb2.ClassificationResult()
classification_result.ParseFromString(image_result.SerializeToString())
self.assertProtoEquals(classification_result, expected_result)
self.assertProtoEquals(expected_result_text_proto, image_result)

def test_classify_model_with_bounding_box(self):
# Creates classifier.
Expand All @@ -173,27 +170,29 @@ def test_classify_model_with_bounding_box(self):
image_result = classifier.classify(image, bounding_box)

# Expected results.
expected_categories = [{
'index': 934,
'score': 0.881507,
'class_name': 'cheeseburger'
}, {
'index': 925,
'score': 0.019457,
'class_name': 'guacamole'
}, {
'index': 932,
'score': 0.012489,
'class_name': 'bagel'
}]

# Builds test data.
expected_result = _build_test_data(expected_categories)
expected_result_text_proto = """
classifications {
classes {
index: 934
score: 0.881507
class_name: "cheeseburger"
}
classes {
index: 925
score: 0.019457
class_name: "guacamole"
}
classes {
index: 932
score: 0.012489
class_name: "bagel"
}
head_index: 0
}
"""

# Comparing results (classification w/ bounding box).
classification_result = classifications_pb2.ClassificationResult()
classification_result.ParseFromString(image_result.SerializeToString())
self.assertProtoEquals(classification_result, expected_result)
self.assertProtoEquals(expected_result_text_proto, image_result)

def test_max_results_option(self):
# Creates classifier.
Expand Down
Loading