Skip to content

Commit

Permalink
Handle self argument in defun.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 207646738
  • Loading branch information
akshayka authored and tensorflower-gardener committed Aug 7, 2018
1 parent 0fc1de7 commit 29b8358
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tensorflow/python/eager/function.py
Expand Up @@ -960,11 +960,17 @@ def __init__(self,
self._lock = threading.Lock()

fullargspec = tf_inspect.getfullargspec(self._python_function)
if tf_inspect.ismethod(self._python_function):
# Remove `self`: default arguments shouldn't be matched to it.
args = fullargspec.args[1:]
else:
args = fullargspec.args

# A cache mapping from argument name to index, for canonicalizing
# arguments that are called in a keyword-like fashion.
self._args_to_indices = {arg: i for i, arg in enumerate(fullargspec.args)}
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
# A cache mapping from arg index to default value, for canonicalization.
offset = len(fullargspec.args) - len(fullargspec.defaults or [])
offset = len(args) - len(fullargspec.defaults or [])
self._arg_indices_to_default_values = {
offset + index: default
for index, default in enumerate(fullargspec.defaults or [])
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/eager/function_test.py
Expand Up @@ -1233,6 +1233,19 @@ def two(self, tensor, other=integer):
self.assertEqual(one.numpy(), 1.0)
self.assertEqual(two.numpy(), 2)

def testDefuningInstanceMethodWithDefaultArgument(self):

integer = constant_op.constant(2, dtypes.int64)

class Foo(object):

@function.defun
def func(self, other=integer):
return other

foo = Foo()
self.assertEqual(foo.func().numpy(), int(integer))

def testPythonCallWithSideEffects(self):
state = []

Expand Down

0 comments on commit 29b8358

Please sign in to comment.