Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions torchao/prototype/moe_training/kernels/mxfp8/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
blockwise_barrier,
sync_threads,
)
from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.kernels import (
triton_mxfp8_dequant_dim0,
triton_to_mxfp8_dim0,
)
from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx


Expand Down Expand Up @@ -473,11 +476,9 @@ def forward(
"""
# Quantize input
block_size = 32
input_scales, input_data = to_mx(
input_data, input_scales = triton_to_mxfp8_dim0(
input,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
scaling_mode=ScaleCalculationMode.RCEIL,
inner_block_size=block_size,
)

# Dispatch data (async)
Expand All @@ -501,20 +502,17 @@ def forward(
output_data = torch.ops._c10d_functional.wait_tensor(output_data)

# Dequantize output
lowp_dtype = output_data.dtype
hp_dtype = input.dtype
hp_output = to_dtype(
triton_hp_output = triton_mxfp8_dequant_dim0(
output_data,
output_scales.view(torch.float8_e8m0fnu),
lowp_dtype,
block_size,
output_scales,
hp_dtype,
block_size,
)

ctx.input_splits = input_splits
ctx.output_splits = output_splits
ctx.group = group
return hp_output
return triton_hp_output

@staticmethod
def backward(ctx, grad_output_hp):
Expand All @@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):

# Quantize grad_output
block_size = 32
grad_out_scales, grad_out_data = to_mx(
grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
grad_output_hp,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
scaling_mode=ScaleCalculationMode.RCEIL,
inner_block_size=block_size,
)

# Dispatch data (async)
Expand All @@ -557,13 +553,11 @@ def backward(ctx, grad_output_hp):
grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)

hp_dtype = grad_output_hp.dtype
lowp_dtype = grad_input_data.dtype
grad_input_hp = to_dtype(
grad_input_hp = triton_mxfp8_dequant_dim0(
grad_input_data,
grad_input_scales.view(torch.float8_e8m0fnu),
lowp_dtype,
block_size,
grad_input_scales,
hp_dtype,
block_size,
)
return grad_input_hp, None, None, None

Expand Down
Loading