Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 0 additions & 43 deletions torchao/prototype/mx_formats/mx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,49 +311,6 @@ def mx_copy_(func, types, args, kwargs):
)


@implements([aten._to_copy.default])
def autocast_to_copy(func, types, args, kwargs):
"""Autocast + device movement"""
assert isinstance(args[0], MXTensor)

# Handle dtype parameter
dtype = kwargs.pop("dtype", None)
if dtype is not None:
assert dtype in {
torch.float16,
torch.bfloat16,
}, "Only support floating point conversion for autocast w/ MXTensor"

# Handle device parameter
device = kwargs.pop("device", None)
if device is not None:
# Apply device change using _apply_fn_to_data
tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device))
tensor = return_and_correct_aliasing(func, args, {}, tensor)
else:
tensor = args[0]

# Verify no other kwargs remain
assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast"

# If dtype is specified, create a new MXTensor with the requested dtype
if dtype is not None:
res = MXTensor(
tensor.qdata,
tensor._scale_e8m0,
tensor._elem_dtype,
tensor._block_size,
dtype,
tensor._gemm_kernel_choice,
tensor._pack_fp6,
tensor.act_quant_kwargs,
)
return res

# If only device was changed, return the device-changed tensor
return tensor


@implements([aten.clone.default])
def mx_clone(func, types, args, kwargs):
self = args[0]
Expand Down
Loading