-
Notifications
You must be signed in to change notification settings - Fork 105
Add trace decomposition, invoke from C++ #740
Add trace decomposition, invoke from C++ #740
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good - might be nicer for the decomp stuff to be a boxed fallback instead.
functorch/csrc/BatchRulesViews.cpp
Outdated
| return std::make_tuple(at::_unsafe_view(self_, view_size), 0); | ||
| } | ||
|
|
||
| Tensor trace_decomp(const Tensor& self) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we make this a macro? or perhaps a boxed fallback?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done... depends on pytorch/pytorch#76493
functorch/_src/vmap.py
Outdated
| def register_jit_decomposition(decomp): | ||
| assert decomp in decomposition_table, f"could not find {decomp}" | ||
| decomp_fn = decomposition_table[decomp] | ||
| scripted_decomp_fn = torch.jit.script(decomp_fn) | ||
| torch.jit._register_decomposition(decomp, scripted_decomp_fn.graph) | ||
|
|
||
|
|
||
| register_jit_decomposition(torch.ops.aten.trace.default) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Decompositions aren't a vmap-specific thing, we also would like to apply them to our other transforms (jvp for now). Maybe a better place to put these lines would be eager_transforms.py ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is super cool! I can't wait for this to be landed; my plan on record previously was to transcribe everything in decompositions.py into C++ that I needed
|
@zou3519 couldn't you have also just called into the Python directly |
|
@ezyang, we have three short-term options really:
There's an open question around what the medium and long term state of this should be. That I am not sure about.
|
|
I think the appropriate long term attitude is that decomps get traced through in a JIT (e.g., AOTAutograd via torchdynamo) and therefore it doesn't matter that they directly get run in Python. |
…ch#740) * [WIP] add trace example * use better api * Move decomp registration * fix flake * oop * use boxed fallback * rename * move and fix string handling Co-authored-by: Elias Ellison <eellison@devfair044.h1.fair>
…ch#740) * [WIP] add trace example * use better api * Move decomp registration * fix flake * oop * use boxed fallback * rename * move and fix string handling Co-authored-by: Elias Ellison <eellison@devfair044.h1.fair>
TODO - integrate registrations of the decomposition w/jit into the
decompositions.pyfileRelies on pytorch/pytorch#76252
cc @zou3519 @Chillee