Prims+NvFuser Backend Prototype #80591
✅ Dr. CI: No failures (0 pending) as of commit b411b3e.
Thanks for working on this first caching prototype. While it works fine for different inputs with the same shape, caching based only on the GraphModule is error-prone:
Case 1 (changing number of dimensions):
In [1]: import torch
In [2]: from torch._prims.executor import execute
In [3]: from torch._prims.context import TorchRefsMode
In [4]: from torch.fx.experimental.proxy_tensor import make_fx
In [5]: def func(a):
...: return torch.sigmoid(a)
...:
In [6]: a = torch.randn(3, 3, device='cuda')
In [7]: with TorchRefsMode.push():
...: gm = make_fx(func)(a)
...:
In [8]: %%time
...: execute(gm, a, executor="nvfuser")
...:
...:
CPU times: user 143 ms, sys: 21.1 ms, total: 164 ms
Wall time: 164 ms
Out[8]:
tensor([[0.1372, 0.4281, 0.6746],
[0.5049, 0.6179, 0.5654],
[0.7701, 0.3155, 0.6855]], device='cuda:0')
In [9]: %%time
...: execute(gm, a, executor="nvfuser");
...:
...:
CPU times: user 108 µs, sys: 93 µs, total: 201 µs
Wall time: 211 µs # Caching works as expected
Out[9]:
tensor([[0.1372, 0.4281, 0.6746],
[0.5049, 0.6179, 0.5654],
[0.7701, 0.3155, 0.6855]], device='cuda:0')
In [10]: b = torch.randn(2, 3, 3, device="cuda")
In [11]: execute(gm, b, executor="nvfuser")
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [12], in <cell line: 1>()
----> 1 execute(gm, b, executor="nvfuser")
File ~/dev/pytorch/master/torch/_prims/executor.py:85, in execute(gm, executor, *args)
81 # store fusion in the cache
82 execute.fusion_cache[gm] = (fusion, unflatten_spec)
84 return tree_unflatten(
---> 85 fusion.execute(tuple(arg for arg in args if isinstance(arg, torch.Tensor))),
86 unflatten_spec,
87 )
89 msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
90 executor
91 )
92 raise ValueError(msg)
RuntimeError: at_tensor.ndimension() == static_cast<int>(root_domain.size()) INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/evaluator_common.cpp":497, please report a bug to PyTorch. Something went wrong configuring launch. Inputs do not match.
Case 2 (changing strides):
In [1]: import torch
In [2]: from torch._prims.executor import execute
In [3]: from torch._prims.context import TorchRefsMode
In [4]: from torch.fx.experimental.proxy_tensor import make_fx
In [5]: a = torch.randn(3, 3, device='cuda')
In [6]: def func(a):
...: return torch.sigmoid(a)
...:
In [7]: with TorchRefsMode.push():
...: gm = make_fx(func)(a)
...:
In [8]: execute(gm, a, executor="nvfuser")
Out[8]:
tensor([[0.6656, 0.2625, 0.3951],
[0.4097, 0.4118, 0.6406],
[0.0929, 0.5299, 0.3691]], device='cuda:0')
# strides of a are (3, 1) while strides of a.mT are (1, 3)
In [9]: execute(gm, a.mT, executor="nvfuser") # same output but should be different
Out[9]:
tensor([[0.6656, 0.2625, 0.3951],
[0.4097, 0.4118, 0.6406],
[0.0929, 0.5299, 0.3691]], device='cuda:0')
In [10]: execute(gm, a.mT, executor="aten")
Out[10]:
tensor([[0.6656, 0.4097, 0.0929],
[0.2625, 0.4118, 0.5299],
[0.3951, 0.6406, 0.3691]], device='cuda:0')
In [11]: torch.allclose(execute(gm, a.mT, executor="nvfuser"), execute(gm, a.mT, executor="aten"))
Out[11]: False
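For illustration, a minimal sketch of a cache key that would at least distinguish these cases (the fusion_cache_key helper is hypothetical, not part of this PR):

import torch

# Hypothetical helper: key the cache on the GraphModule plus the metadata of
# every tensor argument, so inputs that differ in rank, shape, strides, dtype,
# or device miss the cache instead of reusing a stale Fusion.
def fusion_cache_key(gm, args):
    tensor_meta = tuple(
        (tuple(a.shape), tuple(a.stride()), a.dtype, a.device)
        for a in args
        if isinstance(a, torch.Tensor)
    )
    return (gm, tensor_meta)

With strides in the key, the a.mT call above would trigger a fresh compilation instead of silently returning the wrong result.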
I have a PR, #80525, that takes care of nvFuser's Fusion creation caching; please take a look.
Hi @IvanYashchuk, you are absolutely right that the Fusion cache is a naive implementation that would fail in many cases. It's actually a toy cache that I use to avoid re-compilation so I can get some reasonable performance measurements. Please ignore this cache for now... I will revert it before checking this in.
# TODO: use a better way to identify fused submodule
if "fused_" in node.name:
    fused_module = getattr(fused_graph_module, node.name)
    fused_module._wrapped_call = self.lower_to_prims_and_execute
What's going on here?
Effectively, the call function of fused_module is replaced with self.lower_to_prims_and_execute.
An alternate method to do this is to replace the entire submodule wholesale; an example of it is in pytorch/test/test_dynamo_cudagraphs.py, lines 166 to 179 (at db6df04):
class ApplyCudaGraphs(ProxyTensorInterpreter):
    # All module calls are assumed to be fusion groups, since
    # this is post AOTAutograd which would have squashed all the modules.
    # Module assumed to be called only once.
    def call_module(self, target, args, kwargs):
        assert not kwargs
        # Don't trace the module, but do run the module to get the correct
        # out result
        out = super().call_module(target, tree_map(unwrap_elem, args), kwargs)
        submod = self.module.get_submodule(target)
        mutated_inputs = FindInputMutations(submod)(*map(unwrap_elem, args))
        # smh the module didn't get transferred wut
        self.new_module.add_submodule(target, CudaGraphModule(submod, mutated_inputs))
        return wrap_output(out, torch.fx.Proxy(self.new_graph.call_module(target, tree_map(unwrap_proxy_node, args), tree_map(unwrap_proxy_node, kwargs)), self.tracer))
I'm not so keen on monkeypatching the _wrapped_call method, since it violates expectations (you're not expecting a method on the class to have gotten overwritten).
There are some aspects of this new version that I'm not entirely sure I've gotten quite right, so more golfing on this pattern would be good.
Another consideration/feature request: this fused module needs to run in two modes:
- If compilation succeeds, it should go through the prims_decomp+nvfuser path.
- If compilation fails, it should fall back to running the original aten module. If the backend doesn't handle some specific input values, it should also fall back to the aten module: [Prims+NVFuser] Supports 0-sized inputs #80231
Looks like we can handle such a fallback with a try/except pattern in the call_module function?
Yeah that seems like the right thing to do.
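A minimal sketch of what that could look like (lower_to_prims_and_execute here is a stand-in for whatever the nvFuser path ends up being called, not this PR's exact code):

# Hypothetical call_module with fallback: try the prims+nvFuser path first
# and fall back to running the original aten submodule if anything fails.
def call_module(self, target, args, kwargs):
    submod = self.module.get_submodule(target)
    try:
        # prims decomposition + nvFuser execution path
        return self.lower_to_prims_and_execute(submod, *args, **kwargs)
    except Exception:
        # compilation failure or unsupported inputs (e.g. 0-sized tensors,
        # #80231): run the original aten module instead
        return submod(*args, **kwargs)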
Ok, I will update to use ProxyTensorInterpreter when it's put in place.
torch/_prims/executor.py (outdated)

    return decorate


@static_vars(fusion_cache={})
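For context, static_vars here is presumably the usual attach-attributes-to-the-function pattern; a sketch of that assumption:

# Presumed shape of the static_vars decorator: attach the given keyword
# arguments as attributes on the decorated function, so execute() can carry
# state such as fusion_cache across calls.
def static_vars(**kwargs):
    def decorate(func):
        for name, value in kwargs.items():
            setattr(func, name, value)
        return func
    return decorate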
I think there isn't a need for a cache if you integrate directly with torchdynamo. It might be worth going there for the E2E.
Is there a list of functionality that torchdynamo provides or aims to provide?
I'm not sure how I would give such a list; it's more about understanding the compilation model torchdynamo has. The compilation callback that dynamo calls once it has acquired an FX graph is expected to return a Python callable corresponding to the "compiled" version of the code; dynamo will take that return result and directly save it in its cache (see https://github.com/pytorch/torchdynamo/#guards ) so the next time the guards all match we will jump straight to this callable.
To get a better sense for what dynamo does, I'd recommend reading the README in the repo. You could also watch my livestream about it https://www.youtube.com/watch?v=egZB5Uxki0I if you like that sort of thing
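For reference, the shape of that callback looks roughly like this (a sketch; nvfuser_compiler, model, and inputs are placeholder names, not code from this PR):

import torch.fx
import torchdynamo

# Sketch of a dynamo backend: dynamo hands the callback an FX graph plus
# example inputs and caches whatever callable it returns, guarded on the
# conditions it recorded while tracing.
def nvfuser_compiler(gm: torch.fx.GraphModule, example_inputs):
    # ... lower gm to prims and build an nvFuser fusion here ...
    return gm.forward  # the callable dynamo stores and re-dispatches to

with torchdynamo.optimize(nvfuser_compiler):
    model(inputs)  # placeholder for whatever you want to run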
Thank you for the guards link! I watched the livestream and skimmed through the README; I think the word "cache" is missing, which is why I didn't find it when I tried to look it up.
For the list, I was thinking of a list of problems that dynamo resolves, as opposed to just using a __torch_dispatch__-based tracer like torch.fx.experimental.proxy_tensor.make_fx to acquire the FX graph. I'd like to understand in what situations make_fx fails to create an FX graph while dynamo succeeds; one such situation is having calls to an external library (like NumPy: https://github.com/pytorch/torchdynamo/blob/77a7808515d95a3b84b5ffc3943bcbd1960505b8/torchdynamo/allowed_functions.py#L210-L211).
I think it's not right to think about this as a coverage thing, because in fact dynamo captures less stuff than make_fx. It is a soundness thing: you can always capture a trace with make_fx, but dynamo will tell you under what conditions the trace is sound to reuse later, because the Python code might execute differently and give you a different trace later. Neither dynamo nor make_fx will create a graph with numpy calls in it, but dynamo will graph break whereas make_fx will bake in a tensor constant corresponding to the computed numpy result.
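A hypothetical example of that difference (not from this thread):

import numpy as np
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    # The numpy call produces an ordinary value, not a traced tensor,
    # so it cannot appear as an op in the FX graph.
    return a + torch.tensor(np.sin(1.0))

# make_fx traces through this: the computed numpy value is baked into the
# graph as a constant, so the trace is only valid for that value.
# dynamo would instead graph-break at the numpy call.
gm = make_fx(func)(torch.randn(3))
print(gm.graph)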
LGTM
return torch._prims.convert_element_type(self, dtype)


# decomposition_table currently contains both aten2aten and aten2prim decompositions
# this is a hack to separate them, as we only need aten2prim decompositions for nvfuser-supported aten graph lowering
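A hedged sketch of one way such a split could be expressed (an assumption about the approach, not this PR's actual code): filter the table by where each decomposition function is defined.

import torch._decomp

# Assumed heuristic: treat entries implemented in torch._refs / torch._prims
# as aten2prim decompositions and everything else as aten2aten.
aten2prim_decomp = {
    op: fn
    for op, fn in torch._decomp.decomposition_table.items()
    if fn.__module__.startswith(("torch._refs", "torch._prims"))
}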
What if you need to chain an aten2aten decomp into an aten2prim decomp to get nvfuser going?
If an op needs aten2aten->aten2prim decomps chained together, then by our current design the aten2aten step will happen at the dynamo+aot layer.
Well, we have to be careful about CompositeImplicitAutograd aten2aten decompositions, and what we used to call torch._decomp aten2aten decompositions. If the original aten operator is a fused one, we also don't want to apply the aten2aten decomposition until we know that a compiler is going to be able to nail either
class NvFuserBackend:
    def __init__(self):
        self.supported_ops = NvFuserOperatorSupport()
This doesn't actually have any mutable state so it would be clearer if this was just a singleton object (actually, it's not really clear to me why operator supports need to be a class in the first place, it's literally just a callable without any state.)
OperatorSupport was implemented as an interface class in the first place; NvFuserOperatorSupport is a subclass of OperatorSupport.
Does NvFuserBackend need to be a singleton?
OK, updated prescription (but it doesn't have to be done in this PR): I think OperatorSupport should be represented just as a Callable, and you're welcome to pass in a function or a proper class that implements __call__. To represent this in typing you can write an alias for Callable[...].
This is Python, not Java; we don't have to create an interface class just for a function, we can just pass in a function directly.
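A minimal sketch of what that alias could look like (the signature mirrors the current is_node_supported; the names here are illustrative):

from typing import Callable, Mapping
import torch
import torch.fx

# Operator support expressed as a plain callable: given the named submodules
# and a node, report whether the backend can handle it.
SupportsNode = Callable[[Mapping[str, torch.nn.Module], torch.fx.Node], bool]

def nvfuser_supports(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool:
    # illustrative check only; a real implementation would consult a table
    # of nvFuser-supported prims
    return node.op == "call_function"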
OK, I will update OperatorSupport in the next PR.
self.partitioner_cache: Dict[GraphModule, GraphModule] = {}

# TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs
self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {}
It would be helpful to have a sketch for what the envisioned proper design here would be
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Merge failed due to: Refusing to merge as mandatory check(s) pull failed for rule superuser.
@pytorchbot rebase -s
@pytorchbot rebase -s
@pytorchbot successfully started a rebase job. Check the current status here.
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/2634204553
@pytorchbot rebase -b master
@pytorchbot successfully started a rebase job. Check the current status here.
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/2634240713
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Hey @SherlockNoMad.
Summary: This PR integrates FX graph partitioner + Aten2Prims DecompositionInterpreter + Prims' TraceExecutor + naive caches for nvFuser.
Pull Request resolved: #80591
Approved by: https://github.com/jjsjann123, https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/fc10a6372736dd8bac1a894a92beebbafee74f01
Reviewed By: mehtanirav
Differential Revision: D37749351
Pulled By: SherlockNoMad
fbshipit-source-id: fb6eabbca954c58934bd10894f6420d87e413e96