4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -519,15 +519,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x_mx.qdata,
         x_mx.scale,
         x_mx._elem_dtype,
-        x_mx._block_size,
+        x_mx.block_size,
         hp_dtype,  # noqa: E501
         pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
         x_mx_c.qdata,
         x_mx_c.scale,
         x_mx_c._elem_dtype,
-        x_mx_c._block_size,
+        x_mx_c.block_size,
         hp_dtype,
         pack_fp6,
     )
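For context on the rename, this test round-trips a tensor through MX quantization and reads the block size back off the quantized tensor. A minimal sketch of that pattern, assuming `MXTensor.to_mx` and the `to_dtype` signature used in the test above (the constructor call and shapes are illustrative, not part of this PR):

```python
# Sketch only: assumes MXTensor.to_mx(hp_tensor, elem_dtype, block_size) still
# exists alongside to_dtype; only the attribute names mirror this diff.
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_dtype

x = torch.randn(128, 128, dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size=32)

x_mx_dq = to_dtype(
    x_mx.qdata,
    x_mx.scale,
    x_mx._elem_dtype,
    x_mx.block_size,  # renamed from x_mx._block_size in this PR
    torch.bfloat16,   # hp_dtype
    False,            # pack_fp6
)
```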
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -71,7 +71,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
 
     reconstructed_amax = x_nvfp4.get_hp_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
     max_abs = torch.amax(
-        torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
+        torch.abs(x.reshape(shape[0], -1, x_nvfp4.block_size)), dim=-1
     ).unsqueeze(-1)
 
     assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
@@ -526,7 +526,7 @@ def test_nvfp4_to_copy():
     assert y.per_tensor_scale is None
     assert x.act_per_tensor_scale is None
     assert y.act_per_tensor_scale is None
-    assert x._block_size == y._block_size
+    assert x.block_size == y.block_size
     assert x.use_triton_kernel == y.use_triton_kernel
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
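The `block_size` read here only controls how the last dimension is grouped when computing the per-block amax that the scales are checked against. A self-contained shape sketch with plain tensors (sizes are illustrative; `block_size=16` matches NVFP4, nothing else comes from the PR):

```python
import torch

block_size = 16          # NVFP4's fixed block size
x = torch.randn(4, 64)   # illustrative (M, K) input

# group the last dim into K // block_size blocks and take one amax per block,
# mirroring the reshape in the test above
max_abs = torch.amax(
    torch.abs(x.reshape(x.shape[0], -1, block_size)), dim=-1
).unsqueeze(-1)
print(max_abs.shape)  # torch.Size([4, 4, 1])
```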
28 changes: 14 additions & 14 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -490,7 +490,7 @@ class MXTensor(TorchAOBaseTensor):
     tensor_data_names = ["qdata", "scale"]
     tensor_attribute_names = [
         "_elem_dtype",
-        "_block_size",
+        "block_size",
         "_orig_dtype",
         "_gemm_kernel_choice",
         "_pack_fp6",
@@ -547,7 +547,7 @@ def __new__(
         self.qdata = qdata
         self.scale = scale_e8m0_bits
         self._elem_dtype = elem_dtype
-        self._block_size = block_size
+        self.block_size = block_size
         self._orig_dtype = orig_dtype
         self._gemm_kernel_choice = gemm_kernel_choice
         self._pack_fp6 = pack_fp6
@@ -560,7 +560,7 @@ def __repr__(self):
         return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}, _is_swizzled_scales={self._is_swizzled_scales}"  # noqa: E501
 
     def _quantization_type(self):
-        return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"
+        return f"{self._elem_dtype=}, {self.block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         if output_dtype is None:
@@ -575,17 +575,17 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             else:
                 leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
             scale = from_blocked(
-                scale, math.prod(leading_dims) * M, K // self._block_size
+                scale, math.prod(leading_dims) * M, K // self.block_size
             )
-            scale = scale.view(*leading_dims, M, K // self._block_size)
+            scale = scale.view(*leading_dims, M, K // self.block_size)
         if is_transposed:
             scale = scale.transpose(-2, -1)
 
         return to_dtype(
             self.qdata,
             scale,
             self._elem_dtype,
-            self._block_size,
+            self.block_size,
             output_dtype,
             self._pack_fp6,
         )
@@ -699,19 +699,19 @@ def _addmm_mx_dispatch(
     M, K, N = a.shape[0], a.shape[1], b.shape[1]
     assert a.qdata.is_contiguous()
     assert b.qdata.t().is_contiguous()
-    assert a._block_size == 32, f"Invalid block size {a._block_size}"
-    assert b._block_size == 32, f"Invalid block size {b._block_size}"
+    assert a.block_size == 32, f"Invalid block size {a.block_size}"
+    assert b.block_size == 32, f"Invalid block size {b.block_size}"
 
     if a._is_swizzled_scales:
         a_scale_block = a.scale
     else:
-        a_scale = a.scale.view(M, K // a._block_size)
+        a_scale = a.scale.view(M, K // a.block_size)
         a_scale_block = to_blocked(a_scale)
 
     if b._is_swizzled_scales:
         b_scale_block = b.scale.t()
     else:
-        b_scale = b.scale.t().view(N, K // b._block_size)
+        b_scale = b.scale.t().view(N, K // b.block_size)
         b_scale_block = to_blocked(b_scale)
 
     if a._elem_dtype == torch.float8_e4m3fn:
@@ -804,7 +804,7 @@ def mx_t(func, types, args, kwargs):
         old.qdata.t(),
         old.scale.t(),
         old._elem_dtype,
-        old._block_size,
+        old.block_size,
         old._orig_dtype,
         old._gemm_kernel_choice,
         old._pack_fp6,
@@ -849,7 +849,7 @@ def mx_view_op(func, types, args, kwargs):
         new_data,
         args[0].scale,
         args[0]._elem_dtype,
-        args[0]._block_size,
+        args[0].block_size,
         args[0]._orig_dtype,
         args[0]._gemm_kernel_choice,
         args[0]._pack_fp6,
@@ -875,7 +875,7 @@ def mx_slice(func, types, args, kwargs):
         sliced_data,
         sliced_scale,
         x._elem_dtype,
-        x._block_size,
+        x.block_size,
         x._orig_dtype,
         x._gemm_kernel_choice,
         x._pack_fp6,
@@ -910,7 +910,7 @@ def mx_select(func, types, args, kwargs):
         old_mx_tensor.qdata[index],
         old_mx_tensor.scale[index],
         old_mx_tensor._elem_dtype,
-        old_mx_tensor._block_size,
+        old_mx_tensor.block_size,
         old_mx_tensor._orig_dtype,
         old_mx_tensor._gemm_kernel_choice,
         old_mx_tensor._pack_fp6,
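Everywhere this file divides `K` by `block_size`, it is recovering the per-row scale count: one e8m0 scale per 1x32 block of `qdata`. A plain-tensor shape sketch of the arithmetic used in `dequantize` and `_addmm_mx_dispatch` (sizes are illustrative and the tensor contents are dummies):

```python
import math

import torch

block_size = 32                       # MX block size asserted in _addmm_mx_dispatch
leading_dims, M, K = (2, 3), 64, 128  # illustrative high-precision shape (2, 3, 64, 128)

# one scale element per 1x32 block => M * (K // block_size) scales per matrix
scale = torch.empty(math.prod(leading_dims) * M * (K // block_size))

# the reshape dequantize() applies before broadcasting against qdata
scale = scale.view(*leading_dims, M, K // block_size)
print(scale.shape)  # torch.Size([2, 3, 64, 4])
```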
34 changes: 17 additions & 17 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -78,15 +78,15 @@ class NVFP4Tensor(TorchAOBaseTensor):
         scale: Blockwise scales in float8_e4m3fn format (may be swizzled)
         per_tensor_scale: Optional global per-tensor scale in float32 format
         act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
-        _block_size (int): Block size for quantization (fixed at 16)
+        block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
         use_triton_kernel (bool): Whether to use triton kernels
     """
 
     tensor_data_names = ["qdata", "scale"]
     tensor_attribute_names = [
-        "_block_size",
+        "block_size",
         "_orig_dtype",
     ]
     optional_tensor_data_names = ["per_tensor_scale", "act_per_tensor_scale"]
@@ -126,7 +126,7 @@ def __new__(
 
         self.qdata = qdata
         self.scale = scale
-        self._block_size = block_size
+        self.block_size = block_size
         self._orig_dtype = orig_dtype
         self.per_tensor_scale = per_tensor_scale
         self.act_per_tensor_scale = act_per_tensor_scale
@@ -238,10 +238,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
         data_f32 = data_f32.view(
-            *leading_dims, M, K // self._block_size, self._block_size
+            *leading_dims, M, K // self.block_size, self.block_size
         )
         scale_e4m3_reshaped = self.get_hp_scales().view(
-            *leading_dims, M, K // self._block_size, 1
+            *leading_dims, M, K // self.block_size, 1
         )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
         result = data_scaled.view(*leading_dims, M, K).to(output_dtype)
@@ -267,7 +267,7 @@ def get_hp_scales(self) -> torch.Tensor:
 
         if self._is_swizzled_scales:
             scale_e4m3 = from_blocked(
-                scale_e4m3, math.prod(leading_dims) * M, K // self._block_size
+                scale_e4m3, math.prod(leading_dims) * M, K // self.block_size
             )
 
         return (
@@ -297,7 +297,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
         return (
             isinstance(self, NVFP4Tensor)
             and isinstance(src, NVFP4Tensor)
-            and self._block_size == src._block_size
+            and self.block_size == src.block_size
             and self._orig_dtype == src._orig_dtype
             and self._is_swizzled_scales == src._is_swizzled_scales
             and self.scale.shape == src.scale.shape
@@ -341,7 +341,7 @@ def nvfp4_to_copy(func, types, args, kwargs):
         res = NVFP4Tensor(
             tensor.qdata,
             tensor.scale,
-            tensor._block_size,
+            tensor.block_size,
             dtype,
             tensor.per_tensor_scale,
             tensor.act_per_tensor_scale,
@@ -399,7 +399,7 @@ def nvfp4_slice(func, types, args, kwargs):
     result = NVFP4Tensor(
         sliced_data,
         sliced_scale,
-        x._block_size,
+        x.block_size,
         x._orig_dtype,
         x.per_tensor_scale,
         x.act_per_tensor_scale,
@@ -418,7 +418,7 @@ def nvfp4_t(func, types, args, kwargs):
     new = NVFP4Tensor(
         old.qdata.t(),
         old.scale.t(),
-        old._block_size,
+        old.block_size,
         old._orig_dtype,
         old.per_tensor_scale,
         old.act_per_tensor_scale,
@@ -440,7 +440,7 @@ def nvfp4_transpose(func, types, args, kwargs):
     new = NVFP4Tensor(
         new_qdata,
         new_scale,
-        old._block_size,
+        old.block_size,
         old._orig_dtype,
         old.per_tensor_scale,
         old.act_per_tensor_scale,
@@ -460,7 +460,7 @@ def nvfp4_view_op(func, types, args, kwargs):
     return NVFP4Tensor(
         new_data,
         args[0].scale,
-        args[0]._block_size,
+        args[0].block_size,
         args[0]._orig_dtype,
         args[0].per_tensor_scale,
         args[0].act_per_tensor_scale,
@@ -478,7 +478,7 @@ def nvfp4_select(func, types, args, kwargs):
     new = old.__class__(
         old.qdata[index],
         old.scale[index],
-        old._block_size,
+        old.block_size,
         old._orig_dtype,
         old.per_tensor_scale,
         old.act_per_tensor_scale,
@@ -500,8 +500,8 @@ def _addmm_nvfp4_dispatch(
     assert a.scale.is_contiguous()
     assert b.qdata.t().is_contiguous()
     assert b.scale.t().is_contiguous()
-    assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
-    assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
+    assert a.block_size == 16, f"NVFP4 requires block_size=16, got {a.block_size}"
+    assert b.block_size == 16, f"NVFP4 requires block_size=16, got {b.block_size}"
     assert len(a.shape) == 2 and len(b.shape) == 2
 
     M, K = a.shape[0], a.shape[1]
@@ -511,13 +511,13 @@
     if a._is_swizzled_scales:
         a_scale_blocked = a.scale  # Already swizzled
     else:
-        a_scale = a.scale.view(M, K // a._block_size)
+        a_scale = a.scale.view(M, K // a.block_size)
         a_scale_blocked = to_blocked(a_scale)
 
     if b._is_swizzled_scales:
         b_scale_blocked = b.scale.t()  # Already swizzled
     else:
-        b_scale = b.scale.t().view(N, K // b._block_size)
+        b_scale = b.scale.t().view(N, K // b.block_size)
         b_scale_blocked = to_blocked(b_scale)
 
     # Merge double quant scales into 1 scale for Scale_In^D
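The `dequantize` reshapes touched above boil down to broadcasting one scale across each 16-element block of the unpacked values. A stand-in sketch with plain float tensors (no fp4 packing, illustrative sizes):

```python
import torch

block_size = 16                          # NVFP4's fixed block size
M, K = 32, 64
data_f32 = torch.randn(M, K)             # stand-in for the unpacked fp4 values
scales = torch.rand(M, K // block_size)  # stand-in for get_hp_scales()

# one scale per 16-wide block, broadcast over the block dimension
out = (
    data_f32.view(M, K // block_size, block_size)
    * scales.view(M, K // block_size, 1)
).view(M, K)
print(out.shape)  # torch.Size([32, 64])
```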
22 changes: 11 additions & 11 deletions torchao/prototype/mx_formats/utils.py
@@ -232,7 +232,7 @@ def _swizzle_aware_slice(
 
     if x._is_swizzled_scales:
         scale_rows = M
-        scale_cols = K // x._block_size
+        scale_cols = K // x.block_size
         n_row_blocks = ceil_div(scale_rows, 128)
         n_col_blocks = ceil_div(scale_cols, 4)
         elements_per_block = 32 * 16  # 512 elements
@@ -351,24 +351,24 @@ def _swizzle_aware_slice(
         )
 
     else:
-        scale_shaped = x.scale.view(M, K // x._block_size)
+        scale_shaped = x.scale.view(M, K // x.block_size)
 
         if dim == 0:
             sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step)
             sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step)
 
         elif dim == 1:
             if start is not None:
-                assert start % x._block_size == 0, (
-                    f"Start index {start} must be a multiple of block_size {x._block_size}"
+                assert start % x.block_size == 0, (
+                    f"Start index {start} must be a multiple of block_size {x.block_size}"
                 )
                 assert start % 2 == 0, (
                     f"Start index {start} must be even for FP4 packing"
                 )
 
             if end is not None and end != sys.maxsize:
-                assert end % x._block_size == 0, (
-                    f"End index {end} must be a multiple of block_size {x._block_size}"
+                assert end % x.block_size == 0, (
+                    f"End index {end} must be a multiple of block_size {x.block_size}"
                 )
                 assert end % 2 == 0, f"End index {end} must be even for FP4 packing"
 
@@ -382,8 +382,8 @@ def _swizzle_aware_slice(
                 x.qdata, dim, packed_start, packed_end, step
             )
 
-            start_block = 0 if start is None else start // x._block_size
-            end_block = None if end is None else end // x._block_size
+            start_block = 0 if start is None else start // x.block_size
+            end_block = None if end is None else end // x.block_size
             sliced_scale = aten.slice.Tensor(
                 scale_shaped, 1, start_block, end_block, step
             )
@@ -398,12 +398,12 @@ def _swizzle_aware_slice(
     # multiply by 2 to convert from bytes to num_elements
     sliced_K = sliced_data.shape[1] * 2
     if x._is_swizzled_scales:
-        if x._block_size == 16:
+        if x.block_size == 16:
            scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(
                sliced_M, sliced_K
            )
        else:
-            assert x._block_size == 32, f"unexpected {x._block_size=}"
+            assert x.block_size == 32, f"unexpected {x.block_size=}"
            scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_mx(
                sliced_M, sliced_K
            )
@@ -413,7 +413,7 @@ def _swizzle_aware_slice(
         # mx: a 1x32 unpacked or 1x16 packed qdata tile corresponds to 1
         # scale element
         scale_M = sliced_M
-        scale_K = sliced_K // x._block_size
+        scale_K = sliced_K // x.block_size
     sliced_scale = sliced_scale.view(scale_M, scale_K)
 
     return sliced_data, sliced_scale
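The dim-1 constraints above (start and end must be multiples of `block_size` and even, because two fp4 values share a byte) exist so that a slice maps cleanly between element indices, packed-byte indices, and scale-block indices. A plain-integer sketch of that bookkeeping (values are illustrative):

```python
block_size = 16   # NVFP4; MX uses 32
start, end = 32, 96

# the preconditions enforced for a dim-1 slice in _swizzle_aware_slice
assert start % block_size == 0 and end % block_size == 0
assert start % 2 == 0 and end % 2 == 0  # two fp4 values per packed byte

packed_start, packed_end = start // 2, end // 2                  # indices into packed qdata
start_block, end_block = start // block_size, end // block_size  # indices into the scale rows
print(packed_start, packed_end, start_block, end_block)          # 16 48 2 6
```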