Skip to content

Commit

Permalink
Accept "custom_objects" as arguments to `TFLiteConverter.from_keras_m…
Browse files Browse the repository at this point in the history
…odel`

This would be needed to, for example, load a keras model containing a `tensorflow_hub.KerasLayer`

PiperOrigin-RevId: 241821677
  • Loading branch information
MarkDaoust authored and tensorflower-gardener committed Apr 3, 2019
1 parent ea5004e commit 09deaeb
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
7 changes: 5 additions & 2 deletions tensorflow/lite/python/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ def from_keras_model_file(cls,
model_file,
input_arrays=None,
input_shapes=None,
output_arrays=None):
output_arrays=None,
custom_objects=None):
"""Creates a TFLiteConverter class from a tf.keras model file.
Args:
Expand All @@ -592,13 +593,15 @@ def from_keras_model_file(cls,
None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
custom_objects: Dict mapping names (strings) to custom classes or
functions to be considered during model deserialization. (default None)
Returns:
TFLiteConverter class.
"""
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
keras_model = _keras.models.load_model(model_file)
keras_model = _keras.models.load_model(model_file, custom_objects)
sess = _keras.backend.get_session()

# Get input and output tensors.
Expand Down
54 changes: 53 additions & 1 deletion tensorflow/lite/python/lite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,16 +1066,33 @@ def testSimpleModelTocoConverter(self):
interpreter.allocate_tensors()


class MyAddLayer(keras.layers.Layer):

def __init__(self, increment, **kwargs):
super(MyAddLayer, self).__init__(**kwargs)
self._increment = increment

def call(self, inputs):
return inputs + self._increment

def get_config(self):
config = super(MyAddLayer, self).get_config()
config['increment'] = self._increment
return config


@test_util.run_v1_only('b/120545219')
class FromKerasFile(test_util.TensorFlowTestCase):

def setUp(self):
keras.backend.clear_session()

def _getSequentialModel(self):
def _getSequentialModel(self, include_custom_layer=False):
with session.Session().as_default():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
if include_custom_layer:
model.add(MyAddLayer(1.0))
model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model.compile(
Expand All @@ -1093,6 +1110,10 @@ def _getSequentialModel(self):
keras.models.save_model(model, keras_file)
finally:
os.close(fd)

if include_custom_layer:
custom_objects = {'MyAddLayer': MyAddLayer}
return keras_file, custom_objects
return keras_file

def testSequentialModel(self):
Expand Down Expand Up @@ -1133,6 +1154,37 @@ def testSequentialModel(self):
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)

def testCustomLayer(self):
"""Test a Sequential tf.keras model with default inputs."""
keras_file, custom_objects = self._getSequentialModel(
include_custom_layer=True)

converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, custom_objects=custom_objects)

tflite_model = converter.convert()
self.assertTrue(tflite_model)

# Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Check inference of converted model.
input_data = np.array([[1, 2, 3]], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
tflite_result = interpreter.get_tensor(output_details[0]['index'])

keras_model = keras.models.load_model(
keras_file, custom_objects=custom_objects)
keras_result = keras_model.predict(input_data)

np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)

def testSequentialModelInputArray(self):
"""Test a Sequential tf.keras model testing input arrays argument."""
keras_file = self._getSequentialModel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ tf_class {
}
member_method {
name: "from_keras_model_file"
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_saved_model"
Expand Down

1 comment on commit 09deaeb

@Lotte1990
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somehow, this functionality is not in TF2. In TensorFlow 2.1.0 I get:

TypeError: from_keras_model() got an unexpected keyword argument 'custom_objects'
when using tf.lite.TFLiteConverter.from_keras_model

or

AttributeError: type object 'TFLiteConverterV2' has no attribute 'from_keras_model_file'
when using tf.lite.TFLiteConverter.from_keras_model_file

My model was created using TensorFlow 2.1.0 (Keras 2.2.4-tf) so I cannot use the V1 converter.

If I understand correctly, it is currently not possible to use custom objects in .tflite models for TF2. When I use tf.lite.TFLiteConverter.from_keras_model without custom_objects it creates the .tflite file without errors/warnings, but in TensorFlow Lite this results in failure of interpreter->AllocateTensors(). This is probably due to an error in the converted model, because this does not happen when I don't use custom objects. interpreter->AllocateTensors() clearly fails because I didn't define custom_objects, which makes sense. Please copy this functionality to TF2.

Please sign in to comment.