Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions tensorflow_lite_support/python/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,6 @@ package(
licenses = ["notice"], # Apache 2.0
)

py_library(
name = "base_test",
testonly = 1,
srcs = ["base_test.py"],
srcs_version = "PY3",
)

py_library(
name = "test_util",
testonly = 1,
Expand Down
50 changes: 0 additions & 50 deletions tensorflow_lite_support/python/test/base_test.py

This file was deleted.

2 changes: 0 additions & 2 deletions tensorflow_lite_support/python/test/task/audio/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ py_test(
"//tensorflow_lite_support/python/task/processor/proto:class_pb2",
"//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:classifications_pb2",
"//tensorflow_lite_support/python/test:base_test",
"//tensorflow_lite_support/python/test:test_util",
"@absl_py//absl/testing:parameterized",
"@com_google_protobuf//:protobuf_python",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,17 @@
"""Tests for audio_classifier."""

import enum
import json

from absl.testing import parameterized
import tensorflow as tf

from google.protobuf import json_format
import unittest
from tensorflow_lite_support.python.task.audio import audio_classifier
from tensorflow_lite_support.python.task.audio.core import audio_record
from tensorflow_lite_support.python.task.audio.core import tensor_audio
from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.processor.proto import class_pb2
from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2
from tensorflow_lite_support.python.task.processor.proto import classifications_pb2
from tensorflow_lite_support.python.test import base_test
from tensorflow_lite_support.python.test import test_util

_mock = unittest.mock
Expand All @@ -38,58 +34,75 @@

_FIXED_INPUT_SIZE_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
_SPEECH_AUDIO_FILE = 'speech.wav'
_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = {
'scores': [{
'index': 0,
'score': 0.91796875,
'class_name': 'Speech'
}, {
'index': 500,
'score': 0.05859375,
'class_name': 'Inside, small room'
}, {
'index': 494,
'score': 0.01367188,
'class_name': 'Silence'
}]
_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = """
classifications {
classes {
index: 0
score: 0.917969
class_name: "Speech"
}
classes {
index: 500
score: 0.058594
class_name: "Inside, small room"
}
classes {
index: 494
score: 0.013672
class_name: "Silence"
}
head_index: 0
head_name: "scores"
}
"""

_MULTIHEAD_MODEL_FILE = 'two_heads.tflite'
_TWO_HEADS_AUDIO_FILE = 'two_heads.wav'
_MULTIHEAD_MODEL_CLASSIFICATIONS = {
'yamnet_classification': [{
'index': 508,
'score': 0.5486158,
'class_name': 'Environmental noise'
}, {
'index': 507,
'score': 0.38086897,
'class_name': 'Noise'
}, {
'index': 106,
'score': 0.25613675,
'class_name': 'Bird'
}],
'bird_classification': [{
'index': 4,
'score': 0.93399656,
'class_name': 'Chestnut-crowned Antpitta'
}, {
'index': 1,
'score': 0.065934494,
'class_name': 'White-breasted Wood-Wren'
}, {
'index': 0,
'score': 6.1469495e-05,
'class_name': 'Red Crossbill'
}]
_MULTIHEAD_MODEL_CLASSIFICATIONS = """
classifications {
classes {
index: 508
score: 0.548616
class_name: "Environmental noise"
}
classes {
index: 507
score: 0.380869
class_name: "Noise"
}
classes {
index: 106
score: 0.256137
class_name: "Bird"
}
head_index: 0
head_name: "yamnet_classification"
}
classifications {
classes {
index: 4
score: 0.933997
class_name: "Chestnut-crowned Antpitta"
}
classes {
index: 1
score: 0.065934
class_name: "White-breasted Wood-Wren"
}
classes {
index: 0
score: 6.1469495e-05
class_name: "Red Crossbill"
}
head_index: 1
head_name: "bird_classification"
}
"""

_ALLOW_LIST = ['Speech', 'Inside, small room']
_DENY_LIST = ['Speech']
_SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3
_ACCEPTABLE_ERROR_RANGE = 0.005


class ModelFileType(enum.Enum):
Expand All @@ -106,22 +119,7 @@ def _create_classifier_from_options(base_options, **classification_options):
return classifier


def _build_test_data(classifications):
expected_result = classifications_pb2.ClassificationResult()

for index, (head_name, categories) in enumerate(classifications.items()):
classifications = classifications_pb2.Classifications(
head_index=index, head_name=head_name)
classifications.classes.extend(
[class_pb2.Category(**args) for args in categories])
expected_result.classifications.append(classifications)

expected_result_dict = json.loads(json_format.MessageToJson(expected_result))

return expected_result_dict


class AudioClassifierTest(parameterized.TestCase, base_test.BaseTestCase):
class AudioClassifierTest(parameterized.TestCase, tf.test.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -193,7 +191,7 @@ def test_create_audio_record_from_classifier_succeeds(self, _):
(_MULTIHEAD_MODEL_FILE, ModelFileType.FILE_CONTENT, _TWO_HEADS_AUDIO_FILE,
3, _MULTIHEAD_MODEL_CLASSIFICATIONS))
def test_classify_model(self, model_name, model_file_type, audio_file_name,
max_results, expected_classifications):
max_results, expected_result_text_proto):
# Creates classifier.
model_path = test_util.get_test_data_path(model_name)
if model_file_type is ModelFileType.FILE_NAME:
Expand All @@ -216,14 +214,9 @@ def test_classify_model(self, model_name, model_file_type, audio_file_name,

# Classifies the input.
audio_result = classifier.classify(tensor)
audio_result_dict = json.loads(json_format.MessageToJson(audio_result))

# Builds test data.
expected_result_dict = _build_test_data(expected_classifications)

# Comparing results.
self.assertDeepAlmostEqual(
audio_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE)
self.assertProtoEquals(expected_result_text_proto, audio_result)

def test_max_results_option(self):
# Creates classifier.
Expand All @@ -238,9 +231,7 @@ def test_max_results_option(self):

# Classifies the input.
audio_result = classifier.classify(tensor)
audio_result_dict = json.loads(json_format.MessageToJson(audio_result))

categories = audio_result_dict['classifications'][0]['classes']
categories = audio_result.classifications[0].classes

self.assertLessEqual(
len(categories), _MAX_RESULTS, 'Too many results returned.')
Expand All @@ -258,14 +249,11 @@ def test_score_threshold_option(self):

# Classifies the input.
audio_result = classifier.classify(tensor)
audio_result_dict = json.loads(json_format.MessageToJson(audio_result))

categories = audio_result_dict['classifications'][0]['classes']
categories = audio_result.classifications[0].classes

for category in categories:
score = category['score']
self.assertGreaterEqual(
score, _SCORE_THRESHOLD,
category.score, _SCORE_THRESHOLD,
'Classification with score lower than threshold found. {0}'.format(
category))

Expand All @@ -282,12 +270,10 @@ def test_allowlist_option(self):

# Classifies the input.
audio_result = classifier.classify(tensor)
audio_result_dict = json.loads(json_format.MessageToJson(audio_result))

categories = audio_result_dict['classifications'][0]['classes']
categories = audio_result.classifications[0].classes

for category in categories:
label = category['className']
label = category.class_name
self.assertIn(
label, _ALLOW_LIST,
'Label "{0}" found but not in label allow list'.format(label))
Expand All @@ -305,12 +291,10 @@ def test_denylist_option(self):

# Classifies the input.
audio_result = classifier.classify(tensor)
audio_result_dict = json.loads(json_format.MessageToJson(audio_result))

categories = audio_result_dict['classifications'][0]['classes']
categories = audio_result.classifications[0].classes

for category in categories:
label = category['className']
label = category.class_name
self.assertNotIn(label, _DENY_LIST,
'Label "{0}" found but in deny list.'.format(label))

Expand Down
6 changes: 0 additions & 6 deletions tensorflow_lite_support/python/test/task/vision/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ py_test(
"//tensorflow_lite_support/python/task/processor/proto:classifications_pb2",
"//tensorflow_lite_support/python/task/vision:image_classifier",
"//tensorflow_lite_support/python/task/vision/core:tensor_image",
"//tensorflow_lite_support/python/test:base_test",
"//tensorflow_lite_support/python/test:test_util",
"@absl_py//absl/testing:parameterized",
"@com_google_protobuf//:protobuf_python",
],
)

Expand Down Expand Up @@ -78,15 +76,11 @@ py_test(
deps = [
# build rule placeholder: tensorflow dep,
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2",
"//tensorflow_lite_support/python/task/processor/proto:class_pb2",
"//tensorflow_lite_support/python/task/processor/proto:detection_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:detections_pb2",
"//tensorflow_lite_support/python/task/vision:object_detector",
"//tensorflow_lite_support/python/task/vision/core:tensor_image",
"//tensorflow_lite_support/python/test:base_test",
"//tensorflow_lite_support/python/test:test_util",
"@absl_py//absl/testing:parameterized",
"@com_google_protobuf//:protobuf_python",
],
)
Loading