Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
08faa8e
Update
vkuzo Sep 25, 2025
c97533c
Update
vkuzo Sep 25, 2025
b88afaf
Update
vkuzo Sep 25, 2025
00d2634
Update
vkuzo Sep 26, 2025
5a840c1
Update
vkuzo Sep 26, 2025
ff57676
Update
vkuzo Sep 26, 2025
4edba12
Update
vkuzo Sep 26, 2025
6d6e465
Update
vkuzo Sep 26, 2025
1f1fc5e
Update
vkuzo Sep 26, 2025
f58607e
Update
vkuzo Sep 26, 2025
263ad98
Update
vkuzo Sep 26, 2025
235494e
Update
vkuzo Sep 26, 2025
ebd3226
Update
vkuzo Sep 26, 2025
b9dbfa8
Update
vkuzo Sep 26, 2025
19ac204
Update
vkuzo Sep 26, 2025
15248b1
Update
vkuzo Sep 26, 2025
1cc9581
Update
vkuzo Sep 26, 2025
3ea5b94
Update
vkuzo Sep 26, 2025
38942fe
Update
vkuzo Sep 26, 2025
9482380
Update
vkuzo Sep 26, 2025
c9bc96c
Update
vkuzo Sep 26, 2025
a668c02
Update
vkuzo Sep 26, 2025
04f54e8
Update
vkuzo Sep 26, 2025
2fa1c9b
Update
vkuzo Sep 26, 2025
a88918e
Update
vkuzo Sep 26, 2025
1735077
Update
vkuzo Sep 26, 2025
7d12624
Update
vkuzo Sep 26, 2025
f4471a5
Update
vkuzo Sep 26, 2025
4c8a966
Update
vkuzo Sep 26, 2025
088a286
Update
vkuzo Sep 26, 2025
1b0ec76
Update
vkuzo Sep 26, 2025
3d22740
Update
vkuzo Sep 26, 2025
9ea1221
Update
vkuzo Sep 26, 2025
e9cea19
Update
vkuzo Sep 26, 2025
4fb76ae
Update
vkuzo Sep 26, 2025
e4f2855
Update
vkuzo Sep 26, 2025
9e7094e
Update
vkuzo Sep 26, 2025
1b57f52
Update
vkuzo Sep 26, 2025
2e73cef
Update
vkuzo Sep 26, 2025
0f12582
Update
vkuzo Sep 26, 2025
e3c719a
Update
vkuzo Sep 26, 2025
5d4f713
Update
vkuzo Sep 26, 2025
deeef68
Update
vkuzo Sep 26, 2025
f07fb27
Update
vkuzo Sep 26, 2025
3467ee3
Update
vkuzo Sep 26, 2025
e7de5db
Update
vkuzo Sep 26, 2025
f962851
Update
vkuzo Sep 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion test/prototype/mx_formats/test_inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
from torchao.utils import (
is_sm_at_least_89,
is_sm_at_least_100,
Expand Down Expand Up @@ -190,3 +190,31 @@ def test_inference_workflow_nvfp4(
assert sqnr >= SQNR_THRESHOLD, (
f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
)


class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
    """Runs the shared vllm-style integration checks against an MXFP config.

    Both tests use the same configuration: ``float8_e4m3fn`` activations and
    weights with the emulated gemm kernel choice.
    """

    @staticmethod
    def _emulated_fp8_config():
        """Build the MXFP inference config shared by the tests in this class."""
        fp8_dtype = torch.float8_e4m3fn
        return MXFPInferenceConfig(
            activation_dtype=fp8_dtype,
            weight_dtype=fp8_dtype,
            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(
        not torch_version_at_least("2.8.0"),
        reason="torch.compile requires PyTorch 2.8+",
    )
    def test_slice_and_copy_similar_to_vllm(self):
        """Delegate to the base-class slice-and-copy check with the MXFP config."""
        self._test_slice_and_copy_similar_to_vllm(self._emulated_fp8_config())

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(
        not torch_version_at_least("2.8.0"),
        reason="torch.compile requires PyTorch 2.8+",
    )
    def test_narrow_similar_to_vllm(self):
        """Delegate to the base-class narrow check with the MXFP config."""
        self._test_narrow_similar_to_vllm(self._emulated_fp8_config())
18 changes: 1 addition & 17 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,9 +836,7 @@ def mx_slice(func, types, args, kwargs):
end_block = -1 if end is None else end // x._block_size

# Slice the scale tensor accordingly
sliced_scale = aten.slice.Tensor(
scale_shaped, 1, start_block, end_block, step
).unsqueeze(-1)
sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block, step)
else:
raise ValueError(
f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"
Expand All @@ -861,20 +859,6 @@ def mx_slice(func, types, args, kwargs):
)


@implements([aten.copy_.default])
def mx_copy_(func, types, args, kwargs):
    """Handle in-place ``aten.copy_`` between two tensors with matching metadata.

    ``args[0]`` is the destination and ``args[1]`` the source. When
    ``MXTensor._same_metadata`` accepts the pair, every inner tensor named by
    the destination's ``__tensor_flatten__()`` is copied in place from the
    attribute of the same name on the source.

    Returns:
        None.

    Raises:
        ValueError: if ``MXTensor._same_metadata(args[0], args[1])`` is falsy.
    """
    self = args[0]
    src = args[1]

    # Reject mismatched metadata up front so nothing is partially copied.
    if not MXTensor._same_metadata(self, src):
        raise ValueError(
            "Not supported args for copy_ due to metadata mismatch: "
            f"{args[0], args[1]}"
        )

    # The first element of __tensor_flatten__() holds the inner tensor
    # attribute names; copy each one from src into self in place.
    inner_tensor_names = self.__tensor_flatten__()[0]
    for tensor_name in inner_tensor_names:
        getattr(self, tensor_name).copy_(getattr(src, tensor_name))


@implements([aten.clone.default])
def mx_clone(func, types, args, kwargs):
self = args[0]
Expand Down
20 changes: 19 additions & 1 deletion torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

import torchao
from torchao.core.config import AOBaseConfig
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import MappingType
Expand Down Expand Up @@ -426,7 +427,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class TorchAOIntegrationTestCase(common_utils.TestCase):
def _test_slice_and_copy_similar_to_vllm(self, config):
def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
# making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
# the test is similar to the linked code, but with some hardcoded arguments
# and does not use tensor parallelism
Expand Down Expand Up @@ -607,6 +608,23 @@ def process_key(key: str) -> torch.Tensor:
# make sure it runs
moe_combined(input)

def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
    """Check that ``narrow`` keeps the rank of every inner tensor of a quantized weight.

    A 1024x1024 bfloat16 ``Linear`` on CUDA is quantized with ``config``, its
    weight is narrowed along dim 1 (start 0, length 1024), and each attribute
    listed in the narrowed weight's ``tensor_data_names`` is compared with the
    same attribute on the original weight by number of dimensions.
    """
    # vllm narrows weights in several places while slicing them around
    linear = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
    quantize_(linear, config)

    full_weight = linear.weight
    narrowed_weight = full_weight.narrow(1, 0, 1024)

    for attr_name in narrowed_weight.tensor_data_names:
        orig_attr = getattr(full_weight, attr_name)
        new_attr = getattr(narrowed_weight, attr_name)
        same_rank = len(orig_attr.shape) == len(new_attr.shape)
        assert same_rank, f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}"


common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
Expand Down
Loading