Skip to content

Commit

Permalink
add support for MicroResourceVariables from python wrapper (#1462)
Browse files Browse the repository at this point in the history
  • Loading branch information
suleshahid committed Oct 13, 2022
1 parent 9dcdde4 commit 81f5208
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ InterpreterWrapper::~InterpreterWrapper() {

InterpreterWrapper::InterpreterWrapper(
PyObject* model_data, const std::vector<std::string>& registerers_by_name,
size_t arena_size) {
size_t arena_size, int num_resource_variables) {
interpreter_ = nullptr;

// `model_data` is used as a raw pointer beyond the scope of this
Expand All @@ -221,7 +221,11 @@ InterpreterWrapper::InterpreterWrapper(

const Model* model = GetModel(buf);
model_ = model_data;
memory_arena_ = std::unique_ptr<uint8_t[]>(new uint8_t[arena_size]);
allocator_ = MicroAllocator::Create(new uint8_t[arena_size], arena_size);
resource_variables_ = nullptr;
if (num_resource_variables > 0)
resource_variables_ =
MicroResourceVariables::Create(allocator_, num_resource_variables);

for (const auto& registerer : registerers_by_name) {
if (!AddCustomOpRegistererByName(registerer.c_str(), &all_ops_resolver_)) {
Expand All @@ -232,8 +236,8 @@ InterpreterWrapper::InterpreterWrapper(
}
}

interpreter_ = new MicroInterpreter(model, all_ops_resolver_,
memory_arena_.get(), arena_size);
interpreter_ = new MicroInterpreter(model, all_ops_resolver_, allocator_,
resource_variables_);

TfLiteStatus status = interpreter_->AllocateTensors();
if (status != kTfLiteOk) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class InterpreterWrapper {
public:
InterpreterWrapper(PyObject* model_data,
const std::vector<std::string>& registerers_by_name,
size_t arena_size);
size_t arena_size, int num_resource_variables);
~InterpreterWrapper();

int Invoke();
Expand All @@ -38,7 +38,8 @@ class InterpreterWrapper {

private:
const PyObject* model_;
std::unique_ptr<uint8_t[]> memory_arena_;
tflite::MicroAllocator* allocator_;
tflite::MicroResourceVariables* resource_variables_;
tflite::AllOpsResolver all_ops_resolver_;
tflite::MicroInterpreter* interpreter_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ PYBIND11_MODULE(interpreter_wrapper_pybind, m) {
py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
.def(py::init([](const py::bytes& data,
const std::vector<std::string>& registerers_by_name,
size_t arena_size) {
return std::unique_ptr<InterpreterWrapper>(new InterpreterWrapper(
data.ptr(), registerers_by_name, arena_size));
size_t arena_size, int num_resource_variables) {
return std::unique_ptr<InterpreterWrapper>(
new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size,
num_resource_variables));
}))
.def("Invoke", &InterpreterWrapper::Invoke)
.def("Reset", &InterpreterWrapper::Reset)
Expand Down
32 changes: 26 additions & 6 deletions tensorflow/lite/micro/python/interpreter/src/tflm_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

class Interpreter(object):

def __init__(self, model_data, custom_op_registerers, arena_size):
def __init__(self,
model_data,
custom_op_registerers,
arena_size,
num_resource_variables=0):
if model_data is None:
raise ValueError("Model must not be None")

Expand All @@ -34,10 +38,14 @@ def __init__(self, model_data, custom_op_registerers, arena_size):
arena_size = len(model_data) * 10

self._interpreter = interpreter_wrapper_pybind.InterpreterWrapper(
model_data, custom_op_registerers, arena_size)
model_data, custom_op_registerers, arena_size, num_resource_variables)

@classmethod
def from_file(self, model_path, custom_op_registerers=[], arena_size=None):
def from_file(self,
model_path,
custom_op_registerers=[],
arena_size=None,
num_resource_variables=0):
"""Instantiates a TFLM interpreter from a model .tflite filepath.
Args:
Expand All @@ -46,6 +54,9 @@ def from_file(self, model_path, custom_op_registerers=[], arena_size=None):
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
num_resource_variables: (Only required if using MicroResourceVariables)
The number of resource variables can be found by counting the
ASSIGN_VARIABLE operators in the initialization subgraph.
Returns:
An Interpreter instance
Expand All @@ -56,10 +67,15 @@ def from_file(self, model_path, custom_op_registerers=[], arena_size=None):
with open(model_path, "rb") as f:
model_data = f.read()

return Interpreter(model_data, custom_op_registerers, arena_size)
return Interpreter(model_data, custom_op_registerers, arena_size,
num_resource_variables)

@classmethod
def from_bytes(self, model_data, custom_op_registerers=[], arena_size=None):
def from_bytes(self,
model_data,
custom_op_registerers=[],
arena_size=None,
num_resource_variables=0):
"""Instantiates a TFLM interpreter from a model in byte array.
Args:
Expand All @@ -68,12 +84,16 @@ def from_bytes(self, model_data, custom_op_registerers=[], arena_size=None):
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
num_resource_variables: (Only required if using MicroResourceVariables)
The number of resource variables can be found by counting the
ASSIGN_VARIABLE operators in the initialization subgraph.
Returns:
An Interpreter instance
"""

return Interpreter(model_data, custom_op_registerers, arena_size)
return Interpreter(model_data, custom_op_registerers, arena_size,
num_resource_variables)

def invoke(self):
"""Invoke the TFLM interpreter to run an inference.
Expand Down
18 changes: 15 additions & 3 deletions tensorflow/lite/micro/python/interpreter/tests/interpreter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,13 @@ def testCompareWithTFLite(self):
# TODO: Remove tolerance when the bug is fixed.
self.assertAllLessEqual((tflite_output - tflm_output), 1)

def testModelFromFileAndBufferEqual(self):
def _helperModelFromFileAndBufferEqual(self, number_resource_variables=0):
model_data = generate_test_models.generate_conv_model(True, self.filename)

file_interpreter = tflm_runtime.Interpreter.from_file(self.filename)
bytes_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
file_interpreter = tflm_runtime.Interpreter.from_file(
self.filename, num_resource_variables=number_resource_variables)
bytes_interpreter = tflm_runtime.Interpreter.from_bytes(
model_data, num_resource_variables=number_resource_variables)

num_steps = 100
for i in range(0, num_steps):
Expand All @@ -196,6 +198,9 @@ def testModelFromFileAndBufferEqual(self):
# Same interpreter and model, should expect all equal
self.assertAllEqual(file_output, bytes_output)

def testModelFromFileAndBufferEqual(self):
self._helperModelFromFileAndBufferEqual()

def testMultipleInterpreters(self):
model_data = generate_test_models.generate_conv_model(False)

Expand Down Expand Up @@ -269,6 +274,13 @@ def testNonExistentCustomOps(self):
interpreter = tflm_runtime.Interpreter.from_bytes(
model_data, custom_op_registerers)

def testResourceVariableFunctionCall(self):
# Both interpreter function call cases should be valid for various
# numbers of resource variables.
self._helperModelFromFileAndBufferEqual(-2)
self._helperModelFromFileAndBufferEqual(1)
self._helperModelFromFileAndBufferEqual(12)


if __name__ == "__main__":
test.main()

0 comments on commit 81f5208

Please sign in to comment.