
Commit ba4593f

Rename NVFP4Tensor's _per_tensor_scale and _act_per_tensor_scale fields (#3168)

* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]

1 parent 1c7ceea · commit ba4593f
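
The change is a pure rename: the two optional double-quantization scale fields on NVFP4Tensor lose their leading underscores, moving them into the tensor's public field set (see optional_tensor_data_names below). A minimal sketch of the caller-visible effect; the construction is hypothetical, since to_nvfp4's full signature is not shown in this diff:

    import torch
    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

    # hypothetical construction; block_size is assumed to default to 16
    t = NVFP4Tensor.to_nvfp4(torch.randn(128, 64, dtype=torch.bfloat16))

    # before this commit: t._per_tensor_scale and t._act_per_tensor_scale
    # after this commit the fields are public (and may be None, as the test below checks)
    print(t.per_tensor_scale, t.act_per_tensor_scale)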

File tree

2 files changed: +36 −36 lines changed

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 4 additions & 4 deletions

@@ -522,10 +522,10 @@ def test_nvfp4_to_copy():
     y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
     assert torch.equal(x.qdata, y.qdata)
     assert torch.equal(x.scale, y.scale)
-    assert x._per_tensor_scale is None
-    assert y._per_tensor_scale is None
-    assert x._act_per_tensor_scale is None
-    assert y._act_per_tensor_scale is None
+    assert x.per_tensor_scale is None
+    assert y.per_tensor_scale is None
+    assert x.act_per_tensor_scale is None
+    assert y.act_per_tensor_scale is None
     assert x._block_size == y._block_size
     assert x.use_triton_kernel == y.use_triton_kernel
     assert x.act_quant_kwargs == y.act_quant_kwargs

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 32 additions & 32 deletions
@@ -76,8 +76,8 @@ class NVFP4Tensor(TorchAOBaseTensor):
     Attributes:
         qdata: Packed FP4 data (2 values per byte)
         scale: Blockwise scales in float8_e4m3fn format (may be swizzled)
-        _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
+        per_tensor_scale: Optional global per-tensor scale in float32 format
+        act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
@@ -89,7 +89,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
         "_block_size",
         "_orig_dtype",
     ]
-    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
+    optional_tensor_data_names = ["per_tensor_scale", "act_per_tensor_scale"]
     optional_tensor_attribute_names = [
         "_is_swizzled_scales",
         "use_triton_kernel",
@@ -102,8 +102,8 @@ def __new__(
         scale,
         block_size,
         orig_dtype,
-        _per_tensor_scale=None,
-        _act_per_tensor_scale=None,
+        per_tensor_scale=None,
+        act_per_tensor_scale=None,
         _is_swizzled_scales=False,
         use_triton_kernel=False,
         act_quant_kwargs=None,
@@ -128,15 +128,15 @@ def __new__(
         self.scale = scale
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self._per_tensor_scale = _per_tensor_scale
-        self._act_per_tensor_scale = _act_per_tensor_scale
+        self.per_tensor_scale = per_tensor_scale
+        self.act_per_tensor_scale = act_per_tensor_scale
         self._is_swizzled_scales = _is_swizzled_scales
         self.use_triton_kernel = use_triton_kernel
         self.act_quant_kwargs = act_quant_kwargs
         return self
 
     def __repr__(self):
-        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
 
     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
@@ -270,8 +270,8 @@ def get_hp_scales(self) -> torch.Tensor:
 
         return (
             scale_e4m3.to(self._orig_dtype)
-            if self._per_tensor_scale is None
-            else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype)
+            if self.per_tensor_scale is None
+            else self.per_tensor_scale * scale_e4m3.to(self._orig_dtype)
         )
 
     @classmethod
@@ -286,11 +286,11 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             bool: True if both tensors have identical metadata, False otherwise
         """
         per_tensor_scale_equal = (
-            self._per_tensor_scale is None and src._per_tensor_scale is None
-        ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+            self.per_tensor_scale is None and src.per_tensor_scale is None
+        ) or (self.per_tensor_scale.shape == src.per_tensor_scale.shape)
         act_per_tensor_scale_equal = (
-            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
-        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)
+            self.act_per_tensor_scale is None and src.act_per_tensor_scale is None
+        ) or (self.act_per_tensor_scale.shape == src.act_per_tensor_scale.shape)
 
         return (
             isinstance(self, NVFP4Tensor)
@@ -341,8 +341,8 @@ def nvfp4_to_copy(func, types, args, kwargs):
            tensor.scale,
            tensor._block_size,
            dtype,
-            tensor._per_tensor_scale,
-            tensor._act_per_tensor_scale,
+            tensor.per_tensor_scale,
+            tensor.act_per_tensor_scale,
            tensor._is_swizzled_scales,
            tensor.use_triton_kernel,
            tensor.act_quant_kwargs,
@@ -565,8 +565,8 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_scale,
         x._block_size,
         x._orig_dtype,
-        x._per_tensor_scale,
-        x._act_per_tensor_scale,
+        x.per_tensor_scale,
+        x.act_per_tensor_scale,
         x._is_swizzled_scales,
         x.use_triton_kernel,
         x.act_quant_kwargs,
@@ -584,8 +584,8 @@ def nvfp4_t(func, types, args, kwargs):
         old.scale.t(),
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -606,8 +606,8 @@ def nvfp4_transpose(func, types, args, kwargs):
         new_scale,
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -626,8 +626,8 @@ def nvfp4_view_op(func, types, args, kwargs):
         args[0].scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0]._per_tensor_scale,
-        args[0]._act_per_tensor_scale,
+        args[0].per_tensor_scale,
+        args[0].act_per_tensor_scale,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
         args[0].act_quant_kwargs,
@@ -644,8 +644,8 @@ def nvfp4_select(func, types, args, kwargs):
         old.scale[index],
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -684,11 +684,11 @@ def _addmm_nvfp4_dispatch(
     b_scale_blocked = to_blocked(b_scale)
 
     # Merge double quant scales into 1 scale for Scale_In^D
-    if a._per_tensor_scale is not None:
-        assert b._per_tensor_scale is not None
-        scale_result = a._per_tensor_scale * b._per_tensor_scale
+    if a.per_tensor_scale is not None:
+        assert b.per_tensor_scale is not None
+        scale_result = a.per_tensor_scale * b.per_tensor_scale
     else:
-        assert b._per_tensor_scale is None and a._per_tensor_scale is None
+        assert b.per_tensor_scale is None and a.per_tensor_scale is None
         scale_result = None
 
     # THIS IS A WORKAROUND:
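
Aside: the merge in the hunk above is sound because the two global per-tensor scales are scalars, so they factor out of the matmul reduction. A standalone sketch of the identity in plain PyTorch (names are illustrative, not from torchao):

    import torch

    # stand-ins for a.per_tensor_scale and b.per_tensor_scale
    s_a, s_b = torch.tensor(0.5), torch.tensor(2.0)
    a_hp, b_hp = torch.randn(4, 8), torch.randn(8, 4)

    # scaling each operand first equals applying one merged scale afterwards
    unmerged = (s_a * a_hp) @ (s_b * b_hp)
    merged = (s_a * s_b) * (a_hp @ b_hp)
    torch.testing.assert_close(unmerged, merged)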
@@ -772,7 +772,7 @@ def nvfp4_mm(func, types, args, kwargs):
         tensor_amax = torch.max(torch.abs(input_tensor))
         per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
     else:
-        per_tensor_scale = weight_tensor._act_per_tensor_scale
+        per_tensor_scale = weight_tensor.act_per_tensor_scale
     input_tensor = NVFP4Tensor.to_nvfp4(
         input_tensor,
         block_size=k.block_size,
@@ -805,7 +805,7 @@ def nvfp4_addmm(func, types, args, kwargs):
         tensor_amax = torch.max(torch.abs(input_tensor))
         per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
     else:
-        per_tensor_scale = weight_tensor._act_per_tensor_scale
+        per_tensor_scale = weight_tensor.act_per_tensor_scale
     input_tensor = NVFP4Tensor.to_nvfp4(
         input_tensor,
         block_size=k.block_size,
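
For context on the dynamic branch above: per_tensor_amax_to_scale maps the activation's absolute maximum onto the combined FP8-block-scale times FP4 range used by NVFP4. A hedged sketch of that mapping, assuming the standard NVFP4 recipe (the actual torchao helper may differ in dtype handling):

    import torch

    F8E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
    F4E2M1_MAX = 6.0    # largest FP4 (e2m1) magnitude

    def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
        # choose the global scale so that (e4m3 block scale) * (fp4 value)
        # can reach amax after the per-tensor scale is applied
        return (amax / (F8E4M3_MAX * F4E2M1_MAX)).to(torch.float32)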
