25 changes: 16 additions & 9 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -50,12 +50,12 @@ def run_around_tests():
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("emulate", [True, False])
@torch.no_grad()
@skip_if_rocm(
"ROCm float4 gemm require gfx950"
) # TODO(future): deploy gfx950 in ROCM CI
@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
"""
Smoke test for inference compile
"""
@@ -64,17 +64,24 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
elif not is_sm_at_least_100() and not emulate:
pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
elif elem_dtype == torch.float4_e2m1fn_x2:
if not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for float4 gemm")
if not is_sm_at_least_100() and not emulate:
pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
elif not is_sm_at_least_100() and emulate and compile:
# TODO(future PR): investigate and fix this
pytest.skip("mxfp4 + emulate + compile currently does not work, low SQNR")

m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)
kernel_choice = (
MXGemmKernelChoice.CUTLASS
if elem_dtype == torch.float4_e2m1fn_x2
else MXGemmKernelChoice.CUBLAS
)

if emulate:
kernel_choice = MXGemmKernelChoice.EMULATED
elif elem_dtype == torch.float4_e2m1fn_x2:
kernel_choice = MXGemmKernelChoice.CUTLASS
else:
kernel_choice = MXGemmKernelChoice.CUBLAS
config = MXFPInferenceConfig(
activation_dtype=elem_dtype,
weight_dtype=elem_dtype,
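For context beyond the test diff, here is a minimal end-to-end sketch of the new emulated path. The quantize_ entry point, the import paths, and the gemm_kernel_choice field name are assumptions (the config block above is truncated in the diff), so treat this as illustrative rather than as the PR's own code:

import copy
import torch
import torch.nn as nn

# Import locations are assumptions; the test file's imports are not shown in this diff.
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig

m = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)

# EMULATED avoids the sm100-only CUTLASS/CUBLAS mxfp kernels, which is what the
# new `emulate` parametrization exercises in the test above.
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,  # field name assumed
)
quantize_(m_mx, config)

with torch.no_grad():
    x = torch.randn(4, 32, dtype=torch.bfloat16, device="cuda")
    y_ref = m(x)
    y_mx = m_mx(x)  # emulated MXFP8 matmul; compare against the bf16 baseline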
4 changes: 0 additions & 4 deletions torchao/prototype/mx_formats/inference_workflow.py
@@ -96,10 +96,6 @@ def _linear_extra_repr(self):
def _mx_inference_linear_transform(
module: torch.nn.Module, config: MXFPInferenceConfig
):
# TODO Sm120 has slightly more restrictive reqs
# TODO handle AMD
assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"

weight = module.weight

assert weight.dtype == torch.bfloat16, (
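With the hard sm100 assert removed from _mx_inference_linear_transform, capability gating moves to the caller, as the new skip/emulate logic in the test shows. A hedged sketch of what such a caller-side guard could look like; the enum values and is_sm_at_least_100 helper mirror names in the diff, but this function and its import paths are assumptions, not part of the PR:

import torch
from torchao.utils import is_sm_at_least_100  # import path assumed
from torchao.prototype.mx_formats.config import MXGemmKernelChoice  # import path assumed

# Hypothetical helper, not in the PR: mirrors the new if/elif/else in the test.
def pick_mx_gemm_kernel(emulate: bool, elem_dtype: torch.dtype) -> MXGemmKernelChoice:
    if emulate:
        # Emulation runs on pre-sm100 GPUs, so no capability check is needed.
        return MXGemmKernelChoice.EMULATED
    if not is_sm_at_least_100():
        raise RuntimeError("non-emulated MXFP gemms require CUDA capability >= 10.0")
    if elem_dtype == torch.float4_e2m1fn_x2:
        return MXGemmKernelChoice.CUTLASS
    return MXGemmKernelChoice.CUBLAS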