[New Feature] CUTLASS kernels for w4a8 quantization #64

Open
supriyar opened this issue Mar 18, 2024 · 4 comments
Labels: enhancement (New feature or request)

Comments

supriyar commented Mar 18, 2024

We plan to add QAT (quantization-aware training) for LLMs to torchao, as mentioned in the original RFC (#47).

For this to run efficiently on the GPU, we'd need kernel support for W4A8 quantization (int4 weights, int8 activations).
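
For reference, here is a minimal sketch of the arithmetic such a kernel has to implement: int8 activations against int4 weights packed two per byte, accumulated in int32. The kernel name, the K-major nibble packing, and the one-thread-per-output layout are hypothetical illustrations, not the CUTLASS design discussed in this issue.

```cuda
#include <cstdint>

// Extract the low (hi == 0) or high (hi == 1) nibble of a byte as a
// signed int4 value in [-8, 7], via shift-left + arithmetic shift-right.
__device__ __forceinline__ int unpack_int4(uint8_t b, int hi) {
    int8_t shifted = hi ? (int8_t)b : (int8_t)(b << 4);
    return (int)(shifted >> 4);
}

// Naive reference: C[m][n] = sum_k A[m][k] * W[k][n], with A int8 and W
// int4 packed two values per byte along K. One thread per output element.
__global__ void w4a8_gemm_ref(const int8_t* A, const uint8_t* W_packed,
                              int32_t* C, int M, int N, int K) {
    int m = blockIdx.y * blockDim.y + threadIdx.y;
    int n = blockIdx.x * blockDim.x + threadIdx.x;
    if (m >= M || n >= N) return;
    int32_t acc = 0;
    for (int k = 0; k < K; ++k) {
        uint8_t byte = W_packed[(k / 2) * N + n];
        acc += (int32_t)A[m * K + k] * unpack_int4(byte, k & 1);
    }
    // int32 accumulator; dequantization scales would be applied downstream.
    C[m * N + n] = acc;
}
```

A real CUTLASS kernel would instead stage tiles through shared memory and feed tensor cores, but the quantization semantics are the same.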

Other places where this has been raised before:

- NVIDIA/cutlass#1316
- NVIDIA/cutlass#1370

cc @andrewor14

supriyar commented:

cc @alexsamardzic @cpuhrsch

alexsamardzic commented:

Working on this: NVIDIA/cutlass#1413.

alexsamardzic removed their assignment Mar 21, 2024
jeromeku commented:

@alexsamardzic

Great work so far on integrating w4a8 GEMM in Cutlass!

Do you have plans to re-implement this functionality for pre-Hopper architectures using CUTLASS 3.x / CuTe rather than the CUTLASS 2.x APIs, which seem to be deprecated?

The 3.x interface has some convenient sub-byte primitives for slicing 4-bit tensors, but warp-level shuffling would still be needed for efficient tensor-core loading and MMA (sketched below).
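
To make that concrete, here is a rough sketch of the two register-level ingredients: sub-byte extraction and warp-level data exchange. The helper names and the lane mapping are made up for illustration; the fragment layout a real mma.sync expects is architecture-specific and not shown.

```cuda
#include <cstdint>

// Each lane holds eight int4 values packed into one 32-bit register.
// Shift nibble i to the top bits, then arithmetic-shift down to sign-extend.
__device__ __forceinline__ int get_int4(uint32_t packed, int i) {
    return ((int32_t)(packed << (28 - 4 * i))) >> 28;
}

// Exchange packed fragments across the warp so each lane ends up holding
// the values its tensor-core fragment expects (the src_lane mapping here
// is a placeholder for the real, layout-dependent permutation).
__device__ __forceinline__ uint32_t exchange_fragment(uint32_t packed,
                                                      int src_lane) {
    return __shfl_sync(0xFFFFFFFFu, packed, src_lane);
}
```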

I'd be happy to help adapt the 4-bit mixed-type GEMM to CuTe for Ampere.

alexsamardzic commented:

> Do you have plans to re-implement this functionality for pre-Hopper architectures using CUTLASS 3.x / CuTe rather than the CUTLASS 2.x APIs, which seem to be deprecated?

(Please send further comments to the PR mentioned above; I think it makes the most sense to discuss CUTLASS features on the CUTLASS GitHub pages.)

As can be seen from my PR, this feature is implemented the same way as F16/S8 and similar mixed-type GEMMs. For my purpose, which is adding support for this operation to PyTorch, for the Ampere architecture and for both eager and compiled mode, this is good enough. I'm not sure how my changes could be made more 3.x-style, as the functionality is implemented at the warp level, but if you have any suggestions, please post them either in that PR or in a separate one.

msaroufim added the enhancement (New feature or request) label May 7, 2024