Torch cond operator, python dispatch, pyoperator #83154

Closed · wants to merge 86 commits into master from voz/ctfl_proto
Changes shown from 71 of 86 commits.
ab64805  Initial and reset (voznesenskym, Aug 9, 2022)
599d91a  Unset (voznesenskym, Aug 9, 2022)
2e3e976  Hax (voznesenskym, Aug 10, 2022)
d92b6f4  Impl wip (voznesenskym, Aug 10, 2022)
0d5e2bf  b etter outs (voznesenskym, Aug 10, 2022)
e3db51f  Work a little better (voznesenskym, Aug 10, 2022)
ccec9ef  less wrong (voznesenskym, Aug 10, 2022)
d472fa2  Hack up make_fx to natively support varargs (ezyang, Aug 10, 2022)
4071b1b  fixes (voznesenskym, Aug 11, 2022)
dab9822  Progress, works unnested (voznesenskym, Aug 11, 2022)
f2d885c  Feedback from Ed (voznesenskym, Aug 11, 2022)
56db44a  progress (voznesenskym, Aug 12, 2022)
28db0f8  Add data_dependent_output tag; generalize proxy tensor to test it (ezyang, Aug 12, 2022)
9dfeaee  Update base for Update on "Add data_dependent_output tag; generalize … (ezyang, Aug 12, 2022)
a3efe51  Update on "Add data_dependent_output tag; generalize proxy tensor to … (ezyang, Aug 12, 2022)
2ea2860  Add *_only and all/any pytree utilities (ezyang, Aug 12, 2022)
af5fe52  Update base for Update on "Add *_only and all/any pytree utilities" (ezyang, Aug 12, 2022)
255a288  Update on "Add *_only and all/any pytree utilities" (ezyang, Aug 12, 2022)
2f13c14  Working (voznesenskym, Aug 12, 2022)
97cdd25  lint (voznesenskym, Aug 12, 2022)
247a68c  Update base for Update on "Add *_only and all/any pytree utilities" (ezyang, Aug 12, 2022)
946d4fe  Update on "Add *_only and all/any pytree utilities" (ezyang, Aug 12, 2022)
1aa70d2  Update base for Update on "Add *_only and all/any pytree utilities" (ezyang, Aug 12, 2022)
fb4b669  Update on "Add *_only and all/any pytree utilities" (ezyang, Aug 12, 2022)
cdee501  RFC: Delete ProxyTensor wrapper subclass (ezyang, Aug 12, 2022)
ca203af  Update on "RFC: Delete ProxyTensor wrapper subclass" (ezyang, Aug 12, 2022)
f38bc6f  Update on "RFC: Delete ProxyTensor wrapper subclass" (ezyang, Aug 13, 2022)
670191a  Rearrange the chairs (voznesenskym, Aug 13, 2022)
7e903c7  Tests. real ones (voznesenskym, Aug 13, 2022)
63ebfb3  lints, fixes (voznesenskym, Aug 13, 2022)
0807c85  Linter (voznesenskym, Aug 13, 2022)
2bf2686  Refactors (voznesenskym, Aug 13, 2022)
343e9af  More cleanup (voznesenskym, Aug 13, 2022)
208c5b0  Update base for Update on "RFC: Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
a5cb0fc  Update on "RFC: Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
d67a117  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
a51fcb5  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
6f29257  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
dcba0b4  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
6fa4bd4  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
e7521fe  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 15, 2022)
d634ea0  Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto (voznesenskym, Aug 16, 2022)
776cf3a  Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto (voznesenskym, Aug 16, 2022)
f8cf09c  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
2415bbe  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
3974ecd  get rid of the reentrant parts (voznesenskym, Aug 16, 2022)
f10df36  first small feedback (voznesenskym, Aug 16, 2022)
a5b874a  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
010700b  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
651f351  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
e4f5ff8  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
9cebd47  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
68018c0  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
aee98ad  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
20cf7ab  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
ea58ef6  Update base for Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
604ec98  Update on "Delete ProxyTensor wrapper subclass" (ezyang, Aug 16, 2022)
7ceaf24  Still not working, progress (voznesenskym, Aug 17, 2022)
d55baf4  cMerge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto (voznesenskym, Aug 17, 2022)
aeae58c  Merge branch 'gh/ezyang/1326/head' of github.com:pytorch/pytorch into… (voznesenskym, Aug 17, 2022)
f009a58  Fixes (voznesenskym, Aug 17, 2022)
1d62b06  lint (voznesenskym, Aug 17, 2022)
d593d25  Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto (voznesenskym, Aug 18, 2022)
1e98c87  feedback pass (voznesenskym, Aug 18, 2022)
9de2972  Revert "feedback pass" (voznesenskym, Aug 19, 2022)
1bde5ee  Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto (voznesenskym, Aug 19, 2022)
52ffdc6  Make it less bad, safeguards, questionable choices (voznesenskym, Aug 20, 2022)
ee61ade  Add back accidently rm important line (voznesenskym, Aug 20, 2022)
2059523  Feedback pass (voznesenskym, Aug 22, 2022)
2b7b876  Feedback pass 2 (voznesenskym, Aug 22, 2022)
366591c  Feedback pass 3 (voznesenskym, Aug 23, 2022)
50181ee  Feedback pass, cleanup, improvements (voznesenskym, Aug 23, 2022)
663f8ca  types (voznesenskym, Aug 23, 2022)
10db393  Lints (voznesenskym, Aug 23, 2022)
8b56428  Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto (voznesenskym, Aug 23, 2022)
4c115b0  Trace better (voznesenskym, Aug 24, 2022)
f2445ce  Better tests, less execution (voznesenskym, Aug 24, 2022)
c4fe703  Cleanups, improvements (voznesenskym, Aug 24, 2022)
a346b3d  Feedback cleanup (voznesenskym, Aug 24, 2022)
fc5d7ff  Last few little nits (voznesenskym, Aug 24, 2022)
0840e56  Fix test (voznesenskym, Aug 24, 2022)
ddd8f0a  missed a spot (voznesenskym, Aug 24, 2022)
4000d77  Lint the test (voznesenskym, Aug 25, 2022)
25e0701  Shuffle stuff around (voznesenskym, Aug 25, 2022)
1c152bb  More lint stuff (voznesenskym, Aug 25, 2022)
670f3d2  Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto (voznesenskym, Aug 25, 2022)
3 changes: 3 additions & 0 deletions aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -114,6 +114,9 @@ namespace detail {
* they have been registered as fallthrough. The set of excluded backends
* varies from operator, as some operators may have overridden the
* fallthrough with custom behavior.
*
* Note - this should maintain identical impl to the py dispatcher key extraction logic
* at pytorch/torch/dispatcher.py
*/
struct TORCH_API DispatchKeyExtractor final {
public:
143 changes: 143 additions & 0 deletions functorch/functorch/experimental/cond.py
@@ -0,0 +1,143 @@
import torch
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from functorch.experimental.ops import PyOperator, fallthrough_fn
from torch.utils._pytree import tree_flatten
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
import torch.utils._pytree as pytree
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
import random
import string

"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
from contextlib import contextmanager

# TODO(voz): Move out somewhere else once other py dispatched ops need it
@contextmanager
def suspend_mode(mode):
    assert(mode is not None), "Cannot suspend None mode"
    assert(isinstance(mode, TorchDispatchMode)), f"Unexpected mode type {mode.__class__}"
    torch._C._set_torch_dispatch_mode(None)
    try:
        yield
    finally:
        torch._C._set_torch_dispatch_mode(mode)


@contextmanager
def enable_mode(mode):
    curr_mode = torch._C._get_torch_dispatch_mode()
    torch._C._set_torch_dispatch_mode(mode)
    try:
        yield
    finally:
        torch._C._set_torch_dispatch_mode(curr_mode)


def trace_cond(proxy_mode, func_overload, args, kwargs=None):
    assert kwargs is None or not kwargs
    pred, true_fn, false_fn, operands = args

    def _unwrap_proxy(e):
        proxy_or_e = get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
        return proxy_or_e

    assert isinstance(operands, list), "Cond operands must be a list of tensors"
    assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"

    true_graph = get_isolated_graphmodule(true_fn, operands, {})
    false_graph = get_isolated_graphmodule(false_fn, operands, {})

    # There are probably better ways - I know that create_arg has some self-incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    random_slug = ''.join(random.choices(string.digits, k=5))
    true_name = next_name
    false_name = f"false_graph_{i}"

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)
    with no_dispatch():

Collaborator Author: I'm sorry @Chillee

Contributor: The no_dispatch here shouldn't be necessary, as you have already dispelled the ambient proxy mode upon entry to this function.
        # This is not amazing.
        # However, if we have nested operators that have a call_function
        # in their graph that is not a torch op (ex: see conditional below, nested cond)
        # we cannot get metadata for it from just looking at out vars.
        # The reason is that an operation on the output of such an op is not
        # evaluated as a torch.Tensor.
        # So we execute the real true and false fn here and compare metadata
        # inp_ops = [o for o in operands]
        true_result = true_graph(operands)
        false_result = false_graph(operands)
Contributor: Yeah, this shouldn't be necessary. You're probably getting hobbled by the get_isolated_graphmodule API. That API should have the real tensor output from when it traced through; you just need to get it to disgorge that information so you can use it directly.

Collaborator Author: Yea, we can try that! Let me play with it a bit, I don't love this as it is.

        def recursive_compare_same(a, b):
            assert(type(a) == type(b))
            if isinstance(a, torch.Tensor):
                assert(a.dtype == b.dtype)
                assert(a.size() == b.size())
                assert(a.stride() == b.stride())
                assert(a.device == b.device)
            elif isinstance(a, (list, tuple)):
                assert(len(a) == len(b))
                for i in range(0, len(a)):
                    recursive_compare_same(a[i], b[i])

        recursive_compare_same(true_result, false_result)
Contributor: This won't tell the difference between a tuple of three Tensors and a list of three Tensors, right?

The right way to do this might be to compare the specs after tree_flattening:

list_a, a_spec = tree_flatten(a)
list_b, b_spec = tree_flatten(b)
assert a_spec == b_spec
for ai, bi in zip(list_a, list_b):
    ...

Collaborator Author: Yeah, agreed.
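For reference, a minimal runnable sketch of the spec-based comparison suggested above (the helper name compare_flattened_meta is illustrative, not part of the PR):

import torch
from torch.utils._pytree import tree_flatten

def compare_flattened_meta(a, b):
    # The specs capture container structure (tuple vs list, nesting),
    # so comparing them distinguishes a tuple of three Tensors from a
    # list of three Tensors.
    list_a, a_spec = tree_flatten(a)
    list_b, b_spec = tree_flatten(b)
    assert a_spec == b_spec, f"Mismatched structure: {a_spec} vs {b_spec}"
    for ai, bi in zip(list_a, list_b):
        if isinstance(ai, torch.Tensor):
            assert ai.dtype == bi.dtype
            assert ai.size() == bi.size()
            assert ai.stride() == bi.stride()
            assert ai.device == bi.device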


    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(_unwrap_proxy, args)

    proxy_res = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, kwargs,
                                               name="conditional")
    return proxy_res


def cond_dense(pred, true_fn, false_fn, *operands):
Contributor: As a minor improvement, perhaps you would like to support kwargs too?

Collaborator Author: I kind of didn't want kwargs in this API, but if pressed, I have no qualms about adding them.

Contributor: If you want people to manually flatten before passing into cond, you should accept operands as a list of tensors and not as varargs. Vararg functions in torch.ops are not a thing.

    mode = torch._C._get_torch_dispatch_mode()
    assert (mode is None), "Mode should never be enabled for CPU key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
    assert all([not f.requires_grad for f in flat_operands
                if isinstance(f, torch.Tensor)])

    # Keep the guard alive across the redispatch: it excludes AutogradCPU
    # so the call below reaches the next key.
    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)


def python_fallback(op):
    def inner(*args, **kwargs):
        mode = torch._C._get_torch_dispatch_mode()
        assert (mode is not None), "Mode should always be enabled for python fallback key"
        with suspend_mode(mode):
Contributor: The mode suspension mechanic here is not quite right. Ordinarily, we would call into TorchDispatchMode, which would take care of reapplying the inner mode before the inside of the mode function gets run.

Collaborator Author: Can you elaborate? I don't think I have the full mental model for modes right.

Contributor: There is a mode stack. When you call into the handler for a mode, you pop that mode off the stack before you do it, so that internal calls in the handler go to the next mode in the stack.

Collaborator Author: Got it.
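To illustrate the pop-before-handler discipline described above, a self-contained toy model (a plain Python list stands in for the real TorchDispatchMode stack; every name here is illustrative):

from contextlib import contextmanager

_mode_stack = []  # stand-in for the real dispatch mode stack

@contextmanager
def _pop_mode_for_handler():
    # Pop the current mode before running its handler, so dispatches
    # made *inside* the handler see the next mode down the stack.
    mode = _mode_stack.pop()
    try:
        yield mode
    finally:
        _mode_stack.append(mode)

def dispatch(op, *args):
    if _mode_stack:
        with _pop_mode_for_handler() as mode:
            return mode.handle(op, *args)  # may recursively call dispatch()
    return op(*args)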

            res = trace_cond(mode, op, args, kwargs)
            return res
Contributor: This registration is not quite right. Registering to the Python key implies that the implementation works for any Python mode and for any tensor subclass with __torch_dispatch__ on it. To actually make good on this promise, you must actually call the __torch_dispatch__ method that was provided in this way.

What this implementation does, instead, is call into a hard-coded implementation of the mode that is suitable for ProxyTensor. This is fine, but the python fallback mustn't imply that it can be used in other situations, which is what it will blindly crash through here. So you should at least assert that the mode stack has exactly one mode in it and that it is proxy tensor. Which, speaking of which, means that fake tensor mode is not going to work properly in that case...

The easiest fix may just be to write the generalized version of this logic. That just means recreating the python fallback key in C++ in PythonFallbackKernel.cpp, but you get to pass functions in the args instead. Then, ProxyTensorMode would be responsible for seeing that cond is being called and handling it properly. You could put your trace_cond implementation directly in ProxyMode, or you could add a little registration system for custom trace handling and then register your trace_cond that way.

Collaborator Author: I like the idea of a registration system. A much earlier version of this PR had this going through __torch_dispatch__ on the mode, which in turn would call back to this in one version, and had the logic inside Proxy in another.

> That just means recreating the python fallback key in C++ in PythonFallbackKernel.cpp; but you get to pass functions in the args instead.

What does moving it to C++ give us? Just sharing more of the stack with how we dispatch today?

Contributor: Sorry, I wasn't clear. I meant porting that current C++ logic into Python, the same way you did with the dispatcher.

Collaborator Author: Ah, yeah, we absolutely should.

    return inner


cond = PyOperator('cond')
cond.impl(DispatchKey.CPU, cond_dense)
cond.impl(DispatchKey.Python, python_fallback(cond))
cond.impl(DispatchKey.PythonTLSSnapshot, fallthrough_fn)
cond.impl(DispatchKey.AutogradCPU, cond_autograd)
cond.impl(DispatchKey.ADInplaceOrView, fallthrough_fn)
cond.impl(DispatchKey.BackendSelect, fallthrough_fn)
Contributor: A minor improvement that I quite like from the Python operator registration API (we also need to be careful about names, haha; maybe we'll call them Operator and PyOperator) is to do the registrations via decorator. This makes it immediately clear that the registration is for a particular dispatch key.

Collaborator Author: Sure, that's reasonable. I kind of prefer this kind of API over decorators, but am open to moving to decorators.
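For illustration, a sketch of what decorator-based registration could look like on a PyOperator-style class (the py_impl name is hypothetical, not the API added by this PR):

from torch._C import DispatchKey

class PyOperatorSketch:
    def __init__(self, name):
        self.name = name
        self.table = {}

    def py_impl(self, dispatch_key):
        # Decorator form: the dispatch key sits right above the function
        # it registers, making the association obvious at a glance.
        def decorator(fn):
            assert dispatch_key not in self.table
            self.table[dispatch_key] = fn
            return fn
        return decorator

cond_sketch = PyOperatorSketch('cond')

@cond_sketch.py_impl(DispatchKey.CPU)
def cond_dense_sketch(pred, true_fn, false_fn, *operands):
    return true_fn(*operands) if pred else false_fn(*operands)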

37 changes: 37 additions & 0 deletions functorch/functorch/experimental/ops.py
@@ -0,0 +1,37 @@
from torch.dispatch.dispatcher import dispatcher_singleton, to_flat_tuple, has_torch_function, compute_keyset
from torch._C import DispatchKey, DispatchKeySet
from torch.nn.functional import handle_torch_function

import torch._C as _C

class PyOperator:
    def __init__(self, name):
        self.name = name
        self.table = {}

        self.__name__ = name

    def impl(self, dispatch_key, fn):
        assert dispatch_key not in self.table
        if fn is fallthrough_fn:
            self.table[dispatch_key] = fn(self, dispatch_key)
        else:
            self.table[dispatch_key] = fn

    def lookup(self, keyset):
        dispatch_key = keyset.highestPriorityTypeId()
        return self.table[dispatch_key]

    def __call__(self, *args, **kwargs):
        flat_args = to_flat_tuple(args, kwargs)
        if has_torch_function(flat_args):
            return handle_torch_function(self, flat_args, *args, **kwargs)

        return dispatcher_singleton.call(self, args, kwargs)

def fallthrough_fn(operator, dispatch_key):
    def inner(*args, **kwargs):
        # Mask the keyset down to everything below the current key, then redispatch.
        all_keys_after_current = _C._dispatch_keyset_full_after(dispatch_key)
        all_keys_after_current_masked = all_keys_after_current & compute_keyset(args, kwargs)
        return dispatcher_singleton.redispatch(operator, all_keys_after_current_masked, args, kwargs)
    return inner
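To make the fallthrough/redispatch mechanic concrete, here is a toy model with plain sets standing in for DispatchKeySet (editor's sketch; the real keyset logic lives in torch._C):

PRIORITY = ["PythonTLSSnapshot", "Python", "AutogradCPU", "ADInplaceOrView", "BackendSelect", "CPU"]

def keyset_full_after(key):
    # Everything strictly lower priority than `key`.
    return set(PRIORITY[PRIORITY.index(key) + 1:])

def dispatch(table, keys):
    # Run the kernel for the highest-priority key present in `keys`.
    key = next(k for k in PRIORITY if k in keys)
    return table[key](table, keys, key)

def fallthrough(table, keys, key):
    # Mask out this key and everything above it, then redispatch,
    # mirroring fallthrough_fn above.
    return dispatch(table, keys & keyset_full_after(key))

table = {
    "BackendSelect": fallthrough,
    "CPU": lambda table, keys, key: "ran CPU kernel",
}
print(dispatch(table, {"BackendSelect", "CPU"}))  # -> ran CPU kernel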
154 changes: 154 additions & 0 deletions functorch/test/test_control_flow.py
@@ -0,0 +1,154 @@
import torch

from torch.testing._internal.common_utils import TestCase
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx


class TestControlFlow(TestCase):
    def test_cond(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        x = torch.randn(4)
        result = cond(False, true_fn, false_fn, [x])
        self.assertEqual(result, torch.cos(x))


class TestControlFlowTraced(TestCase):
    def test_cond_traced(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False))
        result_true = graph.forward(x, torch.tensor(True))
        result_false = graph.forward(x, torch.tensor(False))
        self.assertFalse(torch.allclose(result_true, result_false))
        print("result_true", result_true)
        self.assertEqual(result_true, torch.sin(x))
        self.assertEqual(result_false, torch.cos(x))

    def test_cond_nested_traced(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(x, pred2):
            z = cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def f(x, pred, pred2):
            return cond(pred, true_fn, false_fn, [x, pred2])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        result_true_true = graph.forward(x, torch.tensor(True), torch.tensor(True))  # True + True -> x * x
        result_true_false = graph.forward(x, torch.tensor(True), torch.tensor(False))  # True + False -> x + x
        result_false_true = graph.forward(x, torch.tensor(False), torch.tensor(True))  # False + either -> cos
        result_false_false = graph.forward(x, torch.tensor(False), torch.tensor(False))  # False + either -> cos

        self.assertNotEqual(result_true_true, result_true_false)
        self.assertFalse(torch.allclose(result_false_true, result_true_true))

        self.assertEqual(result_false_true, result_false_false)

        self.assertEqual(result_true_true, (x * x) + x)
        self.assertEqual(result_true_false, x + x + x)

        self.assertEqual(result_false_true, torch.cos(x))

    def test_cond_nested_traced_other_inputs(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(k, pred2):
            z = cond(pred2, true_nested, false_nested, [k])
            return torch.add(torch.tensor([.25, .25]), z)

        def false_fn(k, _):
            return k.cos()

        def f(k, pred, pred2):
            return cond(pred, true_fn, false_fn, [k, pred2])

        x = torch.tensor([0.5, 0.5])
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        a = torch.tensor([1.0, 1.0])
        result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))

        b = torch.tensor([2.0, 2.0])
        result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))

    def test_cond_nested_traced_multi(self):
        def true_a(y):
            return y * y

        def false_a(y):
            return y + y

        def true_b(y, z):
            return y + z

        def false_b(y, z):
            return y * z

        def f(x, pred, pred2):
            a_out = cond(pred, true_a, false_a, [x])
            b_out = cond(pred2, true_b, false_b, [x, x])
            return a_out + b_out

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
Contributor: Maybe more of a question for @Chillee @eellison -- do we want to test that this works under the other flavors of make_fx (I'm mostly just thinking fake tensor tracing; this is probably not going to work with symbolic yet), or do we punt that to the future?

        # Brittle, yet, delicious
        out = """
def forward(self, x_1, pred_1, pred2_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    conditional = functorch_experimental_ops_cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    conditional_1 = functorch_experimental_ops_cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
    add = conditional + conditional_1; conditional = conditional_1 = None
    return add
"""
        code = graph.code

        # Normalization hack, because .code makes some weird whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)

        code = graph.true_graph_0.code
        out = """
def forward(self, flat_args):
    flat_args_1, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)
    mul_tensor = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1); flat_args_1 = None
    return pytree.tree_unflatten([mul_tensor], self._out_spec)
"""
        # Normalization hack, because .code makes some weird whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)