Restore mixed dtypes GEMM auto-tuning for Ampere #129058
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129058
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures as of commit 44caacd with merge base f565d16.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Recent changes in the CUTLASS-based auto-tuning code for Inductor (the stack of changes roughly listed here) pretty much disabled any kind of CUTLASS-based auto-tuning for the Ampere architecture, including auto-tuning for mixed dtypes GEMM and sparse semi-structured GEMM. The PRs in this stack are intended to restore the functionality for the two mentioned special GEMM types. (The rationale for the original changes is provided in the comments section of PR 124577, starting from this comment.)

The regression is not detected by CI because the tests for CUTLASS-based auto-tuning check only that there were no crashes during the auto-tuning procedure, not whether any CUTLASS-based candidate kernels were generated at all. This kind of testing actually makes sense, because CUTLASS lists configurations that are known to work on, for example, A100 GPUs, while on less capable GPUs of the Ampere architecture some of these configurations may crash the corresponding candidate kernel during execution (for lack of resources, etc.). However, some kind of test is apparently needed that would actually check the number of CUTLASS candidate kernels generated; a rough sketch of such a test is given below.

Besides this extended testing, it remains to be decided how changes from this stack are to be incorporated. @kadeng suggested that the CUTLASS 2.x and 3.x code paths should be kept separate, and this is how this particular PR is implemented: it adds a new `CUTLASS2xGemmTemplate` class (in `gemm_template_2x.py`), alongside the existing `CUTLASSGemmTemplate`.

Another thing to discuss: is there any need for CUTLASS-based auto-tuning on Ampere for ordinary MM/ADDMM operators? The PRs in this stack restore the auto-tuning functionality on Ampere for the two special GEMM types mentioned, but for ordinary dense arguments of the same dtype, auto-tuning is still not possible on Ampere (while it was available before the mentioned changes).

(As seen from the comments above: the PRs in this stack are not yet to be considered for merging.)
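For example, a test along these lines could assert that candidate kernels are actually generated. This is only a sketch: the patching point `torch._inductor.select_algorithm.autotune_select_algorithm` and the substring-based check are assumptions that would need to be verified against the current Inductor code, and call sites that import the function directly would not be intercepted by this patch:

```python
# Sketch: count the CUTLASS candidate kernels offered during auto-tuning,
# instead of only checking that auto-tuning does not crash.
import unittest.mock as mock

import torch
import torch._inductor.config as inductor_config
from torch._inductor import select_algorithm


def test_cutlass_candidates_generated():
    counts = []
    orig = select_algorithm.autotune_select_algorithm

    def hook(name, choices, *args, **kwargs):
        # Heuristic: treat a choice as CUTLASS-based if "cutlass" appears in
        # its string representation.
        counts.append(sum("cutlass" in str(c).lower() for c in choices))
        return orig(name, choices, *args, **kwargs)

    with inductor_config.patch(
        {"max_autotune": True, "max_autotune_gemm_backends": "CUTLASS"}
    ), mock.patch.object(select_algorithm, "autotune_select_algorithm", hook):
        a = torch.randn(256, 128, device="cuda", dtype=torch.float16)
        b = torch.randn(128, 256, device="cuda", dtype=torch.float16)
        torch.compile(torch.mm)(a, b)

    assert counts and counts[0] > 0, "no CUTLASS candidate kernels generated"
```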
gemm_template_2x.py is almost identical to the existing gemm_template.py.
What is the justification for having two isolated implementations for this? It looks like it would be very easy to cleanly refactor this into a common base + two small derived classes.
```python
CUTLASSGemmTemplate.add_cutlass_gemm_choices(
    choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
)
CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
```
Can both of these be used unconditionally? Or does there need to be a branch on the cutlass version?
If we can use both, do we need both?
Edit: from @alexsamardzic's comment I see this is adding back support for Ampere, so if the existing template does not work for Ampere, should there be a branch on SM version to decide which to add?
Roughly speaking: add_cutlass_gemm_choices() will go through all configurations offered by CUTLASS, and will then call filter_op() on each one. For CUTLASSGemmTemplate, filter_op() will eliminate operations that are not for Hopper, and for CUTLASS2xGemmTemplate, it will eliminate operations that are not for Ampere. So the only harm here, as is, is going through all the configurations twice - but that's just another reason to unify these two classes.
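In code, the gating amounts to roughly the following (a sketch, assuming the `cutlass_library` GemmOperation objects carry an integer `arch` field; the function names are illustrative):

```python
# Per-template architecture gate: each template walks the full list of
# CUTLASS configurations and drops the ops of the "wrong" generation.
def filter_op_3x(op):
    # CUTLASSGemmTemplate: keep only Hopper (SM 9.x) ops.
    return op if op.arch == 90 else None


def filter_op_2x(op):
    # CUTLASS2xGemmTemplate: keep only Ampere (SM 8.x) ops.
    return op if op.arch == 80 else None
```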
I would think that it makes sense to factor out the gen_ops() and filter_op() methods from CUTLASSGemmTemplate and CUTLASS2xGemmTemplate into a shared module and make them functions, not methods. Then you can pass in a flag that controls the filtering. @alexsamardzic, is that the kind of refactoring you already had in mind?
(This comment applies to your suggestions below too.) The main point is not the commonalities, but exactly that the differences between the two classes are actually minor. Thus, my question is: do you still hold the position that we need two separate classes, plus a base class? Frankly, I think a single class with some if-else statements, as was the case before the major refactoring, would be simpler. Plus, it would avoid going through all the ops offered by the CUTLASS generator twice in cases where we want to support both CUTLASS 2.x and CUTLASS 3.x architecture kernels for auto-tuning the same operator.
I can understand where you're coming from. It might appear simpler and more compact to write at the moment, but it definitely impacts readability and testability (and thereby maintainability) very badly, and that's not to be underestimated. So, yes, I still think it's better to have two classes. But that's a bit of software design philosophy: I personally prioritize readability, modularity, testability, and orthogonality in software design over the DRY principle, but I know that opinions on that differ.
Going through all ops twice can be solved differently: sort them into separate lists/sets just once and remember these (see the sketch below). It's inefficient anyway at the moment. CUTLASS 3.x will certainly evolve a lot in the future, while CUTLASS 2.x likely won't, so I would not judge what's easy and what isn't just by how the code looks at the moment. It should be possible to evolve both variants independently.
You could factor out shared code into a shared parent class or a utility module if that helps.
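For the sort-once idea, something along these lines could work (a sketch; `generate_all_gemm_ops` is a hypothetical enumeration helper standing in for the CUTLASS generator call):

```python
# Enumerate the CUTLASS generator output once, partition the ops by
# architecture, and cache the result, so that neither template has to
# re-walk the full configuration list.
import functools


@functools.lru_cache(maxsize=None)
def partitioned_ops():
    ops_2x, ops_3x = [], []
    for op in generate_all_gemm_ops():  # hypothetical: yields GemmOperation objects
        (ops_3x if op.arch >= 90 else ops_2x).append(op)
    return tuple(ops_2x), tuple(ops_3x)
```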
OK. I'm on leave now for several weeks, but as soon as I'm back, I'll refactor the code and will ping you for a review then.
Same for me, I will be back in a bit more than two weeks. Sorry I had to delay this.
Removed bits of …
| """ | ||
|
|
||
| # Additional includes which are neccessary if the standalone test / debug runner is generated as wel | ||
| GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r""" |
I think there's likely no harm in deduplicating the standalone runner related stuff, e.g. refactor it such that these template strings are included from a single module. Or are there any changes in there that I don't see?
```python
return None

@staticmethod
def flip_cutlass_layout(
```
I think these static utility methods would also be cases where it's likely OK to just call the original implementations, or to factor them out into a separate module that both classes use.
```python
new_op.D.layout = CUTLASS2xGemmTemplate.cutlass_layout(d_layout)
return new_op

def filter_op(
```
This should likely be factored out into a separate module, and take a flag as input that tells whether to return CUTLASS 2.x or 3.x ops.
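For example (a sketch only; the signature and the `arch` field are assumptions for illustration):

```python
# A single module-level filter shared by both templates, with a flag that
# selects CUTLASS 2.x (Ampere) vs. 3.x (Hopper) ops.
def filter_op(op, layout, *, cutlass_2x: bool):
    if op.arch != (80 if cutlass_2x else 90):
        return None
    # ... dtype / layout / alignment checks shared by both templates ...
    return op
```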
```python
op.C.layout = op.D.layout
return op

def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]":  # type: ignore[name-defined]  # noqa: F821
```
Same here, a candidate for refactoring (@alexsamardzic, I assume you already had this in mind, as you mentioned).
```python
res += "\n\n" + test_runner_code
return res

def test_call_statement(
```
I think this can likely also be factored out of this and the original class.
@kadeng: This PR is now ready for review, as well as #123742. I made the changes according to your comments above, i.e. the common parts of the 2.x and 3.x generators are now extracted into an abstract base class.
Side question, regarding a previous discussion: I looked recently into the Python part of the CUTLASS library - is there a reason it wasn't used for generating the kernel code here?
Yes, the Python part of the CUTLASS library was not fully available yet when this CUTLASS backend was initially written. To me, the Python backend appears to be well written and would likely provide a cleaner way to generate the code (I would not use the compilation/linking implementation it provides, though). That said, it's likely pretty much an all-or-nothing thing that would require a complete rewrite of the CUTLASS backend, without any performance gains.
Agreed. Its main advantage is that it has a quick and nice method of generating epilogues - for example, it would make the job much easier for me, and the solution much cleaner, if/when I proceed to auto-tuning for mixed data types and/or sparse semi-structured ADDMM operators. In any case: as mentioned above, the code is now changed according to your request that there be two separate classes for the CUTLASS 2.x and 3.x based code, together with an abstract base class for the commonalities - so, would it be possible for you to give a full review of this PR?
```python
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM80OrLater, "need sm_80")
@unittest.skipIf(not SM80, "need sm_80 exactly")
```
Why does this need SM80 exactly, not SM80 or up?
Mixed data types MM is enabled only for SM 8.x in eager mode - it's implemented quite differently in CUTLASS for SM 9.x, and I simply never had an opportunity to test it there. So, for now, the corresponding ATen operator can be auto-tuned only on SM 8.x too, and thus there is no need to enable the test for SM 9.x (also, if I remember correctly, it was failing CI when enabled).
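The guard in question boils down to checking the major compute capability exactly, e.g. (a sketch of how such flags could be defined; the actual test utilities may define them differently):

```python
# Distinguish "SM 8.x exactly" from "SM 8.0 or later" when gating the
# mixed-dtype test, since the SM 9.x CUTLASS path is untested here.
import torch

major, _minor = torch.cuda.get_device_capability()
SM80OrLater = major >= 8
SM80 = major == 8  # SM 8.x only; excludes Hopper (SM 9.x)
```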
Hmm... I think we should ensure that this doesn't lead to exceptions, or silently to incorrect results, on SM90 systems. Do you know how it failed?
Looks good to me. I didn't understand why that one test is SM80 exactly, but that's not blocking IMO.
@pytorchbot merge |
Merge failed. Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the PyTorch labeling wiki.
Details for Dev Infra team: raised by workflow job.
Frankly, I don't remember it any more, as it was on the first CI run, when I created the PR... Eager mode has …
@pytorchbot merge |
@alexsamardzic The ciflow/trunk label that's automatically added when merging adds a few more tests, and some of them failed. Could you try to verify whether any of these are real failures? (hud.pytorch.org helps with that.)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 jobs have failed; the first few of them are: trunk / macos-py3-arm64 / test (default, 2, 3, macos-m1-stable), trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable).
Details for Dev Infra team: raised by workflow job.
I've seen these and looked into the logs - they seem unrelated to my changes.
@pytorchbot merge -i |
Merge started. Your change will be merged while ignoring the following 3 checks: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable), trunk / macos-py3-arm64 / test (default, 2, 3, macos-m1-stable), trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```python
if epilogue_template is None:
    arguments = self._template_from_string(argument_template).render(
        split_k=1, **options
```
@alexsamardzic Wondering why we fixed the split_k value here? Thanks!
cc @kadeng
If I remember it correctly, it won't compile with any other value.
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang