diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 1729901933..4b6586b385 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -442,8 +442,8 @@ def triton_to_mxfp8_dim0_reference(
     not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
-@pytest.mark.parametrize("M", (256, 2048))
-@pytest.mark.parametrize("K", (256, 2048))
+@pytest.mark.parametrize("M", (128, 256))
+@pytest.mark.parametrize("K", (128, 256))
 def test_triton_mxfp8_dim1_randn(M, K):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
@@ -457,8 +457,8 @@ def test_triton_mxfp8_dim1_randn(M, K):
     not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
-@pytest.mark.parametrize("M", (256, 2048, 131072))
-@pytest.mark.parametrize("K", (256, 5120, 7168))
+@pytest.mark.parametrize("M", (128, 256))
+@pytest.mark.parametrize("K", (128, 256))
 def test_triton_mxfp8_dim0_randn(M, K):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
@@ -473,7 +473,7 @@ def test_triton_mxfp8_dim0_randn(M, K):
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 def test_triton_mxfp8_dim0_zeros():
-    x = torch.zeros(8192, 5120, dtype=torch.bfloat16, device="cuda")
+    x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
@@ -486,8 +486,8 @@ def test_triton_mxfp8_dim0_zeros():
     not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
-@pytest.mark.parametrize("M", (256, 2048, 131072))
-@pytest.mark.parametrize("K", (256, 5120, 7168))
+@pytest.mark.parametrize("M", (128, 256))
+@pytest.mark.parametrize("K", (128, 256))
 @pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
 def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
@@ -529,8 +529,8 @@ def test_rearrange(shape):
     not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
-@pytest.mark.parametrize("M", (32, 64, 2048))
-@pytest.mark.parametrize("K", (32, 64, 2048))
+@pytest.mark.parametrize("M", (32, 256))
+@pytest.mark.parametrize("K", (32, 256))
 @pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
 @pytest.mark.parametrize(
     "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index c858657af6..49343c6608 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -238,9 +238,11 @@ def test_activation_checkpointing():
     "recipe_name",
     [
         "mxfp8_emulated",
-        "mxfp4_emulated",
         "mxfp8_cublas",
-        "mxfp4_cutlass",
+        # TODO(future PR): add mxfp4 back here, but ensure CI speed is not too
+        # slow
+        # "mxfp4_emulated",
+        # "mxfp4_cutlass",
     ],
 )
 @pytest.mark.parametrize("bias", [False, True])
@@ -258,7 +260,6 @@ def test_activation_checkpointing():
     "scale_calculation_mode",
     [
         ScaleCalculationMode.FLOOR,
-        ScaleCalculationMode.CEIL,
         # even + compile does not work yet:
         # https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4
         # ScaleCalculationMode.EVEN,