Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix the deadlock issue of recursive tf.function.
Replace threading.Lock with threading.RLock to allow recursive tf.function.

PiperOrigin-RevId: 401282729
Change-Id: I3d10416f2eb2c15e2055bb4f4afee3d62bd6c428
  • Loading branch information
JXRiver authored and tensorflower-gardener committed Oct 6, 2021
1 parent 0fec20b commit afac815
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tensorflow/python/eager/def_function.py
Expand Up @@ -572,7 +572,7 @@ def __init__(self,
ValueError: if `input_signature` is not None and the `python_function`'s
argspec has keyword arguments.
"""
self._lock = threading.Lock()
self._lock = threading.RLock()
self._python_function = python_function
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
python_function,
Expand Down Expand Up @@ -613,7 +613,7 @@ def __getstate__(self):
def __setstate__(self, state):
  """Restore from pickled state.

  Attributes that cannot be pickled are rebuilt instead of restored:
  the lock, the descriptor cache, and the call-stats key.

  Args:
    state: the pickled `__dict__` produced by `__getstate__`.
  """
  self.__dict__ = state
  # Must be a reentrant lock (RLock, not Lock): a recursive tf.function
  # re-enters this lock on the same thread, and a plain Lock would deadlock.
  self._lock = threading.RLock()
  self._descriptor_cache = weakref.WeakKeyDictionary()
  self._key_for_call_stats = self._get_key_for_call_stats()

Expand Down
113 changes: 113 additions & 0 deletions tensorflow/python/eager/def_function_test.py
Expand Up @@ -25,6 +25,7 @@
from six.moves import range

from tensorflow.python.autograph.core import converter
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import constant_op
Expand All @@ -36,6 +37,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
Expand Down Expand Up @@ -1261,6 +1263,117 @@ def testDouble(self, a):
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)

def test_recursive_tf_function(self):
  """A tf.function that calls itself traces without deadlocking."""

  @def_function.function
  def recurse(n):
    # Guard clause for the base case; otherwise recurse toward it.
    if n <= 0:
      return 1
    return recurse(n - 1)

  self.assertEqual(recurse(5).numpy(), 1)

def test_recursive_tf_function_with_gradients(self):
  """Gradients flow through a self-recursive tf.function: d(5!*x)/dx = 120."""

  @def_function.function
  def scaled(depth, value):
    if depth <= 0:
      return value
    return depth * scaled(depth - 1, value)

  var = variables.Variable(1.0)
  with backprop.GradientTape() as tape:
    out = scaled(5, var)

  grad = tape.gradient(out, var)
  self.assertEqual(grad.numpy(), 120)

def test_recursive_python_function(self):
  """A recursive plain-Python helper wrapped by tf.function works."""

  def py_recurse(n):
    return py_recurse(n - 1) if n > 0 else 1

  @def_function.function
  def wrapped(n):
    return py_recurse(n)

  self.assertEqual(wrapped(5).numpy(), 1)

def test_recursive_python_function_with_gradients(self):
  """Gradients flow through a recursive Python helper inside tf.function."""

  def py_scaled(n, value):
    return n * py_scaled(n - 1, value) if n > 0 else value

  @def_function.function
  def wrapped(n, value):
    return py_scaled(n, value)

  var = variables.Variable(1.0)
  with backprop.GradientTape() as tape:
    out = wrapped(5, var)

  grad = tape.gradient(out, var)
  self.assertEqual(grad.numpy(), 120)

def test_recursive_tf_function_call_each_other(self):
  """Mutually recursive tf.functions don't deadlock; result tracks parity."""

  @def_function.function
  def bounce_a(n):
    if n > 1:
      return bounce_b(n - 1)
    return 1

  @def_function.function
  def bounce_b(n):
    if n > 1:
      return bounce_a(n - 1)
    return 2

  # Which base case is hit depends on the parity of the starting depth.
  self.assertEqual(bounce_a(5).numpy(), 1)
  self.assertEqual(bounce_a(6).numpy(), 2)
  self.assertEqual(bounce_b(5).numpy(), 2)
  self.assertEqual(bounce_b(6).numpy(), 1)

def test_recursive_tf_function_call_each_other_with_gradients(self):
  """Gradients flow through a pair of mutually recursive tf.functions."""

  @def_function.function
  def chain_a(n, value):
    if n > 1:
      return n * chain_b(n - 1, value)
    return value

  @def_function.function
  def chain_b(n, value):
    if n > 1:
      return n * chain_a(n - 1, value)
    return 2 * value

  var = variables.Variable(1.0)

  # Starting from chain_a at depth 5 ends in chain_a's base case: 5! = 120.
  with backprop.GradientTape() as tape:
    out_a = chain_a(5, var)
  self.assertEqual(tape.gradient(out_a, var).numpy(), 120)

  # Starting from chain_b ends in chain_b's base case (2x): 2 * 5! = 240.
  with backprop.GradientTape() as tape:
    out_b = chain_b(5, var)
  self.assertEqual(tape.gradient(out_b, var).numpy(), 240)

def test_recursive_tf_function_with_cond(self):
  """Unbounded trace-time recursion surfaces as RecursionError, not deadlock."""
  # autograph=False so the condition is not rewritten; the point of this
  # test is the *Python-level* recursion that happens during tracing.
  @def_function.function(autograph=False)
  def recursive_fn(n):
    # NOTE: `recursive_fn(n - 1)` is evaluated eagerly as an argument here
    # (it is not wrapped in a callable branch), so tracing recurses with no
    # base case. This is deliberate: with a non-reentrant lock this pattern
    # used to deadlock; with RLock it fails fast with RecursionError.
    return cond_v2.cond_v2(n > 0, recursive_fn(n - 1), 1)

  with self.assertRaises(RecursionError):
    recursive_fn(constant_op.constant(5))


if __name__ == '__main__':
ops.enable_eager_execution()
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/eager/function.py
Expand Up @@ -3037,7 +3037,7 @@ def __init__(self,
if self.input_signature is not None:
self._hashable_input_signature = hash(self.flat_input_signature)

self._lock = threading.Lock()
self._lock = threading.RLock()
# _descriptor_cache is a of instance of a class to an instance-specific
# `Function`, used to make sure defun-decorated methods create different
# functions for each instance.
Expand Down

0 comments on commit afac815

Please sign in to comment.