diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index 26e48216ee..69aa62afd4 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -7,7 +7,7 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -39,8 +39,6 @@
 aten = torch.ops.aten
 
 
-NVFP4_OPS_TABLE: Dict[Any, Any] = {}
-
 
 class NVFP4MMConfig(Enum):
     DYNAMIC = "dynamic"
@@ -55,18 +53,6 @@ class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
     use_dynamic_per_tensor_scale: bool = False
 
 
-# TODO(future PR): move over to TorchAOBaseTensor's dispatch
-def implements(aten_ops):
-    """Register aten ops to the NVFP4 op table"""
-
-    def decorator(func):
-        for op in aten_ops:
-            NVFP4_OPS_TABLE[op] = func
-        return func
-
-    return decorator
-
-
 class NVFP4Tensor(TorchAOBaseTensor):
     """NVIDIA FP4 (NVFP4) Tensor subclass.
 
@@ -141,14 +127,6 @@ def __repr__(self):
     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
 
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        # Use NVFP4-specific ops table
-        if func in NVFP4_OPS_TABLE:
-            return NVFP4_OPS_TABLE[func](func, types, args, kwargs)
-
-        raise NotImplementedError(f"{func} not implemented for NVFP4Tensor")
-
     @staticmethod
     def to_nvfp4(
         data_hp: torch.Tensor,
@@ -308,13 +286,10 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
         )
 
 
-@implements([aten.detach.default, aten.alias.default])
-def nvfp4_detach_alias(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(func)
-    )
+implements = NVFP4Tensor.implements
 
 
+# TODO(future PR): move this to AOBaseTensor (will require debugging/fixing CI)
 @implements([aten._to_copy.default])
 def nvfp4_to_copy(func, types, args, kwargs):
     """Autocast + device movement"""
@@ -354,33 +329,6 @@ def nvfp4_to_copy(func, types, args, kwargs):
     return tensor
 
 
-@implements([aten.copy_.default])
-def nvfp4_copy_(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if NVFP4Tensor._same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return self
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {self}, {src}"
-    )
-
-
-@implements([aten.clone.default])
-def nvfp4_clone(func, types, args, kwargs):
-    self = args[0]
-    memory_format = kwargs.get("memory_format", None)
-
-    if memory_format is not None:
-        clone_fn = lambda x: x.clone(memory_format=memory_format)
-    else:
-        clone_fn = lambda x: x.clone()
-
-    return self._apply_fn_to_data(clone_fn)
-
-
 @implements([aten.slice.Tensor])
 def nvfp4_slice(func, types, args, kwargs):
     x, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
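
Note: for reference, a minimal sketch of the inherited-dispatch pattern this diff adopts. MyBaseTensor is a stand-in, not torchao's actual TorchAOBaseTensor (which, among other differences, keeps a separate op table per subclass rather than one shared dict); only the shape of the implements + __torch_dispatch__ pair is illustrated.

    import torch

    aten = torch.ops.aten


    class MyBaseTensor(torch.Tensor):
        # Simplified: a single op table on the base class, keyed by aten
        # overload. The real TorchAOBaseTensor scopes this per subclass.
        _OPS_TABLE: dict = {}

        @classmethod
        def implements(cls, aten_ops):
            """Register a handler for one or more aten ops on this subclass."""

            def decorator(func):
                for op in aten_ops:
                    cls._OPS_TABLE[op] = func
                return func

            return decorator

        @classmethod
        def __torch_dispatch__(cls, func, types, args, kwargs=None):
            # Route every aten call through the registered handlers.
            kwargs = kwargs or {}
            if func in cls._OPS_TABLE:
                return cls._OPS_TABLE[func](func, types, args, kwargs)
            raise NotImplementedError(f"{func} not implemented for {cls.__name__}")


    # Usage mirrors the state after this diff: bind the inherited classmethod
    # once, then decorate op handlers with it.
    implements = MyBaseTensor.implements


    @implements([aten.detach.default])
    def _detach(func, types, args, kwargs):
        return args[0]

With this structure the subclass no longer needs its own module-level ops table or a hand-written __torch_dispatch__, which is exactly what the hunks above delete from nvfp4_tensor.py.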