From 08faa8e6b008b7694d2cc06936f336d8bfdc15c9 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 12:56:36 -0700 Subject: [PATCH 01/13] Update [ghstack-poisoned] --- .../mx_formats/test_inference_workflow.py | 22 +++++++++++-------- .../mx_formats/inference_workflow.py | 4 ---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 90dc2700ce..a3756a305b 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -50,12 +50,12 @@ def run_around_tests(): @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize("emulate", [True, False]) @torch.no_grad() @skip_if_rocm( "ROCm float4 gemm require gfx950" ) # TODO(future): deploy gfx950 in ROCM CI -@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required") -def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool): +def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool): """ Smoke test for inference compile """ @@ -64,17 +64,21 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool): if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + elif not is_sm_at_least_100() and not emulate: + pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm") elif elem_dtype == torch.float4_e2m1fn_x2: - if not is_sm_at_least_100(): - pytest.skip("CUDA capability >= 10.0 required for float4 gemm") + if not is_sm_at_least_100() and not emulate: + pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm") m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda") m_mx = copy.deepcopy(m) - kernel_choice = ( - MXGemmKernelChoice.CUTLASS - if elem_dtype == torch.float4_e2m1fn_x2 - else MXGemmKernelChoice.CUBLAS - ) + + if emulate: + kernel_choice = MXGemmKernelChoice.EMULATED + elif elem_dtype == torch.float4_e2m1fn_x2: + kernel_choice = MXGemmKernelChoice.CUTLASS + else: + kernel_choice = MXGemmKernelChoice.CUBLAS config = MXFPInferenceConfig( activation_dtype=elem_dtype, weight_dtype=elem_dtype, diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 39f0725390..1b9c369be5 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -96,10 +96,6 @@ def _linear_extra_repr(self): def _mx_inference_linear_transform( module: torch.nn.Module, config: MXFPInferenceConfig ): - # TODO Sm120 has slightly more restrictive reqs - # TODO handle AMD - assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now" - weight = module.weight assert weight.dtype == torch.bfloat16, ( From b88afaf4156aeddde1971c000a0cf11a52ce03d5 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 14:29:48 -0700 Subject: [PATCH 02/13] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/inference_workflow.py | 3 ++- torchao/prototype/mx_formats/mx_tensor.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 1b9c369be5..cc6e34a708 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -24,6 +24,7 @@ QuantizeTensorToNVFP4Kwargs, per_tensor_amax_to_scale, ) +from torchao.quantization.quant_api import _quantization_type from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -89,7 +90,7 @@ def __post_init__(self): def _linear_extra_repr(self): - return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}" + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" @register_quantize_module_handler(MXFPInferenceConfig) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index b717462b4d..d5ec30c10f 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -544,6 +544,9 @@ def __repr__(self): # TODO better elem dtype print for fp4 return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501 + def _quantization_type(self): + return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # avoid circular dependency From 00d263455f7ee7a882ebb644c038f26845510744 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 17:35:03 -0700 Subject: [PATCH 03/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 14 ++++++++++++++ torchao/prototype/mx_formats/mx_ops.py | 13 +++++++++++++ 2 files changed, 27 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 38eefbff07..2c89fae96d 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -455,6 +455,20 @@ def test_view(elem_dtype): x_mx_2 = x_mx.view(2, 4) # noqa: F841 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_clone(): + data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16) + block_size = 4 + data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size) + data_mx_c = data_mx.clone() + torch.testing.assert_close( + data_mx.to_dtype(torch.bfloat16), + data_mx_c.to_dtype(torch.bfloat16), + atol=0, + rtol=0, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]) @pytest.mark.parametrize("pack_fp6", [False, True]) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 07e47eed66..1779b0e278 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -352,3 +352,16 @@ def autocast_to_copy(func, types, args, kwargs): # If only device was changed, return the device-changed tensor return tensor + + +@implements([aten.clone.default]) +def mx_clone(func, types, args, kwargs): + self = args[0] + memory_format = kwargs.get("memory_format", None) + + if memory_format is not None: + clone_fn = lambda x: x.clone(memory_format=memory_format) + else: + clone_fn = lambda x: x.clone() + + return self._apply_fn_to_data(clone_fn) From 5a840c1aedad8bf40e383cd3de71f39e2fe8d881 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Sep 2025 17:44:15 -0700 Subject: [PATCH 04/13] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/mx_ops.py | 43 -------------------------- 1 file changed, 43 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 1779b0e278..d870698601 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -311,49 +311,6 @@ def mx_copy_(func, types, args, kwargs): ) -@implements([aten._to_copy.default]) -def autocast_to_copy(func, types, args, kwargs): - """Autocast + device movement""" - assert isinstance(args[0], MXTensor) - - # Handle dtype parameter - dtype = kwargs.pop("dtype", None) - if dtype is not None: - assert dtype in { - torch.float16, - torch.bfloat16, - }, "Only support floating point conversion for autocast w/ MXTensor" - - # Handle device parameter - device = kwargs.pop("device", None) - if device is not None: - # Apply device change using _apply_fn_to_data - tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device)) - tensor = return_and_correct_aliasing(func, args, {}, tensor) - else: - tensor = args[0] - - # Verify no other kwargs remain - assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast" - - # If dtype is specified, create a new MXTensor with the requested dtype - if dtype is not None: - res = MXTensor( - tensor.qdata, - tensor._scale_e8m0, - tensor._elem_dtype, - tensor._block_size, - dtype, - tensor._gemm_kernel_choice, - tensor._pack_fp6, - tensor.act_quant_kwargs, - ) - return res - - # If only device was changed, return the device-changed tensor - return tensor - - @implements([aten.clone.default]) def mx_clone(func, types, args, kwargs): self = args[0] From ff57676b149a7c2d5af28dd446a0da85bc46994c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 05:21:38 -0700 Subject: [PATCH 05/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_inference_workflow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index a3756a305b..db9731276b 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -69,6 +69,9 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b elif elem_dtype == torch.float4_e2m1fn_x2: if not is_sm_at_least_100() and not emulate: pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm") + elif not is_sm_at_least_100() and emulate and compile: + # TODO(future PR): investigate and fix this + pytest.skip("mxfp4 + emulate + compile currently does not work, low SQNR") m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda") m_mx = copy.deepcopy(m) From f58607e22778d0d2eaa18fa5468bc026dc7e7250 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 05:21:40 -0700 Subject: [PATCH 06/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_inference_workflow.py | 11 +++++++++++ torchao/prototype/mx_formats/inference_workflow.py | 8 +++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index db9731276b..df5a5a28ff 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy +import tempfile import pytest import torch @@ -100,6 +101,16 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}" ) + # serialization + with tempfile.NamedTemporaryFile() as f: + torch.save(m_mx.state_dict(), f) + f.seek(0) + + # temporary workaround for https://github.com/pytorch/ao/issues/3077 + torch.serialization.add_safe_globals([getattr]) + + _ = torch.load(f, weights_only=True) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index cc6e34a708..d5bf290589 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -17,7 +17,11 @@ _validate_elem_dtype, _validate_gemm_kernel_choice, ) -from torchao.prototype.mx_formats.mx_tensor import MXTensor, QuantizeTensorToMXKwargs +from torchao.prototype.mx_formats.mx_tensor import ( + MXTensor, + QuantizeTensorToMXKwargs, + ScaleCalculationMode, +) from torchao.prototype.mx_formats.nvfp4_tensor import ( NVFP4MMConfig, NVFP4Tensor, @@ -206,6 +210,8 @@ def _nvfp4_inference_linear_transform( NVFP4Tensor, NVFP4MMConfig, MXGemmKernelChoice, + QuantizeTensorToMXKwargs, + ScaleCalculationMode, ] ) From 263ad983efb45c54ce097b5c4cdedfbd3eaa194e Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 06:00:12 -0700 Subject: [PATCH 07/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 22 +++++++++++++++++++++ torchao/prototype/mx_formats/mx_ops.py | 20 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 2c89fae96d..90a0047bea 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -26,6 +26,7 @@ from torchao.quantization.utils import compute_error from torchao.utils import ( is_sm_at_least_89, + is_sm_at_least_90, is_sm_at_least_100, torch_version_at_least, ) @@ -556,6 +557,27 @@ def test_to_mx_inductor_single_kernel(): FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0]) +@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+") +def test_index_select(): + """ + test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is + useful when stitching checkpoints of `num_experts` 2D parameters into + a single 3D parameter when converting between model definitions that + use 2D and 3D parameters for their expert weights. + """ + + E, K, N = 128, 256, 512 + x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16) + x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32) + + # import pdb; pdb.set_trace() + + x_mx_1 = x_mx[1] + torch.testing.assert_close( + x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0 + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not is_sm_at_least_89(), diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index d870698601..8fcda5e69c 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -322,3 +322,23 @@ def mx_clone(func, types, args, kwargs): clone_fn = lambda x: x.clone() return self._apply_fn_to_data(clone_fn) + + +@implements([aten.select.int]) +def mx_select(func, types, args, kwargs): + old_mx_tensor, dim, index = args + assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported" + assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), ( + "unsupported" + ) + new_mx_tensor = old_mx_tensor.__class__( + old_mx_tensor.qdata[index], + old_mx_tensor._scale_e8m0[index], + old_mx_tensor._elem_dtype, + old_mx_tensor._block_size, + old_mx_tensor._orig_dtype, + old_mx_tensor._gemm_kernel_choice, + old_mx_tensor._pack_fp6, + old_mx_tensor.act_quant_kwargs, + ) + return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor) From 235494e8bff43cf6b094310e197d9bd74fb68c4a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 06:32:02 -0700 Subject: [PATCH 08/13] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/mx_ops.py | 10 +- torchao/prototype/mx_formats/mx_tensor.py | 318 ++++++++++++++++++++-- torchao/prototype/mx_formats/utils.py | 5 +- 3 files changed, 305 insertions(+), 28 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 8fcda5e69c..f9cb2a9bfe 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -45,15 +45,7 @@ MX_FUNCTION_TABLE: Dict[Any, Any] = {} -def implements(aten_ops): - """Register aten ops to the mx op table""" - - def decorator(func): - for op in aten_ops: - MX_OPS_TABLE[op] = func - return func - - return decorator +implements = MXTensor.implements @implements([aten.detach.default, aten.alias.default]) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index d5ec30c10f..51ec74c12b 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -22,7 +22,12 @@ import torch from torch.distributed._tensor import DTensor +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) +from torch.utils._pytree import tree_map +import torchao.ops from torchao.prototype.mx_formats.config import MXGemmKernelChoice, ScaleCalculationMode from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, @@ -57,10 +62,13 @@ triton_f6_e3m2_to_scaled_bf16, unpack_uint4, ) +from torchao.prototype.mx_formats.utils import to_blocked from torchao.quantization.quantize_.common import ( QuantizeTensorKwargs, ) -from torchao.utils import TorchAOBaseTensor +from torchao.utils import TorchAOBaseTensor, fill_defaults + +aten = torch.ops.aten # TODO(later): read from somewhere else? SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23 @@ -547,23 +555,6 @@ def __repr__(self): def _quantization_type(self): return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - # avoid circular dependency - from torchao.prototype.mx_formats.mx_funcs import MX_FUNC_TABLE - from torchao.prototype.mx_formats.mx_ops import MX_OPS_TABLE - - if func in MX_OPS_TABLE: - return MX_OPS_TABLE[func](func, types, args, kwargs) - - # TODO AO BASE_TENSOR doesn't respect dispatch and function modes - # We are calling nn.functional.linear from within LinearAct Tensor even though - # We are in a __torch__dispatch. This disables the decomposition and we get this op - if func == torch.ops.aten.linear.default: - return MX_FUNC_TABLE[func](func, types, args, kwargs) - - raise NotImplementedError(f"{func} not implemented") - def to_dtype(self, target_dtype): return to_dtype( self.qdata, @@ -624,3 +615,294 @@ def to_mx( # Do not force the MXTensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl + + +implements = MXTensor.implements + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(func) + ) + + +def _get_gemm_choice( + choice_a: Optional[MXGemmKernelChoice], choice_b: Optional[MXGemmKernelChoice] +) -> MXGemmKernelChoice: + if choice_a is not None and choice_b is not None: + assert choice_a == choice_b, ( + "Both MXTensor inputs must have the same gemm config if specified" + ) + return choice_a + + # Assert that at least one is set and return that one + assert choice_a is not None or choice_b is not None, ( + "At least one gemm choice must be specified" + ) + return choice_a if choice_a is not None else choice_b + + +def _addmm_mx_dispatch( + a: torch.Tensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: + """ + Core implementation shared between mx_mm and mx_addmm. + The only difference is whether bias is None or not. + """ + + if not isinstance(a, MXTensor): + assert b.act_quant_kwargs is not None, "weight-only quant not yet supported" + k = b.act_quant_kwargs + a = MXTensor.to_mx( + a, + k.elem_dtype, + k.block_size, + k.scaling_mode, + k.gemm_kernel_choice, + k.pack_fp6, + ) + + gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) + + if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): + # real MX gemm backed by torchao's CUTLASS kernels + M, K, N = a.shape[0], a.shape[1], b.shape[1] + assert a.qdata.is_contiguous() + assert b.qdata.t().is_contiguous() + assert a._block_size == 32, f"Invalid block size {a._block_size}" + assert b._block_size == 32, f"Invalid block size {b._block_size}" + + a_scale = a._scale_e8m0.view(M, K // a._block_size) + b_scale = b._scale_e8m0.view(N, K // b._block_size) + a_scale_block = to_blocked(a_scale) + b_scale_block = to_blocked(b_scale) + + if a._elem_dtype == torch.float8_e4m3fn: + assert b._elem_dtype == torch.float8_e4m3fn + assert gemm_choice is MXGemmKernelChoice.CUBLAS, ( + "CUBLAS is the only supported kernel choice for MX FP8 operations" + ) + + res = torch._scaled_mm( + a.qdata, + b.qdata, + a_scale_block.view(torch.float8_e8m0fnu), + b_scale_block.view(torch.float8_e8m0fnu), + bias=bias, + out_dtype=torch.bfloat16, + ) + else: + assert a._elem_dtype == torch.float4_e2m1fn_x2 + assert b._elem_dtype == torch.float4_e2m1fn_x2 + assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported" + # FP4 operations + res = torchao.ops.mx_fp4_bf16( + a.qdata, b.qdata, a_scale_block, b_scale_block + ) + # TODO add optional bias to kernel + if bias is not None: + res = res + bias + + else: + # emulated MX gemm + a_hp = a.to_dtype(a._orig_dtype) + b_hp = b.to_dtype(b._orig_dtype) + # assert memory layout we expect to be required in hardware + assert a_hp.is_contiguous() + assert b_hp.t().is_contiguous() + + # Call appropriate aten_op based on whether bias is provided + if bias is not None: + res = aten_op(bias, a_hp, b_hp) # addmm + else: + res = aten_op(a_hp, b_hp) # mm + + return res + + +@implements([aten.mm.default, aten.matmul.default]) +def mx_mm(func, types, args, kwargs): + a = args[0] + b = args[1] + assert isinstance(b, MXTensor) + + return _addmm_mx_dispatch(a, b, func) + + +@implements([aten.addmm.default]) +def mx_addmm(func, types, args, kwargs): + assert isinstance(args[0], torch.Tensor) and isinstance(args[2], MXTensor) + bias = args[0] + a = args[1] + b = args[2] + return _addmm_mx_dispatch(a, b, func, bias=bias) + + +@implements([aten.t.default]) +def mx_t(func, types, args, kwargs): + # For now, only transpose(input, 0, 1) is supported. + old = args[0] + new = MXTensor( + old.qdata.t(), + old._scale_e8m0, + old._elem_dtype, + old._block_size, + old._orig_dtype, + old._gemm_kernel_choice, + old._pack_fp6, + old.act_quant_kwargs, + ) + return new + + +@implements([aten.sum.dim_IntList]) +def mx_cast_up_op(func, types, args, kwargs): + """Be careful with this function, this is a "fallback" op that + casts the output of the op to the original precision. And performs the op. + + We currently need this to support the backward for admmm bias. + "addmm" -> out + "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" + """ + + def unwrap(x): + if isinstance(x, MXTensor): + return x.to_dtype(x._orig_dtype) + return x + + new_args = tree_map(unwrap, args) + new_kwargs = tree_map(unwrap, kwargs) + return func(*new_args, **new_kwargs) + + +@implements([aten.view.default]) +def mx_view_op(func, types, args, kwargs): + data = args[0].qdata + new_size = args[1] + if args[0]._elem_dtype == torch.float4_e2m1fn_x2: + # special case fp4 as we pack two elements per byte + new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) + elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6: + # special case fp6 as we pack 4 elements in 3 bytes + new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) + new_data = func(data, new_size, *args[2:], **kwargs) + return MXTensor( + new_data, + args[0]._scale_e8m0, + args[0]._elem_dtype, + args[0]._block_size, + args[0]._orig_dtype, + args[0]._gemm_kernel_choice, + args[0]._pack_fp6, + args[0].act_quant_kwargs, + ) + + +@implements([aten.slice.Tensor]) +def mx_slice(func, types, args, kwargs): + x, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + + if step != 1: + raise ValueError("Only support aten.slice with step=1") + + M, K = x.shape[0], x.shape[1] + + # TODO why doesn't scale have shape? + scale_shaped = x._scale_e8m0.view(M, K // x._block_size) + + if dim == 0: + # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now + sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) + sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step).unsqueeze(-1) + elif dim == 1: + # Slicing along reduciton dim + if start is not None: + # Assert start is a multiple of block_size + assert start % x._block_size == 0, ( + f"Start index {start} must be a multiple of block_size {x._block_size}" + ) + + if end is not None: + # Assert end is a multiple of block_size + assert end % x._block_size == 0, ( + f"End index {end} must be a multiple of block_size {x._block_size}" + ) + + sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) + + # Calculate which scale elements to keep + start_block = 0 if start is None else start // x._block_size + end_block = -1 if end is None else end // x._block_size + + # Slice the scale tensor accordingly + sliced_scale = aten.slice.Tensor( + scale_shaped, 1, start_block, end_block, step + ).unsqueeze(-1) + else: + raise ValueError( + f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}" + ) + + return return_and_correct_aliasing( + func, + args, + kwargs, + MXTensor( + sliced_data, + sliced_scale, + x._elem_dtype, + x._block_size, + x._orig_dtype, + x._gemm_kernel_choice, + x._pack_fp6, + x.act_quant_kwargs, + ), + ) + + +@implements([aten.copy_.default]) +def mx_copy_(func, types, args, kwargs): + self = args[0] + src = args[1] + if MXTensor._same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + return + raise ValueError( + f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" + ) + + +@implements([aten.clone.default]) +def mx_clone(func, types, args, kwargs): + self = args[0] + memory_format = kwargs.get("memory_format", None) + + if memory_format is not None: + clone_fn = lambda x: x.clone(memory_format=memory_format) + else: + clone_fn = lambda x: x.clone() + + return self._apply_fn_to_data(clone_fn) + + +@implements([aten.select.int]) +def mx_select(func, types, args, kwargs): + old_mx_tensor, dim, index = args + assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported" + assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), ( + "unsupported" + ) + new_mx_tensor = old_mx_tensor.__class__( + old_mx_tensor.qdata[index], + old_mx_tensor._scale_e8m0[index], + old_mx_tensor._elem_dtype, + old_mx_tensor._block_size, + old_mx_tensor._orig_dtype, + old_mx_tensor._gemm_kernel_choice, + old_mx_tensor._pack_fp6, + old_mx_tensor.act_quant_kwargs, + ) + return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor) diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 2802888980..247b17d838 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -16,7 +16,6 @@ triton_mx_block_rearrange, triton_to_mxfp8_dim1, ) -from torchao.prototype.mx_formats.mx_tensor import MXTensor Tensor = torch.Tensor @@ -120,6 +119,10 @@ def _to_mxfp8_dim1_kernel_wrapper( cast_kernel_choice, scale_calculation_mode: ScaleCalculationMode, ): + # avoid circular import + # TODO(future PR): split this utils file in two + from torchao.prototype.mx_formats.mx_tensor import MXTensor + if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON: assert scale_calculation_mode == ScaleCalculationMode.FLOOR a_data, a_scale = triton_to_mxfp8_dim1(a, block_size) From ebd322691ddc58d68551f2cf983f2669bad3d488 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 06:32:02 -0700 Subject: [PATCH 09/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 90a0047bea..67e372a009 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -570,8 +570,6 @@ def test_index_select(): x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16) x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32) - # import pdb; pdb.set_trace() - x_mx_1 = x_mx[1] torch.testing.assert_close( x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0 From b9dbfa8abd762a1169bfb0e0883f5b773ff75b36 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 06:56:40 -0700 Subject: [PATCH 10/13] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/mx_ops.py | 336 ------------------------- 1 file changed, 336 deletions(-) delete mode 100644 torchao/prototype/mx_formats/mx_ops.py diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py deleted file mode 100644 index f9cb2a9bfe..0000000000 --- a/torchao/prototype/mx_formats/mx_ops.py +++ /dev/null @@ -1,336 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -This file defines the ops needed for our tensor subclass implementation -of `MXTensor` to work naturally in PyTorch programs. For example, if -the modeling code is written as - - x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn) - w_mx = MXTensor.to_mx(w, torch.float8_e4m3fn) - y = F.linear(x_mx, w_mx) - -then the ops in this file are used under the hood to properly route -the underlying data fields to the MX matmul. -""" - -from typing import Any, Dict, Optional - -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) -from torch.utils._pytree import tree_map - -import torchao.ops -from torchao.prototype.mx_formats.config import MXGemmKernelChoice -from torchao.prototype.mx_formats.constants import ( - DTYPE_FP6_E2M3, - DTYPE_FP6_E3M2, -) -from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501 - MXTensor, - tensor_size_hp_to_fp4x2, - tensor_size_hpx3_to_fp6x4, -) -from torchao.prototype.mx_formats.utils import to_blocked -from torchao.utils import fill_defaults - -aten = torch.ops.aten - -MX_OPS_TABLE: Dict[Any, Any] = {} -MX_FUNCTION_TABLE: Dict[Any, Any] = {} - - -implements = MXTensor.implements - - -@implements([aten.detach.default, aten.alias.default]) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(func) - ) - - -def _get_gemm_choice( - choice_a: Optional[MXGemmKernelChoice], choice_b: Optional[MXGemmKernelChoice] -) -> MXGemmKernelChoice: - if choice_a is not None and choice_b is not None: - assert choice_a == choice_b, ( - "Both MXTensor inputs must have the same gemm config if specified" - ) - return choice_a - - # Assert that at least one is set and return that one - assert choice_a is not None or choice_b is not None, ( - "At least one gemm choice must be specified" - ) - return choice_a if choice_a is not None else choice_b - - -def _addmm_mx_dispatch( - a: torch.Tensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None -) -> torch.Tensor: - """ - Core implementation shared between mx_mm and mx_addmm. - The only difference is whether bias is None or not. - """ - - if not isinstance(a, MXTensor): - assert b.act_quant_kwargs is not None, "weight-only quant not yet supported" - k = b.act_quant_kwargs - a = MXTensor.to_mx( - a, - k.elem_dtype, - k.block_size, - k.scaling_mode, - k.gemm_kernel_choice, - k.pack_fp6, - ) - - gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) - - if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): - # real MX gemm backed by torchao's CUTLASS kernels - M, K, N = a.shape[0], a.shape[1], b.shape[1] - assert a.qdata.is_contiguous() - assert b.qdata.t().is_contiguous() - assert a._block_size == 32, f"Invalid block size {a._block_size}" - assert b._block_size == 32, f"Invalid block size {b._block_size}" - - a_scale = a._scale_e8m0.view(M, K // a._block_size) - b_scale = b._scale_e8m0.view(N, K // b._block_size) - a_scale_block = to_blocked(a_scale) - b_scale_block = to_blocked(b_scale) - - if a._elem_dtype == torch.float8_e4m3fn: - assert b._elem_dtype == torch.float8_e4m3fn - assert gemm_choice is MXGemmKernelChoice.CUBLAS, ( - "CUBLAS is the only supported kernel choice for MX FP8 operations" - ) - - res = torch._scaled_mm( - a.qdata, - b.qdata, - a_scale_block.view(torch.float8_e8m0fnu), - b_scale_block.view(torch.float8_e8m0fnu), - bias=bias, - out_dtype=torch.bfloat16, - ) - else: - assert a._elem_dtype == torch.float4_e2m1fn_x2 - assert b._elem_dtype == torch.float4_e2m1fn_x2 - assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported" - # FP4 operations - res = torchao.ops.mx_fp4_bf16( - a.qdata, b.qdata, a_scale_block, b_scale_block - ) - # TODO add optional bias to kernel - if bias is not None: - res = res + bias - - else: - # emulated MX gemm - a_hp = a.to_dtype(a._orig_dtype) - b_hp = b.to_dtype(b._orig_dtype) - # assert memory layout we expect to be required in hardware - assert a_hp.is_contiguous() - assert b_hp.t().is_contiguous() - - # Call appropriate aten_op based on whether bias is provided - if bias is not None: - res = aten_op(bias, a_hp, b_hp) # addmm - else: - res = aten_op(a_hp, b_hp) # mm - - return res - - -@implements([aten.mm.default, aten.matmul.default]) -def mx_mm(func, types, args, kwargs): - a = args[0] - b = args[1] - assert isinstance(b, MXTensor) - - return _addmm_mx_dispatch(a, b, func) - - -@implements([aten.addmm.default]) -def mx_addmm(func, types, args, kwargs): - assert isinstance(args[0], torch.Tensor) and isinstance(args[2], MXTensor) - bias = args[0] - a = args[1] - b = args[2] - return _addmm_mx_dispatch(a, b, func, bias=bias) - - -@implements([aten.t.default]) -def mx_t(func, types, args, kwargs): - # For now, only transpose(input, 0, 1) is supported. - old = args[0] - new = MXTensor( - old.qdata.t(), - old._scale_e8m0, - old._elem_dtype, - old._block_size, - old._orig_dtype, - old._gemm_kernel_choice, - old._pack_fp6, - old.act_quant_kwargs, - ) - return new - - -@implements([aten.sum.dim_IntList]) -def mx_cast_up_op(func, types, args, kwargs): - """Be careful with this function, this is a "fallback" op that - casts the output of the op to the original precision. And performs the op. - - We currently need this to support the backward for admmm bias. - "addmm" -> out - "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" - """ - - def unwrap(x): - if isinstance(x, MXTensor): - return x.to_dtype(x._orig_dtype) - return x - - new_args = tree_map(unwrap, args) - new_kwargs = tree_map(unwrap, kwargs) - return func(*new_args, **new_kwargs) - - -@implements([aten.view.default]) -def mx_view_op(func, types, args, kwargs): - data = args[0].qdata - new_size = args[1] - if args[0]._elem_dtype == torch.float4_e2m1fn_x2: - # special case fp4 as we pack two elements per byte - new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) - elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6: - # special case fp6 as we pack 4 elements in 3 bytes - new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) - new_data = func(data, new_size, *args[2:], **kwargs) - return MXTensor( - new_data, - args[0]._scale_e8m0, - args[0]._elem_dtype, - args[0]._block_size, - args[0]._orig_dtype, - args[0]._gemm_kernel_choice, - args[0]._pack_fp6, - args[0].act_quant_kwargs, - ) - - -@implements([aten.slice.Tensor]) -def mx_slice(func, types, args, kwargs): - x, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - - if step != 1: - raise ValueError("Only support aten.slice with step=1") - - M, K = x.shape[0], x.shape[1] - - # TODO why doesn't scale have shape? - scale_shaped = x._scale_e8m0.view(M, K // x._block_size) - - if dim == 0: - # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now - sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) - sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step).unsqueeze(-1) - elif dim == 1: - # Slicing along reduciton dim - if start is not None: - # Assert start is a multiple of block_size - assert start % x._block_size == 0, ( - f"Start index {start} must be a multiple of block_size {x._block_size}" - ) - - if end is not None: - # Assert end is a multiple of block_size - assert end % x._block_size == 0, ( - f"End index {end} must be a multiple of block_size {x._block_size}" - ) - - sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) - - # Calculate which scale elements to keep - start_block = 0 if start is None else start // x._block_size - end_block = -1 if end is None else end // x._block_size - - # Slice the scale tensor accordingly - sliced_scale = aten.slice.Tensor( - scale_shaped, 1, start_block, end_block, step - ).unsqueeze(-1) - else: - raise ValueError( - f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}" - ) - - return return_and_correct_aliasing( - func, - args, - kwargs, - MXTensor( - sliced_data, - sliced_scale, - x._elem_dtype, - x._block_size, - x._orig_dtype, - x._gemm_kernel_choice, - x._pack_fp6, - x.act_quant_kwargs, - ), - ) - - -@implements([aten.copy_.default]) -def mx_copy_(func, types, args, kwargs): - self = args[0] - src = args[1] - if MXTensor._same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" - ) - - -@implements([aten.clone.default]) -def mx_clone(func, types, args, kwargs): - self = args[0] - memory_format = kwargs.get("memory_format", None) - - if memory_format is not None: - clone_fn = lambda x: x.clone(memory_format=memory_format) - else: - clone_fn = lambda x: x.clone() - - return self._apply_fn_to_data(clone_fn) - - -@implements([aten.select.int]) -def mx_select(func, types, args, kwargs): - old_mx_tensor, dim, index = args - assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported" - assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), ( - "unsupported" - ) - new_mx_tensor = old_mx_tensor.__class__( - old_mx_tensor.qdata[index], - old_mx_tensor._scale_e8m0[index], - old_mx_tensor._elem_dtype, - old_mx_tensor._block_size, - old_mx_tensor._orig_dtype, - old_mx_tensor._gemm_kernel_choice, - old_mx_tensor._pack_fp6, - old_mx_tensor.act_quant_kwargs, - ) - return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor) From 19ac204ae00d18cc9bd396003cbf77c791f5cc7e Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 07:57:27 -0700 Subject: [PATCH 11/13] Update [ghstack-poisoned] --- .../mx_formats/test_inference_workflow.py | 20 ++++++++++++++++++- torchao/prototype/mx_formats/mx_tensor.py | 18 +---------------- torchao/testing/utils.py | 20 ++++++++++++++++++- 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index df5a5a28ff..50b61cda48 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -21,7 +21,7 @@ ) from torchao.quantization import quantize_ from torchao.quantization.utils import compute_error -from torchao.testing.utils import skip_if_rocm +from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm from torchao.utils import ( is_sm_at_least_89, is_sm_at_least_100, @@ -190,3 +190,21 @@ def test_inference_workflow_nvfp4( assert sqnr >= SQNR_THRESHOLD, ( f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}" ) + + +class VLLMIntegrationTestCase(TorchAOIntegrationTestCase): + def test_slice_and_copy_similar_to_vllm(self): + config = MXFPInferenceConfig( + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + ) + self._test_slice_and_copy_similar_to_vllm(config) + + def test_narrow_similar_to_vllm(self): + config = MXFPInferenceConfig( + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + ) + self._test_narrow_similar_to_vllm(config) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 51ec74c12b..15d29431f2 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -836,9 +836,7 @@ def mx_slice(func, types, args, kwargs): end_block = -1 if end is None else end // x._block_size # Slice the scale tensor accordingly - sliced_scale = aten.slice.Tensor( - scale_shaped, 1, start_block, end_block, step - ).unsqueeze(-1) + sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block, step) else: raise ValueError( f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}" @@ -861,20 +859,6 @@ def mx_slice(func, types, args, kwargs): ) -@implements([aten.copy_.default]) -def mx_copy_(func, types, args, kwargs): - self = args[0] - src = args[1] - if MXTensor._same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" - ) - - @implements([aten.clone.default]) def mx_clone(func, types, args, kwargs): self = args[0] diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index bb9c2ca8dc..5fec85fee6 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -16,6 +16,7 @@ ) import torchao +from torchao.core.config import AOBaseConfig from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx from torchao.quantization import Int8WeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import MappingType @@ -426,7 +427,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TorchAOIntegrationTestCase(common_utils.TestCase): - def _test_slice_and_copy_similar_to_vllm(self, config): + def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig): # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly # the test is similar to the linked code, but with some hardcoded arguments # and does not use tensor parallelism @@ -607,6 +608,23 @@ def process_key(key: str) -> torch.Tensor: # make sure it runs moe_combined(input) + def _test_narrow_similar_to_vllm(self, config: AOBaseConfig): + # this happens various times in vllm when slicing weights around + + dtype = torch.bfloat16 + l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype) + quantize_(l, config) + + orig = l.weight + new = orig.narrow(1, 0, 1024) + + for data_attr_name in new.tensor_data_names: + orig_attr = getattr(orig, data_attr_name) + new_attr = getattr(new, data_attr_name) + assert len(orig_attr.shape) == len(new_attr.shape), ( + f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}" + ) + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) From 15248b14c0d126b84ab5986e13fa80456ae856cf Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 08:00:10 -0700 Subject: [PATCH 12/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 67e372a009..305ec64500 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -557,6 +557,7 @@ def test_to_mx_inductor_single_kernel(): FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+") def test_index_select(): """ From 38942fe3ed27f65e75bcc8741dc2bc4199192a0c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Sep 2025 09:18:50 -0700 Subject: [PATCH 13/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_inference_workflow.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 50b61cda48..afdabecd5d 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -193,6 +193,11 @@ def test_inference_workflow_nvfp4( class VLLMIntegrationTestCase(TorchAOIntegrationTestCase): + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif( + not torch_version_at_least("2.8.0"), + reason="torch.compile requires PyTorch 2.8+", + ) def test_slice_and_copy_similar_to_vllm(self): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, @@ -201,6 +206,11 @@ def test_slice_and_copy_similar_to_vllm(self): ) self._test_slice_and_copy_similar_to_vllm(config) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif( + not torch_version_at_least("2.8.0"), + reason="torch.compile requires PyTorch 2.8+", + ) def test_narrow_similar_to_vllm(self): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn,