8 changes: 4 additions & 4 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -522,10 +522,10 @@ def test_nvfp4_to_copy():
     y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
     assert torch.equal(x.qdata, y.qdata)
     assert torch.equal(x.scale, y.scale)
-    assert x._per_tensor_scale is None
-    assert y._per_tensor_scale is None
-    assert x._act_per_tensor_scale is None
-    assert y._act_per_tensor_scale is None
+    assert x.per_tensor_scale is None
+    assert y.per_tensor_scale is None
+    assert x.act_per_tensor_scale is None
+    assert y.act_per_tensor_scale is None
     assert x._block_size == y._block_size
     assert x.use_triton_kernel == y.use_triton_kernel
     assert x.act_quant_kwargs == y.act_quant_kwargs
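The test changes above only swap the private attribute names for the new public ones. For downstream code that has to work with torchao builds from both before and after this rename, one option is a small compatibility shim; the helper below is hypothetical and not part of this PR or of torchao.

# Hypothetical compatibility shim (not part of this PR): prefer the new public
# attribute and fall back to the old private name on older torchao builds.
def get_global_scale(t):
    if hasattr(t, "per_tensor_scale"):
        return t.per_tensor_scale
    return getattr(t, "_per_tensor_scale", None)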
64 changes: 32 additions & 32 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -76,8 +76,8 @@ class NVFP4Tensor(TorchAOBaseTensor):
     Attributes:
         qdata: Packed FP4 data (2 values per byte)
         scale: Blockwise scales in float8_e4m3fn format (may be swizzled)
-        _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
+        per_tensor_scale: Optional global per-tensor scale in float32 format
+        act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
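As a quick illustration of the renamed public attributes, a minimal sketch is below. It assumes NVFP4Tensor.to_nvfp4 can be called with just a high-precision tensor and a block_size, as it is further down in this file; the shapes and the environment (dtype/hardware support) are illustrative, not verified.

# Minimal sketch (assumptions noted above); inspects the attributes documented here.
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x_hp = torch.randn(128, 64, dtype=torch.bfloat16)
x = NVFP4Tensor.to_nvfp4(x_hp, block_size=16)
print(x.qdata.shape, x.scale.dtype)                # packed FP4 payload, e4m3 block scales
print(x.per_tensor_scale, x.act_per_tensor_scale)  # optional global scales, often None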
@@ -89,7 +89,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
         "_block_size",
         "_orig_dtype",
     ]
-    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
+    optional_tensor_data_names = ["per_tensor_scale", "act_per_tensor_scale"]
     optional_tensor_attribute_names = [
         "_is_swizzled_scales",
         "use_triton_kernel",
@@ -102,8 +102,8 @@ def __new__(
         scale,
         block_size,
         orig_dtype,
-        _per_tensor_scale=None,
-        _act_per_tensor_scale=None,
+        per_tensor_scale=None,
+        act_per_tensor_scale=None,
         _is_swizzled_scales=False,
         use_triton_kernel=False,
         act_quant_kwargs=None,
@@ -128,15 +128,15 @@ def __new__(
         self.scale = scale
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self._per_tensor_scale = _per_tensor_scale
-        self._act_per_tensor_scale = _act_per_tensor_scale
+        self.per_tensor_scale = per_tensor_scale
+        self.act_per_tensor_scale = act_per_tensor_scale
         self._is_swizzled_scales = _is_swizzled_scales
         self.use_triton_kernel = use_triton_kernel
         self.act_quant_kwargs = act_quant_kwargs
         return self
 
     def __repr__(self):
-        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
 
     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
@@ -270,8 +270,8 @@ def get_hp_scales(self) -> torch.Tensor:
 
         return (
             scale_e4m3.to(self._orig_dtype)
-            if self._per_tensor_scale is None
-            else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype)
+            if self.per_tensor_scale is None
+            else self.per_tensor_scale * scale_e4m3.to(self._orig_dtype)
         )
 
     @classmethod
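The math in get_hp_scales is unchanged by the rename: blockwise e4m3 scales are upcast to the original dtype and, when a global per-tensor scale is present, multiplied by it. A small standalone sketch with made-up values:

# Standalone sketch of the scale reconstruction above (values are made up).
import torch

scale_e4m3 = torch.rand(4, 8).to(torch.float8_e4m3fn)  # blockwise scales
per_tensor_scale = torch.tensor(0.03125)                # optional global scale (float32)
# With a global scale it multiplies the upcast block scales; with None it is just the upcast.
hp_scales = per_tensor_scale * scale_e4m3.to(torch.bfloat16)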
@@ -286,11 +286,11 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             bool: True if both tensors have identical metadata, False otherwise
         """
         per_tensor_scale_equal = (
-            self._per_tensor_scale is None and src._per_tensor_scale is None
-        ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+            self.per_tensor_scale is None and src.per_tensor_scale is None
+        ) or (self.per_tensor_scale.shape == src.per_tensor_scale.shape)
         act_per_tensor_scale_equal = (
-            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
-        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)
+            self.act_per_tensor_scale is None and src.act_per_tensor_scale is None
+        ) or (self.act_per_tensor_scale.shape == src.act_per_tensor_scale.shape)
 
         return (
             isinstance(self, NVFP4Tensor)
@@ -341,8 +341,8 @@ def nvfp4_to_copy(func, types, args, kwargs):
             tensor.scale,
             tensor._block_size,
             dtype,
-            tensor._per_tensor_scale,
-            tensor._act_per_tensor_scale,
+            tensor.per_tensor_scale,
+            tensor.act_per_tensor_scale,
             tensor._is_swizzled_scales,
             tensor.use_triton_kernel,
             tensor.act_quant_kwargs,
@@ -565,8 +565,8 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_scale,
         x._block_size,
         x._orig_dtype,
-        x._per_tensor_scale,
-        x._act_per_tensor_scale,
+        x.per_tensor_scale,
+        x.act_per_tensor_scale,
         x._is_swizzled_scales,
         x.use_triton_kernel,
         x.act_quant_kwargs,
@@ -584,8 +584,8 @@ def nvfp4_t(func, types, args, kwargs):
         old.scale.t(),
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -606,8 +606,8 @@ def nvfp4_transpose(func, types, args, kwargs):
         new_scale,
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -626,8 +626,8 @@ def nvfp4_view_op(func, types, args, kwargs):
         args[0].scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0]._per_tensor_scale,
-        args[0]._act_per_tensor_scale,
+        args[0].per_tensor_scale,
+        args[0].act_per_tensor_scale,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
         args[0].act_quant_kwargs,
@@ -644,8 +644,8 @@ def nvfp4_select(func, types, args, kwargs):
         old.scale[index],
         old._block_size,
         old._orig_dtype,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
+        old.per_tensor_scale,
+        old.act_per_tensor_scale,
         old._is_swizzled_scales,
         old.use_triton_kernel,
         old.act_quant_kwargs,
@@ -684,11 +684,11 @@ def _addmm_nvfp4_dispatch(
     b_scale_blocked = to_blocked(b_scale)
 
     # Merge double quant scales into 1 scale for Scale_In^D
-    if a._per_tensor_scale is not None:
-        assert b._per_tensor_scale is not None
-        scale_result = a._per_tensor_scale * b._per_tensor_scale
+    if a.per_tensor_scale is not None:
+        assert b.per_tensor_scale is not None
+        scale_result = a.per_tensor_scale * b.per_tensor_scale
     else:
-        assert b._per_tensor_scale is None and a._per_tensor_scale is None
+        assert b.per_tensor_scale is None and a.per_tensor_scale is None
         scale_result = None
 
     # THIS IS A WORKAROUND:
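Apart from the names, the merge above is the same: when both operands carry a global per-tensor scale, the two scales combine multiplicatively into the single scale the comment calls Scale_In^D, otherwise no merged scale is passed. A toy sketch with illustrative values:

# Toy sketch of the double-quant scale merge (illustrative values only).
import torch

a_scale = torch.tensor(0.02, dtype=torch.float32)  # activation global scale
b_scale = torch.tensor(0.05, dtype=torch.float32)  # weight global scale
scale_result = a_scale * b_scale                    # merged scale; None when both are absent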
@@ -772,7 +772,7 @@ def nvfp4_mm(func, types, args, kwargs):
             tensor_amax = torch.max(torch.abs(input_tensor))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = weight_tensor._act_per_tensor_scale
+            per_tensor_scale = weight_tensor.act_per_tensor_scale
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,
@@ -805,7 +805,7 @@ def nvfp4_addmm(func, types, args, kwargs):
             tensor_amax = torch.max(torch.abs(input_tensor))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = weight_tensor._act_per_tensor_scale
+            per_tensor_scale = weight_tensor.act_per_tensor_scale
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,
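nvfp4_mm and nvfp4_addmm pick the activation's global scale the same way: derive it dynamically from the activation amax, or reuse the act_per_tensor_scale precomputed on the weight. A condensed sketch of that choice is below; the function and parameter names are illustrative, and amax_to_scale stands in for the per_tensor_amax_to_scale helper used above.

# Condensed, illustrative sketch of the activation-scale selection above.
import torch

def choose_act_scale(activation, weight, use_dynamic, amax_to_scale):
    if use_dynamic:
        # Dynamic: global scale from the current activation amax.
        return amax_to_scale(torch.max(torch.abs(activation)))
    # Static: reuse the scale stored on the weight at quantization time.
    return weight.act_per_tensor_scale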