diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst
index fbe680953e..26e1266c09 100644
--- a/docs/source/api_ref_dtypes.rst
+++ b/docs/source/api_ref_dtypes.rst
@@ -6,19 +6,42 @@ torchao.dtypes

 .. currentmodule:: torchao.dtypes

+Layouts and Tensor Subclasses
+-----------------------------
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    NF4Tensor
+    AffineQuantizedTensor
+    Layout
+    PlainLayout
+    SemiSparseLayout
+    TensorCoreTiledLayout
+    Float8Layout
+    FloatxTensor
+    FloatxTensorCoreLayout
+    MarlinSparseLayout
+    BlockSparseLayout
+    UintxLayout
+    MarlinQQQTensor
+    MarlinQQQLayout
+    Int4CPULayout
+    CutlassInt4PackedLayout
+
+Quantization techniques
+-----------------------
 .. autosummary::
     :toctree: generated/
     :nosignatures:

-    to_nf4
     to_affine_quantized_intx
     to_affine_quantized_intx_static
+    to_affine_quantized_fpx
     to_affine_quantized_floatx
     to_affine_quantized_floatx_static
-    to_affine_quantized_fpx
-    NF4Tensor
-    AffineQuantizedTensor
-
+    to_marlinqqq_quantized_intx
+    to_nf4

 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation.
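As a quick orientation for the two groups above, here is a minimal usage sketch. `to_nf4`'s signature is documented in this diff; the argument names of `to_affine_quantized_intx` and the `MappingType` import path are assumptions based on current torchao sources, not something this diff shows:

```python
import torch

from torchao.dtypes import to_affine_quantized_intx, to_nf4
from torchao.quantization.quant_primitives import MappingType  # assumed import path

weight = torch.randn(1024, 1024, dtype=torch.bfloat16)

# NF4 conversion; defaults match the to_nf4 signature in nf4tensor.py below
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

# int8 symmetric per-row quantization (argument names assumed)
int8_weight = to_affine_quantized_intx(
    weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, weight.shape[1]),  # one scale per row
    target_dtype=torch.int8,
)
```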
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index e7aca34c5f..e3ac420de7 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -44,9 +44,8 @@
 # Tensor Subclass Definition #
 ##############################
 class AffineQuantizedTensor(TorchAOBaseTensor):
-    """
-    Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
-      quantized_tensor = float_tensor / scale + zero_point
+    """Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
+    quantized_tensor = float_tensor / scale + zero_point

     To see what happens during choose_qparams, quantization and dequantization for affine quantization,
     please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
@@ -56,21 +55,18 @@ class AffineQuantizedTensor(TorchAOBaseTensor):
       regardless of the internal representation's type or orientation.

     fields:
-      tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data,
-         e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device
-         and operator/kernel
-      block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
-         e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
-      shape (torch.Size): the shape for the original high precision Tensor
-      quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
-        if zero_point is in integer domain, zero point is added to the quantized integer value during
-        quantization
-        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
-        value during quantization
-        default is ZeroPointDomain.INT
-      dtype: dtype for original high precision tensor, e.g. torch.float32
+    - tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data,
+      e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device and operator/kernel
+    - block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that share the same qparam
+      e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
+    - shape (torch.Size): the shape for the original high precision Tensor
+    - quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
+    - quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
+    - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+      if zero_point is in integer domain, zero point is added to the quantized integer value during quantization
+      if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization
+      default is ZeroPointDomain.INT
+    - dtype: dtype for original high precision tensor, e.g. torch.float32
     """

     @staticmethod
@@ -207,6 +203,7 @@ def from_hp_to_intx(
         _layout: Layout = PlainLayout(),
         use_hqq: bool = False,
     ):
+        """Convert a high precision tensor to an integer affine quantized tensor."""
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)

@@ -302,6 +299,7 @@ def from_hp_to_intx_static(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
     ):
+        """Create an integer AffineQuantizedTensor from a high precision tensor using static quantization parameters."""
         if target_dtype not in FP8_TYPES:
             assert (
                 zero_point_domain is not None
@@ -348,6 +346,7 @@ def from_hp_to_floatx(
         _layout: Layout,
         scale_dtype: Optional[torch.dtype] = None,
     ):
+        """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
             return cls.from_hp_to_intx(
                 input_float=input_float,
@@ -378,6 +377,7 @@ def from_hp_to_floatx_static(
         target_dtype: torch.dtype,
         _layout: Layout,
     ):
+        """Create a float8 AffineQuantizedTensor from a high precision tensor using static quantization parameters."""
         if target_dtype in FP8_TYPES:
             return cls.from_hp_to_intx_static(
                 input_float=input_float,
@@ -401,6 +401,7 @@ def from_hp_to_fpx(
         input_float: torch.Tensor,
         _layout: Layout,
     ):
+        """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1 through float7."""
         from torchao.dtypes.floatx import FloatxTensorCoreLayout

         assert isinstance(
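To make the affine mapping above concrete, here is a plain-tensor round trip in the integer zero_point domain. It is a sketch of the math described in the docstring, not the actual code path in quant_primitives.py:

```python
import torch

x = torch.randn(4, 8)

# per-tensor qparams for int8, zero_point in the integer domain
quant_min, quant_max = -128, 127
scale = (x.max() - x.min()) / (quant_max - quant_min)
zero_point = int(quant_min - torch.round(x.min() / scale))

# quantize: round(float_tensor / scale) + zero_point, clamped to the int8 range
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max).to(torch.int8)

# dequantize: (quantized_tensor - zero_point) * scale
x_hat = (q.to(torch.float32) - zero_point) * scale
assert torch.allclose(x, x_hat, atol=float(scale))  # rounding error bounded by scale/2
```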
+ """ + mm_config: Optional[Float8MMConfig] = None diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 0f67e9826e..beaa2e536e 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -450,7 +450,9 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> # quantization api integrations @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl""" + """FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits). + This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations. + """ ebits: int mbits: int diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 14a8c2d43e..5ae06a1fe1 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -662,10 +662,9 @@ def dequantize_scalers( ) -> torch.Tensor: """Used to unpack the double quantized scalers - Args; + Args: input_tensor: Input tensor to convert to QLoRA format this is the quantized scalers in int8 format quantization_factor: Tensor of per_scaler_block quantization factors stored in inpt_weight.dtype - size: (n_scaler_blocks) scaler_block_size: Scaler block size to use for double quantization. """ @@ -953,6 +952,7 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): + """Convert a given tensor to normalized float 4-bit tensor.""" return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 0670986b13..6681847608 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -27,6 +27,12 @@ @dataclass(frozen=True) class BlockSparseLayout(Layout): + """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. + + Attributes: + blocksize (int): The size of the blocks in the sparse matrix. Default is 64. + """ + blocksize: int = 64 diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index a6412ec88c..9c0d0bb055 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -29,6 +29,8 @@ def _aqt_is_int4(aqt): @dataclass(frozen=True) class CutlassInt4PackedLayout(Layout): + """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel.""" + pass diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 7c734a8a44..d587591ccc 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -24,15 +24,16 @@ @dataclass(frozen=True) class Int4CPULayout(Layout): - """Only for PyTorch version at least 2.6""" + """Layout class for int4 CPU layout for affine quantized tensor, used by tinygemm kernels `_weight_int4pack_mm_for_cpu`. 
diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py
index 7c734a8a44..d587591ccc 100644
--- a/torchao/dtypes/uintx/int4_cpu_layout.py
+++ b/torchao/dtypes/uintx/int4_cpu_layout.py
@@ -24,15 +24,16 @@
 @dataclass(frozen=True)
 class Int4CPULayout(Layout):
-    """Only for PyTorch version at least 2.6"""
+    """Layout class for int4 CPU layout for affine quantized tensors, used by the tinygemm kernel `_weight_int4pack_mm_for_cpu`.
+    Requires PyTorch 2.6 or later.
+    """

     pass


 @register_layout(Int4CPULayout)
 class Int4CPUAQTTensorImpl(AQTTensorImpl):
-    """
-    TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
+    """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
     used by tinygemm kernels `_weight_int4pack_mm_for_cpu`

     It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
     dimension: [n][k / 2] (uint8 dtype)
diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py
index b75d959b41..3a4253bb3f 100644
--- a/torchao/dtypes/uintx/marlin_qqq_tensor.py
+++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py
@@ -29,8 +29,7 @@

 class MarlinQQQTensor(AffineQuantizedTensor):
-    """
-    MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class.
+    """MarlinQQQ quantized tensor subclass, which inherits from the AffineQuantizedTensor class.

     To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization,
     please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
@@ -58,6 +57,7 @@ def from_hp_to_intx(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         _layout: Optional[Layout] = None,
     ):
+        """Convert a floating point tensor to a Marlin QQQ quantized tensor."""
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
         nbits = int(math.log2(quant_max - quant_min + 1))
@@ -81,6 +81,8 @@ def from_hp_to_intx(

 @dataclass(frozen=True)
 class MarlinQQQLayout(Layout):
+    """MarlinQQQLayout is a layout class for Marlin QQQ quantization."""
+
     pass

diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py
index 2a84dd1813..22763eb0c2 100644
--- a/torchao/dtypes/uintx/marlin_sparse_layout.py
+++ b/torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -71,6 +71,17 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b

 @dataclass(frozen=True)
 class MarlinSparseLayout(Layout):
+    """MarlinSparseLayout is a layout class for handling sparse tensor formats
+    specifically designed for the Marlin sparse kernel. This layout is used
+    to optimize the storage and computation of affine quantized tensors with
+    2:4 sparsity patterns.
+
+    The layout ensures that the tensor data is pre-processed and stored in a
+    format that is compatible with the Marlin sparse kernel operations. It
+    provides methods for preprocessing input tensors and managing the layout
+    of quantized tensors.
+    """
+
     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
         """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel.
         - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format
diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py
index a554fd9bc6..3c35a4d8cd 100644
--- a/torchao/dtypes/uintx/semi_sparse_layout.py
+++ b/torchao/dtypes/uintx/semi_sparse_layout.py
@@ -66,6 +66,13 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(

 @dataclass(frozen=True)
 class SemiSparseLayout(Layout):
+    """SemiSparseLayout is a layout class for handling semi-structured sparse
+    matrices in affine quantized tensors. This layout is specifically designed
+    to work with the 2:4 sparsity pattern, where two out of every four elements
+    are pruned to zero. This class provides methods for preprocessing input
+    tensors to conform to this sparsity pattern.
+    """
+
     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
         # prune to 2:4 if not already
         temp = input.detach()
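For intuition about the 2:4 pattern that `SemiSparseLayout.pre_process` enforces, here is a plain-PyTorch sketch that keeps the two largest-magnitude values in each group of four and zeros the rest. It illustrates the pattern only; torchao's actual pruning kernel is not shown in this diff:

```python
import torch

def prune_2_4(x: torch.Tensor) -> torch.Tensor:
    # keep the 2 largest-magnitude entries in every contiguous group of 4
    groups = x.reshape(-1, 4)
    _, drop = groups.abs().topk(k=2, dim=-1, largest=False)  # 2 smallest |values|
    return groups.scatter(-1, drop, 0.0).reshape_as(x)

w = torch.randn(4, 8)
w_24 = prune_2_4(w)
assert (w_24.reshape(-1, 4) == 0).sum(dim=-1).eq(2).all()  # 2 zeros per group of 4
```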
diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
index 378744e7e1..b29c9d167b 100644
--- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py
+++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -91,9 +91,10 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):

 @dataclass(frozen=True)
 class TensorCoreTiledLayout(Layout):
-    """
-    inner_k_tiles is an internal argument for packing function of tensor core tiled layout
-    that can affect the performance of the matmul kernel
+    """TensorCoreTiledLayout is a layout class for handling tensor core tiled layouts in affine quantized tensors. It provides methods for pre-processing and post-processing tensors to fit the required layout for efficient computation on tensor cores.
+
+    Attributes:
+        inner_k_tiles (int): An internal argument for the packing function of tensor core tiled layout that can affect the performance of the matmul kernel. Defaults to 8.
     """

     inner_k_tiles: int = 8
@@ -149,8 +150,7 @@ def extra_repr(self):

 @register_layout(TensorCoreTiledLayout)
 class TensorCoreTiledAQTTensorImpl(AQTTensorImpl):
-    """
-    TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
+    """TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
     used by tinygemm kernels `_weight_int4pack_mm`

     It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of
diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py
index 29c2ae93fe..ef85319cd5 100644
--- a/torchao/dtypes/uintx/uintx_layout.py
+++ b/torchao/dtypes/uintx/uintx_layout.py
@@ -209,6 +209,17 @@ def _(func, types, args, kwargs):

 @dataclass(frozen=True)
 class UintxLayout(Layout):
+    """A layout class for Uintx tensors, which are tensors with elements packed into
+    smaller bit-widths than the standard 8-bit byte. This layout is used to define
+    how the data is stored and processed in UintxTensor objects.
+
+    Attributes:
+        dtype (torch.dtype): The data type of the tensor elements, which determines
+            the bit-width used for packing.
+        pack_dim (int): The dimension along which the data is packed. Default is -1,
+            which indicates the last dimension.
+    """
+
     dtype: torch.dtype
     pack_dim: int = -1
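To give intuition for the sub-byte packing that `UintxLayout` describes, here is a self-contained sketch packing pairs of uint4 values into single uint8 bytes along the last dimension (the `pack_dim=-1` default). It mirrors the idea, not torchao's actual bit-packing code:

```python
import torch

def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    # x holds values in [0, 15]; two uint4 values share one uint8 byte
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    lo, hi = x[..., ::2], x[..., 1::2]
    return lo | (hi << 4)

def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    lo, hi = packed & 0x0F, (packed >> 4) & 0x0F
    return torch.stack([lo, hi], dim=-1).flatten(-2)

vals = torch.randint(0, 16, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_uint4(pack_uint4(vals)), vals)  # lossless round trip
```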
+ """ + def pre_process(self, input: torch.Tensor) -> torch.Tensor: return input @@ -49,13 +58,13 @@ def extra_repr(self) -> str: return "" -""" -Plain Layout, the most basic Layout, also has no extra metadata, will typically be the default -""" - - @dataclass(frozen=True) class PlainLayout(Layout): + """PlainLayout is the most basic layout class, inheriting from the Layout base class. + It does not add any additional metadata or processing steps to the tensor. + Typically, this layout is used as the default when no specific layout is required. + """ + pass