Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add globs from Lambda before building it #18926

Merged
merged 3 commits into from
Jun 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 7 additions & 7 deletions tensorflow/python/estimator/keras_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
Expand Down Expand Up @@ -146,13 +146,13 @@ def randomize_io_type(array, name):
def multi_inputs_multi_outputs_model():
a = keras.layers.Input(shape=(16,), name='input_a')
b = keras.layers.Input(shape=(16,), name='input_b')
m = keras.layers.Input(shape=(8,), dtype='bool', name='input_m')
m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
dense = keras.layers.Dense(8, name='dense_1')

a_2 = dense(a)
# Apply a mask
s_2 = keras.layers.Lambda(lambda k:
K.switch(k[0], k[1], K.zeros_like(k[1])))([m, a_2])
# Read m
m_2 = keras.layers.Lambda(gen_parsing_ops.string_to_number)(m)
s_2 = keras.layers.Lambda(lambda k: k[0] * k[1])([m_2, a_2])
b_2 = dense(b)
merged = keras.layers.concatenate([s_2, b_2], name='merge')
c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
Expand Down Expand Up @@ -372,13 +372,13 @@ def test_multi_inputs_multi_outputs(self):

def train_input_fn():
input_dict = {'input_a': a_train, 'input_b': b_train,
'input_m': input_m_train > 0}
'input_m': input_m_train.astype(np.str)}
output_dict = {'dense_2': c_train, 'dense_3': d_train}
return input_dict, output_dict

def eval_input_fn():
input_dict = {'input_a': a_test, 'input_b': b_test,
'input_m': input_m_test > 0}
'input_m': input_m_test.astype(np.str)}
output_dict = {'dense_2': c_test, 'dense_3': d_test}
return input_dict, output_dict

Expand Down
26 changes: 25 additions & 1 deletion tensorflow/python/keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from __future__ import print_function

import copy
import sys
import types as python_types
import warnings

import numpy as np

Expand Down Expand Up @@ -714,28 +716,34 @@ def compute_mask(self, inputs, mask=None):
return self.mask

def get_config(self):
module = self.function.__module__
if isinstance(self.function, python_types.LambdaType):
function = generic_utils.func_dump(self.function)
function_type = 'lambda'
else:
function = self.function.__name__
function_type = 'function'

output_shape_module = None
if isinstance(self._output_shape, python_types.LambdaType):
output_shape = generic_utils.func_dump(self._output_shape)
output_shape_type = 'lambda'
output_shape_module = self._output_shape.__module__
elif callable(self._output_shape):
output_shape = self._output_shape.__name__
output_shape_type = 'function'
output_shape_module = self._output_shape.__module__
else:
output_shape = self._output_shape
output_shape_type = 'raw'

config = {
'function': function,
'module': module,
'function_type': function_type,
'output_shape': output_shape,
'output_shape_type': output_shape_type,
'output_shape_module': output_shape_module,
'arguments': self.arguments
}
base_config = super(Lambda, self).get_config()
Expand All @@ -745,8 +753,16 @@ def get_config(self):
def from_config(cls, config, custom_objects=None):
config = config.copy()
globs = globals()
module = config.pop('module', None)
if module in sys.modules:
globs.update(sys.modules[module].__dict__)
elif module is not None:
# Note: we don't know the name of the function if it's a lambda.
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
'It may cause errors.'.format(module)
, UserWarning)
if custom_objects:
globs = dict(list(globs.items()) + list(custom_objects.items()))
globs.update(custom_objects)
function_type = config.pop('function_type')
if function_type == 'function':
# Simple lookup in custom objects
Expand All @@ -760,6 +776,14 @@ def from_config(cls, config, custom_objects=None):
else:
raise TypeError('Unknown function type:', function_type)

output_shape_module = config.pop('output_shape_module', None)
if output_shape_module in sys.modules:
globs.update(sys.modules[output_shape_module].__dict__)
elif output_shape_module is not None:
# Note: we don't know the name of the function if it's a lambda.
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
'It may cause errors.'.format(output_shape_module)
, UserWarning)
output_shape_type = config.pop('output_shape_type')
if output_shape_type == 'function':
# Simple lookup in custom objects
Expand Down