diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 2c89fae96d..305ec64500 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -26,6 +26,7 @@
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     is_sm_at_least_89,
+    is_sm_at_least_90,
     is_sm_at_least_100,
     torch_version_at_least,
 )
@@ -556,6 +557,26 @@ def test_to_mx_inductor_single_kernel():
     FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not is_sm_at_least_90(), reason="Need sm90+")
+def test_index_select():
+    """
+    test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
+    useful when stitching checkpoints of `num_experts` 2D parameters into
+    a single 3D parameter when converting between model definitions that
+    use 2D and 3D parameters for their expert weights.
+    """
+
+    E, K, N = 128, 256, 512
+    x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+    x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)
+
+    x_mx_1 = x_mx[1]
+    torch.testing.assert_close(
+        x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0
+    )
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     not is_sm_at_least_89(),
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index d870698601..8fcda5e69c 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -322,3 +322,23 @@
 def mx_clone(func, types, args, kwargs):
     clone_fn = lambda x: x.clone()
     return self._apply_fn_to_data(clone_fn)
+
+
+@implements([aten.select.int])
+def mx_select(func, types, args, kwargs):
+    old_mx_tensor, dim, index = args
+    assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported"
+    assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), (
+        "unsupported"
+    )
+    new_mx_tensor = old_mx_tensor.__class__(
+        old_mx_tensor.qdata[index],
+        old_mx_tensor._scale_e8m0[index],
+        old_mx_tensor._elem_dtype,
+        old_mx_tensor._block_size,
+        old_mx_tensor._orig_dtype,
+        old_mx_tensor._gemm_kernel_choice,
+        old_mx_tensor._pack_fp6,
+        old_mx_tensor.act_quant_kwargs,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor)