Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a single positional argument mode for shape inference in subclassed Models #20203

Merged
merged 1 commit on Jun 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 37 additions & 8 deletions tensorflow/python/keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import print_function

import collections
import enum # pylint: disable=g-bad-import-order
import inspect # Necessary supplement to tf_inspect to deal with variadic args.

import numpy as np
Expand Down Expand Up @@ -50,6 +51,20 @@
from tensorflow.python.util.tf_export import tf_export


class CallConvention(enum.Enum):
  """Calling conventions for passing `Layer` inputs to `Layer.call`."""

  # `call()` accepts an argument literally named "inputs", mirroring the
  # signature of `Layer.__call__`. This is the convention assumed for Layers
  # which are not subclassed Models.
  EXPLICIT_INPUTS_ARGUMENT = 1

  # `call()` accepts exactly one positional argument whose name is not
  # "inputs"; that argument is treated as if it were the "inputs" argument.
  SINGLE_POSITIONAL_ARGUMENT = 2

  # `call()` accepts multiple positional arguments, and the Layer's inputs
  # should be bound to them.
  POSITIONAL_ARGUMENTS_ARE_INPUTS = 3


@tf_export('keras.layers.Layer')
class Layer(checkpointable.CheckpointableBase):
"""Base layer class.
Expand Down Expand Up @@ -149,7 +164,7 @@ def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
self._uses_inputs_arg = True
self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT

# These lists will be filled via successive calls
# to self._add_inbound_node().
Expand Down Expand Up @@ -793,12 +808,22 @@ def _set_mask_metadata(self, inputs, outputs, previous_mask):
pass # C type such as dict. Masking not supported in this case.

def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
if args and getattr(self, '_uses_inputs_arg', True):
raise TypeError(
'This Layer takes an `inputs` argument to call(), and only the '
'`inputs` argument may be specified as a positional argument. '
'Pass everything else as a keyword argument (those arguments will'
' not be tracked as inputs to the Layer).')
call_convention = getattr(self, '_call_convention',
CallConvention.EXPLICIT_INPUTS_ARGUMENT)
if args:
if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT:
raise TypeError(
'This Layer takes an `inputs` argument to call(), and only the '
'`inputs` argument may be specified as a positional argument. '
'Pass everything else as a keyword argument (those arguments will'
' not be tracked as inputs to the Layer).')
elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT:
raise TypeError(
'This Layer takes a single positional argument to call(), which is '
'by convention the inputs argument, and only this argument may be '
'specified as a positional argument. Pass everything else as a '
'keyword argument (those arguments will not be tracked as inputs '
'to the Layer).')

# If the layer returns tensors from its inputs, unmodified,
# we copy them to avoid loss of tensor metadata.
Expand Down Expand Up @@ -834,7 +859,11 @@ def _inputs_from_call_args(self, call_args, call_kwargs):
A tuple of (inputs, non_input_kwargs). These may be the same objects as
were passed in (call_args and call_kwargs).
"""
if getattr(self, '_uses_inputs_arg', True):
call_convention = getattr(self, '_call_convention',
CallConvention.EXPLICIT_INPUTS_ARGUMENT)
if (call_convention in (
CallConvention.EXPLICIT_INPUTS_ARGUMENT,
CallConvention.SINGLE_POSITIONAL_ARGUMENT)):
assert len(call_args) == 1 # TypeError raised earlier in __call__.
return call_args[0], call_kwargs
else:
Expand Down
50 changes: 43 additions & 7 deletions tensorflow/python/keras/engine/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _base_init(self, name=None):
self._in_progress_restore_finalizer = None

def _init_graph_network(self, inputs, outputs, name=None):
self._uses_inputs_arg = True
self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
Expand Down Expand Up @@ -295,19 +295,55 @@ def _init_graph_network(self, inputs, outputs, name=None):
def _init_subclassed_network(self, name=None):
  """Initializes state for a subclassed (non-graph) network.

  The visible span contained both the pre-change and post-change lines of the
  rendered diff; this is the merged, post-change version of the method.

  Args:
    name: Optional name for the network, forwarded to `_base_init`.
  """
  self._base_init(name=name)
  self._is_graph_network = False
  # Inspect the subclass's call() once and derive both the training-arg
  # expectation and the calling convention from the same argspec.
  call_argspec = tf_inspect.getargspec(self.call)
  self._expects_training_arg = 'training' in call_argspec.args
  self._call_convention = self._determine_call_convention(call_argspec)
  # Shape inference has not run yet: inputs/outputs are unknown until the
  # model is called or built.
  self.outputs = None
  self.inputs = None
  self.built = False

def _determine_call_convention(self, call_argspec):
  """Decides how `self.call()` is invoked. See `base_layer.CallConvention`.

  Args:
    call_argspec: The result of `tf_inspect.getargspec(self.call)`.

  Returns:
    One of the `base_layer.CallConvention` enum values:
      - SINGLE_POSITIONAL_ARGUMENT if call() takes exactly one non-self
        positional (default-less) argument and no separate 'inputs' argument;
      - EXPLICIT_INPUTS_ARGUMENT if call() has an argument named 'inputs';
      - POSITIONAL_ARGUMENTS_ARE_INPUTS otherwise.

  Raises:
    TypeError: If call() takes both a single positional argument and a
      keyword argument named 'inputs', making it ambiguous which argument
      carries the inputs.
  """
  if call_argspec.varargs:
    # *args makes the positional-argument count open-ended, so call() cannot
    # be classified as taking a single positional argument.
    may_take_single_argument = False
  else:
    try:
      # Note: tf_inspect doesn't raise a TypeError when regular inspect would,
      # so we need to keep in mind that "getcallargs" may have returned
      # something even though we under-specified positional arguments.
      all_args = tf_inspect.getcallargs(self.call, None)
      # Record which parameter names were bound to `self`, so they can be
      # excluded from the positional-argument count below.
      self_args = set()
      for arg_name, obj in all_args.items():
        if obj is self:
          self_args.add(arg_name)
      may_take_single_argument = True
    except TypeError:
      may_take_single_argument = False
  if may_take_single_argument:
    # A single positional argument (plus "self") is considered equivalent to
    # an "inputs" argument.
    all_positional_args = len(call_argspec.args)
    if call_argspec.defaults is not None:
      # Arguments with defaults are passable by keyword; only the leading
      # default-less arguments count as strictly positional here.
      all_positional_args -= len(call_argspec.defaults)
    non_self_positional_args = all_positional_args
    for positional_arg_name in call_argspec.args[:all_positional_args]:
      if positional_arg_name in self_args:
        non_self_positional_args -= 1
    if non_self_positional_args == 1:
      if 'inputs' in call_argspec.args[all_positional_args:]:
        raise TypeError(
            "Model.call() takes a single positional argument (to which "
            "inputs are passed by convention) and a separate 'inputs' "
            "argument. Unable to determine which arguments are inputs.")
      return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT
  if 'inputs' in call_argspec.args:
    return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
  else:
    return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS

def _track_layers(self, layers):
"""Add Checkpointable dependencies on a list of Layers."""
weight_layer_index = 0
Expand Down
27 changes: 16 additions & 11 deletions tensorflow/python/keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.base_layer import DeferredTensor
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
Expand Down Expand Up @@ -523,7 +522,7 @@ def handle_metrics(metrics, weights=None):

# Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
if isinstance(metric_fn, Layer) and metric_fn.stateful:
if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates
Expand Down Expand Up @@ -959,11 +958,17 @@ def _set_inputs(self, inputs, training=None):
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
"""
if not getattr(self, '_uses_inputs_arg', True):
call_convention = getattr(
self,
'_call_convention',
base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
if call_convention not in (
base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT,
base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT):
raise NotImplementedError(
'Subclassed Models without "inputs" in their call() signatures do '
'not yet support shape inference. File a feature request if this '
'limitation bothers you.')
'Subclassed Models without "inputs" (or single positional arguments) '
'in their call() signatures do not yet support shape inference. File '
'a feature request if this limitation bothers you.')
if self.__class__.__name__ == 'Sequential':
# Note: we can't test whether the model is `Sequential` via `isinstance`
# since `Sequential` depends on `Model`.
Expand Down Expand Up @@ -1020,11 +1025,11 @@ def _eager_set_inputs(self, inputs):
else:
dummy_output_values = [dummy_output_values]
self.outputs = [
DeferredTensor(shape=(None for _ in v.shape),
dtype=v.dtype) for v in dummy_output_values]
base_layer.DeferredTensor(shape=(None for _ in v.shape),
dtype=v.dtype) for v in dummy_output_values]
self.inputs = [
DeferredTensor(shape=(None for _ in v.shape),
dtype=v.dtype) for v in dummy_input_values]
base_layer.DeferredTensor(shape=(None for _ in v.shape),
dtype=v.dtype) for v in dummy_input_values]
self.input_names = [
'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
self.output_names = [
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/python/keras/model_subclassing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self, use_bn=False, use_dp=False, num_classes=10):
if self.use_bn:
self.bn = keras.layers.BatchNormalization(axis=-1)

def call(self, inputs):
x = self.dense1(inputs)
def call(self, x):
x = self.dense1(x)
if self.use_dp:
x = self.dp(x)
if self.use_bn:
Expand Down