diff --git a/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py b/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py index cd98fe69ac..fb1d7452ab 100644 --- a/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py +++ b/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py @@ -73,7 +73,7 @@ def test_a2a_fwd_bwd(self): tokens_per_ep_rank, dim, device=self.device, - dtype=torch.float32, + dtype=torch.bfloat16, requires_grad=True, ) ref_input_tensor = input_tensor.detach().clone().requires_grad_(True) @@ -107,7 +107,7 @@ def test_a2a_fwd_bwd(self): total_tokens_on_rank_after_a2a, dim, device=self.device, - dtype=torch.float32, + dtype=torch.bfloat16, ) # Do the actual all_to_all_single @@ -188,7 +188,7 @@ def test_a2a_fwd_bwd(self): tokens_per_ep_rank, dim, device=self.device, - dtype=torch.float32, + dtype=torch.bfloat16, requires_grad=True, ) ref_input_tensor = input_tensor.detach().clone().requires_grad_(True) diff --git a/torchao/prototype/moe_training/kernels/mxfp8/quant.py b/torchao/prototype/moe_training/kernels/mxfp8/quant.py index 353688f185..c83b0e1cdf 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8/quant.py +++ b/torchao/prototype/moe_training/kernels/mxfp8/quant.py @@ -175,7 +175,7 @@ def compute_blocked_scale_offsets_for_M_groups(offsets: torch.Tensor): - starting_row_after_padding: 1D integer tensor representing the starting row after padding each to blocked format. """ # Calculate group sizes - zero = torch.tensor([0], dtype=offsets.dtype, device=offsets.device) + zero = torch.zeros(1, dtype=offsets.dtype, device=offsets.device) group_sizes = torch.diff(offsets, prepend=zero) # Round each group size up to the nearest multiple of 128 @@ -203,8 +203,8 @@ def compute_blocked_scale_offsets_for_K_groups( - starting_col_after_padding: 1D integer tensor representing the starting row after padding each to blocked format. """ # Calculate group sizes - zero = torch.tensor( - [0], dtype=scale_group_offsets.dtype, device=scale_group_offsets.device + zero = torch.zeros( + 1, dtype=scale_group_offsets.dtype, device=scale_group_offsets.device ) group_sizes = torch.diff(scale_group_offsets, prepend=zero)