@@ -76,8 +76,8 @@ class NVFP4Tensor(TorchAOBaseTensor):
     Attributes:
         qdata: Packed FP4 data (2 values per byte)
         scale: Blockwise scales in float8_e4m3fn format (may be swizzled)
-        _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
+        per_tensor_scale: Optional global per-tensor scale in float32 format
+        act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
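
For readers skimming the attribute docs: an NVFP4Tensor dequantizes through two levels of scaling, the per-block float8_e4m3fn scale plus the optional global float32 scale being renamed here. A minimal sketch of the dequantization math for one block (names are local to this sketch, not the library API):

    import torch

    # decoded e2m1 (fp4) values for part of one 16-element block
    fp4_vals = torch.tensor([0.5, 1.0, -2.0, 4.0])
    # per-block scale, stored as float8_e4m3fn
    block_scale = torch.tensor(1.5).to(torch.float8_e4m3fn)
    # optional global scale (the renamed per_tensor_scale); None when unused
    per_tensor_scale = torch.tensor(3.0)

    # hp_value ~= fp4_value * block_scale * per_tensor_scale
    hp = fp4_vals * block_scale.to(torch.float32) * per_tensor_scale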
@@ -89,7 +89,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
8989 "_block_size" ,
9090 "_orig_dtype" ,
9191 ]
92- optional_tensor_data_names = ["_per_tensor_scale " , "_act_per_tensor_scale " ]
92+ optional_tensor_data_names = ["per_tensor_scale " , "act_per_tensor_scale " ]
9393 optional_tensor_attribute_names = [
9494 "_is_swizzled_scales" ,
9595 "use_triton_kernel" ,
@@ -102,8 +102,8 @@ def __new__(
         scale,
         block_size,
         orig_dtype,
-        _per_tensor_scale=None,
-        _act_per_tensor_scale=None,
+        per_tensor_scale=None,
+        act_per_tensor_scale=None,
         _is_swizzled_scales=False,
         use_triton_kernel=False,
         act_quant_kwargs=None,
@@ -128,15 +128,15 @@ def __new__(
         self.scale = scale
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self._per_tensor_scale = _per_tensor_scale
-        self._act_per_tensor_scale = _act_per_tensor_scale
+        self.per_tensor_scale = per_tensor_scale
+        self.act_per_tensor_scale = act_per_tensor_scale
         self._is_swizzled_scales = _is_swizzled_scales
         self.use_triton_kernel = use_triton_kernel
         self.act_quant_kwargs = act_quant_kwargs
         return self

     def __repr__(self):
-        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"

     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
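
With the underscore gone, the global scale reads as public API. A hypothetical usage sketch (assumes a GPU build where the NVFP4 prototype is available; the import path reflects the prototype location and may differ):

    import torch
    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

    w = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
    w_nvfp4 = NVFP4Tensor.to_nvfp4(w)
    # the global scale is now a public attribute; None unless double
    # quantization was requested at construction time
    print(w_nvfp4.per_tensor_scale)
    print(w_nvfp4)  # __repr__ also reports per_tensor_scale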
@@ -270,8 +270,8 @@ def get_hp_scales(self) -> torch.Tensor:

         return (
             scale_e4m3.to(self._orig_dtype)
-            if self._per_tensor_scale is None
-            else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype)
+            if self.per_tensor_scale is None
+            else self.per_tensor_scale * scale_e4m3.to(self._orig_dtype)
         )

     @classmethod
@@ -286,11 +286,11 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             bool: True if both tensors have identical metadata, False otherwise
         """
         per_tensor_scale_equal = (
-            self._per_tensor_scale is None and src._per_tensor_scale is None
-        ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+            self.per_tensor_scale is None and src.per_tensor_scale is None
+        ) or (self.per_tensor_scale.shape == src.per_tensor_scale.shape)
         act_per_tensor_scale_equal = (
-            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
-        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)
+            self.act_per_tensor_scale is None and src.act_per_tensor_scale is None
+        ) or (self.act_per_tensor_scale.shape == src.act_per_tensor_scale.shape)

         return (
             isinstance(self, NVFP4Tensor)
@@ -341,8 +341,8 @@ def nvfp4_to_copy(func, types, args, kwargs):
             tensor.scale,
             tensor._block_size,
             dtype,
-            tensor._per_tensor_scale,
-            tensor._act_per_tensor_scale,
+            tensor.per_tensor_scale,
+            tensor.act_per_tensor_scale,
             tensor._is_swizzled_scales,
             tensor.use_triton_kernel,
             tensor.act_quant_kwargs,
@@ -565,8 +565,8 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_scale,
         x._block_size,
         x._orig_dtype,
-        x._per_tensor_scale,
-        x._act_per_tensor_scale,
+        x.per_tensor_scale,
+        x.act_per_tensor_scale,
         x._is_swizzled_scales,
         x.use_triton_kernel,
         x.act_quant_kwargs,
@@ -584,8 +584,8 @@ def nvfp4_t(func, types, args, kwargs):
         old.scale.t(),
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -606,8 +606,8 @@ def nvfp4_transpose(func, types, args, kwargs):
         new_scale,
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -626,8 +626,8 @@ def nvfp4_view_op(func, types, args, kwargs):
         args[0].scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0]._per_tensor_scale,
-        args[0]._act_per_tensor_scale,
+        args[0].per_tensor_scale,
+        args[0].act_per_tensor_scale,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
         args[0].act_quant_kwargs,
@@ -644,8 +644,8 @@ def nvfp4_select(func, types, args, kwargs):
         old.scale[index],
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -684,11 +684,11 @@ def _addmm_nvfp4_dispatch(
         b_scale_blocked = to_blocked(b_scale)

     # Merge double quant scales into 1 scale for Scale_In^D
-    if a._per_tensor_scale is not None:
-        assert b._per_tensor_scale is not None
-        scale_result = a._per_tensor_scale * b._per_tensor_scale
+    if a.per_tensor_scale is not None:
+        assert b.per_tensor_scale is not None
+        scale_result = a.per_tensor_scale * b.per_tensor_scale
     else:
-        assert b._per_tensor_scale is None and a._per_tensor_scale is None
+        assert b.per_tensor_scale is None and a.per_tensor_scale is None
         scale_result = None

     # THIS IS A WORKAROUND:
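
The merge above is just bilinearity of the matmul: if A ~ s_a * A_q and B ~ s_b * B_q with scalar global scales, then A @ B ~ (s_a * s_b) * (A_q @ B_q), so the kernel can consume a single merged scalar. A toy check in plain torch:

    import torch

    A_q, B_q = torch.randn(4, 8), torch.randn(8, 4)
    s_a, s_b = torch.tensor(0.5), torch.tensor(2.0)

    ref = (s_a * A_q) @ (s_b * B_q)
    merged = (s_a * s_b) * (A_q @ B_q)
    assert torch.allclose(ref, merged)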
@@ -772,7 +772,7 @@ def nvfp4_mm(func, types, args, kwargs):
             tensor_amax = torch.max(torch.abs(input_tensor))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = weight_tensor._act_per_tensor_scale
+            per_tensor_scale = weight_tensor.act_per_tensor_scale
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,
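
For context on the dynamic branch: per_tensor_amax_to_scale derives the global scale from the activation tensor's absolute maximum. A hedged sketch of the usual derivation (the library's exact constants and dtype handling may differ; amax_to_scale is an illustrative stand-in name):

    import torch

    F4_E2M1_MAX = 6.0    # largest magnitude representable in fp4 e2m1
    F8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

    def amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
        # pick the global scale so that block scales (e4m3) times fp4
        # values can jointly reach amax without overflowing either format
        return amax.to(torch.float32) / (F4_E2M1_MAX * F8_E4M3_MAX)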
@@ -805,7 +805,7 @@ def nvfp4_addmm(func, types, args, kwargs):
             tensor_amax = torch.max(torch.abs(input_tensor))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = weight_tensor._act_per_tensor_scale
+            per_tensor_scale = weight_tensor.act_per_tensor_scale
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,