Torch cond operator, python dispatch, pyoperator #83154

Closed · wants to merge 86 commits

Changes from 8 commits

Commits (86)
ab64805
Initial and reset
voznesenskym Aug 9, 2022
599d91a
Unset
voznesenskym Aug 9, 2022
2e3e976
Hax
voznesenskym Aug 10, 2022
d92b6f4
Impl wip
voznesenskym Aug 10, 2022
0d5e2bf
b etter outs
voznesenskym Aug 10, 2022
e3db51f
Work a little better
voznesenskym Aug 10, 2022
ccec9ef
less wrong
voznesenskym Aug 10, 2022
d472fa2
Hack up make_fx to natively support varargs
ezyang Aug 10, 2022
4071b1b
fixes
voznesenskym Aug 11, 2022
dab9822
Progress, works unnested
voznesenskym Aug 11, 2022
f2d885c
Feedback from Ed
voznesenskym Aug 11, 2022
56db44a
progress
voznesenskym Aug 12, 2022
28db0f8
Add data_dependent_output tag; generalize proxy tensor to test it
ezyang Aug 12, 2022
9dfeaee
Update base for Update on "Add data_dependent_output tag; generalize …
ezyang Aug 12, 2022
a3efe51
Update on "Add data_dependent_output tag; generalize proxy tensor to …
ezyang Aug 12, 2022
2ea2860
Add *_only and all/any pytree utilities
ezyang Aug 12, 2022
af5fe52
Update base for Update on "Add *_only and all/any pytree utilities"
ezyang Aug 12, 2022
255a288
Update on "Add *_only and all/any pytree utilities"
ezyang Aug 12, 2022
2f13c14
Working
voznesenskym Aug 12, 2022
97cdd25
lint
voznesenskym Aug 12, 2022
247a68c
Update base for Update on "Add *_only and all/any pytree utilities"
ezyang Aug 12, 2022
946d4fe
Update on "Add *_only and all/any pytree utilities"
ezyang Aug 12, 2022
1aa70d2
Update base for Update on "Add *_only and all/any pytree utilities"
ezyang Aug 12, 2022
fb4b669
Update on "Add *_only and all/any pytree utilities"
ezyang Aug 12, 2022
cdee501
RFC: Delete ProxyTensor wrapper subclass
ezyang Aug 12, 2022
ca203af
Update on "RFC: Delete ProxyTensor wrapper subclass"
ezyang Aug 12, 2022
f38bc6f
Update on "RFC: Delete ProxyTensor wrapper subclass"
ezyang Aug 13, 2022
670191a
Rearrange the chairs
voznesenskym Aug 13, 2022
7e903c7
Tests. real ones
voznesenskym Aug 13, 2022
63ebfb3
lints, fixes
voznesenskym Aug 13, 2022
0807c85
Linter
voznesenskym Aug 13, 2022
2bf2686
Refactors
voznesenskym Aug 13, 2022
343e9af
More cleanup
voznesenskym Aug 13, 2022
208c5b0
Update base for Update on "RFC: Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
a5cb0fc
Update on "RFC: Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
d67a117
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
a51fcb5
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
6f29257
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
dcba0b4
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
6fa4bd4
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
e7521fe
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 15, 2022
d634ea0
Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto
voznesenskym Aug 16, 2022
776cf3a
Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto
voznesenskym Aug 16, 2022
f8cf09c
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
2415bbe
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
3974ecd
get rid of the reentrant parts
voznesenskym Aug 16, 2022
f10df36
first small feedback
voznesenskym Aug 16, 2022
a5b874a
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
010700b
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
651f351
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
e4f5ff8
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
9cebd47
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
68018c0
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
aee98ad
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
20cf7ab
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
ea58ef6
Update base for Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
604ec98
Update on "Delete ProxyTensor wrapper subclass"
ezyang Aug 16, 2022
7ceaf24
Still not working, progress
voznesenskym Aug 17, 2022
d55baf4
cMerge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto
voznesenskym Aug 17, 2022
aeae58c
Merge branch 'gh/ezyang/1326/head' of github.com:pytorch/pytorch into…
voznesenskym Aug 17, 2022
f009a58
Fixes
voznesenskym Aug 17, 2022
1d62b06
lint
voznesenskym Aug 17, 2022
d593d25
Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto
voznesenskym Aug 18, 2022
1e98c87
feedback pass
voznesenskym Aug 18, 2022
9de2972
Revert "feedback pass"
voznesenskym Aug 19, 2022
1bde5ee
Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto
voznesenskym Aug 19, 2022
52ffdc6
Make it less bad, safeguards, questionable choices
voznesenskym Aug 20, 2022
ee61ade
Add back accidently rm important line
voznesenskym Aug 20, 2022
2059523
Feedback pass
voznesenskym Aug 22, 2022
2b7b876
Feedback pass 2
voznesenskym Aug 22, 2022
366591c
Feedback pass 3
voznesenskym Aug 23, 2022
50181ee
Feedback pass, cleanup, improvements
voznesenskym Aug 23, 2022
663f8ca
types
voznesenskym Aug 23, 2022
10db393
Lints
voznesenskym Aug 23, 2022
8b56428
Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto
voznesenskym Aug 23, 2022
4c115b0
Trace better
voznesenskym Aug 24, 2022
f2445ce
Better tests, less execution
voznesenskym Aug 24, 2022
c4fe703
Cleanups, improvements
voznesenskym Aug 24, 2022
a346b3d
Feedback cleanup
voznesenskym Aug 24, 2022
fc5d7ff
Last few little nits
voznesenskym Aug 24, 2022
0840e56
Fix test
voznesenskym Aug 24, 2022
ddd8f0a
missed a spot
voznesenskym Aug 24, 2022
4000d77
Lint the test
voznesenskym Aug 25, 2022
25e0701
Shuffle stuff around
voznesenskym Aug 25, 2022
1c152bb
More lint stuff
voznesenskym Aug 25, 2022
670f3d2
Merge branch 'master' of github.com:pytorch/pytorch into voz/ctfl_proto
voznesenskym Aug 25, 2022
338 changes: 338 additions & 0 deletions control_flow.py
@@ -0,0 +1,338 @@
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
import torch._C as _C
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch.overrides import handle_torch_function, has_torch_function

"""
Structured control flow operators prototype.

This is a prototype of the cond operator. Its API is the following:

    cond(pred, true_fn, false_fn, *operands)

Note that cond has some special features that make it different from
an ATen operator:
- It accepts two functions as arguments (true_fn, false_fn)
- It accepts varargs (*operands)

We don't quite know how to shoehorn these types into the PyTorch dispatcher
(I'm sure it's possible though), so the proposal is: handle all of the
weirdness in Python.

The approach we take is:
- We set up a "Python version of the PyTorch Dispatcher" (call this PyDispatcher).
  This is responsible for performing dispatch on operations in Python.
- We have a notion of a "pyoperator" (not to be confused with Anjali's Python Op
  Registration API). A "pyoperator" is an operator that was defined in Python
  and handled by the "Python version of the PyTorch Dispatcher"
  (Anjali's Python Op Registration API creates operators in Python that are handled
  by the PyTorch C++ Dispatcher).
- A "pyoperator":
  - Does not require a schema
  - Can accept functions as arguments
  - Can accept varargs as arguments

Given a PyOperator, we can define "rules" for it for each dispatch key.
"""

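# A rough sketch of the intended call pattern (illustrative only; the real,
# runnable examples are the test cases at the bottom of this file):
#
#   pred = torch.tensor(True)
#   out = cond(pred, lambda x: x.sin(), lambda x: x.cos(), torch.randn(4))
#
# Eagerly, only the selected branch runs; under tracing, both branches need to
# be traced out.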

SUPPORTED_KEYS = {
    DispatchKey.CPU,
    DispatchKey.BackendSelect,
    DispatchKey.ADInplaceOrView,
    DispatchKey.AutogradCPU,
    DispatchKey.Python,
    DispatchKey.PythonTLSSnapshot,
}

"""
This is a dispatcher (in Python)
- You can define new operations (in Python) without schemas
- It interfaces with the PyTorch dispatcher
"""

class PyDispatcher:
    def __init__(self):
        self.currently_dispatching_op = None
        self.already_dispatched_keys = None

    def call(self, operator, args, kwargs):
        try:
            key = compute_dispatch_key(operator, args, kwargs)
            self.record_dispatch(key, operator)
            print(f"PyDispatcher.call {key}")
            return dispatch(key, operator, args, kwargs)
        finally:
            self.reset_dispatch_record()

    def redispatch(self, operator, args, kwargs):
        # Redispatch doesn't go to the top
        assert operator == self.currently_dispatching_op
        key = compute_dispatch_key(operator, args, kwargs, self.already_dispatched_keys)
        self.record_dispatch(key, operator)
        print(f"PyDispatcher.redispatch {key}")
        return dispatch(key, operator, args, kwargs)

    def reset_dispatch_record(self):
        self.currently_dispatching_op = None
        self.already_dispatched_keys = None

    def record_dispatch(self, dispatch_key, operator):
        self.currently_dispatching_op = operator
        if self.already_dispatched_keys is None:
            self.already_dispatched_keys = DispatchKeySet(dispatch_key)
        else:
            self.already_dispatched_keys = self.already_dispatched_keys | DispatchKeySet(dispatch_key)


dispatcher_singleton = PyDispatcher()
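# How a call flows through this dispatcher, roughly: `call` computes the
# highest-priority dispatch key from the TLS include set plus the keys of
# every Tensor argument (minus the TLS exclude set) and runs that kernel.
# A fallthrough kernel invokes `redispatch`, which redoes the computation but
# additionally excludes every key already dispatched through, so dispatch
# keeps walking down the priority order instead of looping.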

class PyOperator:
    def __init__(self, name):
        self.name = name
        self.table = {}

        # TODO: torch_dispatch expects PyOperator to be an instance of a torch.ops.aten op.
        self.overloadpacket = self

        # Hack for FX tracing
        self.__name__ = f'torch.{name}'

    def impl(self, dispatch_key, fn):
        assert dispatch_key not in self.table
        self.table[dispatch_key] = fn

    def fallthrough(self, dispatch_key):
        assert dispatch_key not in self.table
        self.table[dispatch_key] = fallthrough_fn(self, dispatch_key)

    def __call__(self, *args, **kwargs):
        flat_args = to_flat_tuple(args, kwargs)
        if has_torch_function(flat_args):
            return handle_torch_function(self, flat_args, *args, **kwargs)

        return dispatcher_singleton.call(self, args, kwargs)


def compute_dispatch_key(operator, args, kwargs, additional_exclude=None):
    tensors = get_tensors(args, kwargs)
    dispatch_key = key_extractor(tensors, additional_exclude)
    return dispatch_key


def dispatch(dispatch_key, operator, args, kwargs):
    print("Dispatching:", dispatch_key, operator.__name__)
    if dispatch_key not in SUPPORTED_KEYS:
        raise RuntimeError(f'NYI: {dispatch_key}')
    assert dispatch_key in operator.table
    kernel = operator.table[dispatch_key]
    return kernel(*args, **kwargs)


def key_extractor(tensors, additional_exclude=None):
    key_set = _C._dispatch_tls_local_include_set()
    for tensor in tensors:
        key_set = key_set | _C._dispatch_keys(tensor)
    key_set = key_set - _C._dispatch_tls_local_exclude_set()
    if additional_exclude is not None:
        key_set = key_set - additional_exclude
    return key_set.highestPriorityTypeId()


def to_flat_tuple(args, kwargs):
    flat_args, _ = tree_flatten(args)
    flat_kwargs, _ = tree_flatten(kwargs)
    flat_all = flat_args + flat_kwargs
    return flat_all

def get_tensors(args, kwargs):
    flat_all = to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)

"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""

def cond_dense(pred, true_fn, false_fn, *operands):
    print("Running cond dense", pred)
    # print("Pred?", pred.code)
    # if pred(tuple()):
    if pred.elem:
        print("True!")
        x = true_fn(*operands)
        return x
    else:
        print("False!")
        x = false_fn(*operands)
        return x


def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten((true_fn, false_fn) + operands)
    assert all([not f.requires_grad for f in flat_operands
                if isinstance(f, torch.Tensor)])

    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)


def cond_adinplaceorview(*args, **kwargs):
    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.ADInplaceOrView))
    return cond(*args, **kwargs)


def fallthrough_fn(operator, dispatch_key):
    def inner(*args, **kwargs):
        return dispatcher_singleton.redispatch(operator, args, kwargs)
    return inner


def python_fallback(op):
    def inner(*args, **kwargs):
        print("Input:", args)
        # Get all tensors. For each tensor, try their torch_dispatch
        # until one returns something other than NotImplemented
        def extract():
            tensors = get_tensors(args, kwargs)
            for tensor in tensors:
                # print("T:", tensor)
                ret = tensor.__torch_dispatch__(op, None, args, kwargs)
                if ret is NotImplemented:
                    continue
                return ret
            return NotImplemented

        mode = torch._C._get_torch_dispatch_mode()
        if mode is not None:
            with mode.restore():
                return extract()
        else:
            return cond_dense(*args)

    return inner


cond = PyOperator('cond')
cond.impl(DispatchKey.CPU, cond_dense)
cond.impl(DispatchKey.AutogradCPU, cond_autograd)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)

cond.impl(DispatchKey.Python, python_fallback(cond))
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
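
# Roughly how an eager call threads through these registrations (a sketch;
# assumes plain CPU tensors and no active dispatch mode):
#
#   cond(pred, true_fn, false_fn, x)        # PyOperator.__call__
#     AutogradCPU       -> cond_autograd: asserts nothing requires grad,
#                          excludes AutogradCPU, calls cond again
#     BackendSelect,
#     ADInplaceOrView,
#     PythonTLSSnapshot -> registered as fallthroughs: redispatch to the next key
#     CPU               -> cond_dense: actually runs true_fn or false_fn
#
# Under make_fx, the Python key is handled by python_fallback instead, which
# defers to __torch_dispatch__ / the active dispatch mode so the branches can
# be traced rather than executed.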


"""
Test case #1: basic
"""

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
# result = cond(False, true_fn, false_fn, x)
# assert torch.allclose(result, torch.cos(x))

"""
Test case #2: tracing

NB: We need some additional way to add a new "lowering rule" for
lowering the cond call to an FX node. In particular,
cond accepts a true_fn/false_fn and these need to be traced out.

I've hardcoded the logic into ProxyTensor.
"""
print("EXAMPLE 2")
from torch.fx.experimental.proxy_tensor import make_fx
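# Conceptually, the hardcoded ProxyTensor handling traces true_fn and false_fn
# out and emits a single call_function node whose target is the cond
# PyOperator (see the captured output below).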

def f(x, y):
    return cond(y, true_fn, false_fn, x)

graph = make_fx(f)(x, torch.tensor(True))
print("graph.code:")
print(graph.code)
graph.graph.print_tabular()
print("Invoking:")
result_false = graph.forward(x, torch.tensor(True))
print("False:", result_false)
result_true = graph(x, torch.tensor(False))
print("True:", result_true)
# result_true()
# print(graph.forward())

exit(0)

"""
def forward(self, x_1, pred_1):
    _tensor_constant0 = self._tensor_constant0
    conditional = __main___torch_cond(False, wrapped(), wrapped(), _tensor_constant0); _tensor_constant0 = None
    return conditional

opcode name target args kwargs
------------- ----------------- ---------------------------------------------- ------------------------------------------------ --------
placeholder x_1 x_1 () {}
placeholder pred_1 pred_1 () {}
get_attr _tensor_constant0 _tensor_constant0 () {}
call_function conditional <__main__.PyOperator object at 0x7ff4b4480100> (False, wrapped(), wrapped(), _tensor_constant0) {}
output output output (conditional,) {}

"""

"""
Test case #3: tracing complex

NB: We need some additional way to add a new "lowering rule" for
lowering the cond call to an FX node. In particular,
cond accepts a true_fn/false_fn and these need to be traced out.

I've hardcoded the logic into ProxyTensor.
"""
from torch.fx.experimental.proxy_tensor import make_fx

def true_fn(x, pred2):
    def true_nested(y):
        return y * y

    def false_nested(y):
        return y + y

    return cond(pred2, true_nested, false_nested, x.sin())

def false_fn(x, _):
    return x.cos()

def f(x, pred, pred2):
    return cond(pred, true_fn, false_fn, (x, pred2))

graph = make_fx(f)(x, False, True)
print("graph.code:")
print(graph.code)
graph.graph.print_tabular()

"""
def forward(self, x_1, pred_1, pred2_1):
    _tensor_constant0 = self._tensor_constant0
    conditional = __main___torch_cond(False, wrapped(), wrapped(), (_tensor_constant0, True)); _tensor_constant0 = None
    return conditional

opcode name target args kwargs
------------- ----------------- ---------------------------------------------- -------------------------------------------------------- --------
placeholder x_1 x_1 () {}
placeholder pred_1 pred_1 () {}
placeholder pred2_1 pred2_1 () {}
get_attr _tensor_constant0 _tensor_constant0 () {}
call_function conditional <__main__.PyOperator object at 0x7ff4b4480100> (False, wrapped(), wrapped(), (_tensor_constant0, True)) {}
output output output (conditional,) {}
"""
"""
More test cases (coming soon)

3. Autograd
4. functorch transforms!
"""
50 changes: 50 additions & 0 deletions torch/csrc/utils/python_dispatch.cpp
@@ -327,6 +327,56 @@ void initDispatchBindings(PyObject* module) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_total_keys", []() {
    return (int64_t)c10::DispatchKey::EndOfFunctionalityKeys;
  });
  m.def("_dispatch_key_name", [](uint64_t dispatch_key) {
    auto dt = (c10::DispatchKey)dispatch_key;
    return c10::toString(dt);
  });
  m.def("_dispatch_num_backends", []() {
    return c10::num_backends;
  });

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      .value("Undefined", c10::DispatchKey::Undefined)
      .value("Dense", c10::DispatchKey::Dense)
      .value("BackendSelect", c10::DispatchKey::BackendSelect)
      .value("CPU", c10::DispatchKey::CPU)
      .value("CUDA", c10::DispatchKey::CUDA)
      .value("AutocastCPU", c10::DispatchKey::AutocastCPU)
      .value("AutocastCUDA", c10::DispatchKey::AutocastCUDA)
      .value("AutogradCPU", c10::DispatchKey::AutogradCPU)
      .value("ADInplaceOrView", c10::DispatchKey::ADInplaceOrView)
      .value("AutogradCUDA", c10::DispatchKey::AutogradCUDA)
      .value("PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot)
      .value("Python", c10::DispatchKey::Python);

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def("has", &c10::DispatchKeySet::has);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });
  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });
  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });
  py::class_<c10::impl::ExcludeDispatchKeyGuard>(m, "ExcludeDispatchKeyGuard")
      .def(py::init<c10::DispatchKeySet>());


  py::class_<at::AutoDispatchBelowAutograd>(m, "_AutoDispatchBelowAutograd")
      .def(py::init<>());

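
For reference, a minimal sketch of how the new bindings can be exercised from
Python, mirroring what key_extractor in control_flow.py does (assumes a build
that includes this patch; the names match the pybind definitions above):

import torch
import torch._C as _C
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard

t = torch.randn(2)

# Recompute a dispatch key set the same way key_extractor does.
keys = _C._dispatch_tls_local_include_set() | _C._dispatch_keys(t)
keys = keys - _C._dispatch_tls_local_exclude_set()
print(_C._dispatch_keyset_to_string(keys))
print(keys.has(DispatchKey.CPU))   # expected True for a dense CPU tensor

# Temporarily mask a key for the current scope, as cond_autograd does.
guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))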