Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Conversation

eellison
Copy link
Contributor

@eellison eellison commented Apr 22, 2022

TODO - integrate registrations of the decomposition w/jit into the decompositions.py file

Relies on pytorch/pytorch#76252

cc @zou3519 @Chillee

@eellison eellison changed the title [WIP] add trace decomposition invocation from C++ Add trace decomposition, invoke from C++ Apr 27, 2022
Copy link
Contributor

@Chillee Chillee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good - might be nicer for the decomp stuff to be a boxed fallback instead.

return std::make_tuple(at::_unsafe_view(self_, view_size), 0);
}

Tensor trace_decomp(const Tensor& self) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make this a macro? or perhaps a boxed fallback?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done... depends on pytorch/pytorch#76493

Comment on lines 40 to 47
def register_jit_decomposition(decomp):
assert decomp in decomposition_table, f"could not find {decomp}"
decomp_fn = decomposition_table[decomp]
scripted_decomp_fn = torch.jit.script(decomp_fn)
torch.jit._register_decomposition(decomp, scripted_decomp_fn.graph)


register_jit_decomposition(torch.ops.aten.trace.default)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decompositions aren't a vmap-specific thing, we also would like to apply them to our other transforms (jvp for now). Maybe a better place to put these lines would be eager_transforms.py ?

Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is super cool! I can't wait for this to be landed; my plan on record previously was to transcribe everything in decompositions.py into C++ that I needed

@ezyang
Copy link
Contributor

ezyang commented Apr 28, 2022

@zou3519 couldn't you have also just called into the Python directly

@zou3519
Copy link
Contributor

zou3519 commented Apr 28, 2022

@ezyang, we have three short-term options really:

  1. Transcribe the Python into C++. This is what I'm doing right now (e.g., in Use nll_loss_backward decomposition for jvp transform #764) because I didn't want to think too hard
  2. Call directly from C++ into Python. Requires us to either use Anjali's Python Op registration logic or implement something similar (a Python function that can be called from C++)
  3. Use this PR (which is getting the decompositions into TorchScript and then being able to call the TorchScript from C++).

There's an open question around what the medium and long term state of this should be. That I am not sure about.

  • for medium-term: It seems more lightweight to call into Python directly, but TorchScript is potentially faster due to not needing to get back into Python interpreter. Need to benchmark to see if this matters at all.
  • for long-term: ???

@Chillee Chillee merged commit a1f569a into pytorch:main May 2, 2022
@ezyang
Copy link
Contributor

ezyang commented May 3, 2022

I think the appropriate long term attitude is that decomps get traced through in a JIT (e.g., AOTAutograd via torchdynamo) and therefore it doesn't matter that they directly get run in Python.

zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
…ch#740)

* [WIP] add trace example

* use better api

* Move decomp registration

* fix flake

* oop

* use boxed fallback

* rename

* move and fix string handling

Co-authored-by: Elias Ellison <eellison@devfair044.h1.fair>
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
…ch#740)

* [WIP] add trace example

* use better api

* Move decomp registration

* fix flake

* oop

* use boxed fallback

* rename

* move and fix string handling

Co-authored-by: Elias Ellison <eellison@devfair044.h1.fair>
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants