diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 90dc2700ce..db9731276b 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -50,12 +50,12 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize("emulate", [True, False])
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
-@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
-def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
+def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
     """
     Smoke test for inference compile
     """
@@ -64,17 +64,24 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+        elif not is_sm_at_least_100() and not emulate:
+            pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
     elif elem_dtype == torch.float4_e2m1fn_x2:
-        if not is_sm_at_least_100():
-            pytest.skip("CUDA capability >= 10.0 required for float4 gemm")
+        if not is_sm_at_least_100() and not emulate:
+            pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
+        elif not is_sm_at_least_100() and emulate and compile:
+            # TODO(future PR): investigate and fix this
+            pytest.skip("mxfp4 + emulate + compile currently does not work, low SQNR")
 
     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
-    kernel_choice = (
-        MXGemmKernelChoice.CUTLASS
-        if elem_dtype == torch.float4_e2m1fn_x2
-        else MXGemmKernelChoice.CUBLAS
-    )
+
+    if emulate:
+        kernel_choice = MXGemmKernelChoice.EMULATED
+    elif elem_dtype == torch.float4_e2m1fn_x2:
+        kernel_choice = MXGemmKernelChoice.CUTLASS
+    else:
+        kernel_choice = MXGemmKernelChoice.CUBLAS
     config = MXFPInferenceConfig(
         activation_dtype=elem_dtype,
         weight_dtype=elem_dtype,
diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index 39f0725390..1b9c369be5 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -96,10 +96,6 @@ def _linear_extra_repr(self):
 def _mx_inference_linear_transform(
     module: torch.nn.Module, config: MXFPInferenceConfig
 ):
-    # TODO Sm120 has slightly more restrictive reqs
-    # TODO handle AMD
-    assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
-
     weight = module.weight
     assert weight.dtype == torch.bfloat16, (
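
For reference, a minimal usage sketch (not part of the diff) of the emulated path this change enables on pre-sm100 GPUs. The import paths are inferred from the file layout above, and the `gemm_kernel_choice` parameter name is an assumption, since the MXFPInferenceConfig call in the test hunk is truncated before that argument.

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_

# Hypothetical sketch: run mxfp8 inference with the emulated gemm kernel,
# which no longer requires sm100 now that the hard assert is removed.
m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,  # assumed parameter name
)
quantize_(m, config)

with torch.no_grad():
    y = m(torch.randn(4, 32, dtype=torch.bfloat16, device="cuda"))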