Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import pytest
import torch
from torch._inductor.utils import run_and_get_code
Expand All @@ -22,6 +24,7 @@
ScaleCalculationMode,
to_dtype,
)
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
from torchao.quantization.utils import compute_error
from torchao.utils import (
is_sm_at_least_89,
Expand Down Expand Up @@ -388,6 +391,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
MXGemmKernelChoice.EMULATED,
pack_fp6,
None,
False,
)
tensor_hp = tensor_mx.dequantize(torch.float)
assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
Expand Down Expand Up @@ -645,8 +649,6 @@ def to_f8(x):
not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked

rows, cols = shape
device = "cuda" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -716,3 +718,68 @@ def test_scale_shape_matches_qdata(transpose, shape):
assert expected_padded_k == actual_padded_k, (
f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
@pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
@pytest.mark.parametrize("transpose", [False, True])
@pytest.mark.parametrize(
"shape",
(
(128, 64),
(1, 128, 64),
),
)
def test_swizzle(elem_dtype, transpose, shape):
if len(shape) == 3 and transpose:
pytest.skip("transpose not yet implemented for 3D MXTensor")

block_size = 32

x_hp = torch.randn(*shape, device="cuda")
x = MXTensor.to_mx(
x_hp,
elem_dtype,
block_size,
ScaleCalculationMode.FLOOR,
)

xs = MXTensor.to_mx(
x_hp,
elem_dtype,
block_size,
ScaleCalculationMode.FLOOR,
is_swizzled_scales=True,
)

if transpose:
x = x.t()
xs = xs.t()

torch.testing.assert_close(x.qdata, xs.qdata, atol=0, rtol=0)

if transpose:
leading_dims, M, K = x.shape[:-2], x.shape[-1], x.shape[-2]
xs_scale_unblocked = from_blocked(
xs.scale.t(), math.prod(leading_dims) * M, K // block_size
)
xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)
xs_scale_unblocked = xs_scale_unblocked.t()
else:
leading_dims, M, K = x.shape[:-2], x.shape[-2], x.shape[-1]
xs_scale_unblocked = from_blocked(
xs.scale, math.prod(leading_dims) * M, K // block_size
)
xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)

torch.testing.assert_close(
x.scale,
xs_scale_unblocked,
atol=0,
rtol=0,
)

x_dq = x.dequantize(x.dtype)
xs_dq = xs.dequantize(xs.dtype)
torch.testing.assert_close(x_dq, xs_dq, atol=0, rtol=0)
2 changes: 2 additions & 0 deletions torchao/prototype/mx_formats/inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _mx_inference_linear_transform(
block_size=config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
pack_fp6=False,
is_swizzled_scales=True,
)

# Convert weight to MX Tensor
Expand All @@ -121,6 +122,7 @@ def _mx_inference_linear_transform(
gemm_kernel_choice=config.gemm_kernel_choice,
pack_fp6=False, # TODO
act_quant_kwargs=act_quant_kwargs,
is_swizzled_scales=True,
)

module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
Expand Down
Loading
Loading