diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 38eefbff07..2c89fae96d 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -455,6 +455,20 @@ def test_view(elem_dtype):
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_clone():
+    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    block_size = 4
+    data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
+    data_mx_c = data_mx.clone()
+    torch.testing.assert_close(
+        data_mx.to_dtype(torch.bfloat16),
+        data_mx_c.to_dtype(torch.bfloat16),
+        atol=0,
+        rtol=0,
+    )
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
 @pytest.mark.parametrize("pack_fp6", [False, True])
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index 07e47eed66..1779b0e278 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -352,3 +352,16 @@ def autocast_to_copy(func, types, args, kwargs):
 
     # If only device was changed, return the device-changed tensor
     return tensor
+
+
+@implements([aten.clone.default])
+def mx_clone(func, types, args, kwargs):
+    self = args[0]
+    memory_format = kwargs.get("memory_format", None)
+
+    if memory_format is not None:
+        clone_fn = lambda x: x.clone(memory_format=memory_format)
+    else:
+        clone_fn = lambda x: x.clone()
+
+    return self._apply_fn_to_data(clone_fn)
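
A minimal usage sketch of what this patch enables, not part of the diff itself. It assumes a CUDA device and that `MXTensor` is importable from `torchao.prototype.mx_formats.mx_tensor`, and it only exercises APIs that appear above (`MXTensor.to_mx`, `clone`, `to_dtype`):

import torch

# Assumed import path for MXTensor in the torchao MX prototype.
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# Quantize a bf16 tensor to MX fp8 (e4m3) with a block size of 4.
data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, 4)

# clone() now dispatches to the new mx_clone handler, which clones the
# tensor's inner components via _apply_fn_to_data, producing an
# independent MXTensor with identical contents.
data_mx_c = data_mx.clone()

# Both copies decode to exactly the same bf16 values.
torch.testing.assert_close(
    data_mx.to_dtype(torch.bfloat16),
    data_mx_c.to_dtype(torch.bfloat16),
    atol=0,
    rtol=0,
)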