diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 51c066ad3f..1c8c1bc207 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -172,6 +172,8 @@ def test_inference_workflow_mx(
     ],
     ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
 )
+@pytest.mark.parametrize("use_inference_mode", [False, True])
+@pytest.mark.parametrize("x_rank", [2, 3])
 @torch.no_grad()
 @skip_if_rocm("ROCm float4 gemm require gfx950")
 def test_inference_workflow_nvfp4(
@@ -182,6 +184,8 @@ def test_inference_workflow_nvfp4(
     use_triton_kernel: bool,
     use_dynamic_per_tensor_scale: bool,
     shapes: tuple,
+    use_inference_mode: bool,
+    x_rank: int,
 ):
     """
     Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
@@ -196,6 +200,16 @@ def test_inference_workflow_nvfp4(
 
     if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
         pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+
+    if use_inference_mode and (
+        shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
+    ):
+        pytest.skip("skipping unnecessary tests for inference mode")
+    if x_rank == 3 and (
+        shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
+    ):
+        pytest.skip("skipping unnecessary tests for x_rank 3")
+
     batch_size, in_features, out_features = shapes
 
     m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda")
@@ -212,6 +226,9 @@ def test_inference_workflow_nvfp4(
         m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")
 
     x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
+    if x_rank == 3:
+        x = x.unsqueeze(0)
+
     y_ref = m(x)
 
     if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
@@ -219,7 +236,11 @@ def test_inference_workflow_nvfp4(
             y_mx = m_mx(x)
         assert result["found"], "Expected quantize_nvfp4 kernel to be found"
     else:
-        y_mx = m_mx(x)
+        if use_inference_mode:
+            with torch.inference_mode():
+                y_mx = m_mx(x)
+        else:
+            y_mx = m_mx(x)
 
     sqnr = compute_error(y_ref, y_mx)
 
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index 2397270d5e..18f05290e5 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -502,6 +502,7 @@ def _addmm_nvfp4_dispatch(
     assert b.scale.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
+    assert len(a.shape) == 2 and len(b.shape) == 2
 
     M, K = a.shape[0], a.shape[1]
     N = b.shape[1]
@@ -576,7 +577,9 @@ def nvfp4_linear(func, types, args, kwargs):
             tensor_amax = torch.max(torch.abs(input_tensor))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = weight_tensor._act_per_tensor_scale
+            per_tensor_scale = weight_tensor.act_per_tensor_scale
+        orig_shape = input_tensor.shape
+        input_tensor = input_tensor.view(-1, orig_shape[-1])
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,
@@ -584,7 +587,9 @@ def nvfp4_linear(func, types, args, kwargs):
             is_swizzled_scales=k.is_swizzled_scales,
             use_triton_kernel=k.use_triton_kernel,
         )
-        return _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)
+        res = _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)
+        res = res.reshape(*orig_shape[:-1], res.shape[-1])
+        return res
 
 
 @implements([aten.mm.default, aten.matmul.default])
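
For reference, the reshape handling added to nvfp4_linear follows the usual pattern of flattening leading activation dims before a 2D matmul and restoring them afterwards; below is a minimal plain-PyTorch sketch of that round-trip (hypothetical shapes, no NVFP4 quantization involved):

import torch

# Sketch of the view/reshape round-trip (assumed shapes): flatten all leading
# dims of a rank-3 activation, run the 2D matmul that the dispatch requires,
# then restore the original leading dims on the output.
x = torch.randn(2, 128, 64)   # rank-3 activation (batch, seq, in_features)
w = torch.randn(256, 64)      # weight laid out (out_features, in_features)

orig_shape = x.shape
x_2d = x.view(-1, orig_shape[-1])                   # (batch * seq, in_features)
y_2d = x_2d @ w.t()                                 # stand-in for _addmm_nvfp4_dispatch
y = y_2d.reshape(*orig_shape[:-1], y_2d.shape[-1])  # back to (batch, seq, out_features)

assert y.shape == (2, 128, 256)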