From 08faa8e6b008b7694d2cc06936f336d8bfdc15c9 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 12:56:36 -0700 Subject: [PATCH 1/9] Update [ghstack-poisoned] --- .../mx_formats/test_inference_workflow.py | 22 +++++++++++-------- .../mx_formats/inference_workflow.py | 4 ---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 90dc2700ce..a3756a305b 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -50,12 +50,12 @@ def run_around_tests(): @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize("emulate", [True, False]) @torch.no_grad() @skip_if_rocm( "ROCm float4 gemm require gfx950" ) # TODO(future): deploy gfx950 in ROCM CI -@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required") -def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool): +def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool): """ Smoke test for inference compile """ @@ -64,17 +64,21 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool): if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + elif not is_sm_at_least_100() and not emulate: + pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm") elif elem_dtype == torch.float4_e2m1fn_x2: - if not is_sm_at_least_100(): - pytest.skip("CUDA capability >= 10.0 required for float4 gemm") + if not is_sm_at_least_100() and not emulate: + pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm") m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda") m_mx = copy.deepcopy(m) - kernel_choice = ( - MXGemmKernelChoice.CUTLASS - if elem_dtype == torch.float4_e2m1fn_x2 - else MXGemmKernelChoice.CUBLAS - ) + + if emulate: + kernel_choice = MXGemmKernelChoice.EMULATED + elif elem_dtype == torch.float4_e2m1fn_x2: + kernel_choice = MXGemmKernelChoice.CUTLASS + else: + kernel_choice = MXGemmKernelChoice.CUBLAS config = MXFPInferenceConfig( activation_dtype=elem_dtype, weight_dtype=elem_dtype, diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 39f0725390..1b9c369be5 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -96,10 +96,6 @@ def _linear_extra_repr(self): def _mx_inference_linear_transform( module: torch.nn.Module, config: MXFPInferenceConfig ): - # TODO Sm120 has slightly more restrictive reqs - # TODO handle AMD - assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now" - weight = module.weight assert weight.dtype == torch.bfloat16, ( From b88afaf4156aeddde1971c000a0cf11a52ce03d5 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 14:29:48 -0700 Subject: [PATCH 2/9] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/inference_workflow.py | 3 ++- torchao/prototype/mx_formats/mx_tensor.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 1b9c369be5..cc6e34a708 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -24,6 +24,7 @@ QuantizeTensorToNVFP4Kwargs, per_tensor_amax_to_scale, ) +from torchao.quantization.quant_api import _quantization_type from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -89,7 +90,7 @@ def __post_init__(self): def _linear_extra_repr(self): - return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}" + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" @register_quantize_module_handler(MXFPInferenceConfig) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index b717462b4d..d5ec30c10f 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -544,6 +544,9 @@ def __repr__(self): # TODO better elem dtype print for fp4 return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501 + def _quantization_type(self): + return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # avoid circular dependency From 00d263455f7ee7a882ebb644c038f26845510744 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 17:35:03 -0700 Subject: [PATCH 3/9] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 14 ++++++++++++++ torchao/prototype/mx_formats/mx_ops.py | 13 +++++++++++++ 2 files changed, 27 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 38eefbff07..2c89fae96d 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -455,6 +455,20 @@ def test_view(elem_dtype): x_mx_2 = x_mx.view(2, 4) # noqa: F841 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_clone(): + data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16) + block_size = 4 + data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size) + data_mx_c = data_mx.clone() + torch.testing.assert_close( + data_mx.to_dtype(torch.bfloat16), + data_mx_c.to_dtype(torch.bfloat16), + atol=0, + rtol=0, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]) @pytest.mark.parametrize("pack_fp6", [False, True]) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 07e47eed66..1779b0e278 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -352,3 +352,16 @@ def autocast_to_copy(func, types, args, kwargs): # If only device was changed, return the device-changed tensor return tensor + + +@implements([aten.clone.default]) +def mx_clone(func, types, args, kwargs): + self = args[0] + memory_format = kwargs.get("memory_format", None) + + if memory_format is not None: + clone_fn = lambda x: x.clone(memory_format=memory_format) + else: + clone_fn = lambda x: x.clone() + + return self._apply_fn_to_data(clone_fn) From 5a840c1aedad8bf40e383cd3de71f39e2fe8d881 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 17:44:15 -0700 Subject: [PATCH 4/9] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/mx_ops.py | 43 -------------------------- 1 file changed, 43 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 1779b0e278..d870698601 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -311,49 +311,6 @@ def mx_copy_(func, types, args, kwargs): ) -@implements([aten._to_copy.default]) -def autocast_to_copy(func, types, args, kwargs): - """Autocast + device movement""" - assert isinstance(args[0], MXTensor) - - # Handle dtype parameter - dtype = kwargs.pop("dtype", None) - if dtype is not None: - assert dtype in { - torch.float16, - torch.bfloat16, - }, "Only support floating point conversion for autocast w/ MXTensor" - - # Handle device parameter - device = kwargs.pop("device", None) - if device is not None: - # Apply device change using _apply_fn_to_data - tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device)) - tensor = return_and_correct_aliasing(func, args, {}, tensor) - else: - tensor = args[0] - - # Verify no other kwargs remain - assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast" - - # If dtype is specified, create a new MXTensor with the requested dtype - if dtype is not None: - res = MXTensor( - tensor.qdata, - tensor._scale_e8m0, - tensor._elem_dtype, - tensor._block_size, - dtype, - tensor._gemm_kernel_choice, - tensor._pack_fp6, - tensor.act_quant_kwargs, - ) - return res - - # If only device was changed, return the device-changed tensor - return tensor - - @implements([aten.clone.default]) def mx_clone(func, types, args, kwargs): self = args[0] From ff57676b149a7c2d5af28dd446a0da85bc46994c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 05:21:38 -0700 Subject: [PATCH 5/9] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_inference_workflow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index a3756a305b..db9731276b 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -69,6 +69,9 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b elif elem_dtype == torch.float4_e2m1fn_x2: if not is_sm_at_least_100() and not emulate: pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm") + elif not is_sm_at_least_100() and emulate and compile: + # TODO(future PR): investigate and fix this + pytest.skip("mxfp4 + emulate + compile currently does not work, low SQNR") m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda") m_mx = copy.deepcopy(m) From f58607e22778d0d2eaa18fa5468bc026dc7e7250 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 05:21:40 -0700 Subject: [PATCH 6/9] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_inference_workflow.py | 11 +++++++++++ torchao/prototype/mx_formats/inference_workflow.py | 8 +++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index db9731276b..df5a5a28ff 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy +import tempfile import pytest import torch @@ -100,6 +101,16 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}" ) + # serialization + with tempfile.NamedTemporaryFile() as f: + torch.save(m_mx.state_dict(), f) + f.seek(0) + + # temporary workaround for https://github.com/pytorch/ao/issues/3077 + torch.serialization.add_safe_globals([getattr]) + + _ = torch.load(f, weights_only=True) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index cc6e34a708..d5bf290589 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -17,7 +17,11 @@ _validate_elem_dtype, _validate_gemm_kernel_choice, ) -from torchao.prototype.mx_formats.mx_tensor import MXTensor, QuantizeTensorToMXKwargs +from torchao.prototype.mx_formats.mx_tensor import ( + MXTensor, + QuantizeTensorToMXKwargs, + ScaleCalculationMode, +) from torchao.prototype.mx_formats.nvfp4_tensor import ( NVFP4MMConfig, NVFP4Tensor, @@ -206,6 +210,8 @@ def _nvfp4_inference_linear_transform( NVFP4Tensor, NVFP4MMConfig, MXGemmKernelChoice, + QuantizeTensorToMXKwargs, + ScaleCalculationMode, ] ) From 263ad983efb45c54ce097b5c4cdedfbd3eaa194e Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 06:00:12 -0700 Subject: [PATCH 7/9] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 22 +++++++++++++++++++++ torchao/prototype/mx_formats/mx_ops.py | 20 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 2c89fae96d..90a0047bea 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -26,6 +26,7 @@ from torchao.quantization.utils import compute_error from torchao.utils import ( is_sm_at_least_89, + is_sm_at_least_90, is_sm_at_least_100, torch_version_at_least, ) @@ -556,6 +557,27 @@ def test_to_mx_inductor_single_kernel(): FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0]) +@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+") +def test_index_select(): + """ + test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is + useful when stitching checkpoints of `num_experts` 2D parameters into + a single 3D parameter when converting between model definitions that + use 2D and 3D parameters for their expert weights. + """ + + E, K, N = 128, 256, 512 + x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16) + x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32) + + # import pdb; pdb.set_trace() + + x_mx_1 = x_mx[1] + torch.testing.assert_close( + x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0 + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not is_sm_at_least_89(), diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index d870698601..8fcda5e69c 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -322,3 +322,23 @@ def mx_clone(func, types, args, kwargs): clone_fn = lambda x: x.clone() return self._apply_fn_to_data(clone_fn) + + +@implements([aten.select.int]) +def mx_select(func, types, args, kwargs): + old_mx_tensor, dim, index = args + assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported" + assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), ( + "unsupported" + ) + new_mx_tensor = old_mx_tensor.__class__( + old_mx_tensor.qdata[index], + old_mx_tensor._scale_e8m0[index], + old_mx_tensor._elem_dtype, + old_mx_tensor._block_size, + old_mx_tensor._orig_dtype, + old_mx_tensor._gemm_kernel_choice, + old_mx_tensor._pack_fp6, + old_mx_tensor.act_quant_kwargs, + ) + return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor) From ebd322691ddc58d68551f2cf983f2669bad3d488 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 06:32:02 -0700 Subject: [PATCH 8/9] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 90a0047bea..67e372a009 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -570,8 +570,6 @@ def test_index_select(): x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16) x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32) - # import pdb; pdb.set_trace() - x_mx_1 = x_mx[1] torch.testing.assert_close( x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0 From 15248b14c0d126b84ab5986e13fa80456ae856cf Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 08:00:10 -0700 Subject: [PATCH 9/9] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 67e372a009..305ec64500 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -557,6 +557,7 @@ def test_to_mx_inductor_single_kernel(): FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+") def test_index_select(): """