diff --git a/torchao/prototype/moe_training/kernels/mxfp8/comms.py b/torchao/prototype/moe_training/kernels/mxfp8/comms.py index 9593b02513..b81abacf1c 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8/comms.py +++ b/torchao/prototype/moe_training/kernels/mxfp8/comms.py @@ -275,7 +275,7 @@ def _mxfp8_on_device_all_to_all_v( world_size=input_hdl.world_size, BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK, BLOCK_SIZE=BLOCK_SIZE, - num_warps=1, + num_warps=16, ) return output