Added proxy tensor #74360
Changes from all commits: b3c2256, 00ec063, a8ef585, b672087, 51affbf, 7fa4527, c09f17c
@@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Any, Dict, Optional, Tuple, Callable, Union
import torch
from torch._C import _disabled_torch_function_impl
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager

__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]
aten = torch.ops.aten

CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}


@contextmanager
def no_dispatch():
Review comment: Use the one from
Review comment: That's a terrible location for it. The root of our evils is we don't have a "torch dispatch" module. I propose torch.dispatch. Any objections?
Review comment: Well, this function should not exist? I do agree that we can have this namespace if needed, though, if we don't want to put it in
Review comment: Well, we need this function to exist in some form for internal implementation purposes (e.g., actually implementing super()). I'm not very hot on torch.overrides because, although the name is generic, it really is very torch_function-leaning right now.
Review comment: Sure, but we could change that? Also, I plead guilty to adding
    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield
    finally:
        del guard
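For orientation, here is a small, hedged sketch (not from this PR; `LoggingTensor` is a hypothetical subclass) of why a guard like `no_dispatch` is needed: without it, calling an op from inside `__torch_dispatch__` would be intercepted again and recurse.

```python
import torch

class LoggingTensor(torch.Tensor):
    # Hypothetical subclass, for illustration only.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted: {func}")
        # no_dispatch (defined above) disables interception, so this call runs
        # the real kernel instead of re-entering __torch_dispatch__.
        with no_dispatch():
            return func(*args, **kwargs)

x = torch.randn(3).as_subclass(LoggingTensor)
torch.sin(x)  # prints the intercepted aten op, then computes sin for real
```

This is the same pattern ProxyTensor uses further down in this file when it computes `real_out` under `no_dispatch()`.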
@contextmanager
def decompose(decomposition_table):
    global CURRENT_DECOMPOSITION_TABLE
    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
Review comment: @Chillee do you think this should be the done and dusted API for decomps in AOTAutograd? I now agree that it should be possible to directly program extra decomps as part of the tracing process, but I still want control of the default set of decompositions (which we're gonna put in PyTorch core) to be based on what the backend declares it supports. I suppose one could do this compositionally in the current API with something like this:
Review comment: I think in practice backends like NVFuser don't actually necessarily want to "try and decompose everything they don't support". For example, something like
Essentially, one nice thing about this API is that the "default option" (i.e. decompose nothing) is very easy to reason about. Similarly, if you're only decomposing a few ops, it's very easy to reason about. I could be convinced that there could be a better API, though.
Review comment: OK, agreed this is not good enough. It seems like something better might be if we had metadata saying what a decomposition would decompose to. Then NVFuser can define "this is the stuff that I actually understand", and we only do decompositions that (eventually) produce things we understand (but this is not 100% well specified, because what if a decomp produces some understandable stuff and some not understandable stuff; is it still profitable to decompose?).
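As a rough illustration of the compositional usage the first comment gestures at (a hedged sketch only; `default_decompositions` and `backend_supported_ops` are hypothetical placeholders, not part of this PR):

```python
import torch

# Hypothetical composition: start from a default table of decompositions assumed
# to ship with core, and keep only entries for ops the backend does not claim to
# support natively.
default_decompositions = {}    # placeholder: {OpOverload: decomposition_fn}
backend_supported_ops = set()  # placeholder: ops the backend handles directly

table = {
    op: decomp
    for op, decomp in default_decompositions.items()
    if op not in backend_supported_ops
}

def fn(x):
    return torch.nn.functional.gelu(x)

# make_fx is defined later in this file; ops found in `table` are decomposed
# during tracing.
traced = make_fx(fn, decomposition_table=table)(torch.randn(4))
```

The design point raised in the later comments still applies: filtering by "ops the backend supports" says nothing about what a decomposition itself produces, so the resulting graph may still contain ops the backend cannot handle.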
class ProxyTensor(torch.Tensor):
    proxy: fx.Proxy

    @staticmethod
    def __new__(cls, elem, proxy):
        # Hack to deal with super().__new__ not working for sparse tensors
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
            r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        else:
            r = super().__new__(cls, elem)  # type: ignore[call-arg]
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
        r.proxy = proxy  # type: ignore[attr-defined]

        return r

    def __repr__(self):
        with no_dispatch():
            return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
        if func_overload in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
        if func_overload == aten._local_scalar_dense.default:
            raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                               "It's likely that this is caused by data-dependent control flow or similar.")

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, ProxyTensor) else e

        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
Review comment: A more robust test would be to replicate the parsing logic in model.py.
Review comment: But the best solution would be just to get this metadata on the freaking overloads. cc @anjali411
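As a sketch of the schema-based direction these comments point at (a hypothetical helper, not part of this PR; it assumes `OpOverload._schema` exposes per-argument `alias_info` with an `is_write` flag, as recent PyTorch does):

```python
import torch

def is_inplace(op_overload):
    # Hypothetical alternative to the name-based check above: treat an op as
    # mutating if its schema marks any argument as written to in place.
    # Note this also catches out= variants, not just trailing-underscore ops.
    return any(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in op_overload._schema.arguments
    )

# e.g. is_inplace(torch.ops.aten.add_.Tensor) would be expected to return True,
# while is_inplace(torch.ops.aten.add.Tensor) would be expected to return False.
```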
            args[0].proxy = proxy_out
            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)
Review comment: We've looked at having proxy tensors as a mode (hacky version of the change here). That still has subclassing, so @Chillee might know if there's something more fundamental, but I can see what breaks locally if we just use attributes on vanilla tensors. We would need to wait for #75966 to land for any part of the mode solution to work at all in a non-hacky way because:

        def wrap_with_proxy(e, proxy):
            if type(e) == torch.Tensor:
                return ProxyTensor(e, proxy)
            else:
                return e

        # Unfortunately, tree_map cannot directly be used here. As the resulting
        # object may be a proxy that represents a tuple, we may need to
        # explicitly unwrap the proxy by simulating the flattening operations.
        if isinstance(real_out, tuple):
            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
        elif isinstance(real_out, torch.Tensor):
            return wrap_with_proxy(real_out, proxy_out)
        else:
            return real_out

class PythonKeyTracer(Tracer):
    def __init__(self):
        super().__init__()

    # In general, we don't want to make modules leaves. In principle, users of
    # this tracer might want to override this in order to turn a couple specific
    # modules into leaves in the traced graph.
    def call_module(
            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        return forward(*args, **kwargs)
    def create_arg(self, a: Any):
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f'_param_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})
        return super().create_arg(a)

def dispatch_trace(
        root: Union[torch.nn.Module, Callable], concrete_args: Optional[Tuple[Any, ...]] = None
) -> GraphModule:
    tracer = PythonKeyTracer()
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)

def wrap_key(f, inps):
    flat_inps, _ = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert len(flat_args) == len(flat_inps)
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped

def make_fx(f, decomposition_table=None):
    if decomposition_table is None:
        decomposition_table = {}

    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)  # type: ignore[attr-defined]
        with decompose(decomposition_table):
            t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs))
        return t

    return wrapped
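Finally, a hedged usage sketch of `make_fx` as defined above (the traced function and input shape are illustrative, not from the diff):

```python
import torch

def fn(x):
    return torch.sin(x) * 2

# make_fx returns a wrapper; calling it with example inputs produces an fx.GraphModule
# whose nodes are the aten-level ops recorded through ProxyTensor.__torch_dispatch__.
traced = make_fx(fn)(torch.randn(3))
print(traced.graph)            # aten.sin / aten.mul nodes
print(traced(torch.randn(3)))  # the GraphModule is callable like the original function
```

Because tracing happens at the dispatcher level, the graph records overload-level aten ops (after any decompositions in the active table) rather than the Python-level torch calls.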