Torch cond operator, python dispatch, pyoperator (#83154)
Fixes #ISSUE_NUMBER
Pull Request resolved: #83154
Approved by: https://github.com/ezyang
1 parent 3c2a078 commit ced2ca8
Showing 9 changed files with 464 additions and 0 deletions.
@@ -0,0 +1,137 @@
import torch
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from functorch.experimental.ops import PyOperator, fallthrough_fn
from torch.utils._pytree import tree_flatten
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.fx.experimental.proxy_tensor import track_tensor_tree


"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
from contextlib import contextmanager


# TODO(voz): Move out somewhere else once other py dispatched ops need it
@contextmanager
def suspend_mode(mode):
    assert mode is not None, "Cannot suspend None mode"
    assert isinstance(mode, TorchDispatchMode), f"Unexpected mode type {mode.__class__}"
    torch._C._set_torch_dispatch_mode(None)
    try:
        yield
    finally:
        torch._C._set_torch_dispatch_mode(mode)


@contextmanager
def enable_mode(mode):
    curr_mode = torch._C._get_torch_dispatch_mode()
    torch._C._set_torch_dispatch_mode(mode)
    try:
        yield
    finally:
        torch._C._set_torch_dispatch_mode(curr_mode)


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    def _unwrap_proxy(e):
        return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)

    assert isinstance(operands, list), "Cond operands must be a list"
    assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must all be tensors"

    # Trace each branch in isolation so both graphs are captured regardless
    # of which branch the predicate would take.
    true_graph = get_isolated_graphmodule(true_fn, operands, {})
    false_graph = get_isolated_graphmodule(false_fn, operands, {})

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == 'output':
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == 'output':
            false_outs.extend(node.args)

    # Both branches must produce outputs that match in number and metadata.
    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_outs)
    assert len(flat_true_outs) == len(flat_false_outs)

    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]
        assert true_out.meta == false_out.meta

    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    true_name = next_name
    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, [operands])

    proxy_args = pytree.tree_map(_unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                               name="conditional")

    # Run the taken branch eagerly to get concrete outputs to track.
    if pred:
        out = true_fn(*operands)
    else:
        out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


def cond_dense(pred, true_fn, false_fn, operands):
    mode = torch._C._get_torch_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
    assert all([not f.requires_grad for f in flat_operands
                if isinstance(f, torch.Tensor)])

    # The guard must stay alive for the duration of the redispatch below.
    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)


def python_fallback(op):
    def inner(pred, true_fn, false_fn, operands):
        mode = torch._C._get_torch_dispatch_mode()
        assert mode is not None, "Mode should always be enabled for python fallback key"
        with suspend_mode(mode):
            res = trace_cond(mode, op, pred, true_fn, false_fn, operands)
        return res

    return inner


cond = PyOperator('cond')
cond.impl(DispatchKey.CPU, cond_dense)
cond.impl(DispatchKey.Python, python_fallback(cond))
cond.impl(DispatchKey.PythonTLSSnapshot, fallthrough_fn)
cond.impl(DispatchKey.AutogradCPU, cond_autograd)
cond.impl(DispatchKey.ADInplaceOrView, fallthrough_fn)
cond.impl(DispatchKey.BackendSelect, fallthrough_fn)
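Taken together, these registrations give `cond` an eager path (the CPU key runs cond_dense) and a tracing path (the Python key routes into trace_cond, which captures both branches). A minimal usage sketch, distilled from the tests added later in this commit; the module above is evidently functorch/experimental/cond.py, given the import those tests use:

# Sketch distilled from this commit's tests: cond picks one branch eagerly
# under normal dispatch, and records both branches when traced via make_fx.
import torch
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)

# Eager: no dispatch mode is active, so cond_dense runs one branch directly.
assert torch.allclose(cond(False, true_fn, false_fn, [x]), x.cos())

# Traced: proxy tensors hit the Python key, so trace_cond embeds both
# branches in the graph; the predicate is only evaluated at call time.
def f(x, pred):
    return cond(pred, true_fn, false_fn, [x])

graph = make_fx(f)(x, torch.tensor(False))
assert torch.allclose(graph.forward(x, torch.tensor(True)), x.sin())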
@@ -0,0 +1,36 @@
from torch._dispatch._dispatcher import PyDispatcher, to_flat_tuple, compute_keyset
from torch.nn.functional import handle_torch_function
from torch.overrides import has_torch_function
import torch._C as _C


class PyOperator:
    def __init__(self, name):
        self.name = name
        self.table = {}

        self.__name__ = name

    def impl(self, dispatch_key, fn):
        assert dispatch_key not in self.table
        if fn is fallthrough_fn:
            self.table[dispatch_key] = fn(self, dispatch_key)
        else:
            self.table[dispatch_key] = fn

    def lookup(self, keyset):
        dispatch_key = keyset.highestPriorityTypeId()
        return self.table[dispatch_key]

    def __call__(self, *args, **kwargs):
        flat_args = to_flat_tuple(args, kwargs)
        if has_torch_function(flat_args):
            return handle_torch_function(self, flat_args, *args, **kwargs)

        return PyDispatcher.call(self, *args, **kwargs)


def fallthrough_fn(operator, dispatch_key):
    def inner(*args, **kwargs):
        all_keys_after_current = _C._dispatch_keyset_full_after(dispatch_key)
        all_keys_after_current_masked = all_keys_after_current & compute_keyset(args, kwargs)
        return PyDispatcher.redispatch(operator, all_keys_after_current_masked, *args, **kwargs)
    return inner
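This table-driven dispatch is what `cond` plugs into above; the module is evidently functorch/experimental/ops.py, given the import at the top of the cond module. As a hedged sketch of the registration pattern the class enables (the operator `double` here is hypothetical, purely for illustration):

# Hypothetical sketch of defining a new PyOperator with the API above.
# `double` and double_dense are illustrative names, not part of this commit.
from torch._C import DispatchKey
from functorch.experimental.ops import PyOperator, fallthrough_fn

def double_dense(x):
    # Concrete kernel: runs once dispatch bottoms out on the CPU key.
    return x + x

double = PyOperator('double')
double.impl(DispatchKey.CPU, double_dense)
# Keys the operator does not handle are registered as fallthroughs:
# fallthrough_fn masks out everything up to the current key and
# redispatches to the next-highest-priority entry in the keyset.
double.impl(DispatchKey.ADInplaceOrView, fallthrough_fn)
double.impl(DispatchKey.BackendSelect, fallthrough_fn)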
@@ -0,0 +1,183 @@
import torch

from torch.testing._internal.common_utils import TestCase
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx


class TestControlFlow(TestCase):
    def test_cond_no_trace(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        x = torch.randn(4)
        result = cond(False, true_fn, false_fn, [x])
        self.assertEqual(result, torch.cos(x))


class TestControlFlowTraced(TestCase):
    def test_cond_traced(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False))
        result_true = graph.forward(x, torch.tensor(True))
        result_false = graph.forward(x, torch.tensor(False))
        self.assertFalse(torch.allclose(result_true, result_false))
        self.assertEqual(result_true, torch.sin(x))
        self.assertEqual(result_false, torch.cos(x))

    def test_cond_nested_traced(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(x, pred2):
            z = cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def f(x, pred, pred2):
            return cond(pred, true_fn, false_fn, [x, pred2])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        result_true_true = graph.forward(x, torch.tensor(True), torch.tensor(True))  # True + True -> x * x
        result_true_false = graph.forward(x, torch.tensor(True), torch.tensor(False))  # True + False -> x + x
        result_false_true = graph.forward(x, torch.tensor(False), torch.tensor(True))  # False + either -> cos
        result_false_false = graph.forward(x, torch.tensor(False), torch.tensor(False))  # False + either -> cos

        self.assertNotEqual(result_true_true, result_true_false)
        self.assertFalse(torch.allclose(result_false_true, result_true_true))

        self.assertEqual(result_false_true, result_false_false)

        self.assertEqual(result_true_true, (x * x) + x)
        self.assertEqual(result_true_false, x + x + x)

        self.assertEqual(result_false_true, torch.cos(x))

    def test_cond_nested_traced_other_inputs(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(k, pred2):
            z = cond(pred2, true_nested, false_nested, [k])
            return torch.add(torch.tensor([.25, .25]), z)

        def false_fn(k, _):
            return k.cos()

        def f(k, pred, pred2):
            return cond(pred, true_fn, false_fn, [k, pred2])

        x = torch.tensor([0.5, 0.5])
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        a = torch.tensor([1.0, 1.0])
        result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))

        b = torch.tensor([2.0, 2.0])
        result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))

    def test_cond_nested_traced_multi(self):
        def true_a(y):
            return y * y

        def false_a(y):
            return y + y

        def true_b(y, z):
            return y + z

        def false_b(y, z):
            return y * z

        def f(x, pred, pred2):
            a_out = cond(pred, true_a, false_a, [x])
            b_out = cond(pred2, true_b, false_b, [x, x])
            return a_out + b_out

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        # Brittle, yet, delicious
        out = """
        def forward(self, x_1, pred_1, pred2_1):
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = functorch_experimental_ops_cond(pred_1,
                true_graph_0, false_graph_0, [[x_1]]);  pred_1 = true_graph_0 = false_graph_0 = None
            true_graph_1 = self.true_graph_1
            false_graph_1 = self.false_graph_1
            conditional_1 = functorch_experimental_ops_cond(pred2_1,
                true_graph_1, false_graph_1, [[x_1, x_1]]);  pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
            add_tensor = torch.ops.aten.add.Tensor(conditional, conditional_1);  conditional = conditional_1 = None
            return add_tensor
        """
        code = graph.code
        # Normalization hack, because .code makes some weird whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)

        code = graph.true_graph_0.code
        out = """
        def forward(self, flat_args):
            flat_args_1, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)
            mul_tensor = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1);  flat_args_1 = None
            return pytree.tree_unflatten([mul_tensor], self._out_spec)
        """
        # Normalization hack, because .code makes some weird whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)

    def test_assert_on_mismatch_type_size(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return (x, x)

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaises(AssertionError):
            make_fx(f)(x, torch.tensor(False))

    def test_assert_on_mismatch_tensor_size(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return torch.zeros([10, 10])

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaises(AssertionError):
            make_fx(f)(x, torch.tensor(False))
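As the expected graph.code above shows, trace_cond registers each traced branch as a numbered submodule (true_graph_0, false_graph_1, ...) on the traced root. A small follow-on sketch, not part of the commit, of inspecting those artifacts directly:

# Hedged sketch (not in this commit): the branch GraphModules that
# trace_cond registers can be pulled off the traced root by name.
import torch
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx

def square(x):
    return x * x

def halve(x):
    return x / 2

def f(x, pred):
    return cond(pred, square, halve, [x])

graph = make_fx(f)(torch.randn(4), torch.tensor(False))
print(graph.true_graph_0.code)   # the traced x * x branch
print(graph.false_graph_0.code)  # the traced x / 2 branch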
Empty file.