Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CherryPick r2.3] Fix a critical breakage in training argument default value in inference for layers with a default of training=True called in e.g. a Sequential container. #40807

Merged
merged 1 commit into from Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflow/python/keras/engine/base_layer.py
Expand Up @@ -2891,9 +2891,9 @@ def _init_call_fn_args(self):
self._expects_training_arg = ('training' in call_fn_args or
self._call_accepts_kwargs)
# The default training arg will be any (non-None) default specified in the
# method signature, or `False` if no non-None default is specified.
# method signature, or None if no value is specified.
self._default_training_arg = self._call_fn_arg_defaults.get(
'training') or False
'training')
self._expects_mask_arg = ('mask' in call_fn_args or
self._call_accepts_kwargs)

Expand Down
28 changes: 24 additions & 4 deletions tensorflow/python/keras/engine/base_layer_test.py
Expand Up @@ -650,6 +650,17 @@ def call(self, inputs, training):
else:
return self._nested_layer(inputs) * 0.5

class CustomLayerDefaultTrainingNone(base_layer.Layer):
  """Test layer whose `call` declares `training=None` as its default.

  Wraps `nested_layer` (falling back to `array_ops.identity` when none is
  given) and scales the wrapped output by 0.5 whenever `training` is falsy,
  so the result reveals which training value reached `call`.
  """

  def __init__(self, nested_layer=None):
    # The base constructor must run before attributes are assigned so the
    # layer's internal state is set up.
    super(CustomLayerDefaultTrainingNone, self).__init__()
    self._nested_layer = nested_layer or array_ops.identity

  def call(self, inputs, training=None):
    """Returns the nested output, halved when `training` is falsy."""
    if training:
      return self._nested_layer(inputs)
    else:
      return self._nested_layer(inputs) * 0.5

class CustomLayerDefaultTrainingFalse(base_layer.Layer):

def __init__(self, nested_layer=None):
Expand Down Expand Up @@ -701,21 +712,30 @@ def call(self, inputs, training=True):
# Outer layers/models should set the training context implicitly for all
# nested layers, respecting whatever mode the outer layer was run with.
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
self.assertAllEqual(layer(x), x)
# No outer value passed: use local defaults
self.assertAllEqual(layer(x), x * 0.25) # Use local default False
# Outer value passed: override local defaults
self.assertAllEqual(layer(x, training=False), x * 0.25)
self.assertAllEqual(layer(x, training=True), x)

layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x * 0.25)
# No outer value passed: use local defaults
self.assertAllEqual(layer(x), x) # Use local default True
# Outer value passed: override local defaults
self.assertAllEqual(layer(x, training=False), x * 0.25)
self.assertAllEqual(layer(x, training=True), x)

# If the outer layer `call` doesn't take a training argument at all,
# it'll set the nested scope as inference when no training arg is passed in.
# it'll set the nested scope as None when no training arg is passed in.
# If a training arg is passed in it won't use it directly in `call`, but
# it will set the nested training mode.
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x * 0.5)
self.assertAllEqual(layer(x), x) # Use local default True
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)

layer = CustomLayerDefaultTrainingNone(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x) # Use local default True
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)

Expand Down
8 changes: 4 additions & 4 deletions tensorflow/python/keras/engine/functional_test.py
Expand Up @@ -2116,13 +2116,13 @@ def call(self, inputs, training=True):

if context.executing_eagerly():
# In v2, construction still works when no `training` is specified
# When no value passed during construction, it uses the runtime value.
# When no value passed during construction, it uses the local default.
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs)
network = functional.Functional(inputs, outputs)
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
self.assertAllEqual(network(x), _call(x, False))
self.assertAllEqual(network(x), _call(x, True)) # Use local default

# `None` value passed positionally during construction is ignored at runtime
inputs = input_layer_lib.Input(10)
Expand All @@ -2131,7 +2131,7 @@ def call(self, inputs, training=True):
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
if context.executing_eagerly():
self.assertAllEqual(network(x), _call(x, False))
self.assertAllEqual(network(x), _call(x, True)) # Use local default
else:
# in v1 training would have defaulted to using the `None` inside the layer
# if training is not passed at runtime
Expand All @@ -2144,7 +2144,7 @@ def call(self, inputs, training=True):
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
if context.executing_eagerly():
self.assertAllEqual(network(x), _call(x, False))
self.assertAllEqual(network(x), _call(x, True)) # Use local default
else:
# in v1 training would have defaulted to using the `None` inside the layer
# if training is not passed at runtime
Expand Down
Expand Up @@ -26,6 +26,7 @@
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import gen_stateful_random_ops
Expand Down Expand Up @@ -1273,5 +1274,38 @@ def test_config_with_custom_name(self):
self.assertEqual(layer_1.name, layer.name)


@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class LearningPhaseTest(keras_parameterized.TestCase):
  """Checks which `training` value random-resize layers end up using.

  Each test compares the per-image output shape against the input shape:
  the assertions expect a different shape when `training` is omitted or
  True, and an unchanged shape when `training=False`.
  """

  @staticmethod
  def _per_image_shape(tensor):
    # Drop the leading batch dimension and normalise entries to plain ints.
    return tuple(int(dim) for dim in tensor.shape[1:])

  def test_plain_call(self):
    resize = image_preprocessing.RandomWidth(.5, seed=123)
    image_shape = (12, 12, 3)
    batch = np.random.random((12,) + image_shape)

    # No `training` argument: expected to behave like training=True.
    result = resize(batch)
    self.assertNotEqual(self._per_image_shape(result), image_shape)

    result = resize(batch, training=True)
    self.assertNotEqual(self._per_image_shape(result), image_shape)

    result = resize(batch, training=False)
    self.assertEqual(self._per_image_shape(result), image_shape)

  def test_call_in_container(self):
    width_layer = image_preprocessing.RandomWidth(.5, seed=123)
    height_layer = image_preprocessing.RandomHeight(.5, seed=123)
    model = sequential.Sequential([width_layer, height_layer])

    image_shape = (12, 12, 3)
    batch = np.random.random((12,) + image_shape)

    # No `training` argument: expected to behave like training=True.
    result = model(batch)
    self.assertNotEqual(self._per_image_shape(result), image_shape)

    result = model(batch, training=True)
    self.assertNotEqual(self._per_image_shape(result), image_shape)

    result = model(batch, training=False)
    self.assertEqual(self._per_image_shape(result), image_shape)


# Script entry point: when this file is executed directly, hand control to
# `test.main()`.
if __name__ == '__main__':
test.main()