Introduce TFDecorator, a base class for Python TensorFlow decorators. Provides basic introspection and "unwrap" services, allowing tooling code to fully 'understand' the wrapped object.

Change: 153854044
Charles Nicholson authored and tensorflower-gardener committed Apr 21, 2017
1 parent c3bf39b commit 8e50419
Showing 57 changed files with 1,354 additions and 335 deletions.
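
As context for the diffs that follow: the new tf_decorator module (imported below in arg_scope.py, generic_utils.py, and estimator.py) pairs with tf_inspect so that tooling can introspect the function a decorator wraps. A minimal sketch of the intended usage; the decorator and function names here are hypothetical, not part of the commit:

from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def log_calls(target):
  """A toy decorator; name and behavior are hypothetical."""
  def wrapper(*args, **kwargs):
    print('calling %s' % target.__name__)
    return target(*args, **kwargs)
  # make_decorator registers `wrapper` as a TFDecorator around `target`,
  # copying its metadata and keeping a reference back to the original.
  return tf_decorator.make_decorator(target, wrapper)


@log_calls
def add(a, b=1):
  return a + b


# tf_inspect sees through the decorator and reports add's real signature,
# not wrapper's (*args, **kwargs).
print(tf_inspect.getargspec(add))  # expected: ArgSpec(args=['a', 'b'], ...)

# unwrap returns the chain of TFDecorators plus the innermost wrapped target.
_, original = tf_decorator.unwrap(add)
print(original.__name__)  # expected: 'add'
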
6 changes: 3 additions & 3 deletions tensorflow/contrib/distributions/python/ops/distribution.py
@@ -20,7 +20,6 @@

import abc
import contextlib
import inspect
import types

import numpy as np
@@ -33,6 +32,7 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect


_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
@@ -154,12 +154,12 @@ def __new__(mcs, classname, baseclasses, attrs):
if class_special_attr_value is None:
# No _special method available, no need to update the docstring.
continue
class_special_attr_docstring = inspect.getdoc(class_special_attr_value)
class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
if not class_special_attr_docstring:
# No docstring to append.
continue
class_attr_value = _copy_fn(base_attr_value)
class_attr_docstring = inspect.getdoc(base_attr_value)
class_attr_docstring = tf_inspect.getdoc(base_attr_value)
if class_attr_docstring is None:
raise ValueError(
"Expected base class fn to contain a docstring: %s.%s"
@@ -18,21 +18,20 @@
from __future__ import division
from __future__ import print_function

import inspect

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect


_DIVERGENCES = {}


def _registered_kl(type_a, type_b):
"""Get the KL function registered for classes a and b."""
hierarchy_a = inspect.getmro(type_a)
hierarchy_b = inspect.getmro(type_b)
hierarchy_a = tf_inspect.getmro(type_a)
hierarchy_b = tf_inspect.getmro(type_b)
dist_to_children = None
kl_fn = None
for mro_to_a, parent_a in enumerate(hierarchy_a):
11 changes: 5 additions & 6 deletions tensorflow/contrib/framework/python/ops/arg_scope.py
@@ -61,8 +61,9 @@ def conv2d(*args, **kwargs)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import functools

from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator

__all__ = ['arg_scope',
'add_arg_scope',
@@ -106,7 +107,7 @@ def _add_op(op):
_DECORATED_OPS[key_op] = _kwarg_names(op)


@contextlib.contextmanager
@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
"""Stores the default arguments for the given set of list_ops.
@@ -170,7 +171,6 @@ def add_arg_scope(func):
Returns:
A tuple with the decorated function func_with_args().
"""
@functools.wraps(func)
def func_with_args(*args, **kwargs):
current_scope = _current_arg_scope()
current_args = kwargs
@@ -181,8 +181,7 @@ def func_with_args(*args, **kwargs):
return func(*args, **current_args)
_add_op(func)
setattr(func_with_args, '_key_op', _key_op(func))
setattr(func_with_args, '__doc__', func.__doc__)
return func_with_args
return tf_decorator.make_decorator(func, func_with_args)


def has_arg_scope(func):
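
For the arg_scope change above, the practical difference is that a functools.wraps-based wrapper copies __name__ and __doc__ but still presents a (*args, **kwargs) signature to argspec introspection, whereas make_decorator records the wrapped function so tf_inspect can recover the real one. A rough comparison; the conv2d stand-in below is hypothetical:

import functools

from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def conv2d(inputs, num_outputs, kernel_size, padding='SAME'):
  """Stand-in for an arg-scoped op; for illustration only."""
  return inputs, num_outputs, kernel_size, padding


def wraps_style(func):
  @functools.wraps(func)
  def func_with_args(*args, **kwargs):
    return func(*args, **kwargs)
  return func_with_args


def make_decorator_style(func):
  def func_with_args(*args, **kwargs):
    return func(*args, **kwargs)
  return tf_decorator.make_decorator(func, func_with_args)


# The functools.wraps wrapper hides the original argument names...
print(tf_inspect.getargspec(wraps_style(conv2d)).args)  # expected: []
# ...while the TFDecorator-based wrapper exposes them.
print(tf_inspect.getargspec(make_decorator_style(conv2d)).args)
# expected: ['inputs', 'num_outputs', 'kernel_size', 'padding']
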
5 changes: 2 additions & 3 deletions tensorflow/contrib/keras/python/keras/backend_test.py
@@ -18,12 +18,11 @@
from __future__ import division
from __future__ import print_function

import inspect

import numpy as np

from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect


def compare_single_input_op_to_numpy(keras_op,
@@ -207,7 +206,7 @@ def test_reduction_ops(self):
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
keras_kwargs={'axis': -1},
np_kwargs={'axis': -1})
if 'keepdims' in inspect.getargspec(keras_op).args:
if 'keepdims' in tf_inspect.getargspec(keras_op).args:
compare_single_input_op_to_numpy(keras_op, np_op,
input_shape=(4, 7, 5),
keras_kwargs={'axis': 1,
8 changes: 4 additions & 4 deletions tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -20,7 +20,6 @@
from __future__ import print_function

import copy
import inspect
import json
import os
import re
@@ -35,6 +34,7 @@
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import tf_inspect


# pylint: disable=g-import-not-at-top
@@ -584,7 +584,7 @@ def __call__(self, inputs, **kwargs):
user_kwargs = copy.copy(kwargs)
if not _is_all_none(previous_mask):
# The previous layer generated a mask.
if 'mask' in inspect.getargspec(self.call).args:
if 'mask' in tf_inspect.getargspec(self.call).args:
if 'mask' not in kwargs:
# If mask is explicitly passed to __call__,
# we should override the default mask.
@@ -2166,7 +2166,7 @@ def run_internal_graph(self, inputs, masks=None):
kwargs = {}
if len(computed_data) == 1:
computed_tensor, computed_mask = computed_data[0]
if 'mask' in inspect.getargspec(layer.call).args:
if 'mask' in tf_inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_mask
output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
@@ -2177,7 +2177,7 @@
else:
computed_tensors = [x[0] for x in computed_data]
computed_masks = [x[1] for x in computed_data]
if 'mask' in inspect.getargspec(layer.call).args:
if 'mask' in tf_inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_masks
output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
4 changes: 2 additions & 2 deletions tensorflow/contrib/keras/python/keras/layers/core.py
@@ -19,7 +19,6 @@
from __future__ import print_function

import copy
import inspect
import types as python_types

import numpy as np
@@ -35,6 +34,7 @@
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import tf_inspect


class Masking(Layer):
@@ -595,7 +595,7 @@ def __init__(self, function, mask=None, arguments=None, **kwargs):

def call(self, inputs, mask=None):
arguments = self.arguments
arg_spec = inspect.getargspec(self.function)
arg_spec = tf_inspect.getargspec(self.function)
if 'mask' in arg_spec.args:
arguments['mask'] = mask
return self.function(inputs, **arguments)
4 changes: 2 additions & 2 deletions tensorflow/contrib/keras/python/keras/layers/wrappers.py
@@ -20,12 +20,12 @@
from __future__ import print_function

import copy
import inspect

from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.engine import InputSpec
from tensorflow.contrib.keras.python.keras.engine import Layer
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import tf_inspect


class Wrapper(Layer):
@@ -284,7 +284,7 @@ def _compute_output_shape(self, input_shape):

def call(self, inputs, training=None, mask=None):
kwargs = {}
func_args = inspect.getargspec(self.layer.call).args
func_args = tf_inspect.getargspec(self.layer.call).args
if 'training' in func_args:
kwargs['training'] = training
if 'mask' in func_args:
5 changes: 2 additions & 3 deletions tensorflow/contrib/keras/python/keras/testing_utils.py
@@ -18,11 +18,10 @@
from __future__ import division
from __future__ import print_function

import inspect

import numpy as np

from tensorflow.contrib.keras.python import keras
from tensorflow.python.util import tf_inspect


def get_test_data(train_samples,
@@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
layer.set_weights(weights)

# test and instantiation from weights
if 'weights' in inspect.getargspec(layer_cls.__init__):
if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
kwargs['weights'] = weights
layer = layer_cls(**kwargs)

6 changes: 4 additions & 2 deletions tensorflow/contrib/keras/python/keras/utils/generic_utils.py
@@ -17,7 +17,6 @@
from __future__ import division
from __future__ import print_function

import inspect
import marshal
import sys
import time
@@ -26,6 +25,8 @@
import numpy as np
import six

from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

_GLOBAL_CUSTOM_OBJECTS = {}

@@ -116,6 +117,7 @@ def get_custom_objects():


def serialize_keras_object(instance):
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
if hasattr(instance, 'get_config'):
@@ -149,7 +151,7 @@ def deserialize_keras_object(identifier,
if cls is None:
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
if hasattr(cls, 'from_config'):
arg_spec = inspect.getargspec(cls.from_config)
arg_spec = tf_inspect.getargspec(cls.from_config)
if 'custom_objects' in arg_spec.args:
custom_objects = custom_objects or {}
return cls.from_config(
@@ -19,13 +19,13 @@
from __future__ import print_function

import copy
import inspect
import types

import numpy as np

from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
from tensorflow.python.util import tf_inspect


class BaseWrapper(object):
@@ -97,7 +97,7 @@ def check_params(self, params):

legal_params = []
for fn in legal_params_fns:
legal_params += inspect.getargspec(fn)[0]
legal_params += tf_inspect.getargspec(fn)[0]
legal_params = set(legal_params)

for params_name in params:
@@ -182,7 +182,7 @@ def filter_sk_params(self, fn, override=None):
"""
override = override or {}
res = {}
fn_args = inspect.getargspec(fn)[0]
fn_args = tf_inspect.getargspec(fn)[0]
for name, value in self.sk_params.items():
if name in fn_args:
res.update({name: value})
4 changes: 2 additions & 2 deletions tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
@@ -24,9 +24,9 @@

import collections
import functools
import inspect
import re

from tensorflow.python.util import tf_inspect

# used for register_type_abbreviation and _type_repr below.
_TYPE_ABBREVIATIONS = {}
@@ -230,7 +230,7 @@ def accepts(*types):

def check_accepts(f):
"""Check the types."""
spec = inspect.getargspec(f)
spec = tf_inspect.getargspec(f)

num_function_arguments = len(spec.args)
if len(types) != num_function_arguments:
5 changes: 3 additions & 2 deletions tensorflow/contrib/learn/python/learn/dataframe/transform.py
@@ -24,11 +24,12 @@
from abc import abstractproperty

import collections
import inspect

from .series import Series
from .series import TransformedSeries

from tensorflow.python.util import tf_inspect


def _make_list_of_series(x):
"""Converts `x` into a list of `Series` if possible.
@@ -120,7 +121,7 @@ def name(self):
def parameters(self):
"""A dict of names to values of properties marked with `@parameter`."""
property_param_names = [name
for name, func in inspect.getmembers(type(self))
for name, func in tf_inspect.getmembers(type(self))
if (hasattr(func, "fget") and hasattr(
getattr(func, "fget"), "is_parameter"))]
return {name: getattr(self, name) for name in property_param_names}
8 changes: 5 additions & 3 deletions tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -21,7 +21,6 @@

import abc
import copy
import inspect
import os
import tempfile

@@ -70,6 +69,8 @@
from tensorflow.python.training import saver
from tensorflow.python.training import summary_io
from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


AS_ITERABLE_DATE = '2016-09-15'
@@ -185,14 +186,15 @@ def _model_fn_args(fn):
Raises:
ValueError: if partial function has positionally bound arguments
"""
_, fn = tf_decorator.unwrap(fn)
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
# Handle functools.partial and similar objects.
return tuple([
arg for arg in inspect.getargspec(fn.func).args[len(fn.args):]
arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
if arg not in set(fn.keywords.keys())
])
# Handle function.
return tuple(inspect.getargspec(fn).args)
return tuple(tf_inspect.getargspec(fn).args)


def _get_replica_device_setter(config):
Expand Down
@@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function

import inspect
from tensorflow.python.util import tf_inspect


def assert_estimator_contract(tester, estimator_class):
@@ -31,7 +31,7 @@ def assert_estimator_contract(tester, estimator_class):
tester: A tf.test.TestCase.
estimator_class: 'type' object of pre-canned estimator.
"""
attributes = inspect.getmembers(estimator_class)
attributes = tf_inspect.getmembers(estimator_class)
attribute_names = [a[0] for a in attributes]

tester.assertTrue('config' in attribute_names)
