diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 1eb12dee90..799b373dd6 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -133,6 +133,9 @@ def __new__( def __repr__(self): return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}" + def _quantization_type(self): + return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}" + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # Use NVFP4-specific ops table