diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index fc38aac8d8..1d46bff0d3 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -519,7 +519,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx.qdata, x_mx.scale, x_mx._elem_dtype, - x_mx._block_size, + x_mx.block_size, hp_dtype, # noqa: E501 pack_fp6, ) @@ -527,7 +527,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx_c.qdata, x_mx_c.scale, x_mx_c._elem_dtype, - x_mx_c._block_size, + x_mx_c.block_size, hp_dtype, pack_fp6, ) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index 5889019af3..e098edb745 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -71,7 +71,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold): reconstructed_amax = x_nvfp4.get_hp_scales().view(shape[0], -1, 1) * F4_E2M1_MAX max_abs = torch.amax( - torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1 + torch.abs(x.reshape(shape[0], -1, x_nvfp4.block_size)), dim=-1 ).unsqueeze(-1) assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0) @@ -526,7 +526,7 @@ def test_nvfp4_to_copy(): assert y.per_tensor_scale is None assert x.act_per_tensor_scale is None assert y.act_per_tensor_scale is None - assert x._block_size == y._block_size + assert x.block_size == y.block_size assert x.use_triton_kernel == y.use_triton_kernel assert x.act_quant_kwargs == y.act_quant_kwargs assert x.dtype == torch.float32 diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 67aa9d767a..9119ae4b24 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -490,7 +490,7 @@ class MXTensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale"] tensor_attribute_names = [ "_elem_dtype", - "_block_size", + "block_size", "_orig_dtype", "_gemm_kernel_choice", "_pack_fp6", @@ -547,7 +547,7 @@ def __new__( self.qdata = qdata self.scale = scale_e8m0_bits self._elem_dtype = elem_dtype - self._block_size = block_size + self.block_size = block_size self._orig_dtype = orig_dtype self._gemm_kernel_choice = gemm_kernel_choice self._pack_fp6 = pack_fp6 @@ -560,7 +560,7 @@ def __repr__(self): return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}, _is_swizzled_scales={self._is_swizzled_scales}" # noqa: E501 def _quantization_type(self): - return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" + return f"{self._elem_dtype=}, {self.block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: @@ -575,9 +575,9 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor else: leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1] scale = from_blocked( - scale, math.prod(leading_dims) * M, K // self._block_size + scale, math.prod(leading_dims) * M, K // self.block_size ) - scale = scale.view(*leading_dims, M, K // self._block_size) + scale = scale.view(*leading_dims, M, K // self.block_size) if is_transposed: scale = scale.transpose(-2, -1) @@ -585,7 +585,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor self.qdata, scale, self._elem_dtype, - self._block_size, + self.block_size, output_dtype, self._pack_fp6, ) @@ -699,19 +699,19 @@ def _addmm_mx_dispatch( M, K, N = a.shape[0], a.shape[1], b.shape[1] assert a.qdata.is_contiguous() assert b.qdata.t().is_contiguous() - assert a._block_size == 32, f"Invalid block size {a._block_size}" - assert b._block_size == 32, f"Invalid block size {b._block_size}" + assert a.block_size == 32, f"Invalid block size {a.block_size}" + assert b.block_size == 32, f"Invalid block size {b.block_size}" if a._is_swizzled_scales: a_scale_block = a.scale else: - a_scale = a.scale.view(M, K // a._block_size) + a_scale = a.scale.view(M, K // a.block_size) a_scale_block = to_blocked(a_scale) if b._is_swizzled_scales: b_scale_block = b.scale.t() else: - b_scale = b.scale.t().view(N, K // b._block_size) + b_scale = b.scale.t().view(N, K // b.block_size) b_scale_block = to_blocked(b_scale) if a._elem_dtype == torch.float8_e4m3fn: @@ -804,7 +804,7 @@ def mx_t(func, types, args, kwargs): old.qdata.t(), old.scale.t(), old._elem_dtype, - old._block_size, + old.block_size, old._orig_dtype, old._gemm_kernel_choice, old._pack_fp6, @@ -849,7 +849,7 @@ def mx_view_op(func, types, args, kwargs): new_data, args[0].scale, args[0]._elem_dtype, - args[0]._block_size, + args[0].block_size, args[0]._orig_dtype, args[0]._gemm_kernel_choice, args[0]._pack_fp6, @@ -875,7 +875,7 @@ def mx_slice(func, types, args, kwargs): sliced_data, sliced_scale, x._elem_dtype, - x._block_size, + x.block_size, x._orig_dtype, x._gemm_kernel_choice, x._pack_fp6, @@ -910,7 +910,7 @@ def mx_select(func, types, args, kwargs): old_mx_tensor.qdata[index], old_mx_tensor.scale[index], old_mx_tensor._elem_dtype, - old_mx_tensor._block_size, + old_mx_tensor.block_size, old_mx_tensor._orig_dtype, old_mx_tensor._gemm_kernel_choice, old_mx_tensor._pack_fp6, diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 18f05290e5..26e48216ee 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -78,7 +78,7 @@ class NVFP4Tensor(TorchAOBaseTensor): scale: Blockwise scales in float8_e4m3fn format (may be swizzled) per_tensor_scale: Optional global per-tensor scale in float32 format act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation - _block_size (int): Block size for quantization (fixed at 16) + block_size (int): Block size for quantization (fixed at 16) _orig_dtype (torch.dtype): Original tensor dtype before quantization _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format use_triton_kernel (bool): Whether to use triton kernels @@ -86,7 +86,7 @@ class NVFP4Tensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale"] tensor_attribute_names = [ - "_block_size", + "block_size", "_orig_dtype", ] optional_tensor_data_names = ["per_tensor_scale", "act_per_tensor_scale"] @@ -126,7 +126,7 @@ def __new__( self.qdata = qdata self.scale = scale - self._block_size = block_size + self.block_size = block_size self._orig_dtype = orig_dtype self.per_tensor_scale = per_tensor_scale self.act_per_tensor_scale = act_per_tensor_scale @@ -238,10 +238,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor data_f32 = f4_unpacked_to_f32(data_unpacked) data_f32 = data_f32.view( - *leading_dims, M, K // self._block_size, self._block_size + *leading_dims, M, K // self.block_size, self.block_size ) scale_e4m3_reshaped = self.get_hp_scales().view( - *leading_dims, M, K // self._block_size, 1 + *leading_dims, M, K // self.block_size, 1 ) data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32) result = data_scaled.view(*leading_dims, M, K).to(output_dtype) @@ -267,7 +267,7 @@ def get_hp_scales(self) -> torch.Tensor: if self._is_swizzled_scales: scale_e4m3 = from_blocked( - scale_e4m3, math.prod(leading_dims) * M, K // self._block_size + scale_e4m3, math.prod(leading_dims) * M, K // self.block_size ) return ( @@ -297,7 +297,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool: return ( isinstance(self, NVFP4Tensor) and isinstance(src, NVFP4Tensor) - and self._block_size == src._block_size + and self.block_size == src.block_size and self._orig_dtype == src._orig_dtype and self._is_swizzled_scales == src._is_swizzled_scales and self.scale.shape == src.scale.shape @@ -341,7 +341,7 @@ def nvfp4_to_copy(func, types, args, kwargs): res = NVFP4Tensor( tensor.qdata, tensor.scale, - tensor._block_size, + tensor.block_size, dtype, tensor.per_tensor_scale, tensor.act_per_tensor_scale, @@ -399,7 +399,7 @@ def nvfp4_slice(func, types, args, kwargs): result = NVFP4Tensor( sliced_data, sliced_scale, - x._block_size, + x.block_size, x._orig_dtype, x.per_tensor_scale, x.act_per_tensor_scale, @@ -418,7 +418,7 @@ def nvfp4_t(func, types, args, kwargs): new = NVFP4Tensor( old.qdata.t(), old.scale.t(), - old._block_size, + old.block_size, old._orig_dtype, old.per_tensor_scale, old.act_per_tensor_scale, @@ -440,7 +440,7 @@ def nvfp4_transpose(func, types, args, kwargs): new = NVFP4Tensor( new_qdata, new_scale, - old._block_size, + old.block_size, old._orig_dtype, old.per_tensor_scale, old.act_per_tensor_scale, @@ -460,7 +460,7 @@ def nvfp4_view_op(func, types, args, kwargs): return NVFP4Tensor( new_data, args[0].scale, - args[0]._block_size, + args[0].block_size, args[0]._orig_dtype, args[0].per_tensor_scale, args[0].act_per_tensor_scale, @@ -478,7 +478,7 @@ def nvfp4_select(func, types, args, kwargs): new = old.__class__( old.qdata[index], old.scale[index], - old._block_size, + old.block_size, old._orig_dtype, old.per_tensor_scale, old.act_per_tensor_scale, @@ -500,8 +500,8 @@ def _addmm_nvfp4_dispatch( assert a.scale.is_contiguous() assert b.qdata.t().is_contiguous() assert b.scale.t().is_contiguous() - assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}" - assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}" + assert a.block_size == 16, f"NVFP4 requires block_size=16, got {a.block_size}" + assert b.block_size == 16, f"NVFP4 requires block_size=16, got {b.block_size}" assert len(a.shape) == 2 and len(b.shape) == 2 M, K = a.shape[0], a.shape[1] @@ -511,13 +511,13 @@ def _addmm_nvfp4_dispatch( if a._is_swizzled_scales: a_scale_blocked = a.scale # Already swizzled else: - a_scale = a.scale.view(M, K // a._block_size) + a_scale = a.scale.view(M, K // a.block_size) a_scale_blocked = to_blocked(a_scale) if b._is_swizzled_scales: b_scale_blocked = b.scale.t() # Already swizzled else: - b_scale = b.scale.t().view(N, K // b._block_size) + b_scale = b.scale.t().view(N, K // b.block_size) b_scale_blocked = to_blocked(b_scale) # Merge double quant scales into 1 scale for Scale_In^D diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 78bfd48ab7..72d8a47b81 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -232,7 +232,7 @@ def _swizzle_aware_slice( if x._is_swizzled_scales: scale_rows = M - scale_cols = K // x._block_size + scale_cols = K // x.block_size n_row_blocks = ceil_div(scale_rows, 128) n_col_blocks = ceil_div(scale_cols, 4) elements_per_block = 32 * 16 # 512 elements @@ -351,7 +351,7 @@ def _swizzle_aware_slice( ) else: - scale_shaped = x.scale.view(M, K // x._block_size) + scale_shaped = x.scale.view(M, K // x.block_size) if dim == 0: sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) @@ -359,16 +359,16 @@ def _swizzle_aware_slice( elif dim == 1: if start is not None: - assert start % x._block_size == 0, ( - f"Start index {start} must be a multiple of block_size {x._block_size}" + assert start % x.block_size == 0, ( + f"Start index {start} must be a multiple of block_size {x.block_size}" ) assert start % 2 == 0, ( f"Start index {start} must be even for FP4 packing" ) if end is not None and end != sys.maxsize: - assert end % x._block_size == 0, ( - f"End index {end} must be a multiple of block_size {x._block_size}" + assert end % x.block_size == 0, ( + f"End index {end} must be a multiple of block_size {x.block_size}" ) assert end % 2 == 0, f"End index {end} must be even for FP4 packing" @@ -382,8 +382,8 @@ def _swizzle_aware_slice( x.qdata, dim, packed_start, packed_end, step ) - start_block = 0 if start is None else start // x._block_size - end_block = None if end is None else end // x._block_size + start_block = 0 if start is None else start // x.block_size + end_block = None if end is None else end // x.block_size sliced_scale = aten.slice.Tensor( scale_shaped, 1, start_block, end_block, step ) @@ -398,12 +398,12 @@ def _swizzle_aware_slice( # multiply by 2 to convert from bytes to num_elements sliced_K = sliced_data.shape[1] * 2 if x._is_swizzled_scales: - if x._block_size == 16: + if x.block_size == 16: scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4( sliced_M, sliced_K ) else: - assert x._block_size == 32, f"unexpected {x._block_size=}" + assert x.block_size == 32, f"unexpected {x.block_size=}" scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_mx( sliced_M, sliced_K ) @@ -413,7 +413,7 @@ def _swizzle_aware_slice( # mx: a 1x32 unpacked or 1x16 packed qdata tile corresponds to 1 # scale element scale_M = sliced_M - scale_K = sliced_K // x._block_size + scale_K = sliced_K // x.block_size sliced_scale = sliced_scale.view(scale_M, scale_K) return sliced_data, sliced_scale