Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torchao.quantization.utils import compute_error
from torchao.utils import (
is_sm_at_least_89,
is_sm_at_least_90,
is_sm_at_least_100,
torch_version_at_least,
)
Expand Down Expand Up @@ -556,6 +557,26 @@ def test_to_mx_inductor_single_kernel():
FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+")
def test_index_select():
"""
test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
useful when stitching checkpoints of `num_experts` 2D parameters into
a single 3D parameter when converting between model definitions that
use 2D and 3D parameters for their expert weights.
"""

E, K, N = 128, 256, 512
x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)

x_mx_1 = x_mx[1]
torch.testing.assert_close(
x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_89(),
Expand Down
20 changes: 20 additions & 0 deletions torchao/prototype/mx_formats/mx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,23 @@ def mx_clone(func, types, args, kwargs):
clone_fn = lambda x: x.clone()

return self._apply_fn_to_data(clone_fn)


@implements([aten.select.int])
def mx_select(func, types, args, kwargs):
old_mx_tensor, dim, index = args
assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported"
assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), (
"unsupported"
)
new_mx_tensor = old_mx_tensor.__class__(
old_mx_tensor.qdata[index],
old_mx_tensor._scale_e8m0[index],
old_mx_tensor._elem_dtype,
old_mx_tensor._block_size,
old_mx_tensor._orig_dtype,
old_mx_tensor._gemm_kernel_choice,
old_mx_tensor._pack_fp6,
old_mx_tensor.act_quant_kwargs,
)
return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor)
Loading