[mxfp8 moe training] integrate mxfp8 grouped gemm and triton kernels for scale conversion to blocked format #2977
Conversation

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2977

Note: Links to docs will display an error until the docs builds have completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    
Returns:
    - starting_row_after_padding: 1D integer tensor representing the starting row after padding each to blocked format.
    - starting_col_after_padding: 1D integer tensor representing the starting row after padding each to blocked format.
nit: row -> col (in the starting_col_after_padding description)
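For orientation, here is a minimal sketch of what a padded-offsets helper like this might compute. The helper name, signature, and the 128-row alignment are assumptions for illustration only; the PR's actual conversion happens in Triton kernels.

```python
import torch

def starting_offsets_after_padding(group_sizes: torch.Tensor,
                                   row_align: int = 128) -> torch.Tensor:
    """Hypothetical sketch: given per-group row counts, compute each group's
    starting row after every group is padded up to a multiple of `row_align`
    (the row alignment the blocked scale layout is assumed to require).

    Returns a 1D tensor of length len(group_sizes) + 1, where entry i is the
    starting row of group i and the last entry is the total padded row count.
    """
    padded = ((group_sizes + row_align - 1) // row_align) * row_align
    starts = torch.zeros(len(group_sizes) + 1, dtype=torch.int64)
    starts[1:] = torch.cumsum(padded, dim=0)
    return starts

# Example: groups of 100, 300, and 50 rows -> starts at 0, 128, 512 (total 640).
print(starting_offsets_after_padding(torch.tensor([100, 300, 50])))
```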
    out_dtype=out_dtype,
)

# Store what we need for backward before returning.
nit: comment not needed
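For context, the flagged comment annotates the standard torch.autograd.Function idiom of saving tensors before forward returns. A minimal generic sketch of that idiom (not the PR's actual forward):

```python
import torch

class _GroupedMMSketch(torch.autograd.Function):
    """Minimal generic sketch (not the PR's forward) of the idiom the
    diff comment annotates: stash tensors for backward, then return."""

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        out = a @ b
        # Store what we need for backward before returning.
        ctx.save_for_backward(a, b)
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        a, b = ctx.saved_tensors
        # d(a@b)/da = grad_out @ b^T ;  d(a@b)/db = a^T @ grad_out
        return grad_out @ b.t(), a.t() @ grad_out
```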
    
Summary
This PR integrates all of the recently landed grouped GEMMs and Triton kernels for per-group scale conversion (details below) into mxfp8 MoE training:
- torch._scaled_grouped_mm (MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump, pytorch#162209)

Test plan
pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k test_mxfp8_grouped_gemm_with_dq_fwd_bwd -s
(A sketch of the grouped GEMM semantics this test exercises appears at the end of this description.)

Next steps
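For intuition about the test plan above, here is a hedged, high-precision reference of what a 2D x 3D grouped GEMM with row offsets computes. The function name and shapes are illustrative assumptions; the mxfp8 path is presumably compared against a dequantized reference of this form (hence the "with_dq" in the test name).

```python
import torch

def grouped_mm_reference(a: torch.Tensor, b: torch.Tensor,
                         offs: torch.Tensor) -> torch.Tensor:
    """Hypothetical bf16/fp32 reference for a 2D x 3D grouped GEMM: rows of
    `a` are split into groups by `offs` (cumulative row ends), and group g is
    multiplied by expert weight b[g]. An mxfp8 grouped GEMM is expected to
    approximate this result after quantize/dequantize."""
    outs = []
    start = 0
    for g, end in enumerate(offs.tolist()):
        outs.append(a[start:end] @ b[g])
        start = end
    return torch.cat(outs, dim=0)

# Example: 2 experts, token groups of 3 and 5 rows, K=16, N=8.
a = torch.randn(8, 16)
b = torch.randn(2, 16, 8)
out = grouped_mm_reference(a, b, offs=torch.tensor([3, 8]))
print(out.shape)  # torch.Size([8, 8])
```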