From 29b8358f51e2da1df50bc6b896f360994bc5eacf Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Mon, 6 Aug 2018 18:50:38 -0700 Subject: [PATCH] Handle `self` argument in defun. PiperOrigin-RevId: 207646738 --- tensorflow/python/eager/function.py | 10 ++++++++-- tensorflow/python/eager/function_test.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 54363ffcba34ec..f315fa296c4df4 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -960,11 +960,17 @@ def __init__(self, self._lock = threading.Lock() fullargspec = tf_inspect.getfullargspec(self._python_function) + if tf_inspect.ismethod(self._python_function): + # Remove `self`: default arguments shouldn't be matched to it. + args = fullargspec.args[1:] + else: + args = fullargspec.args + # A cache mapping from argument name to index, for canonicalizing # arguments that are called in a keyword-like fashion. - self._args_to_indices = {arg: i for i, arg in enumerate(fullargspec.args)} + self._args_to_indices = {arg: i for i, arg in enumerate(args)} # A cache mapping from arg index to default value, for canonicalization. - offset = len(fullargspec.args) - len(fullargspec.defaults or []) + offset = len(args) - len(fullargspec.defaults or []) self._arg_indices_to_default_values = { offset + index: default for index, default in enumerate(fullargspec.defaults or []) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index b9e29635f85303..b7c9334c334617 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1233,6 +1233,19 @@ def two(self, tensor, other=integer): self.assertEqual(one.numpy(), 1.0) self.assertEqual(two.numpy(), 2) + def testDefuningInstanceMethodWithDefaultArgument(self): + + integer = constant_op.constant(2, dtypes.int64) + + class Foo(object): + + @function.defun + def func(self, other=integer): + return other + + foo = Foo() + self.assertEqual(foo.func().numpy(), int(integer)) + def testPythonCallWithSideEffects(self): state = []