Torch cond operator, python dispatch, pyoperator (#83154)
Fixes #ISSUE_NUMBER

Pull Request resolved: #83154
Approved by: https://github.com/ezyang
voznesenskym authored and pytorchmergebot committed Aug 25, 2022
1 parent 3c2a078 commit ced2ca8
Showing 9 changed files with 464 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -114,6 +114,9 @@ namespace detail {
* they have been registered as fallthrough. The set of excluded backends
* varies from operator to operator, as some operators may have overridden the
* fallthrough with custom behavior.
*
* Note - this should maintain an implementation identical to the py dispatcher
* key extraction logic at pytorch/torch/dispatcher.py
*/
struct TORCH_API DispatchKeyExtractor final {
public:
3 changes: 3 additions & 0 deletions docs/source/conf.py
@@ -187,7 +187,10 @@
"DeserializationStorageContext",
"DeviceObjType",
"DictType",
"DispatchKey",
"DispatchKeySet",
"EnumType",
"ExcludeDispatchKeyGuard",
"ExecutionPlan",
"FileCheck",
"FloatType",
137 changes: 137 additions & 0 deletions functorch/functorch/experimental/cond.py
@@ -0,0 +1,137 @@
import torch
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from functorch.experimental.ops import PyOperator, fallthrough_fn
from torch.utils._pytree import tree_flatten
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.fx.experimental.proxy_tensor import track_tensor_tree


"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
from contextlib import contextmanager

# TODO(voz): Move out somewhere else once other py dispatched ops need it
@contextmanager
def suspend_mode(mode):
assert(mode is not None), "Cannot suspend None mode"
assert(isinstance(mode, TorchDispatchMode)), f"Unexpected mode type {mode.__class__}"
torch._C._set_torch_dispatch_mode(None)
try:
yield
finally:
torch._C._set_torch_dispatch_mode(mode)

@contextmanager
def enable_mode(mode):
curr_mode = torch._C._get_torch_dispatch_mode()
torch._C._set_torch_dispatch_mode(mode)
try:
yield
finally:
torch._C._set_torch_dispatch_mode(curr_mode)


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
def _unwrap_proxy(e):
return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)

assert isinstance(operands, list), "Cond operands must be a list of tensors"
assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"

true_graph = get_isolated_graphmodule(true_fn, operands, {})
false_graph = get_isolated_graphmodule(false_fn, operands, {})

true_outs = []
false_outs = []
for node in true_graph.graph.nodes:
if node.op == 'output':
true_outs.extend(node.args)

for node in false_graph.graph.nodes:
if node.op == 'output':
false_outs.extend(node.args)

flat_true_outs, _ = pytree.tree_flatten(true_outs)
flat_false_outs, _ = pytree.tree_flatten(false_outs)
assert(len(flat_true_outs) == len(flat_false_outs))

for i in range(0, len(flat_true_outs)):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
assert true_out.meta == false_out.meta

# There are probably better ways - I know that create_arg has some self-incrementing name
# magic to it, but since we explicitly have to get the name for register_module,
# I was not sure how to reuse that. This roughly simulates it.
next_name = None
i = 0
while not next_name:
candidate = f"true_graph_{i}"
if hasattr(proxy_mode.tracer.root, candidate):
i += 1
else:
next_name = candidate

true_name = next_name
false_name = f"false_graph_{i}"
assert(not hasattr(proxy_mode.tracer.root, false_name))

proxy_mode.tracer.root.register_module(true_name, true_graph)
proxy_mode.tracer.root.register_module(false_name, false_graph)

args = (pred, true_graph, false_graph, [operands])

proxy_args = pytree.tree_map(_unwrap_proxy, args)

out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="conditional")

if pred:
out = true_fn(*operands)
else:
out = false_fn(*operands)

return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


def cond_dense(pred, true_fn, false_fn, operands):
mode = torch._C._get_torch_dispatch_mode()
assert (mode is None), "Mode should never be enabled for CPU key"
if pred:
return true_fn(*operands)
else:
return false_fn(*operands)


def cond_autograd(pred, true_fn, false_fn, *operands):
# TODO: support autograd
flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
assert all([not f.requires_grad for f in flat_operands
if isinstance(f, torch.Tensor)])

guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
return cond(pred, true_fn, false_fn, *operands)


def python_fallback(op):
def inner(pred, true_fn, false_fn, operands):
mode = torch._C._get_torch_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
with suspend_mode(mode):
res = trace_cond(mode, op, pred, true_fn, false_fn, operands)
return res

return inner


cond = PyOperator('cond')
cond.impl(DispatchKey.CPU, cond_dense)
cond.impl(DispatchKey.Python, python_fallback(cond))
cond.impl(DispatchKey.PythonTLSSnapshot, fallthrough_fn)
cond.impl(DispatchKey.AutogradCPU, cond_autograd)
cond.impl(DispatchKey.ADInplaceOrView, fallthrough_fn)
cond.impl(DispatchKey.BackendSelect, fallthrough_fn)
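
Editor's note: the module docstring above says that `cond` needs an implementation per dispatch key, and the registrations above wire those up. Below is a minimal usage sketch (not part of this diff) exercising both the eager CPU path and the traced Python-key path; it mirrors the tests added later in this commit.

import torch
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)

# Eager call: no dispatch mode is active, so cond_dense runs and picks a branch directly.
print(cond(False, true_fn, false_fn, [x]))  # equivalent to x.cos()

# Traced call: make_fx installs a proxy torch-dispatch mode, so the Python-key
# implementation (python_fallback -> trace_cond) records both branches as submodules.
def f(x, pred):
    return cond(pred, true_fn, false_fn, [x])

graph = make_fx(f)(x, torch.tensor(False))
print(graph.code)  # contains a call_function node named "conditional"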
36 changes: 36 additions & 0 deletions functorch/functorch/experimental/ops.py
@@ -0,0 +1,36 @@
from torch._dispatch._dispatcher import PyDispatcher, to_flat_tuple, compute_keyset
from torch.nn.functional import handle_torch_function
from torch.overrides import has_torch_function
import torch._C as _C

class PyOperator:
def __init__(self, name):
self.name = name
self.table = {}

self.__name__ = name

def impl(self, dispatch_key, fn):
assert dispatch_key not in self.table
if fn is fallthrough_fn:
self.table[dispatch_key] = fn(self, dispatch_key)
else:
self.table[dispatch_key] = fn

def lookup(self, keyset):
dispatch_key = keyset.highestPriorityTypeId()
return self.table[dispatch_key]

def __call__(self, *args, **kwargs):
flat_args = to_flat_tuple(args, kwargs)
if has_torch_function(flat_args):
return handle_torch_function(self, flat_args, *args, **kwargs)

return PyDispatcher.call(self, *args, **kwargs)

def fallthrough_fn(operator, dispatch_key):
def inner(*args, **kwargs):
all_keys_after_current = _C._dispatch_keyset_full_after(dispatch_key)
all_keys_after_current_masked = all_keys_after_current & compute_keyset(args, kwargs)
return PyDispatcher.redispatch(operator, all_keys_after_current_masked, *args, **kwargs)
return inner
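
Editor's note: a hypothetical sketch (not part of this commit) of the intended PyOperator workflow: construct the operator, register a dense kernel for the key that should do real work, and register fallthrough_fn for keys that should simply redispatch - the same pattern cond.py follows above. The name my_mul and its kernel are illustrative assumptions.

import torch
from torch._C import DispatchKey
from functorch.experimental.ops import PyOperator, fallthrough_fn

def my_mul_dense(x, y):
    # Runs once dispatch bottoms out at the CPU key.
    return x * y

my_mul = PyOperator('my_mul')
my_mul.impl(DispatchKey.CPU, my_mul_dense)
# Keys with no special behavior here mask themselves out and redispatch.
# Treating AutogradCPU as a fallthrough is an assumption for this toy op
# (cond registers a real handler for it instead).
my_mul.impl(DispatchKey.AutogradCPU, fallthrough_fn)
my_mul.impl(DispatchKey.ADInplaceOrView, fallthrough_fn)
my_mul.impl(DispatchKey.BackendSelect, fallthrough_fn)
my_mul.impl(DispatchKey.PythonTLSSnapshot, fallthrough_fn)

out = my_mul(torch.ones(2), torch.full((2,), 3.0))  # goes through PyDispatcher.call

Each impl call fills PyOperator.table keyed by DispatchKey; lookup picks the highest-priority key from the caller's keyset, and fallthrough_fn re-enters the dispatcher with that key masked out.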
183 changes: 183 additions & 0 deletions functorch/test/test_control_flow.py
@@ -0,0 +1,183 @@
import torch

from torch.testing._internal.common_utils import TestCase
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx


class TestControlFlow(TestCase):
def test_cond_no_trace(self):
def true_fn(x):
return x.sin()

def false_fn(x):
return x.cos()

x = torch.randn(4)
result = cond(False, true_fn, false_fn, [x])
self.assertEqual(result, torch.cos(x))


class TestControlFlowTraced(TestCase):
def test_cond_traced(self):
def true_fn(x):
return x.sin()

def false_fn(x):
return x.cos()

def f(x, y):
return cond(y, true_fn, false_fn, [x])

x = torch.randn(4)
graph = make_fx(f)(x, torch.tensor(False))
result_true = graph.forward(x, torch.tensor(True))
result_false = graph.forward(x, torch.tensor(False))
self.assertFalse(torch.allclose(result_true, result_false))
self.assertEqual(result_true, torch.sin(x))
self.assertEqual(result_false, torch.cos(x))

def test_cond_nested_traced(self):
def true_nested(y):
return y * y

def false_nested(y):
return y + y

def true_fn(x, pred2):
z = cond(pred2, true_nested, false_nested, [x])
return x + z

def false_fn(x, _):
return x.cos()

def f(x, pred, pred2):
return cond(pred, true_fn, false_fn, [x, pred2])

x = torch.randn(4)
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

result_true_true = graph.forward(x, torch.tensor(True), torch.tensor(True)) # True + True -> x * x
result_true_false = graph.forward(x, torch.tensor(True), torch.tensor(False)) # True + False -> x + x
result_false_true = graph.forward(x, torch.tensor(False), torch.tensor(True)) # False + either -> cos
result_false_false = graph.forward(x, torch.tensor(False), torch.tensor(False)) # False + either -> cos

self.assertNotEqual(result_true_true, result_true_false)
self.assertFalse(torch.allclose(result_false_true, result_true_true))

self.assertEqual(result_false_true, result_false_false)

self.assertEqual(result_true_true, (x * x) + x)
self.assertEqual(result_true_false, x + x + x)

self.assertEqual(result_false_true, torch.cos(x))

def test_cond_nested_traced_other_inputs(self):
def true_nested(y):
return y * y

def false_nested(y):
return y + y

def true_fn(k, pred2):
z = cond(pred2, true_nested, false_nested, [k])
return torch.add(torch.tensor([.25, .25]), z)

def false_fn(k, _):
return k.cos()

def f(k, pred, pred2):
return cond(pred, true_fn, false_fn, [k, pred2])

x = torch.tensor([0.5, 0.5])
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

a = torch.tensor([1.0, 1.0])
result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))

b = torch.tensor([2.0, 2.0])
result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))

def test_cond_nested_traced_multi(self):
def true_a(y):
return y * y

def false_a(y):
return y + y

def true_b(y, z):
return y + z

def false_b(y, z):
return y * z

def f(x, pred, pred2):
a_out = cond(pred, true_a, false_a, [x])
b_out = cond(pred2, true_b, false_b, [x, x])
return a_out + b_out

x = torch.randn(4)
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

# Brittle, yet, delicious
out = """
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = functorch_experimental_ops_cond(pred_1,
true_graph_0, false_graph_0, [[x_1]]); pred_1 = true_graph_0 = false_graph_0 = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = functorch_experimental_ops_cond(pred2_1,
true_graph_1, false_graph_1, [[x_1, x_1]]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
add_tensor = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
return add_tensor
"""
code = graph.code
# Normalization hack, because .code makes some weird whitespace
code = "".join(code.split())
out = "".join(out.split())
self.assertEqual(code, out)

code = graph.true_graph_0.code
out = """
def forward(self, flat_args):
flat_args_1, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)
mul_tensor = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1); flat_args_1 = None
return pytree.tree_unflatten([mul_tensor], self._out_spec)
"""
# Normalization hack, because .code makes some weird whitespace
code = "".join(code.split())
out = "".join(out.split())
self.assertEqual(code, out)

def test_assert_on_mismatch_type_size(self):
def true_fn(x):
return x.sin()

def false_fn(x):
return (x, x)

def f(x, y):
return cond(y, true_fn, false_fn, [x])

x = torch.randn(4)
with self.assertRaises(AssertionError):
make_fx(f)(x, torch.tensor(False))


def test_assert_on_mismatch_tensor_size(self):
def true_fn(x):
return x.sin()

def false_fn(x):
return torch.zeros([10, 10])

def f(x, y):
return cond(y, true_fn, false_fn, [x])

x = torch.randn(4)
with self.assertRaises(AssertionError):
make_fx(f)(x, torch.tensor(False))
3 changes: 3 additions & 0 deletions test/test_public_bindings.py
@@ -86,9 +86,12 @@ def test_no_new_bindings(self):
"DeviceObjType",
"DictType",
"DisableTorchFunction",
"DispatchKey",
"DispatchKeySet",
"dtype",
"EnumType",
"ErrorReport",
"ExcludeDispatchKeyGuard",
"ExecutionPlan",
"FatalError",
"FileCheck",
Empty file added torch/_dispatch/__init__.py
Empty file.
