diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index df5a5a28ff..afdabecd5d 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -21,7 +21,7 @@
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
 from torchao.utils import (
     is_sm_at_least_89,
     is_sm_at_least_100,
@@ -190,3 +190,31 @@ def test_inference_workflow_nvfp4(
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
     )
+
+
+class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.8.0"),
+        reason="torch.compile requires PyTorch 2.8+",
+    )
+    def test_slice_and_copy_similar_to_vllm(self):
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+        )
+        self._test_slice_and_copy_similar_to_vllm(config)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.8.0"),
+        reason="torch.compile requires PyTorch 2.8+",
+    )
+    def test_narrow_similar_to_vllm(self):
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+        )
+        self._test_narrow_similar_to_vllm(config)
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 2eff1cc431..432ef393d2 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -836,9 +836,7 @@ def mx_slice(func, types, args, kwargs):
         end_block = -1 if end is None else end // x._block_size
 
         # Slice the scale tensor accordingly
-        sliced_scale = aten.slice.Tensor(
-            scale_shaped, 1, start_block, end_block, step
-        ).unsqueeze(-1)
+        sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block, step)
     else:
         raise ValueError(
             f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"
@@ -861,20 +859,6 @@ def mx_slice(func, types, args, kwargs):
     )
 
 
-@implements([aten.copy_.default])
-def mx_copy_(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if MXTensor._same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-    )
-
-
 @implements([aten.clone.default])
 def mx_clone(func, types, args, kwargs):
     self = args[0]
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index bb9c2ca8dc..5fec85fee6 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -16,6 +16,7 @@
 )
 
 import torchao
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
 from torchao.quantization import Int8WeightOnlyConfig, quantize_
 from torchao.quantization.quant_primitives import MappingType
@@ -426,7 +427,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TorchAOIntegrationTestCase(common_utils.TestCase):
-    def _test_slice_and_copy_similar_to_vllm(self, config):
+    def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
         # the test is similar to the linked code, but with some hardcoded arguments
         # and does not use tensor parallelism
@@ -607,6 +608,23 @@ def process_key(key: str) -> torch.Tensor:
         # make sure it runs
         moe_combined(input)
 
+    def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
+        # this happens various times in vllm when slicing weights around
+
+        dtype = torch.bfloat16
+        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+        quantize_(l, config)
+
+        orig = l.weight
+        new = orig.narrow(1, 0, 1024)
+
+        for data_attr_name in new.tensor_data_names:
+            orig_attr = getattr(orig, data_attr_name)
+            new_attr = getattr(new, data_attr_name)
+            assert len(orig_attr.shape) == len(new_attr.shape), (
+                f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}"
+            )
+
 
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
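
A minimal sketch of the vllm-style weight loading these tests guard. The import path for MXFPInferenceConfig/MXGemmKernelChoice, the two stand-in Linear modules, and the 512-row shard size are illustrative assumptions, not part of this diff:

    import torch

    # assumed import path; the test file above already has these names in scope
    from torchao.prototype.mx_formats import (
        MXFPInferenceConfig,
        MXGemmKernelChoice,
    )
    from torchao.quantization import quantize_

    config = MXFPInferenceConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
        gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
    )

    # stand-ins for a target parameter and a loaded checkpoint shard
    dst = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
    src = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
    quantize_(dst, config)
    quantize_(src, config)

    # vllm narrows the quantized weight along the output dim and copies a
    # matching shard in-place; both ops must keep MXTensor data/scale consistent
    shard = dst.weight.narrow(0, 0, 512)  # hypothetical 512-row shard
    shard.copy_(src.weight.narrow(0, 0, 512))

    # the narrow test additionally checks that slicing along dim 1 preserves
    # the number of dimensions of each inner tensor (data and scale), which is
    # why mx_slice no longer unsqueezes the sliced scale
    narrowed = dst.weight.narrow(1, 0, 1024)
    assert narrowed.ndim == dst.weight.ndim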