Skip to content

Commit

Permalink
Fix API tflite.OperatorCode.BuiltinCode() compatibility issue
Browse files Browse the repository at this point in the history
`tflite.OperatorCode.BuiltinCode()`: maintains API compability in 2.4.0.
See these threads:
* #9
* #10
* tensorflow/tensorflow#46663
  • Loading branch information
zhenhuaw-me committed Nov 12, 2022
1 parent 51c9d13 commit b686bf9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
7 changes: 2 additions & 5 deletions tests/test_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,10 @@ def test_mobilenet():
op_code = model.OperatorCodes(op.OpcodeIndex())

# The first operator is a convolution.
# The builtin code is extended in TensorFlow 2.4.x from 8 bit to 32 bit.
# See this example of how to handle it gracefully in your code
# https://github.com/apache/tvm/blob/b20b7c4ad4ad3774a42f47614245f8eeabe875cb/python/tvm/relay/frontend/tflite.py#L297-L316
assert(op_code.DeprecatedBuiltinCode() == tflite.BuiltinOperator.CONV_2D)
assert(op_code.BuiltinCode() == tflite.BuiltinOperator.CONV_2D)

# Custom operator need more interface, won't cover here.
assert(op_code.DeprecatedBuiltinCode() != tflite.BuiltinOperator.CUSTOM)
assert(op_code.BuiltinCode() != tflite.BuiltinOperator.CUSTOM)


############# the operator ##################################################
Expand Down
11 changes: 8 additions & 3 deletions tflite/OperatorCode.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ def Version(self):
def BuiltinCode(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0
o = self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)

from tflite.BuiltinOperator import BuiltinOperator
if o < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES:
return self.DeprecatedBuiltinCode()
else:
return o

def OperatorCodeStart(builder): builder.StartObject(4)
def Start(builder):
Expand All @@ -73,4 +78,4 @@ def AddBuiltinCode(builder, builtinCode):
return OperatorCodeAddBuiltinCode(builder, builtinCode)
def OperatorCodeEnd(builder): return builder.EndObject()
def End(builder):
return OperatorCodeEnd(builder)
return OperatorCodeEnd(builder)

0 comments on commit b686bf9

Please sign in to comment.