Non-symbolic FX tracer #56862
Comments
My take on the issue: https://github.com/Xilinx/brevitas/blob/fx/src/brevitas/fx/value_tracer.py

```python
import torch
from torch import nn
from brevitas.fx import value_trace

class TestModule(nn.Module):
    def __init__(self):
        super().__init__()

    def call_cat(self, *args):
        out = torch.cat(args, dim=1)
        return out

    def forward(self, x):
        x = torch.split(x, 2, dim=1)
        y = self.call_cat(*x)
        return y

model = TestModule()
inp = torch.randn(1, 4, 10, 10)
out = value_trace(model, concrete_args={'x': inp})
print(out)
```

I get as output (on 1.8.1):

```python
TestModule()

import torch
def forward(self, x):
    split_1 = torch.functional.split(x, 2, dim = 1); x = None
    iter_1 = iter(split_1); split_1 = None
    next_1 = iter_1.__next__()
    next_2 = iter_1.__next__(); iter_1 = None
    cat_1 = torch.cat((next_1, next_2), dim = 1); next_1 = next_2 = None
    return cat_1
```

I'm reinterpreting the way `concrete_args` are used, but it could also be a separate set of args. To preserve compatibility with the symbolic interface, they are not required as long as the tracing doesn't go through a `__bool__`, `__iter__` or `__next__` call.
Hi @volcacius, I get a 404 on that link. Is that a private repo?
Apologies, it's a branch that got merged. Here is a working link: https://github.com/Xilinx/brevitas/blob/master/src/brevitas/fx/value_tracer.py. It's still based on 1.8.1 fx, I haven't had time to upgrade it to 1.9 yet.
So I think something where various properties of the program are specialized during tracing can make sense in some situations but can be confusing in others. Compare how `concrete_args` is handled in `pytorch/torch/fx/_symbolic_trace.py` (line 187 at 7f1b672).
I think we were also kicking around various designs for annotation APIs for shape specialization; let me see if I can dig up those docs somewhere. It's worth noting that JAX, for example, does full shape/value metadata propagation/specialization, and they are able to ensure correctness because all of their transforms appear within the context of a runtime that re-checks those properties and redoes the transformations if something has changed. We've elected not to do that in FX because we believe it makes things more confusing for the end user.
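For reference, the existing `concrete_args` mechanism on `symbolic_trace` is the specialization being discussed here: it burns a Python-level branch into the traced graph. A minimal sketch (the function `f` and its `flag` argument are illustrative, not from the thread):

```python
import torch
from torch.fx import symbolic_trace

def f(x, flag):
    # A data-dependent Python branch that plain symbolic tracing
    # cannot represent as a graph node.
    if flag:
        return x + 1
    return x - 1

# Fixing `flag` at trace time specializes the graph to that branch.
gm = symbolic_trace(f, concrete_args={'flag': True})
print(gm.code)  # only the flag=True path remains in the generated code
```

The resulting `GraphModule` is only valid for the specialized value, which is exactly the "too specialized" behavior the issue author wants to relax.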
@volcacius I was thinking of exactly the same functionality as you presented. May I know how it goes with the newer versions of fx?
I haven't had a chance to look at that yet, sorry.
The operators you will get in the graph will be different, but try ProxyTensor tracing: `from torch.fx.experimental.proxy_tensor import make_fx`
Related: #63076
Could you please elaborate on this? As you pointed out in the conversation in #63076, I'm interested in trying to map the symbolic graph to the operators that will be called, without necessarily having to run concrete data through a model.
make_fx works without concrete data. Pass it a bunch of meta tensors (or fake tensors) and it will trace without running. |
You DO have to know the sizes of all your inputs though (in the form of the meta tensor); that's the price of entry.
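A minimal sketch of what the comment above describes: tracing with `make_fx` using a meta tensor, which carries only shape/dtype metadata so no real data flows through the model (the function `f` is illustrative; it mirrors the split/cat pattern from earlier in the thread):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Tuple-unpacking works here because ProxyTensor tracing returns
    # real tensor objects, not opaque Proxies.
    a, b = torch.split(x, 2, dim=1)
    return torch.cat((a, b), dim=1)

# A meta tensor: known sizes, no concrete data, no real compute.
meta_inp = torch.randn(1, 4, 10, 10, device='meta')
gm = make_fx(f)(meta_inp)
print(gm.code)  # aten-level split/cat ops recorded without running on real data
```

The traced `GraphModule` can then be executed on real tensors of the same shape.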
🚀 Feature
An alternative non-symbolic tracer for fx that wraps around values rather than replacing them, so that it's possible to resolve calls to `__iter__`, `__next__` and `__bool__`. This would require the user to pass in representative inputs, similarly to what happens already for torch.jit.trace.
Motivation
In use cases like quantization where users pass in concrete representative inputs anyway for calibration purposes, it would make sense to take advantage of that to specialize an fx graph against them. It would provide a way to overcome the current limitations around unpacking and conditionals.
Alternatives
I can specify `concrete_args`, but then the generated graph gets too specialized.
The kind of tracer I'm suggesting would record the fact that `torch.split` returns an iterable of two values. The generated graph would then work with any input tensor that can be split in half along `dim=1`.

cc @ezyang