diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 8893a1e9f6..16947b26ef 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -522,10 +522,10 @@ def test_nvfp4_to_copy():
     y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
     assert torch.equal(x.qdata, y.qdata)
     assert torch.equal(x.scale, y.scale)
-    assert x._per_tensor_scale is None
-    assert y._per_tensor_scale is None
-    assert x._act_per_tensor_scale is None
-    assert y._act_per_tensor_scale is None
+    assert x.per_tensor_scale is None
+    assert y.per_tensor_scale is None
+    assert x.act_per_tensor_scale is None
+    assert y.act_per_tensor_scale is None
     assert x._block_size == y._block_size
     assert x.use_triton_kernel == y.use_triton_kernel
     assert x.act_quant_kwargs == y.act_quant_kwargs
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index cbc87313b9..d19991f326 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -76,8 +76,8 @@ class NVFP4Tensor(TorchAOBaseTensor):
     Attributes:
         qdata: Packed FP4 data (2 values per byte)
         scale: Blockwise scales in float8_e4m3fn format (may be swizzled)
-        _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
+        per_tensor_scale: Optional global per-tensor scale in float32 format
+        act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
@@ -89,7 +89,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
         "_block_size",
         "_orig_dtype",
     ]
-    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
+    optional_tensor_data_names = ["per_tensor_scale", "act_per_tensor_scale"]
     optional_tensor_attribute_names = [
         "_is_swizzled_scales",
         "use_triton_kernel",
@@ -102,8 +102,8 @@ def __new__(
         scale,
         block_size,
         orig_dtype,
-        _per_tensor_scale=None,
-        _act_per_tensor_scale=None,
+        per_tensor_scale=None,
+        act_per_tensor_scale=None,
         _is_swizzled_scales=False,
         use_triton_kernel=False,
         act_quant_kwargs=None,
@@ -128,15 +128,15 @@ def __new__(
         self.scale = scale
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self._per_tensor_scale = _per_tensor_scale
-        self._act_per_tensor_scale = _act_per_tensor_scale
+        self.per_tensor_scale = per_tensor_scale
+        self.act_per_tensor_scale = act_per_tensor_scale
         self._is_swizzled_scales = _is_swizzled_scales
         self.use_triton_kernel = use_triton_kernel
         self.act_quant_kwargs = act_quant_kwargs
         return self
 
     def __repr__(self):
-        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
 
     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
@@ -270,8 +270,8 @@ def get_hp_scales(self) -> torch.Tensor:
 
         return (
             scale_e4m3.to(self._orig_dtype)
-            if self._per_tensor_scale is None
-            else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype)
+            if self.per_tensor_scale is None
+            else self.per_tensor_scale * scale_e4m3.to(self._orig_dtype)
         )
 
     @classmethod
@@ -286,11 +286,11 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             bool: True if both tensors have identical metadata, False otherwise
         """
         per_tensor_scale_equal = (
-            self._per_tensor_scale is None and src._per_tensor_scale is None
-        ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+            self.per_tensor_scale is None and src.per_tensor_scale is None
+        ) or (self.per_tensor_scale.shape == src.per_tensor_scale.shape)
         act_per_tensor_scale_equal = (
-            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
-        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)
+            self.act_per_tensor_scale is None and src.act_per_tensor_scale is None
+        ) or (self.act_per_tensor_scale.shape == src.act_per_tensor_scale.shape)
 
         return (
             isinstance(self, NVFP4Tensor)
@@ -341,8 +341,8 @@ def nvfp4_to_copy(func, types, args, kwargs):
             tensor.scale,
             tensor._block_size,
             dtype,
-            tensor._per_tensor_scale,
-            tensor._act_per_tensor_scale,
+            tensor.per_tensor_scale,
+            tensor.act_per_tensor_scale,
             tensor._is_swizzled_scales,
             tensor.use_triton_kernel,
             tensor.act_quant_kwargs,
@@ -565,8 +565,8 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_scale,
         x._block_size,
         x._orig_dtype,
-        x._per_tensor_scale,
-        x._act_per_tensor_scale,
+        x.per_tensor_scale,
+        x.act_per_tensor_scale,
         x._is_swizzled_scales,
         x.use_triton_kernel,
         x.act_quant_kwargs,
@@ -584,8 +584,8 @@ def nvfp4_t(func, types, args, kwargs):
         old.scale.t(),
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -606,8 +606,8 @@ def nvfp4_transpose(func, types, args, kwargs):
         new_scale,
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -626,8 +626,8 @@ def nvfp4_view_op(func, types, args, kwargs):
         args[0].scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0]._per_tensor_scale,
-        args[0]._act_per_tensor_scale,
+        args[0].per_tensor_scale,
+        args[0].act_per_tensor_scale,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
         args[0].act_quant_kwargs,
@@ -644,8 +644,8 @@ def nvfp4_select(func, types, args, kwargs):
         old.scale[index],
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -684,11 +684,11 @@ def _addmm_nvfp4_dispatch(
         b_scale_blocked = to_blocked(b_scale)
 
     # Merge double quant scales into 1 scale for Scale_In^D
-    if a._per_tensor_scale is not None:
-        assert b._per_tensor_scale is not None
-        scale_result = a._per_tensor_scale * b._per_tensor_scale
+    if a.per_tensor_scale is not None:
+        assert b.per_tensor_scale is not None
+        scale_result = a.per_tensor_scale * b.per_tensor_scale
     else:
-        assert b._per_tensor_scale is None and a._per_tensor_scale is None
+        assert b.per_tensor_scale is None and a.per_tensor_scale is None
         scale_result = None
 
     # THIS IS A WORKAROUND:
@@ -772,7 +772,7 @@ def nvfp4_mm(func, types, args, kwargs):
             tensor_amax = torch.max(torch.abs(input_tensor))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = weight_tensor._act_per_tensor_scale
+            per_tensor_scale = weight_tensor.act_per_tensor_scale
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,
@@ -805,7 +805,7 @@ def nvfp4_addmm(func, types, args, kwargs):
             tensor_amax = torch.max(torch.abs(input_tensor))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = weight_tensor._act_per_tensor_scale
+            per_tensor_scale = weight_tensor.act_per_tensor_scale
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,
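For context, a minimal sketch of how the renamed attributes read after this change. It assumes `NVFP4Tensor.to_nvfp4` can be called with just the high-precision tensor (block size defaulting to 16) and that both global scales default to `None`, mirroring the assertions in `test_nvfp4_to_copy` above; the shapes and device are illustrative only.

```python
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# Quantize a bf16 tensor to NVFP4 without supplying a global per-tensor scale
# (assumed default arguments, as in the test above).
x_hp = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
x = NVFP4Tensor.to_nvfp4(x_hp)

# The global scales are now public attributes (previously `_per_tensor_scale`
# and `_act_per_tensor_scale`); both are None when not provided.
assert x.per_tensor_scale is None
assert x.act_per_tensor_scale is None

# Blockwise scales and the packed FP4 payload are unaffected by the rename.
print(x.scale.dtype)   # float8_e4m3fn blockwise scales, per the class docstring
print(x.qdata.shape)   # packed data, two FP4 values per byte
```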