3 changes: 2 additions & 1 deletion docs/source/api_ref_dtypes.rst
@@ -11,7 +11,8 @@ torchao.dtypes
    :nosignatures:

    to_nf4
-   UInt4Tensor
+   to_affine_quantized
+   AffineQuantizedTensor

 ..
   _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
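For orientation, here is a minimal sketch of how the newly documented `to_affine_quantized` constructor could be used to produce an `AffineQuantizedTensor`. The argument list is taken from the docstring example in quant_api.py below; it is an assumption based on this PR, not a verified public signature:

    import torch
    from torchao.dtypes import to_affine_quantized

    # uint4 asymmetric groupwise quantization of a floating point weight
    # (argument order assumed from the docstring example in this PR)
    groupsize = 32
    w = torch.randn(1024, 1024, dtype=torch.bfloat16)
    qw = to_affine_quantized(
        w, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
        zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
    # qw is expected to be an AffineQuantizedTensor instance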
12 changes: 6 additions & 6 deletions docs/source/api_ref_quantization.rst
@@ -9,15 +9,15 @@ torchao.quantization
 .. autosummary::
    :toctree: generated/
    :nosignatures:

-   apply_weight_only_int8_quant
-   apply_dynamic_quant
-   change_linear_weights_to_int8_dqtensors
-   change_linear_weights_to_int8_woqtensors
-   change_linear_weights_to_int4_woqtensors

    SmoothFakeDynQuantMixin
    SmoothFakeDynamicallyQuantizedLinear
    swap_linear_with_smooth_fq_linear
    smooth_fq_linear_to_inference
    Int4WeightOnlyGPTQQuantizer
    Int4WeightOnlyQuantizer
+   quantize
+   int8_dynamic_activation_int4_weight
+   int8_dynamic_activation_int8_weight
+   int4_weight_only
+   int8_weight_only
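These docs drop the old module-swap helpers in favor of the `quantize` API. A rough migration sketch, assuming `quantize` with `int8_weight_only()` is the intended replacement for `change_linear_weights_to_int8_woqtensors` (implied by the rename in this PR, not stated explicitly):

    import torch.nn as nn
    from torchao.quantization import quantize, int8_weight_only

    m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

    # before: change_linear_weights_to_int8_woqtensors(m)
    # after (int8_weight_only is assumed to be callable with no arguments):
    m = quantize(m, int8_weight_only())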
4 changes: 2 additions & 2 deletions torchao/quantization/__init__.py
@@ -32,8 +32,8 @@
"dequantize_affine",
"choose_qprams_affine",
"quantize",
"int8_dynamic_act_int4_weight",
"int8_dynamic_act_int8_weight",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int4_weight_only",
"int8_weight_only",
]
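Downstream imports need the matching rename. A short sketch, assuming `int8_dynamic_activation_int4_weight` is callable with default arguments the way `int4_weight_only` is in the docstring below:

    import torch.nn as nn
    from torchao.quantization import quantize, int8_dynamic_activation_int4_weight

    # old name: int8_dynamic_act_int4_weight (removed by this PR)
    m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
    m = quantize(m, int8_dynamic_activation_int4_weight())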
57 changes: 36 additions & 21 deletions torchao/quantization/quant_api.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Callable, Union, Dict
+from typing import Any, Callable, Union, Dict, Optional

 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
@@ -258,38 +258,53 @@ def insert_subclass(lin):

     return insert_subclass

-def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`

     Args:
-        model: input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance
-        filter_fn: used to filter out the modules that we don't want to apply tenosr subclass
+        model (torch.nn.Module): input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (quantized) tensor subclass instance (e.g. an affine quantized tensor instance)
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run `apply_tensor_subclass` on
+            the weight of that module

     Example::

-        # weight settings
-        groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        import torch
+        import torch.nn as nn
+        from torchao import quantize
+
+        # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
+        # optimized execution paths or kernels (e.g. int4 tinygemm kernel),
+        # also customizable with arguments; the current options are:
+        # int8_dynamic_activation_int4_weight (for executorch)
+        # int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile)
+        # int4_weight_only (optimized with int4 tinygemm kernel and torch.compile)
+        # int8_weight_only (optimized with int8 mm op and torch.compile)
+        from torchao.quantization.quant_api import int4_weight_only
+
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+        m = quantize(m, int4_weight_only(group_size=32))

+        # 2. write your own new apply_tensor_subclass
+        # You can also add your own apply_tensor_subclass by manually calling the tensor subclass
+        # constructor on the weight
+
+        from torchao.dtypes import to_affine_quantized
+
+        # weight-only uint4 asymmetric groupwise quantization
+        groupsize = 32
         apply_weight_quant = lambda x: to_affine_quantized(
-            x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-            zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
+            zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")

-        # apply to modules under block0 submodule
-        def filter_fn(module, fqn):
-            return fqn == "block0"
+        def filter_fn(module: nn.Module, fqn: str) -> bool:
+            return isinstance(module, nn.Linear)

-        m = MyModel(...)
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         m = quantize(m, apply_weight_quant, filter_fn)

"""
if isinstance(apply_tensor_subclass, str):
if apply_tensor_subclass not in _APPLY_TS_TABLE:
Expand Down
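Because `filter_fn` receives both the module and its fully qualified name, the docstring's isinstance check can be combined with name-based filtering (as in the deleted "block0" example). A minimal end-to-end sketch, assuming only the `quantize` and `int4_weight_only` signatures shown in this diff:

    import torch.nn as nn
    from torchao import quantize
    from torchao.quantization.quant_api import int4_weight_only

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.block0 = nn.Linear(32, 1024)
            self.block1 = nn.Linear(1024, 32)

        def forward(self, x):
            return self.block1(self.block0(x))

    # quantize only linear modules whose fully qualified name is "block0"
    def filter_fn(module: nn.Module, fqn: str) -> bool:
        return isinstance(module, nn.Linear) and fqn == "block0"

    m = MyModel()
    m = quantize(m, int4_weight_only(group_size=32), filter_fn)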