From 08faa8e6b008b7694d2cc06936f336d8bfdc15c9 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 25 Sep 2025 12:56:36 -0700
Subject: [PATCH 1/4] Update

[ghstack-poisoned]
---
 .../mx_formats/test_inference_workflow.py | 22 +++++++++++--------
 .../mx_formats/inference_workflow.py      |  4 ----
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 90dc2700ce..a3756a305b 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -50,12 +50,12 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize("emulate", [True, False])
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
-@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
-def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
+def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
     """
     Smoke test for inference compile
     """
@@ -64,17 +64,21 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+        elif not is_sm_at_least_100() and not emulate:
+            pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
     elif elem_dtype == torch.float4_e2m1fn_x2:
-        if not is_sm_at_least_100():
-            pytest.skip("CUDA capability >= 10.0 required for float4 gemm")
+        if not is_sm_at_least_100() and not emulate:
+            pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
 
     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
-    kernel_choice = (
-        MXGemmKernelChoice.CUTLASS
-        if elem_dtype == torch.float4_e2m1fn_x2
-        else MXGemmKernelChoice.CUBLAS
-    )
+
+    if emulate:
+        kernel_choice = MXGemmKernelChoice.EMULATED
+    elif elem_dtype == torch.float4_e2m1fn_x2:
+        kernel_choice = MXGemmKernelChoice.CUTLASS
+    else:
+        kernel_choice = MXGemmKernelChoice.CUBLAS
     config = MXFPInferenceConfig(
         activation_dtype=elem_dtype,
         weight_dtype=elem_dtype,
diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index 39f0725390..1b9c369be5 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -96,10 +96,6 @@ def _linear_extra_repr(self):
 def _mx_inference_linear_transform(
     module: torch.nn.Module, config: MXFPInferenceConfig
 ):
-    # TODO Sm120 has slightly more restrictive reqs
-    # TODO handle AMD
-    assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
-
     weight = module.weight
 
     assert weight.dtype == torch.bfloat16, (

From b88afaf4156aeddde1971c000a0cf11a52ce03d5 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 25 Sep 2025 14:29:48 -0700
Subject: [PATCH 2/4] Update

[ghstack-poisoned]
---
 torchao/prototype/mx_formats/inference_workflow.py | 3 ++-
 torchao/prototype/mx_formats/mx_tensor.py          | 3 +++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index 1b9c369be5..cc6e34a708 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -24,6 +24,7 @@
     QuantizeTensorToNVFP4Kwargs,
     per_tensor_amax_to_scale,
 )
+from torchao.quantization.quant_api import _quantization_type
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -89,7 +90,7 @@ def __post_init__(self):
 
 
 def _linear_extra_repr(self):
-    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}"
+    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
 @register_quantize_module_handler(MXFPInferenceConfig)
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index b717462b4d..d5ec30c10f 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -544,6 +544,9 @@ def __repr__(self):
         # TODO better elem dtype print for fp4
         return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}"  # noqa: E501
 
+    def _quantization_type(self):
+        return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # avoid circular dependency

From 00d263455f7ee7a882ebb644c038f26845510744 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 25 Sep 2025 17:35:03 -0700
Subject: [PATCH 3/4] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_mx_tensor.py | 14 ++++++++++++++
 torchao/prototype/mx_formats/mx_ops.py      | 13 +++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 38eefbff07..2c89fae96d 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -455,6 +455,20 @@ def test_view(elem_dtype):
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_clone():
+    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    block_size = 4
+    data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
+    data_mx_c = data_mx.clone()
+    torch.testing.assert_close(
+        data_mx.to_dtype(torch.bfloat16),
+        data_mx_c.to_dtype(torch.bfloat16),
+        atol=0,
+        rtol=0,
+    )
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
 @pytest.mark.parametrize("pack_fp6", [False, True])
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index 07e47eed66..1779b0e278 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -352,3 +352,16 @@ def autocast_to_copy(func, types, args, kwargs):
 
     # If only device was changed, return the device-changed tensor
     return tensor
+
+
+@implements([aten.clone.default])
+def mx_clone(func, types, args, kwargs):
+    self = args[0]
+    memory_format = kwargs.get("memory_format", None)
+
+    if memory_format is not None:
+        clone_fn = lambda x: x.clone(memory_format=memory_format)
+    else:
+        clone_fn = lambda x: x.clone()
+
+    return self._apply_fn_to_data(clone_fn)

From ff57676b149a7c2d5af28dd446a0da85bc46994c Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 26 Sep 2025 05:21:38 -0700
Subject: [PATCH 4/4] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_inference_workflow.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index a3756a305b..db9731276b 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -69,6 +69,9 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
     elif elem_dtype == torch.float4_e2m1fn_x2:
         if not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
+        elif not is_sm_at_least_100() and emulate and compile:
+            # TODO(future PR): investigate and fix this
+            pytest.skip("mxfp4 + emulate + compile currently does not work, low SQNR")
 
     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)