- Goal: mxfp8 all2all -> stay in mxfp8 through the token shuffle -> mxfp8 grouped gemm (a rough sketch of the dispatch path follows this list)
- Initial mxfp8 all2all impl (drop-in replacement for all_to_all_single_autograd, sync required)
- mxfp8 token shuffle (a modified version of this Triton kernel that also permutes scales so they stay in the same order as their associated tokens; see the shuffle sketch below)
- Extend the mxfp8 grouped gemm autograd func to also accept pre-quantized inputs (see the autograd sketch below)
- Improve the 3D expert weight mxfp8 quantization CUDA kernel (currently at 65-70% of peak memory bandwidth; should target 85%+ like the other mxfp8 quantization kernels; see the back-of-envelope calc below)
- Investigate whether we can write e8m0 scales directly in the blocked format, instead of running separate conversion kernels.
- Improve mxfp8 grouped gemm performance for small K dim (dsv3/kimi shapes). Currently we see less speedup for small, skinny experts than for the larger experts in models like llama4. We need to improve this since dsv3/kimi base models are so popular now.
- Unify the dense and MoE mxfp8 training codebases
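
Rough sketch of the mxfp8 dispatch path, to make the first two items concrete. This is not the actual implementation: `mxfp8_quantize` is a placeholder for whatever quantization entry point we end up using, and I'm assuming a (packed fp8 data, one e8m0 scale per 32-element block) layout; only `torch.distributed.all_to_all_single` is a real API here. The idea is to quantize once on the send side and move the fp8 payload and the scales as two collectives, so tokens stay in mxfp8 from dispatch through the shuffle to the grouped gemm.

```python
import torch
import torch.distributed as dist

def mxfp8_all_to_all(tokens, input_splits, output_splits, group=None):
    """tokens: [num_tokens, dim] bf16 activations to dispatch.
    input_splits / output_splits: per-rank row counts (host-side lists of ints).
    Returns (fp8_data, e8m0_scales) for the received tokens, still in mxfp8."""
    # Hypothetical helper: block-quantize along the last dim (block size 32),
    # producing packed fp8 values and one e8m0 scale per 32-element block.
    fp8_data, scales = mxfp8_quantize(tokens, block_size=32)

    dim = tokens.shape[-1]
    out_rows = sum(output_splits)
    recv_data = torch.empty(out_rows, dim, dtype=torch.uint8, device=tokens.device)
    recv_scales = torch.empty(out_rows, dim // 32, dtype=torch.uint8, device=tokens.device)

    # Two collectives: one for the quantized payload, one for the scales.
    # fp8/e8m0 tensors are sent as their uint8 bit patterns. The split sizes must
    # already be known on the host, hence the sync called out in the task list.
    dist.all_to_all_single(recv_data, fp8_data.view(torch.uint8),
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits, group=group)
    dist.all_to_all_single(recv_scales, scales.view(torch.uint8),
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits, group=group)
    return recv_data.view(torch.float8_e4m3fn), recv_scales
```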
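
Shuffle sketch, in plain PyTorch rather than Triton, just to pin down the contract the kernel needs to satisfy: whatever permutation puts the fp8 token rows into expert-contiguous order must also be applied to the per-token scale rows, so each block scale stays paired with its token.

```python
import torch

def shuffle_tokens_and_scales(fp8_tokens, e8m0_scales, expert_ids):
    """fp8_tokens: [num_tokens, dim] mxfp8 data, e8m0_scales: [num_tokens, dim // 32],
    expert_ids: [num_tokens] expert assignment per token."""
    # Sort tokens by expert so each expert's rows are contiguous for the grouped gemm.
    perm = torch.argsort(expert_ids, stable=True)
    shuffled_tokens = fp8_tokens[perm]
    shuffled_scales = e8m0_scales[perm]  # same permutation keeps scales aligned with tokens
    # A real version would pass minlength=num_experts so empty experts still get a count.
    tokens_per_expert = torch.bincount(expert_ids)
    return shuffled_tokens, shuffled_scales, tokens_per_expert, perm
```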
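
Autograd sketch for the pre-quantized input path. `mxfp8_quantize` and `_mxfp8_grouped_mm` are stand-ins for the real quantization and grouped-gemm kernels, and the backward is elided since this task only changes the forward input handling.

```python
import torch

class MXFP8GroupedMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, A_scales, B_fp8, B_scales, offs):
        if A_scales is None:
            # Existing path: A is high precision, quantize activations here.
            A_fp8, A_scales = mxfp8_quantize(A, block_size=32)  # hypothetical helper
        else:
            # New path: A is already mxfp8 data (e.g. straight out of the all2all +
            # shuffle above) and A_scales are its e8m0 block scales; skip re-quantizing.
            A_fp8 = A
        ctx.save_for_backward(A_fp8, A_scales, B_fp8, B_scales, offs)
        return _mxfp8_grouped_mm(A_fp8, A_scales, B_fp8, B_scales, offs)  # hypothetical kernel

    @staticmethod
    def backward(ctx, grad_out):
        # Backward (quantize grad_out, run the transposed grouped gemms) is unchanged
        # by this task and omitted from the sketch.
        raise NotImplementedError("sketch only covers the forward input contract")
```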
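
Back-of-envelope for the 3D expert weight quantization kernel, to make the "% of peak" target concrete. All numbers here are assumptions (example expert weight shape, made-up kernel time, ~3.35 TB/s as an approximate H100 HBM3 peak): the kernel reads bf16 weights and writes fp8 data plus one e8m0 scale per 32-element block.

```python
E, N, K = 128, 2048, 7168               # hypothetical expert weight shape [E, N, K]
elems = E * N * K
bytes_moved = elems * (2 + 1 + 1 / 32)  # read bf16 (2 B) + write fp8 (1 B) + e8m0 scales (1/32 B)
kernel_time_s = 2.5e-3                  # made-up measured kernel time
achieved_tbps = bytes_moved / kernel_time_s / 1e12
peak_tbps = 3.35                        # approx. H100 HBM3 peak
print(f"achieved {achieved_tbps:.2f} TB/s = {achieved_tbps / peak_tbps:.0%} of peak")
```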