
Conversation

@Chillee
Collaborator

Chillee commented Mar 17, 2022

This is the __torch_dispatch__ subclass used for tracing by AOTAutograd (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py).
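
For context, using the tracer looks roughly like this (a minimal sketch; the import path below is an assumption about where this PR puts the code in core):

import torch
from torch.fx.experimental.proxy_tensor import make_fx  # assumed path

def f(x):
    return torch.sin(x) * 2

# f is run with ProxyTensor inputs; every aten op that hits __torch_dispatch__
# gets recorded as a node in an FX graph.
traced = make_fx(f)(torch.randn(3))
print(traced.graph)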

Given that a couple of folks are now interested in using this infra, it seems like a good idea to put it in core, and focus our efforts on a single implementation.

I put this up as a WIP just for discussion, but here are some questions off the top of my head:

  1. What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy paste and modify? If we do define extension points, what are the extension points we should define?
  2. There are some open questions about the way we're overriding FX to resolve some lingering issues (i.e. dealing with nn.Parameter and call_module calls). @ezyang implemented an alternate version of this tensor in https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py, but it appears he ran into some issues with it that led to me submitting this implementation. That being said, I think some of the things over there should still be ported.
  3. Given that this is going to be shared infra, what other features should we put in here? One that comes to mind is to allow for meta-tensor tracing (perhaps by default?), with a more solid fallback.

Some of the other implementations (for reference on requirements).

  1. FX2TRT: D34868356 (internal only)
  2. Edge's? @gmagogsfm

cc: @ezyang , @jamesr66a , @zou3519 , @gmagogsfm, @842974287



try:
    yield
finally:
    del guard
Contributor

need to find a better spot for this XD

Contributor

Lol people have been copy-pasting this for literally a year
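
For reference, the helper being copy-pasted is roughly the following (reconstructed here as a sketch; this is the same shape as the one in torch.testing._internal.logging_tensor):

from contextlib import contextmanager
import torch

@contextmanager
def no_dispatch():
    # Temporarily disable __torch_dispatch__ interposition so we can call
    # into the real kernels from inside a subclass's __torch_dispatch__.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard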

# ProxyTensor boundary.
# assert not elem.requires_grad or not torch.is_grad_enabled()

r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
Contributor

I am reasonably sure it's safe to do super().__new__(elem) here; this will propagate gradients and is more correct than implicitly making leaves. But need to actually test this on functorch.

Collaborator Author

The only failure I get when trying this is when wrapping sparse tensors.

Contributor

Yeah OK, this makes sense; need to make the default __new__ constructor accept sparse tensors too


@classmethod
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
    func = func_overload.overloadpacket
Contributor

This is another choice tracer users have to make: do you want to record the specific overloads, or the overload packets? Recording the overloads is strictly more information, but many existing passes can't make use of this info in a useful way.
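
To make the distinction concrete (illustrative only):

import torch

overload = torch.ops.aten.add.Tensor       # a specific overload
packet = overload.overloadpacket           # the packet, i.e. torch.ops.aten.add

# Recording `overload` keeps the exact signature that was dispatched;
# recording `packet` matches what most existing FX passes key on today.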

proxy_out = func(*proxy_args, **proxy_kwargs)

# Kind of a hacky way to test if an op is in-place or not
if func.__name__[-1] == "_" and func.__name__[0] != "_":
Contributor

@anjali411 schema APIs would come in handy here!
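
For instance, something along these lines (a sketch; it leans on the schema's alias annotations rather than the op name, and uses the private _schema attribute):

def is_inplace(func_overload):
    # An op mutates an input if its schema marks an argument as written to,
    # e.g. "add_(Tensor(a!) self, ...)".
    schema = func_overload._schema
    return any(
        a.alias_info is not None and a.alias_info.is_write
        for a in schema.arguments
    )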

        setattr(self.root, qualname, a)

    return self.create_node('get_attr', qualname, (), {})
return super().create_arg(a)
Contributor

This is the part of Horace's implementation that I'd most like to get eyeballs from FX folks on, idk maybe @jamesr66a @zdevito ?

Collaborator

So is what's going on here that in a ProxyTensor context, you have nn.Parameter instances that are not installed as members of the module hierarchy and you'd like to intern those with a generated qualified name? Do you have an example of where this could happen?

Collaborator Author

Yes.

So, if you're tracing something like

model = ...
def f(x):
    return model(x)

I think it's natural that in this case the model parameters should just be constants from the perspective of the tracer. But since FX normally resolves parameters by looking them up on the module hierarchy, that behavior needs to be modified.
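
Concretely, tracing that f interns model's parameters onto the traced module and references them via get_attr nodes rather than taking them as placeholders. A rough sketch of the observable behavior (the _param_constant naming of the generated attributes is an assumption):

import torch
from torch import nn
from torch.fx.experimental.proxy_tensor import make_fx  # assumed path

model = nn.Linear(3, 3)

def f(x):
    return model(x)

traced = make_fx(f)(torch.randn(3))
print(traced.graph)
# The weight and bias show up as get_attr nodes on `traced` itself
# (e.g. _param_constant0, _param_constant1), not as graph inputs.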

Collaborator Author

I talked with @jamesr66a and we think some sort of approach along these lines is reasonable.

James has a preference for preserving as much of FX's current hooks as possible.

@ezyang
Contributor

ezyang commented Mar 17, 2022

What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy paste and modify? If we do define extension points, what are the extension points we should define?

One thing is that you can subclass this class and tweak it (for example, decompositions could be implemented in this way). But I'm not sure I want to encourage people to be willy nilly subclassing here.

@ezyang
Contributor

ezyang commented Apr 20, 2022

Any progress on this?

@Chillee
Collaborator Author

Chillee commented Apr 20, 2022

I was sidetracked by the memory ownership issue, will get back to this this week.

@Chillee changed the title from "[WIP] Added proxy tensor" to "Added proxy tensor" Apr 23, 2022
@Chillee requested a review from ezyang April 23, 2022 01:11
try:
    yield CURRENT_DECOMPOSITION_TABLE
finally:
    CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
Contributor

@Chillee do you think this should be the done and dusted API for decomps in AOTAutograd? I now agree that it should be possible to directly program extra decomps as part of the tracing process, but I still want control of the default set of decompositions (which we're gonna put in PyTorch core) to be based on what the backend declares it supports.

I suppose one could do this compositionally in the current API with something like decompose({**get_decomps(dont_decompose), **custom_decomps}). (This inverts the sense in which get_decomps is currently programmed; instead of asking for decomps, you say everything you DON'T want decomposed.) But this seems like a weird way to do the API if you're talking in terms of building an array of decompositions; it seems like maybe this should be abstracted away into a higher-level API.

Collaborator Author

I think in practice... backends like NVFuser don't actually necessarily want to "try and decompose everything they don't support". For example, something like slice_backward decomposes into a new_zeros and a slice_scatter call. NVFuser supports neither of these new ops, and since we've decomposed the original op, the result is now slower.

Essentially, one nice thing about this API is that the "default option" (i.e. decompose nothing) is very easy to reason about. Similarly, if you're only decomposing a few ops, it's very easy to reason about it.

I could be convinced that there could be a better API though.
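
For reference, opting into a specific decomposition with the current API looks roughly like this (a sketch; whether the table is keyed by overload or by overload packet is an assumption here):

import torch
from torch.fx.experimental.proxy_tensor import make_fx  # assumed path
aten = torch.ops.aten

def addmm_decomp(bias, mat1, mat2, beta=1, alpha=1):
    # decompose addmm into mm + add
    return beta * bias + alpha * torch.mm(mat1, mat2)

def f(bias, a, b):
    return torch.addmm(bias, a, b)

traced = make_fx(f, decomposition_table={aten.addmm.default: addmm_decomp})(
    torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)
)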

Contributor

OK, agreed this is not good enough.

It seems like something better might be if we had metadata saying what a decomposition would decompose to. Then NVFuser could define "this is the stuff that I actually understand", and we would only do decompositions that (eventually) produce things it understands. (This is not 100% well specified, though: if a decomp produces some understandable stuff and some non-understandable stuff, is it still profitable to decompose?)

proxy_out = func(*proxy_args, **proxy_kwargs)

# Kind of a hacky way to test if an op is in-place or not
if func.__name__[-1] == "_" and func.__name__[0] != "_":
Contributor

A more robust test would be to replicate the parsing logic in model.py

    @staticmethod
    def parse(op: str) -> 'BaseOperatorName':
        assert op != ''
        assert not op.endswith('_out'), \
            "_out suffix is reserved and not permitted for operator names; " \
            "did you mean to specify an out overload name instead?"
        m = re.match(r'^__([^_]+)__$', op)
        if m is not None:
            dunder_method = True
            base = m.group(1)
            if any(base == f'i{n}' for n in AUGMENTED_ASSIGNMENT_NAMES):
                inplace = True
                base = base[1:]
            else:
                inplace = False
                # temporary, this is not intrinsically true but
                # has been historically true for dunder methods
                # we support  (but, if we ever got, say, __int__, this would
                # be wrong!)
                assert base[0] != 'i'
        else:
            dunder_method = False
            base = op
            if base[-1] == '_':
                inplace = True
                base = base[:-1]
            else:
                inplace = False
        r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method)
        assert str(r) == op, f'{str(r)} != {op}'
        return r

Contributor

But the best solution would be just to get this metadata on the freaking overloads cc @anjali411

proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

with no_dispatch():
    real_out = func_overload(*args, **kwargs)
Contributor

@Chillee wondering... would proxy tensor be better as a mode? We'd just chuck the proxies as attributes on vanilla tensors, no subclassing involved at all, and then just a mode to record IR and propagate proxies. Would be able to trace factories this way.

cc @samdow
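
Rough sketch of the "tracing as a mode" shape (written against the TorchDispatchMode API that landed later, purely for illustration; it is not what this PR implements):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class RecordingMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.ops.append(func)          # a real tracer would build FX nodes / proxies here
        return func(*args, **kwargs)   # run the underlying op

with RecordingMode() as m:
    x = torch.randn(3)                 # factory ops are seen by the mode too
    y = torch.sin(x) + 1

print(m.ops)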

Contributor

We've looked at having proxy tensors as a mode (hacky version of the change here). That still has subclassing, so @Chillee might know if there's something more fundamental, but I can see what breaks locally if we just use attributes on vanilla tensors.

We would need to wait for #75966 to land for any part of the mode solution to work at all in a non-hacky way because:

  1. The current tracer is saved as an attribute on the proxies. The PR introduces the class PythonMode so that we can save that state on the mode. Without that PR, factory functions need the tracer to be saved as global state because they don't take any tensor inputs (this is how it works in the hacky version).
  2. (Minor, not currently in the PR but should be.) Right now, __torch_dispatch__ always warns if it's an instance method.


# We need to do this so that parameters entering the `make_fx` context have
# a reference to them (and also have requires_grad set on them correctly
# I'm not actually sure if this is the right thing to do ...
Contributor

can we figure this out lol

Collaborator Author

I don't think we need to set requires_grad on it (at least, I don't remember why we'd need to, and this code doesn't seem to do anything special about requires_grad...).

So I'll just remove the comment.

Collaborator Author

I think the main complexity here is that fundamentally, the .backward() API allows you to "leak" things like tracing tensors.

For example, what should happen here?

model = resnet18()
def f(x):
    model(x).sum().backward()
    return x.grad

print(make_fx(f)(torch.randn(1,3,224,224, requires_grad=True)).code)

print([i.grad for i in model.parameters()])

Under the current implementation, this will leak ProxyTensors to outside the function. I think this is unavoidable, and just part of the contract of make_fx.

Contributor

Oh yeah, we talked about this on VC just now. It seems to me that when non-proxy requires_grad=True tensors enter the scope of f, they should get detached and turned into leaf nodes, so the proxy doesn't escape. WDYT?
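
Something like this, at the point where real tensors enter the traced region (a hypothetical sketch of the idea; the helper name and hook point are assumptions):

import torch
import torch.utils._pytree as pytree

def detach_external_inputs(args):
    # Turn any requires_grad=True tensor entering the traced region into a
    # fresh leaf, so .backward() inside f can't push proxies out into the
    # original autograd graph.
    def go(t):
        if isinstance(t, torch.Tensor) and t.requires_grad:
            return t.detach().requires_grad_(True)
        return t
    return pytree.tree_map(go, args)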

def make_fx(f, decomposition_table={}):
    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)
Contributor

this implies f cannot have any non-tensor arguments?

Collaborator Author

Yeah, make_fx does not support non-tensor arguments (or rather, non-tensor leaves after pytree flattening).

Collaborator Author

If we wanted to, we could simply bake those in if that's preferred.
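
For what it's worth, callers can already bind non-tensor values before tracing; "baking them in" would just do this automatically (sketch):

import functools
import torch
from torch.fx.experimental.proxy_tensor import make_fx  # assumed path

def f(x, scale):                  # `scale` is a non-tensor argument
    return x * scale

# Only tensor leaves become placeholders, so close over the non-tensor value:
traced = make_fx(functools.partial(f, scale=2.0))(torch.randn(3))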



@contextmanager
def no_dispatch():
Collaborator

Use the one from torch.testing._internal.logging_tensor.no_dispatch?

Contributor

that's a terrible location for it.

The root of our evils is we don't have a "torch dispatch" module. I propose torch.dispatch. Any objections?

Collaborator

Well, this function should not exist? Once the modes are fixed by Sam and users can call super(), there is no need for this function at all!

I do agree that we can have this namespace if needed, though, if we don't want to put it in torch.overrides (which has an ok name to house both __torch_function__ and __torch_dispatch__).

Contributor

Well, we need this function to exist in some form for internal implementation purposes (e.g., actually implementing super()).

I'm not very hot on torch.overrides because although the name is generic, it really is very __torch_function__ leaning right now.

Collaborator

I'm not very hot on torch.overrides because although the name is generic, it really is very __torch_function__ leaning right now.

Sure, but we could change that? Also, I plead guilty to adding torch.overrides.enable_reentrant_dispatch() there. But we can move it somewhere else if we prefer a new namespace.

@Chillee requested a review from ezyang April 27, 2022 08:41
@ezyang
Contributor

ezyang commented May 3, 2022

any remaining blockers for this?

@Chillee
Collaborator Author

Chillee commented May 3, 2022

No, will land today, responded to some other comments yesterday.

@Chillee
Collaborator Author

Chillee commented May 3, 2022

@pytorchbot merge this please

@github-actions
Contributor

github-actions bot commented May 3, 2022

Hey @Chillee.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@Chillee added the 'topic: not user facing' label May 3, 2022
facebook-github-bot pushed a commit that referenced this pull request May 5, 2022

Pull Request resolved: #74360
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/fc95eda285aed184e48fe9657b8c9c3bdb60f283

Reviewed By: malfet

Differential Revision: D36134093

Pulled By: Chillee

fbshipit-source-id: 1d3af72fa79ab97ec685f20fb47f8c00404fb1c3
github-actions bot deleted the proxy_tensor branch February 16, 2024 01:52