9 changes: 9 additions & 0 deletions test/test_fx_experimental.py
@@ -27,6 +27,7 @@
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.experimental.meta_tracer import MetaTracer
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
@@ -689,6 +690,14 @@ def forward(self, x):
        gm = torch.fx.GraphModule(mttm, graph)
        torch.testing.assert_close(gm(x), mttm(x))

    def test_proxy_tensor(self):
        def f(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
        inp = torch.randn(3, requires_grad=True)
        torch.testing.assert_close(traced_graph(inp), f(inp))

    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
184 changes: 184 additions & 0 deletions torch/fx/experimental/proxy_tensor.py
@@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Any, Dict, Optional, Tuple, Callable, Union
import torch
from torch._C import _disabled_torch_function_impl
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager

__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]
aten = torch.ops.aten

CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}


@contextmanager
def no_dispatch():

Collaborator:
Use the one from torch.testing._internal.logging_tensor.no_dispatch ?

Contributor:
that's a terrible location for it.

The root of our evils is we don't have a "torch dispatch" module. I propose torch.dispatch. Any objections?

Collaborator:
Well, this function should not exist?
Once the modes are fixed by Sam and users can call super(), there will be no need for this function at all!

I do agree that we can have this namespace if needed, though, if we don't want to put it in torch.overrides (which has an OK name to house both torch_function and torch_dispatch).

Contributor:
Well, we need this function to exist in some form for internal implementation purposes (e.g., actually implementing super())

I'm not very hot on torch.overrides because although the name is generic it really is very torch_function leaning right now.

Collaborator:
> I'm not very hot on torch.overrides because although the name is generic it really is very torch_function leaning right now.

Sure, but we could change that? Also, I plead guilty to adding torch.overrides.enable_reentrant_dispatch() there. But we can move it somewhere else if we prefer a new namespace.

    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield
    finally:
        del guard


@contextmanager
def decompose(decomposition_table):
    global CURRENT_DECOMPOSITION_TABLE
    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table

Contributor:
@Chillee do you think this should be the done and dusted API for decomps in AOTAutograd? I now agree that it should be possible to directly program extra decomps as part of the tracing process, but I still want control of the default set of decompositions (which we're gonna put in PyTorch core) to be based on what the backend declares it supports.

I suppose one could do this compositionally in the current API with something like this: decompose({**get_decomps(dont_decompose), **custom_decomps}). (This inverts the sense in which get_decomps is currently programmed; instead of asking for decomps, you say everything you DON'T want decomposed.) But this seems like a weird way to do the API if you're talking in terms of building an array of decompositions; it seems like maybe this should be abstracted away into a higher-level API.

Collaborator Author:
I think in practice... backends like NVFuser don't actually necessarily want to "try and decompose everything they don't support". For example, something like slice_backward decomposes into a new_zeros and a slice_scatter call. Now, NVFuser supports neither of these new ops, and since we've decomposed them they're now slower.

Essentially, one nice thing about this API is that the "default option" (i.e. decompose nothing) is very easy to reason about. Similarly, if you're only decomposing a few ops, it's very easy to reason about it.

I could be convinced that there could be a better API though.

Contributor:
OK, agreed this is not good enough.

It seems like something better might be if we had metadata saying what a decomposition would decompose to. Then NVFuser could instead define "this is the stuff that I actually understand", and we would only do decompositions that (eventually) produce things we understand (though this is not 100% well specified: if a decomp produces some understandable stuff and some not-understandable stuff, is it still profitable to decompose?).
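
A minimal sketch of the inverted pattern discussed above (all names here are hypothetical, not part of this PR): given a full registry of known decompositions, build a table that decomposes everything a backend does not declare support for.

# Hypothetical helper, not part of this PR. `all_decomps` would be a
# registry mapping OpOverload -> Python decomposition function.
def decomps_excluding(all_decomps, supported_ops):
    # Decompose every op the backend does NOT natively support.
    return {op: fn for op, fn in all_decomps.items() if op not in supported_ops}

As the thread notes, this can still produce decomposed ops the backend does not support either, which is why the PR's default is to decompose nothing.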



class ProxyTensor(torch.Tensor):
    proxy: fx.Proxy

    @staticmethod
    def __new__(cls, elem, proxy):
        # Hack to deal with super().__new__ not working for sparse tensors
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
            r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        else:
            r = super().__new__(cls, elem)  # type: ignore[call-arg]
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
        r.proxy = proxy  # type: ignore[attr-defined]

        return r

    def __repr__(self):
        with no_dispatch():
            return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
        if func_overload in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
        if func_overload == aten._local_scalar_dense.default:
            raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                               "It's likely that this is caused by data-dependent control flow or similar.")

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, ProxyTensor) else e

        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":

Contributor:
A more robust test would be to replicate the parsing logic in model.py

    @staticmethod
    def parse(op: str) -> 'BaseOperatorName':
        assert op != ''
        assert not op.endswith('_out'), \
            "_out suffix is reserved and not permitted for operator names; " \
            "did you mean to specify an out overload name instead?"
        m = re.match(r'^__([^_]+)__$', op)
        if m is not None:
            dunder_method = True
            base = m.group(1)
            if any(base == f'i{n}' for n in AUGMENTED_ASSIGNMENT_NAMES):
                inplace = True
                base = base[1:]
            else:
                inplace = False
                # temporary, this is not intrinsically true but
                # has been historically true for dunder methods
                # we support  (but, if we ever got, say, __int__, this would
                # be wrong!)
                assert base[0] != 'i'
        else:
            dunder_method = False
            base = op
            if base[-1] == '_':
                inplace = True
                base = base[:-1]
            else:
                inplace = False
        r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method)
        assert str(r) == op, f'{str(r)} != {op}'
        return r

Contributor:
But the best solution would be just to get this metadata on the freaking overloads cc @anjali411
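
For illustration, a hedged sketch (not part of this PR) of a name-based check modeled on the parse() logic quoted above; it keeps the PR's leading-underscore exclusion and adds handling for __iadd__-style dunders:

import re

def _looks_inplace(op_name: str) -> bool:
    # Hypothetical helper: treat __iadd__-style augmented-assignment dunders
    # as in-place; otherwise, a trailing underscore on a name that does not
    # start with an underscore (add_, mul_) marks an in-place op.
    if re.match(r'^__i[^_]+__$', op_name):
        return True
    return not op_name.startswith('_') and op_name.endswith('_')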

            args[0].proxy = proxy_out
            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

Contributor:
@Chillee wondering... would proxy tensor be better as a mode? We'd just chuck the proxies as attributes on vanilla tensors, no subclassing involved at all, and then just a mode to record IR and propagate proxies. Would be able to trace factories this way.

cc @samdow

Contributor:
We've looked at having proxy tensors as a mode (hacky version of the change here). That still has subclassing, so @Chillee might know if there's something more fundamental, but I can see what breaks locally if we just use attributes on vanilla tensors.

We would need to wait for #75966 to land for any part of the mode solution to work at all in a non-hacky way because:

  1. The current tracer is saved as an attribute on the proxies. The PR introduces the class PythonMode so that we can save that state on the mode. Without that PR, factory functions need the tracer to be saved as global state because they don't take any tensor inputs (this is how it works in the hacky version).
  2. (minor, not currently in the PR but should be) Right now, __torch_dispatch__ always warns if it's an instance method.


        def wrap_with_proxy(e, proxy):
            if type(e) == torch.Tensor:
                return ProxyTensor(e, proxy)
            else:
                return e

        # Unfortunately, tree_map cannot directly be used here. As the resulting
        # object may be a proxy that represents a tuple, we may need to
        # explicitly unwrap the proxy by simulating the flattening operations.
        if isinstance(real_out, tuple):
            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
        elif isinstance(real_out, torch.Tensor):
            return wrap_with_proxy(real_out, proxy_out)
        else:
            return real_out


class PythonKeyTracer(Tracer):
    def __init__(self):
        super().__init__()

    # In general, we don't want to make modules leaves. In principle, users of
    # this tracer might want to override this in order to turn a couple specific
    # modules into leaves in the traced graph.
    def call_module(
            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        return forward(*args, **kwargs)

    def create_arg(self, a: Any):
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f'_param_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})
        return super().create_arg(a)


def dispatch_trace(
        root: Union[torch.nn.Module, Callable], concrete_args: Optional[Tuple[Any, ...]] = None
) -> GraphModule:
    tracer = PythonKeyTracer()
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)


def wrap_key(f, inps):
    flat_inps, _ = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert(len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped


def make_fx(f, decomposition_table=None):
    if decomposition_table is None:
        decomposition_table = {}

    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)  # type: ignore[attr-defined]
        with decompose(decomposition_table):
            t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs))
        return t

    return wrapped
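
To show how make_fx, decompose, and ProxyTensor fit together, here is a hedged usage sketch (not from this PR). It assumes torch.addmm reaches __torch_dispatch__ as aten.addmm.default, and addmm_decomp is a caller-defined decomposition:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def addmm_decomp(bias, mat1, mat2, beta=1, alpha=1):
    # Caller-supplied decomposition; the ops it calls re-enter
    # __torch_dispatch__ and are recorded in the traced graph.
    return beta * bias + alpha * (mat1 @ mat2)

def f(bias, mat1, mat2):
    return torch.addmm(bias, mat1, mat2).sum()

traced = make_fx(f, decomposition_table={torch.ops.aten.addmm.default: addmm_decomp})(
    torch.randn(5), torch.randn(3, 4), torch.randn(4, 5))
print(traced.graph)  # the addmm call should appear as the decomposed ops instead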