
Conversation

SherlockNoMad
Contributor

This PR integrates FX graph partitioner + Aten2Prims DecompositionInterpreter + Prims' TraceExecutor + naive caches for nvFuser.

davidberard98 and others added 3 commits June 29, 2022 10:28
If an aten -> prim decomposition is needed *after* the initial trace
with make_fx, this interpreter can be used to perform the decomposition.

ghstack-source-id: 3bf5c97
Pull Request resolved: #79989
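
For reference, a minimal sketch of how such a post-trace decomposition pass can be driven (assuming the DecompositionInterpreter(module, new_graph, decomposition_table=...) API from #79989; fn and aten2prim_table are illustrative placeholders):

```python
import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter

def fn(a):
    return torch.sigmoid(a)

a = torch.randn(3, 3)
gm = make_fx(fn)(a)  # initial aten-level trace

# Placeholder for the aten -> prim entries carved out of the decomposition table
# (see the aten2aten vs. aten2prim discussion further down in this thread).
aten2prim_table = {}

# Re-interpret the traced graph, decomposing ops into a fresh graph.
prim_graph = torch.fx.Graph()
DecompositionInterpreter(gm, prim_graph, decomposition_table=aten2prim_table).run(a)
prim_gm = torch.fx.GraphModule(gm, prim_graph)
```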
@facebook-github-bot
Contributor

facebook-github-bot commented Jun 30, 2022


✅ No Failures (0 Pending)

As of commit b411b3e (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


Collaborator

@IvanYashchuk IvanYashchuk left a comment

Thanks for working on this first caching prototype. While it works fine for different inputs with the same shape, caching based only on the GraphModule is error-prone:
Case 1 (changing number of dimensions):

In [1]: import torch
In [2]: from torch._prims.executor import execute
In [3]: from torch._prims.context import TorchRefsMode
In [4]: from torch.fx.experimental.proxy_tensor import make_fx
In [5]: def func(a):
   ...:     return torch.sigmoid(a)
   ...:
In [6]: a = torch.randn(3, 3, device='cuda')
In [7]: with TorchRefsMode.push():
   ...:     gm = make_fx(func)(a)
   ...:
In [8]: %%time
   ...: execute(gm, a, executor="nvfuser")
   ...:
   ...:
CPU times: user 143 ms, sys: 21.1 ms, total: 164 ms
Wall time: 164 ms
Out[8]:
tensor([[0.1372, 0.4281, 0.6746],
        [0.5049, 0.6179, 0.5654],
        [0.7701, 0.3155, 0.6855]], device='cuda:0')
In [9]: %%time
    ...: execute(gm, a, executor="nvfuser");
    ...:
    ...:
CPU times: user 108 µs, sys: 93 µs, total: 201 µs
Wall time: 211 µs  # Caching works as expected
Out[9]:
tensor([[0.1372, 0.4281, 0.6746],
        [0.5049, 0.6179, 0.5654],
        [0.7701, 0.3155, 0.6855]], device='cuda:0')
In [10]: b = torch.randn(2, 3, 3, device="cuda")
In [11]: execute(gm, b, executor="nvfuser")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [12], in <cell line: 1>()
----> 1 execute(gm, b, executor="nvfuser")

File ~/dev/pytorch/master/torch/_prims/executor.py:85, in execute(gm, executor, *args)
     81             # store fusion in the cache
     82             execute.fusion_cache[gm] = (fusion, unflatten_spec)
     84     return tree_unflatten(
---> 85         fusion.execute(tuple(arg for arg in args if isinstance(arg, torch.Tensor))),
     86         unflatten_spec,
     87     )
     89 msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
     90     executor
     91 )
     92 raise ValueError(msg)

RuntimeError: at_tensor.ndimension() == static_cast<int>(root_domain.size()) INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/evaluator_common.cpp":497, please report a bug to PyTorch. Something went wrong configuring launch. Inputs do not match.

Case 2 (changing strides):

In [1]: import torch
In [2]: from torch._prims.executor import execute
In [3]: from torch._prims.context import TorchRefsMode
In [4]: from torch.fx.experimental.proxy_tensor import make_fx
In [5]: a = torch.randn(3, 3, device='cuda')
In [6]: def func(a):
   ...:     return torch.sigmoid(a)
   ...:
In [7]: with TorchRefsMode.push():
   ...:     gm = make_fx(func)(a)
   ...:
In [8]: execute(gm, a, executor="nvfuser")
Out[8]:
tensor([[0.6656, 0.2625, 0.3951],
        [0.4097, 0.4118, 0.6406],
        [0.0929, 0.5299, 0.3691]], device='cuda:0')

# strides of a are (3, 1) while strides of a.mT are (1, 3)
In [9]: execute(gm, a.mT, executor="nvfuser") # same output but should be different
Out[9]:
tensor([[0.6656, 0.2625, 0.3951],
        [0.4097, 0.4118, 0.6406],
        [0.0929, 0.5299, 0.3691]], device='cuda:0')

In [10]: execute(gm, a.mT, executor="aten")
Out[10]:
tensor([[0.6656, 0.4097, 0.0929],
        [0.2625, 0.4118, 0.5299],
        [0.3951, 0.6406, 0.3691]], device='cuda:0')

In [11]: torch.allclose(execute(gm, a.mT, executor="nvfuser"), execute(gm, a.mT, executor="aten"))
Out[11]: False

I have a PR that takes care of nvFuser's Fusion creation caching, #80525; please take a look.

@SherlockNoMad
Contributor Author

Hi @IvanYashchuk, you are absolutely right that the Fusion cache is a naive implementation that would fail in many cases. It's actually a toy cache that I use to avoid re-compilation so I can get reasonable performance measurements.

Please ignore this cache for now... I will revert it before check-in.

# TODO: use a better way to identify fused submodule
if "fused_" in node.name:
    fused_module = getattr(fused_graph_module, node.name)
    fused_module._wrapped_call = self.lower_to_prims_and_execute
Contributor

What's going on here?

Contributor Author

Effectively, the call function of fused_module is replaced with self.lower_to_prims_and_execute.

Contributor

An alternate method to do this is to replace the entire submodule wholesale; here's an example:

class ApplyCudaGraphs(ProxyTensorInterpreter):
    # All module calls are assumed to be fusion groups, since
    # this is post AOTAutograd which would have squashed all the modules.
    # Module assumed to be called only once.
    def call_module(self, target, args, kwargs):
        assert not kwargs
        # Don't trace the module, but do run the module to get the correct
        # out result
        out = super().call_module(target, tree_map(unwrap_elem, args), kwargs)
        submod = self.module.get_submodule(target)
        mutated_inputs = FindInputMutations(submod)(*map(unwrap_elem, args))
        # smh the module didn't get transferred wut
        self.new_module.add_submodule(target, CudaGraphModule(submod, mutated_inputs))
        return wrap_output(
            out,
            torch.fx.Proxy(
                self.new_graph.call_module(
                    target,
                    tree_map(unwrap_proxy_node, args),
                    tree_map(unwrap_proxy_node, kwargs),
                ),
                self.tracer,
            ),
        )

I'm not so keen on monkeypatching the _wrapped_call method since it violates expectations (you're not expecting a method on the class to have gotten overwritten)

Contributor

There are some aspects of this new version that I'm not entirely sure I've gotten quite right, so more golfing on this pattern would be good.

Contributor Author

Another consideration/feature request:

This fused module needs to run in two modes:

  • If compilation is successful, it should go through the prims_decomp + nvfuser path.
  • If compilation fails, it should fall back to running the original aten module. It should also fall back to the aten module when the backend doesn't handle some specific input values: [Prims+NVFuser] Supports 0-sized inputs #80231

Contributor Author

Looks like we can handle such a fallback with a try-catch pattern in the call_module function?
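
For instance, a hedged sketch of that fallback (FusedModuleRunner and the stub body are hypothetical; only the "fused_" naming convention and lower_to_prims_and_execute come from this PR):

```python
import torch
import torch.fx

class FusedModuleRunner(torch.fx.Interpreter):
    def lower_to_prims_and_execute(self, submod, *args, **kwargs):
        # stand-in for this PR's prims decomposition + nvFuser execution
        raise NotImplementedError

    def call_module(self, target, args, kwargs):
        if "fused_" in target:  # fused partition produced by the partitioner
            submod = self.module.get_submodule(target)
            try:
                # happy path: decompose to prims and run through nvFuser
                return self.lower_to_prims_and_execute(submod, *args, **kwargs)
            except Exception:
                # compilation failure or unsupported inputs (e.g. 0-sized, #80231):
                # fall back to running the original aten submodule eagerly
                return submod(*args, **kwargs)
        return super().call_module(target, args, kwargs)
```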

Contributor

Yeah that seems like the right thing to do.

Contributor Author

Ok, I will update to use ProxyTensorInterpreter when it's put in place.

return decorate


@static_vars(fusion_cache={})
Contributor

I think there isn't a need for a cache if you integrate directly with torchdynamo. It might be worth going there for the E2E

Collaborator

@IvanYashchuk IvanYashchuk Jul 2, 2022

Is there a list of functionality that torchdynamo provides or aims to provide?

Contributor

I'm not sure how I would give such a list; it's more about understanding the compilation model torchdynamo has. The compilation callback that dynamo calls once it has acquired an FX graph is expected to return a Python callable corresponding to the "compiled" version of the code; dynamo will take that return result and directly save it in its cache (see https://github.com/pytorch/torchdynamo/#guards ) so the next time the guards all match we will jump straight to this callable.

To get a better sense for what dynamo does, I'd recommend reading the README in the repo. You could also watch my livestream about it https://www.youtube.com/watch?v=egZB5Uxki0I if you like that sort of thing
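
For a concrete picture, a minimal sketch of such a compiler callback (using the standalone torchdynamo package of that era; in current PyTorch the same thing is spelled torch.compile(backend=...)):

```python
import torch
import torchdynamo  # standalone package at the time of this PR

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Whatever callable we return here is what dynamo stores in its cache;
    # it is re-invoked directly whenever the associated guards all match.
    gm.graph.print_tabular()
    return gm.forward  # "no-op backend": just run the captured graph as-is

def fn(a, b):
    return torch.sigmoid(a) + b

with torchdynamo.optimize(my_compiler):
    fn(torch.randn(8), torch.randn(8))  # first call compiles and caches
    fn(torch.randn(8), torch.randn(8))  # guards match, cached callable runs
```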

Collaborator

Thank you for the guards link! I watched the livestream and skimmed through the README; I think the word "cache" is missing there, which is why I didn't find it when I tried to look it up.

For the list, I was thinking of a list of problems that dynamo resolves, as opposed to just using a __torch_dispatch__-based tracer like torch.fx.experimental.proxy_tensor.make_fx to acquire the FX graph. I'd like to understand in what situations make_fx fails to create an FX graph while dynamo succeeds; one such situation is having calls to an external library (like NumPy https://github.com/pytorch/torchdynamo/blob/77a7808515d95a3b84b5ffc3943bcbd1960505b8/torchdynamo/allowed_functions.py#L210-L211).

Contributor

I think it's not right to think about this as a coverage thing, because in fact dynamo captures less stuff than make_fx. It is a soundness thing: you can always capture a trace with make_fx, but dynamo will tell you under what conditions the trace is sound to reuse later, because the Python code might execute differently and give you a different trace later. Neither dynamo nor make_fx will create a graph with numpy calls in it, but dynamo will graph break whereas make_fx will bake in a tensor constant corresponding to the computed numpy result.
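
To make the contrast concrete, a small hedged illustration (a baked-in scalar rather than a tensor constant, for simplicity): make_fx runs the NumPy call eagerly at trace time and records only its result, while dynamo would graph break around it.

```python
import numpy as np
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    scale = float(np.sqrt(2.0))  # NumPy runs at trace time; no graph node is created
    return x * scale

gm = make_fx(f)(torch.randn(3))
print(gm.code)  # the graph multiplies by the precomputed 1.4142... constant;
                # the np.sqrt call itself is nowhere in the trace
```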

Collaborator

@jjsjann123 jjsjann123 left a comment

LGTM

return torch._prims.convert_element_type(self, dtype)

# decomposition_table currently contains both aten2aten and aten2prim decompositions
# this is a hack to separate them, as we only need aten2prim decompositions for nvfuser-supported aten graph lowering
Contributor

What if you need to chain an aten2aten decomp into an aten2prim decomp to get nvfuser going?

Contributor Author

If an op needs aten2aten->aten2prims decomps to work together, by our current design, aten2aten will happen at the dynamo+aot layer.

Contributor

Well, we have to be careful about CompositeImplicitAutograd aten2aten decompositions, and what we used to call torch._decomp aten2aten decompositions. If the original aten operator is a fused one, we also don't want to apply the aten2aten decomposition until we know that a compiler is going to be able to nail either


class NvFuserBackend:
    def __init__(self):
        self.supported_ops = NvFuserOperatorSupport()
Contributor

This doesn't actually have any mutable state so it would be clearer if this was just a singleton object (actually, it's not really clear to me why operator supports need to be a class in the first place, it's literally just a callable without any state.)

Contributor Author

@SherlockNoMad SherlockNoMad Jul 5, 2022

OperatorSupport was implemented as an interface class in the first place. NvFuserOperatorSupport is a subclass of OperatorSupport.

Does NvFuserBackend need to be a singleton?

Contributor

OK, updated prescription (but it doesn't have to be done in this PR): I think OperatorSupport should be represented just as a Callable, and you're welcome to pass in a function or a proper class that implements __call__. To represent this in typing you can write an alias for Callable[...].

This is Python, not Java; we don't have to create an interface class just for a function, we can just pass in a function directly.
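
A hedged sketch of that suggestion (the (submodules, node) signature mirrors OperatorSupport.is_node_supported as I understand it; SUPPORTED_OPS and nvfuser_supports are illustrative):

```python
from typing import Callable, Mapping
import torch
import torch.fx

# "Operator support" as a plain predicate instead of an interface class.
SupportFn = Callable[[Mapping[str, torch.nn.Module], torch.fx.Node], bool]

SUPPORTED_OPS = {torch.ops.aten.sigmoid.default, torch.ops.aten.add.Tensor}  # illustrative

def nvfuser_supports(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool:
    # A real implementation would consult nvFuser's actual op coverage.
    return node.op == "call_function" and node.target in SUPPORTED_OPS

# The partitioner could then accept either this function or any object
# implementing __call__ with the same signature.
```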

Contributor Author

OK, I will update OperatorSupport in the next PR.

self.partitioner_cache: Dict[GraphModule, GraphModule] = {}

# TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs
self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {}
Contributor

It would be helpful to have a sketch for what the envisioned proper design here would be
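
For instance, one hedged sketch of a guard-aware key (not this PR's design; the key fields just follow the failure modes shown above, i.e. rank/shape, strides, dtype, and device, plus the graph identity):

```python
from typing import Any, Dict, Tuple
import torch
import torch.fx

# Values would be compiled nvFuser fusions; the key adds input metadata guards.
fusion_cache: Dict[Tuple[Any, ...], Any] = {}

def fusion_cache_key(gm: torch.fx.GraphModule, args) -> Tuple[Any, ...]:
    # Guard on the graph *and* on the input metadata that changes codegen:
    # rank/shape (Case 1 above), strides (Case 2), plus dtype and device.
    tensor_meta = tuple(
        (tuple(a.shape), tuple(a.stride()), a.dtype, a.device)
        for a in args
        if isinstance(a, torch.Tensor)
    )
    return (gm, tensor_meta)
```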

@SherlockNoMad
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot
Collaborator

Merge failed due to Refusing to merge as mandatory check(s) pull failed for rule superuser
Raised by https://github.com/pytorch/pytorch/actions/runs/2633954246

@SherlockNoMad
Contributor Author

SherlockNoMad commented Jul 8, 2022

@pytorchbot rebase -s

1 similar comment
@SherlockNoMad
Contributor Author

@pytorchbot rebase -s

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/80591/head returned non-zero exit code 1

Rebasing (1/8)
Auto-merging test/test_proxy_tensor.py
CONFLICT (content): Merge conflict in test/test_proxy_tensor.py
Auto-merging torch/fx/experimental/proxy_tensor.py
error: could not apply df348dc110... Interpreter for decomposing aten -> prims
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply df348dc110... Interpreter for decomposing aten -> prims

Raised by https://github.com/pytorch/pytorch/actions/runs/2634204553

@SherlockNoMad
Contributor Author

@pytorchbot rebase -b master

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/master pull/80591/head returned non-zero exit code 1

Rebasing (1/8)
Auto-merging test/test_proxy_tensor.py
CONFLICT (content): Merge conflict in test/test_proxy_tensor.py
Auto-merging torch/fx/experimental/proxy_tensor.py
error: could not apply df348dc110... Interpreter for decomposing aten -> prims
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply df348dc110... Interpreter for decomposing aten -> prims

Raised by https://github.com/pytorch/pytorch/actions/runs/2634240713

@SherlockNoMad
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

github-actions bot commented Jul 8, 2022

Hey @SherlockNoMad.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jul 12, 2022
Summary:
This PR integrates FX graph partitioner + Aten2Prims DecompositionInterpreter + Prims' TraceExecutor + naive caches for nvFuser.

Pull Request resolved: #80591
Approved by: https://github.com/jjsjann123, https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/fc10a6372736dd8bac1a894a92beebbafee74f01

Reviewed By: mehtanirav

Differential Revision: D37749351

Pulled By: SherlockNoMad

fbshipit-source-id: fb6eabbca954c58934bd10894f6420d87e413e96
@github-actions github-actions bot deleted the bahuang/nvfuser_backend branch February 17, 2024 02:02