Added proxy tensor #74360
Conversation
```python
    try:
        yield
    finally:
        del guard
```
need to find a better spot for this XD
Lol people have been copy-pasting this for literally a year
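For context, the full helper being quoted here is the widely copy-pasted `no_dispatch` context manager; a reconstruction, matching the version in `torch.testing._internal.logging_tensor` at the time (referenced further down in this thread):

```python
import torch
from contextlib import contextmanager

@contextmanager
def no_dispatch():
    # Temporarily disable __torch_dispatch__ so ops inside hit the real kernels.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard
```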
torch/fx/_proxy_tensor.py
```python
        # ProxyTensor boundary.
        # assert not elem.requires_grad or not torch.is_grad_enabled()

        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
```
I am reasonably sure it's safe to do `super().__new__(elem)` here; this will propagate gradients and is more correct than implicitly making leaves. But we need to actually test this on functorch.
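For reference, a minimal sketch of the "implicitly making leaves" behavior (using a trivial subclass as a stand-in for ProxyTensor):

```python
import torch

class MyTensor(torch.Tensor):
    pass

elem = torch.randn(3, requires_grad=True).cos()  # non-leaf with grad history

# _make_subclass shares elem's storage but detaches internally, so the result
# is a fresh leaf: gradients will not propagate back to elem's history.
r = torch.Tensor._make_subclass(MyTensor, elem, elem.requires_grad)
assert r.is_leaf
```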
The only failure I get when trying this is when wrapping sparse tensors.
Yeah OK, this makes sense; need to make the default `__new__` constructor accept sparse tensors too.
```python
    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
```
This is another choice tracer users have to make: do you want to record the specific overloads, or the overload packets? Recording the overloads is strictly more information, but many existing passes can't make use of this info in a useful way.
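Concretely, with a real aten op (just to illustrate the two granularities):

```python
import torch

func_overload = torch.ops.aten.add.Tensor  # one specific overload of aten::add
func = func_overload.overloadpacket        # the packet covering all add overloads

# Recording func_overload preserves which overload actually ran; recording
# func only preserves the op name, which is what most FX passes expect.
```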
```python
        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
```
@anjali411 schema APIs would come in handy here!
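A hedged sketch of what a schema-based check could look like, using the alias annotations exposed on the overload's schema (treating "some input argument is written to" as in-place is a simplification):

```python
def is_inplace(func_overload) -> bool:
    # An in-place op writes to one of its inputs; the schema records this
    # via alias annotations (e.g. "a!") exposed as alias_info.is_write.
    schema = func_overload._schema
    return any(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in schema.arguments
    )
```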
```python
            setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})
        return super().create_arg(a)
```
This is the part of Horace's implementation that I'd most like FX folks to eyeball; idk, maybe @jamesr66a @zdevito?
So is what's going on here that, in a ProxyTensor context, you have `nn.Parameter` instances that are not installed as members of the module hierarchy, and you'd like to intern those with a generated qualified name? Do you have an example of where this could happen?
Yes. So, if you're tracing something like

```python
model = ...

def f(x):
    return model(x)
```

I think it's natural that in this case, the model parameters should just be constants from the perspective of the tracer. But since FX overrides parameters to look them up on the module hierarchy, this behavior needs to be modified.
I talked with @jamesr66a and we think some sort of approach along these lines is reasonable. James has a preference for preserving as much of FX's current hooks as possible.

One thing is that you can subclass this class and tweak it (for example, decompositions could be implemented in this way). But I'm not sure I want to encourage people to be willy-nilly subclassing here.
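A hedged sketch of what that subclass-and-tweak extension point could look like (the class name and the idea of consulting CURRENT_DECOMPOSITION_TABLE here are illustrative assumptions, not this PR's API):

```python
class DecomposingProxyTensor(ProxyTensor):
    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        # If a decomposition is registered, re-trace through it so only the
        # decomposed ops get recorded into the graph.
        if func_overload in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **(kwargs or {}))
        return super().__torch_dispatch__(func_overload, types, args, kwargs)
```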
Any progress on this?
I was sidetracked by the memory ownership issue; will get back to this this week.
```python
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
```
@Chillee do you think this should be the done-and-dusted API for decomps in AOTAutograd? I now agree that it should be possible to directly program extra decomps as part of the tracing process, but I still want control of the default set of decompositions (which we're gonna put in PyTorch core) to be based on what the backend declares it supports.

I suppose one could do this compositionally in the current API with something like `decompose({**get_decomps(dont_decompose), **custom_decomps})`. (This inverts the sense in which `get_decomps` is currently programmed; instead of asking for decomps, you say everything you DON'T want decomposed.) But this seems like a weird way to do the API if you're talking in terms of building an array of decompositions; it seems like maybe this should be abstracted away into a higher-level API.
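Spelled out, that compositional idea might look like the following (`decompose`, `get_decomps`, `custom_decomps`, `f`, and `x` are all hypothetical names from this comment, not an existing API):

```python
# Everything NOT listed here would get decomposed; this inverts the current API.
dont_decompose = {torch.ops.aten.native_batch_norm}
custom_decomps = {}  # extra decompositions programmed by the tracer user

with decompose({**get_decomps(dont_decompose), **custom_decomps}):
    traced = make_fx(f)(x)
```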
I think in practice, backends like NVFuser don't actually necessarily want to "try and decompose everything they don't support". For example, something like `slice_backward` decomposes into a `new_zeros` and a `slice_scatter` call. Now, NVFuser supports neither of these new ops, and since we've decomposed them, they're now slower.

Essentially, one nice thing about this API is that the "default option" (i.e. decompose nothing) is very easy to reason about. Similarly, if you're only decomposing a few ops, it's very easy to reason about.

I could be convinced that there could be a better API, though.
OK, agreed this is not good enough.

It seems like something better might be if we had metadata saying what a decomposition would decompose to. Then NVFuser can define "this is the stuff that I actually understand", and we only do decompositions that (eventually) produce things we understand. (This is not 100% well specified, though: if a decomp produces some understandable stuff and some not-understandable stuff, is it still profitable to decompose?)
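A hedged sketch of that metadata idea (`DECOMP_PRODUCES` is hypothetical; decomposition tables don't record this today):

```python
# Hypothetical metadata: which ops each decomposition bottoms out in.
DECOMP_PRODUCES = {
    torch.ops.aten.slice_backward: {
        torch.ops.aten.new_zeros,
        torch.ops.aten.slice_scatter,
    },
}

def usable_decomps(backend_supported_ops):
    # Keep only decompositions whose outputs the backend fully understands.
    # The mixed case (partially understandable outputs) is left unresolved,
    # as noted above.
    return {
        op: CURRENT_DECOMPOSITION_TABLE[op]
        for op, produces in DECOMP_PRODUCES.items()
        if op in CURRENT_DECOMPOSITION_TABLE and produces <= backend_supported_ops
    }
```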
```python
        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
```
A more robust test would be to replicate the parsing logic in model.py:

```python
@staticmethod
def parse(op: str) -> 'BaseOperatorName':
    assert op != ''
    assert not op.endswith('_out'), \
        "_out suffix is reserved and not permitted for operator names; " \
        "did you mean to specify an out overload name instead?"
    m = re.match(r'^__([^_]+)__$', op)
    if m is not None:
        dunder_method = True
        base = m.group(1)
        if any(base == f'i{n}' for n in AUGMENTED_ASSIGNMENT_NAMES):
            inplace = True
            base = base[1:]
        else:
            inplace = False
            # temporary, this is not intrinsically true but
            # has been historically true for dunder methods
            # we support (but, if we ever got, say, __int__, this would
            # be wrong!)
            assert base[0] != 'i'
    else:
        dunder_method = False
        base = op
        if base[-1] == '_':
            inplace = True
            base = base[:-1]
        else:
            inplace = False
    r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method)
    assert str(r) == op, f'{str(r)} != {op}'
    return r
```
But the best solution would be just to get this metadata on the freaking overloads. cc @anjali411
```python
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)
```
We've looked at having proxy tensors as a mode (a hacky version of the change here). That still has subclassing, so @Chillee might know if there's something more fundamental, but I can see what breaks locally if we just use attributes on vanilla tensors.

We would need to wait for #75966 to land for any part of the mode solution to work at all in a non-hacky way because:
- The current tracer is saved as an attribute on the proxies. That PR introduces the class PythonMode so that we can save the state on the mode. Without it, factory functions need the tracer to be saved as global state, because they don't take any tensor inputs (this is how it works in the hacky version).
- (Minor, not currently in the PR but should be) right now, torch_dispatch always warns if it's an instance method.
```python
    # We need to do this so that parameters entering the `make_fx` context have
    # a reference to them (and also have requires_grad set on them correctly
    # I'm not actually sure if this is the right thing to do ...
```
can we figure this out lol
I don't think we need to set requires_grad on it (at least, I don't remember why we'd need to, and this code doesn't seem to do anything special about requires_grad...). So I'll just remove the comment.
I think the main complexity here is that, fundamentally, the `.backward()` API allows you to "leak" things like tracing tensors.

For example, what should happen here?

```python
model = resnet18()

def f(x):
    model(x).sum().backward()
    return x.grad

print(make_fx(f)(torch.randn(1, 3, 224, 224, requires_grad=True)).code)
print([i.grad for i in model.parameters()])
```

Under the current implementation, this will leak ProxyTensors to outside the function. I think this is unavoidable, and just part of the contract of `make_fx`.
Oh yeah, we talked about this on VC just now. It seems to me that when non-proxy `requires_grad=True` tensors enter the scope of f, they should get detached and turned into leaf nodes, so the proxy doesn't escape. WDYT?
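A minimal sketch of that detach-on-entry idea (`wrap_input` is a hypothetical helper, not part of this PR):

```python
def wrap_input(t):
    # Real (non-proxy) tensors that require grad become fresh leaves inside
    # the traced function, so .backward() cannot push ProxyTensor grads out.
    if isinstance(t, torch.Tensor) and t.requires_grad and not isinstance(t, ProxyTensor):
        return t.detach().requires_grad_(True)
    return t

args = pytree.tree_map(wrap_input, args)
```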
```python
def make_fx(f, decomposition_table={}):
    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)
```
this implies f cannot have any non-tensor arguments?
yeah, `make_fx` does not support non-tensor arguments (or well, after pytree flattening).
If we wanted to, we could simply bake those in if that's preferred.
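For instance, non-tensor arguments can be baked in at the call site today (a usage sketch; the import path follows this PR's file name and may differ):

```python
import functools
import torch
from torch.fx._proxy_tensor import make_fx  # module path per this PR; an assumption

def f(x, scale):
    return x * scale

# Bind the non-tensor argument before tracing, so make_fx only sees tensors.
traced = make_fx(functools.partial(f, scale=2.0))(torch.randn(3))
```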
```python
@contextmanager
def no_dispatch():
```
Use the one from `torch.testing._internal.logging_tensor.no_dispatch`?
That's a terrible location for it.

The root of our evils is that we don't have a "torch dispatch" module. I propose torch.dispatch. Any objections?
Well, this function should not exist, right? Once the modes are fixed by Sam and users can call super(), there is no need for this function at all!

I do agree that we can have this namespace if needed, though, if we don't want to put it in torch.overrides (which has an OK name to house both torch_function and torch_dispatch).
Well, we need this function to exist in some form for internal implementation purposes (e.g., actually implementing super()).

I'm not very hot on torch.overrides, because although the name is generic, it really is very torch_function-leaning right now.
> I'm not very hot on torch.overrides because although the name is generic it really is very torch_function leaning right now.

Sure, but we could change that? Also, I plead guilty to adding torch.overrides.enable_reentrant_dispatch() there. But we can move it somewhere else if we prefer a new namespace.
Any remaining blockers for this?
No, will land today; responded to some other comments yesterday.
@pytorchbot merge this please
Hey @Chillee. |
Summary: This is the `__torch_dispatch__` subclass used for tracing by AOTAutograd (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py). Given that a couple of folks are now interested in using this infra, it seems like a good idea to put it in core and focus our efforts on a single implementation.

I put this up as a WIP, just for discussion, but some questions off the top of my head:

1. What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy-paste and modify? If we do define extension points, what are the extension points we should define?
2. There are some open questions about the way we're overriding FX to resolve some lingering issues (i.e. dealing with `nn.Parameter` and `call_module` calls). ezyang implemented an alternate version of this tensor in https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py, but it appears he ran into some issues with it that led to me submitting this implementation. That being said, I think some of the things over there should still be ported.
3. Given that this is going to be shared infra, what other features should we put in here? One that comes to mind is to allow for meta-tensor tracing (perhaps by default?), with a more solid fallback.

Some of the other implementations (for reference on requirements):

1. FX2TRT: D34868356 (internal only)
2. Edge's? gmagogsfm

CC: ezyang, jamesr66a, zou3519, gmagogsfm, 842974287

Pull Request resolved: #74360
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/fc95eda285aed184e48fe9657b8c9c3bdb60f283
Reviewed By: malfet
Differential Revision: D36134093
Pulled By: Chillee
fbshipit-source-id: 1d3af72fa79ab97ec685f20fb47f8c00404fb1c3