
Non-symbolic FX tracer #56862

Open
volcacius opened this issue Apr 24, 2021 · 11 comments
Labels
module: fx · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@volcacius

volcacius commented Apr 24, 2021

🚀 Feature

An alternative non-symbolic tracer for fx that wraps values rather than replacing them, so that it's possible to resolve calls to __iter__, __next__, and __bool__.
This would require the user to pass in representative inputs, similar to what already happens for torch.jit.trace.

Motivation

In use cases like quantization where users pass in concrete representative inputs anyway for calibration purposes, it would make sense to take advantage of that to specialize an fx graph against them. It would provide a way to overcome the current limitations around unpacking and conditionals.

Alternatives

I can specify concrete_args but then it gets too specialized. Example:

import torch
from torch import nn
from torch.fx import symbolic_trace


class TestModule(nn.Module):

    def __init__(self):
        super().__init__()

    def call_cat(self, *args):
        out = torch.cat(args, dim=1)
        return out

    def forward(self, x):
        x = torch.split(x, 2, dim=1)
        y = self.call_cat(*x)
        return y


model = TestModule()
inp = torch.randn(1, 4, 10, 10)
out = symbolic_trace(model, concrete_args={'x': inp})
print(out)

You get as output (on 1.8.1):

TestModule()

def forward(self, ):
    _tensor_constant0 = self._tensor_constant0
    return _tensor_constant0

The kind of tracer I'm suggesting would record the fact that torch.split returns an iterable of two values. The generated graph would then work with any input tensor that can be split into two chunks along dim=1.
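
For comparison, torch.jit.trace already resolves the unpacking in this example by running against the concrete input; a quick check, reusing the model and inp defined above:

# jit.trace executes the module on the example input, so the split/unpack
# is handled naturally; the request is for an FX-level tracer that records
# a graph in the same value-driven way.
traced = torch.jit.trace(model, inp)
print(traced.graph)  # TorchScript IR specialized against the example input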

cc @ezyang

@volcacius
Author

volcacius commented May 18, 2021

My take on the issue: https://github.com/Xilinx/brevitas/blob/fx/src/brevitas/fx/value_tracer.py

import torch
from torch import nn
from brevitas.fx import value_trace

class TestModule(nn.Module):

    def __init__(self):
        super().__init__()

    def call_cat(self, *args):
        out = torch.cat(args, dim=1)
        return out

    def forward(self, x):
        x = torch.split(x, 2, dim=1)
        y = self.call_cat(*x)
        return y


model = TestModule()
inp = torch.randn(1, 4, 10, 10)
out = value_trace(model, concrete_args={'x': inp})
print(out)

I get as output (on 1.8.1):

TestModule()
import torch
def forward(self, x):
    split_1 = torch.functional.split(x, 2, dim = 1);  x = None
    iter_1 = iter(split_1);  split_1 = None
    next_1 = iter_1.__next__()
    next_2 = iter_1.__next__();  iter_1 = None
    cat_1 = torch.cat((next_1, next_2), dim = 1);  next_1 = next_2 = None
    return cat_1
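
Since the graph only bakes in the fact that the split produced two chunks, it should run on other inputs with the same channel count (a quick check, assuming value_trace returns a callable GraphModule like symbolic_trace does):

# Hypothetical check: the recorded split/iter/cat sequence is replayed on a
# new 4-channel input and matches the eager module.
new_inp = torch.randn(2, 4, 8, 8)
assert torch.equal(out(new_inp), model(new_inp))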

I'm reinterpreting the way concrete_args are used, but it could also be a separate set of args. To preserve compatibility with the symbolic interface, they are not required as long as the tracing doesn't go through a __bool__, __iter__, or __next__.

@jamesr66a
Collaborator

Hi @volcacius,

I get a 404 on that link. Is that a private repo?

@volcacius
Author

Apologies, it's a branch that got merged. Here is a working link: https://github.com/Xilinx/brevitas/blob/master/src/brevitas/fx/value_tracer.py. It's still based on 1.8.1 fx; I haven't had time to upgrade it to 1.9 yet.

@jamesr66a
Collaborator

So I think specializing various properties of the program during tracing can make sense in some situations but be confusing in others. With jit.trace, folks have had issues with e.g. shapes or device values getting baked in and causing hard-to-debug correctness/soundness issues in the traced/generated code; I know of one production use case at FB that explicitly switched to FX to avoid these issues. My take is that making such things configurable, e.g. via annotations or flags, can make this a better experience. For example, someone recently contributed specialization over the shapes of the model parameters:

param_shapes_constant: bool = False) -> None:
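
A minimal sketch of what that opt-in specialization looks like (the toy module below is my own illustration, assuming a PyTorch version where torch.fx.Tracer accepts param_shapes_constant):

import torch
from torch import nn
from torch.fx import Tracer, GraphModule

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 8)

    def forward(self, x):
        # With param_shapes_constant=True, the parameter's shape is read off
        # the real tensor during tracing and baked in as a constant, rather
        # than being recorded as getattr/getitem nodes in the graph.
        return x.reshape(-1, self.lin.weight.shape[1])

net = Net()
graph = Tracer(param_shapes_constant=True).trace(net)
gm = GraphModule(net, graph)
print(gm.code)

The baked-in shape is exactly the kind of opt-in specialization/soundness trade-off being described here.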

I think we were also kicking around various designs for annotation APIs for shape specialization; let me see if I can dig up those docs somewhere.

It's worth noting that JAX, for example, does full shape/value metadata propagation/specialization, and it can ensure correctness because all of its transforms happen within the context of a runtime that re-checks those properties and redoes the transformations if something has changed. We've elected not to do that in FX because we believe it makes things more confusing for the end user.

@RalphMao

@volcacius I was thinking of exactly the same functionality as you presented. May I know how it's going with the newer versions of fx?

@volcacius
Author

I haven't had a chance to look at that yet, sorry.

@ezyang added the module: fx and triaged labels and removed the oncall: fx label on Jul 20, 2022
@ezyang
Contributor

ezyang commented Jul 20, 2022

The operators you will get in the graph will be different, but try ProxyTensor tracing: from torch.fx.experimental.proxy_tensor import make_fx

@ezyang
Contributor

ezyang commented Jul 20, 2022

related #63076

@laserkelvin

The operators you will get in the graph will be different, but try ProxyTensor tracing: from torch.fx.experimental.proxy_tensor import make_fx

Could you please elaborate on this? As you pointed out in the conversation in #63076, make_fx should get you down to the aten calls. I can't infer what to do with the wrapped make_fx module from the source, and functionally the attributes/methods seem identical to fx.GraphModule.

I'm interested in trying to map the symbolic graph to which operators will be called, without necessarily having to run concrete data through a model.

@ezyang
Contributor

ezyang commented Jan 11, 2023

make_fx gives you a GraphModule, like symbolic_trace, so you can do whatever you wanted to do with a symbolically traced module. The operators in the graph will just be different: torch ops vs aten ops.

make_fx works without concrete data. Pass it a bunch of meta tensors (or fake tensors) and it will trace without running.
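
A minimal sketch of that workflow (the toy function below is my own illustration, assuming a PyTorch version that ships torch.fx.experimental.proxy_tensor.make_fx):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x):
    a, b = torch.split(x, 2, dim=1)
    return torch.cat((a, b), dim=1)

# Tracing against a meta tensor: shapes and dtypes are known, but no real
# data is allocated or computed, so the function never actually runs on data.
gm = make_fx(fn)(torch.randn(1, 4, 10, 10, device="meta"))
print(gm.code)  # the graph records aten-level ops rather than torch.* calls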

@ezyang
Contributor

ezyang commented Jan 11, 2023

You DO have to know the sizes of all your inputs though (in the form of the meta tensor), that's the price of entry.
