
Next steps for mxfp8 MoE training #3379

@danielvegamyhre

Description

  • mxfp8 all2all -> stay in mxfp8 through the token shuffle -> mxfp8 grouped gemm
    • Initial mxfp8 all2all impl (drop-in replacement for all_to_all_single_autograd, sync required); see the first sketch after this list.
    • mxfp8 token shuffle (modified version of this Triton kernel that also permutes the scales into the same order as their associated tokens); see the second sketch after this list.
    • Extend the mxfp8 grouped gemm autograd func to also accept pre-quantized inputs.
  • Improve the 3d expert weight mxfp8 quantization CUDA kernel (currently at 65-70% of peak memory bandwidth; should target 85%+ like the other mxfp8 quantization kernels).
  • Investigate whether we can write e8m0 scales directly in blocked format, instead of running separate conversion kernels.
  • Improve mxfp8 grouped gemm performance for small K dims (dsv3/kimi shapes). We currently see less speedup for the small, skinny experts in these models than for the larger experts in llama4; this needs to improve given how popular the dsv3/kimi base models are now.
  • Unify the dense + MoE mxfp8 training code bases.
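
A minimal sketch of what an mxfp8 all2all could look like, not the torchao implementation: the fp8 payload and its e8m0 scales are dispatched as two collectives, with a blocking split-size exchange up front (the "sync required" above). The function name `mxfp8_all_to_all`, the (data, scales) layout, and the omission of the backward pass are all assumptions here.

```python
import torch
import torch.distributed as dist
from typing import Optional


def mxfp8_all_to_all(
    data_fp8: torch.Tensor,      # [num_tokens, hidden], torch.float8_e4m3fn
    scales_e8m0: torch.Tensor,   # [num_tokens, hidden // 32], e8m0 scales stored as uint8
    input_splits: torch.Tensor,  # [world_size], tokens this rank sends to each peer
    group: Optional[dist.ProcessGroup] = None,
):
    world_size = dist.get_world_size(group)
    assert input_splits.numel() == world_size

    # Exchange split sizes first; the .tolist() below forces a device sync,
    # which is the "sync required" caveat in the task above.
    output_splits = torch.empty_like(input_splits)
    dist.all_to_all_single(output_splits, input_splits, group=group)

    in_splits = input_splits.tolist()
    out_splits = output_splits.tolist()
    recv_tokens = sum(out_splits)

    # Ship the fp8 payload as raw bytes so the collective sees a plain uint8 tensor.
    data_u8 = data_fp8.view(torch.uint8)
    recv_data = torch.empty(
        recv_tokens, data_u8.shape[1], dtype=torch.uint8, device=data_u8.device
    )
    dist.all_to_all_single(recv_data, data_u8, out_splits, in_splits, group=group)

    # The scales travel with the tokens so downstream stages can stay in mxfp8.
    recv_scales = torch.empty(
        recv_tokens, scales_e8m0.shape[1], dtype=scales_e8m0.dtype, device=scales_e8m0.device
    )
    dist.all_to_all_single(recv_scales, scales_e8m0, out_splits, in_splits, group=group)

    return recv_data.view(torch.float8_e4m3fn), recv_scales
```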
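
A second sketch, PyTorch-level only, of the invariant the modified token-shuffle Triton kernel would need to enforce: whatever permutation groups the fp8 tokens by expert must also be applied to the scale rows, so each token stays paired with its block scales. Function and argument names are illustrative, not from the actual kernel.

```python
import torch


def shuffle_tokens_and_scales(
    data_fp8: torch.Tensor,     # [num_tokens, hidden], torch.float8_e4m3fn
    scales_e8m0: torch.Tensor,  # [num_tokens, hidden // 32], uint8
    expert_ids: torch.Tensor,   # [num_tokens], expert assignment per token
    num_experts: int,
):
    # Group tokens by expert so each expert's tokens are contiguous for the grouped gemm.
    perm = torch.argsort(expert_ids, stable=True)

    # Gather via a uint8 view to sidestep op-coverage gaps for float8 dtypes.
    shuffled_data = data_fp8.view(torch.uint8)[perm].view(torch.float8_e4m3fn)

    # Apply the *same* permutation to the scales so each scale row stays aligned
    # with its token -- the extra step relative to a plain bf16 token shuffle.
    shuffled_scales = scales_e8m0[perm]

    # Per-expert token counts, used as group sizes/offsets by the grouped gemm.
    tokens_per_expert = torch.bincount(expert_ids, minlength=num_experts)
    return shuffled_data, shuffled_scales, tokens_per_expert, perm
```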
