Use CPU version of TFLite inference model
xeonqq committed Mar 18, 2024
1 parent 844ade7 commit e13f0c8
Showing 8 changed files with 503 additions and 30 deletions.
Binary file added models/ssd_mobilenet_v2_coco_quant_no_nms.tflite
Binary file added models/traffic_light.tflite
106 changes: 106 additions & 0 deletions src/adapters/classify.py
@@ -0,0 +1,106 @@
# Lint as: python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to work with a classification model."""

import collections
import operator
import numpy as np


Class = collections.namedtuple('Class', ['id', 'score'])
"""Represents a single classification, with the following fields:
.. py:attribute:: id
The class id.
.. py:attribute:: score
The prediction score.
"""


def num_classes(interpreter):
"""Gets the number of classes output by a classification model.
Args:
interpreter: The ``tf.lite.Interpreter`` holding the model.
Returns:
The total number of classes output by the model.
"""
return np.prod(interpreter.get_output_details()[0]['shape'])


def get_scores(interpreter):
"""Gets the output (all scores) from a classification model, dequantizing it if necessary.
Args:
interpreter: The ``tf.lite.Interpreter`` to query for output.
Returns:
The output tensor (flattened and dequantized) as :obj:`numpy.array`.
"""
output_details = interpreter.get_output_details()[0]
output_data = interpreter.tensor(output_details['index'])().flatten()

if np.issubdtype(output_details['dtype'], np.integer):
scale, zero_point = output_details['quantization']
# Always convert to np.int64 to avoid overflow on subtraction.
return scale * (output_data.astype(np.int64) - zero_point)

return output_data.copy()


def get_classes_from_scores(scores,
top_k=float('inf'),
score_threshold=-float('inf')):
"""Gets results from a classification model as a list of ordered classes, based on given scores.
Args:
scores: The output from a classification model. Must be flattened and
dequantized.
top_k (int): The number of top results to return.
score_threshold (float): The score threshold for results. All returned
results have a score greater-than-or-equal-to this value.
Returns:
A list of :obj:`Class` objects representing the classification results,
ordered by scores.
"""
top_k = min(top_k, len(scores))
classes = [
Class(i, scores[i])
for i in np.argpartition(scores, -top_k)[-top_k:]
if scores[i] >= score_threshold
]
return sorted(classes, key=operator.itemgetter(1), reverse=True)


def get_classes(interpreter, top_k=float('inf'), score_threshold=-float('inf')):
"""Gets results from a classification model as a list of ordered classes.
Args:
interpreter: The ``tf.lite.Interpreter`` to query for results.
top_k (int): The number of top results to return.
score_threshold (float): The score threshold for results. All returned
results have a score greater-than-or-equal-to this value.
Returns:
A list of :obj:`Class` objects representing the classification results,
ordered by scores.
"""
return get_classes_from_scores(
get_scores(interpreter), top_k, score_threshold)
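
For orientation, here is a minimal CPU-only usage sketch for this adapter. It is not part of the commit: the import path (adapters), the dummy input, and the thresholds are assumptions; only models/traffic_light.tflite and the adapter functions above come from the changeset.

# Hypothetical usage sketch (not part of this commit): run the classifier on
# CPU with the stock tf.lite.Interpreter and the adapter above.
import numpy as np
import tensorflow as tf
from adapters import classify  # assumed import path for src/adapters

interpreter = tf.lite.Interpreter(model_path='models/traffic_light.tflite')
interpreter.allocate_tensors()

# Feed a zero-filled frame matching the model's input shape and dtype,
# just to exercise the pipeline; a real caller would pass image data.
inp = interpreter.get_input_details()[0]
interpreter.set_tensor(inp['index'], np.zeros(inp['shape'], dtype=inp['dtype']))

interpreter.invoke()
for result in classify.get_classes(interpreter, top_k=3, score_threshold=0.1):
    print(result.id, result.score)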
100 changes: 100 additions & 0 deletions src/adapters/common.py
@@ -0,0 +1,100 @@
# Lint as: python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to work with any model."""

import numpy as np


def output_tensor(interpreter, i):
"""Gets a model's ith output tensor.
Args:
interpreter: The ``tf.lite.Interpreter`` holding the model.
i (int): The index position of an output tensor.
Returns:
The output tensor at the specified position.
"""
return interpreter.tensor(interpreter.get_output_details()[i]['index'])()


def input_details(interpreter, key):
"""Gets a model's input details by specified key.
Args:
interpreter: The ``tf.lite.Interpreter`` holding the model.
key (str): The input detail field to return, e.g. ``'shape'``, ``'index'``, or ``'dtype'``.
Returns:
The input details.
"""
return interpreter.get_input_details()[0][key]


def input_size(interpreter):
"""Gets a model's input size as (width, height) tuple.
Args:
interpreter: The ``tf.lite.Interpreter`` holding the model.
Returns:
The input tensor size as (width, height) tuple.
"""
_, height, width, _ = input_details(interpreter, 'shape')
return width, height


def input_tensor(interpreter):
"""Gets a model's input tensor view as numpy array of shape (height, width, 3).
Args:
interpreter: The ``tf.lite.Interpreter`` holding the model.
Returns:
The input tensor view as :obj:`numpy.array` (height, width, 3).
"""
tensor_index = input_details(interpreter, 'index')
return interpreter.tensor(tensor_index)()[0]


def set_input(interpreter, data):
"""Copies data to a model's input tensor.
Args:
interpreter: The ``tf.lite.Interpreter`` to update.
data: The input tensor.
"""
input_tensor(interpreter)[:, :] = data


def set_resized_input(interpreter, size, resize):
"""Copies a resized and properly zero-padded image to a model's input tensor.
Args:
interpreter: The ``tf.lite.Interpreter`` to update.
size (tuple): The original image size as (width, height) tuple.
resize: A function that takes a (width, height) tuple, and returns an
image resized to those dimensions.
Returns:
A (resized_image, (scale_x, scale_y)) tuple, where ``resized_image`` is the
un-padded image returned by ``resize`` and the scale values are the resize
ratio applied to the original width and height.
"""
width, height = input_size(interpreter)
w, h = size
scale = min(width / w, height / h)
w, h = int(w * scale), int(h * scale)
tensor = input_tensor(interpreter)
tensor.fill(0) # padding
_, _, channel = tensor.shape
result = resize((w, h))
tensor[:h, :w] = np.reshape(result, (h, w, channel))
return result, (scale, scale)
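
As a quick illustration of how set_resized_input letterboxes a frame into the model's input tensor, here is a hypothetical sketch (not part of this commit): the image path, the PIL resize backend, and the import path are assumptions; the detection model file is the one added in this changeset.

# Hypothetical usage sketch (not part of this commit): letterbox a frame into
# the detector's input tensor and keep the resize ratio for later rescaling.
import tensorflow as tf
from PIL import Image
from adapters import common  # assumed import path for src/adapters

interpreter = tf.lite.Interpreter(
    model_path='models/ssd_mobilenet_v2_coco_quant_no_nms.tflite')
interpreter.allocate_tensors()

image = Image.open('frame.jpg').convert('RGB')  # any RGB frame
_, (scale_x, scale_y) = common.set_resized_input(
    interpreter, image.size, lambda size: image.resize(size))
interpreter.invoke()
# Raw outputs can now be read with output_tensor(); box coordinates can be
# mapped back to the original frame using (scale_x, scale_y).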