18 | 18 | import torch |
19 | 19 | import torch.nn as nn |
20 | 20 | import torch.nn.functional as F |
21 | | -from typing import Any, Callable, Union, Dict |
| 21 | +from typing import Any, Callable, Union, Dict, Optional |
22 | 22 |
23 | 23 | from torchao.utils import ( |
24 | 24 | TORCH_VERSION_AFTER_2_4, |
@@ -258,38 +258,53 @@ def insert_subclass(lin): |
258 | 258 |
259 | 259 | return insert_subclass |
260 | 260 |
261 | | -def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module: |
| 261 | +def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module: |
262 | 262 | """Convert the weight of linear modules in the model with `apply_tensor_subclass` |
263 | 263 |
264 | 264 | Args: |
265 | | - model: input model |
266 | | - apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance |
267 | | - filter_fn: used to filter out the modules that we don't want to apply tenosr subclass |
| 265 | + model (torch.nn.Module): input model |
| 266 | + apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (quantized) tensor subclass instance (e.g. an affine quantized tensor instance) |
| 267 | + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run `apply_tensor_subclass` on |
| 268 | + the weight of the module |
268 | 269 |
269 | 270 | Example:: |
270 | 271 |
271 | | - # weight settings |
272 | | - groupsize = 32 |
273 | | - mapping_type = MappingType.ASYMMETRIC |
274 | | - block_size = (1, groupsize) |
275 | | - target_dtype = torch.int32 |
276 | | - quant_min = 0 |
277 | | - quant_max = 15 |
278 | | - eps = 1e-6 |
279 | | - preserve_zero = False |
280 | | - zero_point_dtype = torch.bfloat16 |
281 | | - zero_point_domain = ZeroPointDomain.FLOAT |
| 272 | + import torch |
| 273 | + import torch.nn as nn |
| 274 | + from torchao import quantize |
| 275 | +
| 276 | + # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to |
| 277 | + # optimized execution paths or kernels (e.g. int4 tinygemm kernel) |
| 278 | + # also customizable with arguments |
| 279 | + # currently options are |
| 280 | + # int8_dynamic_activation_int4_weight (for executorch) |
| 281 | + # int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile) |
| 282 | + # int4_weight_only (optimized with int4 tinygemm kernel and torch.compile) |
| 283 | + # int8_weight_only (optimized with int8 mm op and torch.compile) |
| 284 | + from torchao.quantization.quant_api import int4_weight_only |
| 285 | +
| 286 | + m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) |
| 287 | + m = quantize(m, int4_weight_only(group_size=32)) |
| 288 | +
| 289 | + # 2. write your own new apply_tensor_subclass |
| 290 | + # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor |
| 291 | + # on weight |
282 | 292 |
| 293 | + from torchao.dtypes import to_affine_quantized |
| 294 | + from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain |
| 295 | + # weight only uint4 asymmetric groupwise quantization |
| 296 | + groupsize = 32 |
283 | 297 | apply_weight_quant = lambda x: to_affine_quantized( |
284 | | - x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, |
285 | | - zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) |
| 298 | + x, MappingType.ASYMMETRIC, (1, groupsize), torch.int32, 0, 15, 1e-6, |
| 299 | + zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain=ZeroPointDomain.FLOAT) |
286 | 300 |
287 | | - # apply to modules under block0 submodule |
| 301 | + # apply to all nn.Linear modules in the model |
288 | | - def filter_fn(module, fqn): |
289 | | - return fqn == "block0" |
| 302 | + def filter_fn(module: nn.Module, fqn: str) -> bool: |
| 303 | + return isinstance(module, nn.Linear) |
290 | 304 |
291 | | - m = MyModel(...) |
| 305 | + m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) |
292 | 306 | m = quantize(m, apply_weight_quant, filter_fn) |
| 307 | +
293 | 308 | """ |
294 | 309 | if isinstance(apply_tensor_subclass, str): |
295 | 310 | if apply_tensor_subclass not in _APPLY_TS_TABLE: |
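For reference, a minimal usage sketch of the updated `quantize` signature (not part of the diff itself), combining the predefined `int4_weight_only` constructor with a custom `filter_fn`; the bfloat16 cast and the `fqn == "0"` check are illustrative assumptions, not something stated in this change:

    import torch
    import torch.nn as nn

    from torchao import quantize
    from torchao.quantization.quant_api import int4_weight_only

    # only quantize the first Linear layer of the Sequential; the fqn value "0"
    # is an illustrative assumption for this toy model
    def filter_fn(module: nn.Module, fqn: str) -> bool:
        return isinstance(module, nn.Linear) and fqn == "0"

    # assumption: cast to bfloat16, since the int4 tinygemm path expects bf16 weights
    m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)).to(torch.bfloat16)
    m = quantize(m, int4_weight_only(group_size=32), filter_fn)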