Support params/buffers inside cond and map (#102310)
With #102022, params and buffers are always treated as a special case of free variables. In this PR, I switch the cond and map implementations to this method and deprecate the old tracing mechanism.
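
A minimal sketch of the behavior this enables, adapted from the new tests added in this PR (the module and buffer names here are only illustrative): a buffer that a cond branch closes over is lifted into the branch subgraph as an ordinary free-variable input, so the exported graph matches eager mode without the old tracing path.

```python
import torch
from functorch.experimental.control_flow import cond

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Buffer closed over by both branches below; it is lifted into the
        # cond subgraphs as an extra input during export.
        self.register_buffer("buf", torch.ones(6, 4))

    def forward(self, x):
        def true_fn(x):
            return x.cos() + self.buf.sum()

        def false_fn(x):
            return x.sin() - self.buf.sum()

        return cond(x.shape[0] > 4, true_fn, false_fn, [x])

gm, _ = torch._dynamo.export(M(), torch.ones(6, 4), aten_graph=False)
# The exported graph should agree with eager mode on both branches.
print(torch.allclose(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))))
print(torch.allclose(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))))
```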

Differential Revision: [D46746202](https://our.internmc.facebook.com/intern/diff/D46746202)
Pull Request resolved: #102310
Approved by: https://github.com/avikchaudhuri, https://github.com/zou3519
tugsbayasgalan authored and pytorchmergebot committed Jun 20, 2023
1 parent 1be1f50 commit d4b85f3
Showing 11 changed files with 563 additions and 277 deletions.
29 changes: 26 additions & 3 deletions functorch/experimental/_cond.py
@@ -125,11 +125,34 @@ def cond_dense(pred, true_fn, false_fn, operands):
 def cond_autograd(pred, true_fn, false_fn, *operands):
     # TODO: support autograd
     flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
-    assert all(not f.requires_grad for f in flat_operands
-               if isinstance(f, torch.Tensor))
+
+    requires_grad = any(
+        isinstance(arg, torch.Tensor) and arg.requires_grad
+        for arg in flat_operands
+    )
 
     with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)):
-        return cond(pred, true_fn, false_fn, *operands)
+        result = cond(pred, true_fn, false_fn, *operands)
+
+    # If there is requires_grad, we delay the error until backward pass
+    if requires_grad:
+        # cond can only return one value
+        err_fn = torch._C._functions.DelayedError(
+            b"NYI: torch.cond doesn't support autograd",
+            1,
+        )
+        # Create aliases of the output that has requires_grad=True. We need
+        # at least one of the inputs to err_fn to require grad so that the
+        # output will have a grad_fn.
+
+        def fake_requires_grad(var):
+            if var is not None:
+                var = var.detach()
+                var.requires_grad = True
+            return var
+        return err_fn(fake_requires_grad(result))
+
+    return result
 
 
 @cond.py_impl(ProxyTorchDispatchMode)
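With this change, cond_autograd no longer asserts up front when an operand requires grad; instead it wraps the result in torch._C._functions.DelayedError so the "NYI" message is only raised if autograd actually tries to backpropagate through the cond output. Below is a rough standalone sketch of that deferred-error pattern, not part of this diff, assuming DelayedError keeps the (message, num_inputs) constructor used above and that its backward surfaces as a RuntimeError.

```python
import torch

x = torch.ones(3, requires_grad=True)
result = x * 2  # stand-in for the value cond would return

# Same construction as in cond_autograd above: the message is only raised
# if autograd ever reaches the wrapped output during backward.
err_fn = torch._C._functions.DelayedError(
    b"NYI: torch.cond doesn't support autograd", 1
)

def fake_requires_grad(var):
    # Detach and re-mark as requiring grad so err_fn's output gets a grad_fn.
    if var is not None:
        var = var.detach()
        var.requires_grad = True
    return var

out = err_fn(fake_requires_grad(result))
print(out)  # the forward pass still works

try:
    out.sum().backward()  # the delayed "NYI" error surfaces here
except RuntimeError as e:
    print("backward failed:", e)
```
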
275 changes: 204 additions & 71 deletions test/dynamo/test_export.py
@@ -1585,22 +1585,7 @@ def forward(self, pred, x):
     def test_export_with_cond_closure(self):
         from functorch.experimental.control_flow import cond
 
-        class ModuleAccidentallyPassingError(torch.nn.Module):
-            # error
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, pred, x):
-                def true_fn(val):
-                    return x * 2
-
-                def false_fn(val):
-                    return x - 2
-
-                return cond(pred, true_fn, false_fn, [x])
-
-        class ModuleAccidentallyPassingFixed(torch.nn.Module):
-            # error
+        class Foo(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
@@ -1613,22 +1598,7 @@ def false_fn(x):
 
             return cond(pred, true_fn, false_fn, [x])
 
-        class ModuleNoAccidentError(torch.nn.Module):
-            # error
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, pred, x):
-                def true_fn(val):
-                    return x * 2
-
-                def false_fn(val):
-                    return x - 2
-
-                return cond(pred, true_fn, false_fn, [x + 1])
-
-        class ModuleNoAccidentFixed(torch.nn.Module):
-            # error
+        class Bar(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
@@ -1641,25 +1611,7 @@ def false_fn(x):
 
             return cond(pred, true_fn, false_fn, [x + 1])
 
-        class ModuleClosureReproError(torch.nn.Module):
-            # error
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(3, 3)
-
-            def forward(self, pred, x):
-                y = x + x
-
-                def true_fn(val):
-                    return self.linear(val) * (x + y)
-
-                def false_fn(val):
-                    return val * (y - x)
-
-                return cond(pred, true_fn, false_fn, [x])
-
-        class ModuleClosureReproFixed(torch.nn.Module):
-            # error
+        class FooBar(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)
@@ -1675,25 +1627,7 @@ def false_fn(x, y):
 
             return cond(pred, true_fn, false_fn, [x, y])
 
-        for Module in [
-            ModuleAccidentallyPassingError,
-            ModuleNoAccidentError,
-            ModuleClosureReproError,
-        ]:
-            mod = Module()
-            x = torch.randn([3, 3])
-            pred = torch.tensor(x[0][0].item() < 0)
-            with self.assertRaisesRegex(
-                torch._dynamo.exc.UserError,
-                "Cannot create subgraph for nested function.*because it closes over variables",
-            ):
-                torch._dynamo.export(mod.forward, pred, x)
-
-        for Module in [
-            ModuleAccidentallyPassingFixed,
-            ModuleNoAccidentFixed,
-            ModuleClosureReproFixed,
-        ]:
+        for Module in [Foo, Bar, FooBar]:
             mod = Module()
             x = torch.randn([3, 3])
             pred = torch.tensor(x[0][0].item() < 0)
@@ -2945,7 +2879,7 @@ def f_branch_return_non_tensor(x):
         example_inputs = (torch.rand(5),)
         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
-            "Expected branch out type to be a single tensor",
+            "HigherOrderOperator can't return non-tensor scalar output",
         ):
             torch._dynamo.export(
                 f_branch_return_non_tensor,
@@ -3110,6 +3044,205 @@ def forward(self, x):
msg="test_capture_symbolic_tracing_aten_graph_" + str(aten_graph),
)

def test_cond_op_param_buffer_lifted(self):
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.zeros(6, 4))

def forward(self):
return self.buffer1.sum()

class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer2", torch.ones(6, 4))

def forward(self):
return self.buffer2.sum()

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = A()
self.b = B()

def forward(self, x):
def true_fn(x):
return x.cos() + self.a()

def false_fn(x):
return x.sin() + self.b()

return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)

gm, _ = torch._dynamo.export(M(), torch.ones(6, 4), aten_graph=False)
self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4)))
self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4)))

def test_nested_cond_op_param_buffer_lifted(self):
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.zeros(6, 4))

def forward(self):
return self.buffer1.sum()

class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer2", torch.ones(6, 4))

def forward(self):
return self.buffer2.sum()

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = A()
self.b = B()

def forward(self, x):
def true_true_fn(x):
return x.cos() + self.a()

def true_false_fn(x):
return x.cos() + self.a() + 1

def true_fn(x):
return cond(x.shape[0] > 5, true_true_fn, true_false_fn, [x])

def false_fn(x):
return x.sin() + self.b()

return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)

gm, _ = torch._dynamo.export(M(), torch.ones(6, 4), aten_graph=False)
self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4)))
self.assertEqual(gm(torch.ones(5, 4)), M()(torch.ones(5, 4)))
self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4)))

def test_map_cond_param_buffer_lifted(self):
from functorch.experimental.control_flow import cond, map

class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.zeros(6, 4))

def forward(self):
return self.buffer1.sum()

class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer2", torch.ones(6, 4))

def forward(self):
return self.buffer2.sum()

class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = A()
self.b = B()

def inner(self, x, pred):
def true_fn(x):
return x + x + self.a()

def false_fn(x):
return x * x + self.b()

return cond(pred, true_fn, false_fn, [x])

def forward(self, pred, xs):
def body(x, pred):
return self.inner(x, pred) + self.b()

return map(body, xs, pred)

mod = Module()
x = torch.randn(3, 2, 1)
pred_x = torch.tensor(True)

y = torch.randn(4, 3, 2)
pred_y = torch.tensor(False)
real_result = mod(pred_y, y)

out_graph, _ = torch._dynamo.export(mod, pred_x, x)
self.assertEqual(real_result, out_graph(pred_y, y))

def test_cond_free_variables_overlapping(self):
from functorch.experimental.control_flow import cond

class Module(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, pred, x):
a = torch.ones(6, 4)
b = torch.ones(6, 4)
c = torch.ones(6, 4)
d = torch.ones(6, 4)

def true_fn(x):
return x + x + a.cos() + b.cos() + d.cos()

def false_fn(x):
return x * x + a.sin() + b.sin() + c.sin()

return cond(pred, true_fn, false_fn, [x])

mod = Module()
x = torch.ones(6, 4)
pred_x = torch.tensor(True)

out_graph, _ = torch._dynamo.export(mod, pred_x, x)
self.assertExpectedInline(
out_graph.code.strip(),
"""\
def forward(self, pred, x):
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
ones = torch.ones(6, 4)
ones_1 = torch.ones(6, 4)
ones_2 = torch.ones(6, 4)
ones_3 = torch.ones(6, 4)
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.cond(arg0, cond_true_0, cond_false_0, [arg1, ones, ones_1, ones_3, ones, ones_1, ones_2]); arg0 = cond_true_0 = cond_false_0 = arg1 = ones = ones_1 = ones_3 = ones_2 = None
return pytree.tree_unflatten([cond], self._out_spec)""", # noqa: B950,E122
)

self.assertExpectedInline(
out_graph.cond_true_0.code.strip(),
"""\
def forward(self, l_x_, ones, ones_1, ones_3, ones_2_false_branch, ones_1_false_branch, ones_false_branch):
add = l_x_ + l_x_; l_x_ = None
cos = ones.cos(); ones = None
add_1 = add + cos; add = cos = None
cos_1 = ones_1.cos(); ones_1 = None
add_2 = add_1 + cos_1; add_1 = cos_1 = None
cos_2 = ones_3.cos(); ones_3 = None
add_3 = add_2 + cos_2; add_2 = cos_2 = None
return add_3""",
)

self.assertExpectedInline(
out_graph.cond_false_0.code.strip(),
"""\
def forward(self, l_x_, ones_3_true_branch, ones_1_true_branch, ones_true_branch, ones, ones_1, ones_2):
mul = l_x_ * l_x_; l_x_ = None
sin = ones.sin(); ones = None
add = mul + sin; mul = sin = None
sin_1 = ones_1.sin(); ones_1 = None
add_1 = add + sin_1; add = sin_1 = None
sin_2 = ones_2.sin(); ones_2 = None
add_2 = add_1 + sin_2; add_1 = sin_2 = None
return add_2""",
)


common_utils.instantiate_parametrized_tests(ExportTests)

