diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 1779b0e278..d870698601 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -311,49 +311,6 @@ def mx_copy_(func, types, args, kwargs): ) -@implements([aten._to_copy.default]) -def autocast_to_copy(func, types, args, kwargs): - """Autocast + device movement""" - assert isinstance(args[0], MXTensor) - - # Handle dtype parameter - dtype = kwargs.pop("dtype", None) - if dtype is not None: - assert dtype in { - torch.float16, - torch.bfloat16, - }, "Only support floating point conversion for autocast w/ MXTensor" - - # Handle device parameter - device = kwargs.pop("device", None) - if device is not None: - # Apply device change using _apply_fn_to_data - tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device)) - tensor = return_and_correct_aliasing(func, args, {}, tensor) - else: - tensor = args[0] - - # Verify no other kwargs remain - assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast" - - # If dtype is specified, create a new MXTensor with the requested dtype - if dtype is not None: - res = MXTensor( - tensor.qdata, - tensor._scale_e8m0, - tensor._elem_dtype, - tensor._block_size, - dtype, - tensor._gemm_kernel_choice, - tensor._pack_fp6, - tensor.act_quant_kwargs, - ) - return res - - # If only device was changed, return the device-changed tensor - return tensor - - @implements([aten.clone.default]) def mx_clone(func, types, args, kwargs): self = args[0]