
Commit fd9f95d

Add doc page to README.md (#367)
1 parent 9dc2c11 commit fd9f95d

File tree

4 files changed: +46, -30 lines

docs/source/api_ref_dtypes.rst
docs/source/api_ref_quantization.rst
torchao/quantization/__init__.py
torchao/quantization/quant_api.py

docs/source/api_ref_dtypes.rst

Lines changed: 2 additions & 1 deletion

@@ -11,7 +11,8 @@ torchao.dtypes
     :nosignatures:
 
     to_nf4
-    UInt4Tensor
+    to_affine_quantized
+    AffineQuantizedTensor
 
 ..
   _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
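The dtypes page swaps `UInt4Tensor` for the affine-quantization entry points. For reference, a minimal sketch of calling `to_affine_quantized` directly, reusing the argument pattern from the quant_api.py docstring later in this commit (the weight shape and dtype are illustrative, not part of the commit):

    import torch
    from torchao.dtypes import to_affine_quantized

    # uint4 asymmetric groupwise quantization; the result should be an
    # AffineQuantizedTensor instance (both names are newly documented above)
    groupsize = 32
    weight = torch.randn(1024, 32, dtype=torch.bfloat16)
    qweight = to_affine_quantized(
        weight, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
        zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")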

docs/source/api_ref_quantization.rst

Lines changed: 6 additions & 6 deletions

@@ -9,15 +9,15 @@ torchao.quantization
 .. autosummary::
     :toctree: generated/
     :nosignatures:
-
-    apply_weight_only_int8_quant
-    apply_dynamic_quant
-    change_linear_weights_to_int8_dqtensors
-    change_linear_weights_to_int8_woqtensors
-    change_linear_weights_to_int4_woqtensors
+
     SmoothFakeDynQuantMixin
     SmoothFakeDynamicallyQuantizedLinear
     swap_linear_with_smooth_fq_linear
     smooth_fq_linear_to_inference
     Int4WeightOnlyGPTQQuantizer
     Int4WeightOnlyQuantizer
+    quantize
+    int8_dynamic_activation_int4_weight
+    int8_dynamic_activation_int8_weight
+    int4_weight_only
+    int8_weight_only
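The doc page now lists the new unified entry points in place of the removed `change_linear_weights_to_*` helpers. A minimal usage sketch of the headline `quantize` flow, taken from the docstring example updated in quant_api.py below:

    import torch.nn as nn
    from torchao import quantize
    from torchao.quantization.quant_api import int4_weight_only

    # swap each linear weight for an int4 weight-only quantized tensor subclass
    m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
    m = quantize(m, int4_weight_only(group_size=32))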

torchao/quantization/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -30,8 +30,8 @@
     "dequantize_affine",
     "choose_qprams_affine",
     "quantize",
-    "int8_dynamic_act_int4_weight",
-    "int8_dynamic_act_int8_weight",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
     "int8_weight_only",
 ]
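Since these names are exported through `__all__`, the rename is the only change an import site sees; a sketch of updating one (assuming `int8_dynamic_activation_int8_weight` takes no required arguments, which this diff does not show):

    import torch.nn as nn
    # before this commit: from torchao.quantization import int8_dynamic_act_int8_weight
    from torchao.quantization import quantize, int8_dynamic_activation_int8_weight

    m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
    m = quantize(m, int8_dynamic_activation_int8_weight())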

torchao/quantization/quant_api.py

Lines changed: 36 additions & 21 deletions

@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Callable, Union, Dict
+from typing import Any, Callable, Union, Dict, Optional
 
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
@@ -258,38 +258,53 @@ def insert_subclass(lin):
 
     return insert_subclass
 
-def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`
 
     Args:
-        model: input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance
-        filter_fn: used to filter out the modules that we don't want to apply tenosr subclass
+        model (torch.nn.Module): input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance)
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
+        the weight of the module
 
     Example::
 
-        # weight settings
-        groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        import torch
+        import torch.nn as nn
+        from torchao import quantize
+
+        # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
+        # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
+        # also customizable with arguments
+        # currently options are
+        # int8_dynamic_activation_int4_weight (for executorch)
+        # int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile)
+        # int4_weight_only (optimized with int4 tinygemm kernel and torch.compile)
+        # int8_weight_only (optimized with int8 mm op and torch.compile
+        from torchao.quantization.quant_api import int4_weight_only
+
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+        m = quantize(m, int4_weight_only(group_size=32))
+
+        # 2. write your own new apply_tensor_subclass
+        # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
+        # on weight
 
+        from torchao.dtypes import to_affine_quantized
+
+        # weight only uint4 asymmetric groupwise quantization
+        groupsize = 32
         apply_weight_quant = lambda x: to_affine_quantized(
-            x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-            zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
+            zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
 
         # apply to modules under block0 submodule
-        def filter_fn(module, fqn):
-            return fqn == "block0"
+        def filter_fn(module: nn.Module, fqn: str) -> bool:
+            return isinstance(module, nn.Linear)
 
-        m = MyModel(...)
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         m = quantize(m, apply_weight_quant, filter_fn)
+
     """
     if isinstance(apply_tensor_subclass, str):
         if apply_tensor_subclass not in _APPLY_TS_TABLE: