diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst
index 4cb797beb4..a9cd7bf803 100644
--- a/docs/source/api_ref_dtypes.rst
+++ b/docs/source/api_ref_dtypes.rst
@@ -11,7 +11,8 @@ torchao.dtypes
     :nosignatures:
 
     to_nf4
-    UInt4Tensor
+    to_affine_quantized
+    AffineQuantizedTensor
 
 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst
index 8a580ae24e..7e61bea1d1 100644
--- a/docs/source/api_ref_quantization.rst
+++ b/docs/source/api_ref_quantization.rst
@@ -9,15 +9,15 @@ torchao.quantization
 .. autosummary::
     :toctree: generated/
     :nosignatures:
-
-    apply_weight_only_int8_quant
-    apply_dynamic_quant
-    change_linear_weights_to_int8_dqtensors
-    change_linear_weights_to_int8_woqtensors
-    change_linear_weights_to_int4_woqtensors
+    SmoothFakeDynQuantMixin
     SmoothFakeDynamicallyQuantizedLinear
     swap_linear_with_smooth_fq_linear
     smooth_fq_linear_to_inference
     Int4WeightOnlyGPTQQuantizer
     Int4WeightOnlyQuantizer
+    quantize
+    int8_dynamic_activation_int4_weight
+    int8_dynamic_activation_int8_weight
+    int4_weight_only
+    int8_weight_only
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index f1bb82921e..3e45bf34f1 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -32,8 +32,8 @@
     "dequantize_affine",
     "choose_qprams_affine",
     "quantize",
-    "int8_dynamic_act_int4_weight",
-    "int8_dynamic_act_int8_weight",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
     "int8_weight_only",
 ]
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 3a1516d9b5..6f7f549704 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Callable, Union, Dict
+from typing import Any, Callable, Union, Dict, Optional
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
@@ -258,38 +258,53 @@ def insert_subclass(lin):
         return insert_subclass
 
-def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`
 
     Args:
-        model: input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance
-        filter_fn: used to filter out the modules that we don't want to apply tenosr subclass
+        model (torch.nn.Module): input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (quantized) tensor subclass instance (e.g. an affine quantized tensor instance)
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run `apply_tensor_subclass` on the weight of the module
 
     Example::
 
-        # weight settings
-        groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        import torch
+        import torch.nn as nn
+        from torchao import quantize
+
+        # 1. quantize with a predefined `apply_tensor_subclass` that corresponds to an
+        # optimized execution path or kernel (e.g. the int4 tinygemm kernel),
+        # customizable with arguments
+        # the current options are:
+        #   int8_dynamic_activation_int4_weight (for executorch)
+        #   int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile)
+        #   int4_weight_only (optimized with int4 tinygemm kernel and torch.compile)
+        #   int8_weight_only (optimized with int8 mm op and torch.compile)
+        from torchao.quantization.quant_api import int4_weight_only
+
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+        m = quantize(m, int4_weight_only(group_size=32))
+
+        # 2. write your own `apply_tensor_subclass`
+        # you can also quantize by manually calling a tensor subclass constructor
+        # on the weight
+        from torchao.dtypes import to_affine_quantized
+
+        # weight-only uint4 asymmetric groupwise quantization
+        groupsize = 32
         apply_weight_quant = lambda x: to_affine_quantized(
-            x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-            zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
+            zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
 
-        # apply to modules under block0 submodule
-        def filter_fn(module, fqn):
-            return fqn == "block0"
+        # apply to every nn.Linear module in the model
+        def filter_fn(module: nn.Module, fqn: str) -> bool:
+            return isinstance(module, nn.Linear)
 
-        m = MyModel(...)
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         m = quantize(m, apply_weight_quant, filter_fn)
+
     """
     if isinstance(apply_tensor_subclass, str):
         if apply_tensor_subclass not in _APPLY_TS_TABLE:
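As a quick sanity check of the renamed entry points, here is a minimal sketch (not part of the patch). It assumes `quantize` is re-exported from the top-level `torchao` package, as the new docstring example does, and that `int8_weight_only()` takes no required arguments; `ToyModel` and `only_block0` are hypothetical names used only for illustration.

import torch
import torch.nn as nn

# assumed import paths, mirroring the docstring example in this patch
from torchao import quantize
from torchao.quantization.quant_api import int8_weight_only


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block0 = nn.Linear(32, 64)
        self.block1 = nn.Linear(64, 32)

    def forward(self, x):
        return self.block1(self.block0(x))


# filter_fn uses the (module, fqn) signature from the new docstring: only the
# weight of the `block0` linear layer is swapped for a quantized subclass.
def only_block0(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear) and fqn == "block0"


m = ToyModel().eval()
m = quantize(m, int8_weight_only(), only_block0)

with torch.no_grad():
    out = m(torch.randn(2, 32))
assert out.shape == (2, 32)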