
Commit de999a5

Add module fqn regex support for ModuleFqnToConfig
Summary: To simplify the config file for torchao quantized models, we want to allow people to configure ModuleFqnToConfig through regex keys, e.g. `re:linear.*`, `re:language\.layers\..*\.gate_proj`.

Test Plan: python test/quantization/test_quant_api.py -k test_module_fqn_to_config_module_name_regex
1 parent 7690612 commit de999a5
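
For context, a minimal usage sketch of the behavior this commit adds. This is a hedged illustration, not part of the commit: ToyModel is hypothetical, while the imports and config arguments mirror the tests in this diff (CUDA and bfloat16 are assumed, as in the tests).

# Hedged sketch: ToyModel is illustrative only; quantize_, ModuleFqnToConfig,
# Int4WeightOnlyConfig and IntxWeightOnlyConfig are imported from
# torchao.quantization, as in the tests below.
import torch
from torchao.quantization import (
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,
    quantize_,
)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(32, 64, dtype=torch.bfloat16, device="cuda")
        self.linear2 = torch.nn.Linear(64, 32, dtype=torch.bfloat16, device="cuda")

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = ToyModel()
config = ModuleFqnToConfig(
    {
        # exact fqn key: wins over any regex key for `linear1`
        "linear1": Int4WeightOnlyConfig(
            group_size=32, int4_packing_format="tile_packed_to_4d"
        ),
        # "re:"-prefixed key: applied to the remaining modules whose fqn fully
        # matches the regex, i.e. `linear2` here
        "re:linear.*": IntxWeightOnlyConfig(),
    }
)
quantize_(model, config)
model(torch.randn(1, 32, dtype=torch.bfloat16, device="cuda"))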

File tree

2 files changed: +130 -13 lines


test/quantization/test_quant_api.py

Lines changed: 96 additions & 3 deletions
@@ -34,6 +34,8 @@
     TensorCoreTiledLayout,
 )
 from torchao.quantization import (
+    Int4TilePackedTo4dTensor,
+    IntxUnpackedToInt8Tensor,
     LinearActivationQuantizedTensor,
     PerGroup,
 )
@@ -57,9 +59,6 @@
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
-    IntxUnpackedToInt8Tensor,
-)
 from torchao.quantization.subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -691,6 +690,100 @@ def test_module_fqn_to_config_module_name(self):
         assert isinstance(model.linear2.weight, AffineQuantizedTensor)
         assert isinstance(model.linear2.weight._layout, PlainLayout)

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_basic(self):
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config = ModuleFqnToConfig({"re:linear.*": config1})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(model.linear2.weight, Int4TilePackedTo4dTensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_precedence(self):
+        """Test that a full-path config takes precedence over
+        a regex config in ModuleFqnToConfig.
+        """
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config2 = IntxWeightOnlyConfig()
+        config = ModuleFqnToConfig({"linear1": config1, "re:linear.*": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_precedence2(self):
+        """Test that a full-path config takes precedence over
+        a regex config in ModuleFqnToConfig, swapping the order of
+        `re:linear.*` and `linear1` to make sure that the `linear1`
+        config has precedence even if it comes after `re:linear.*`.
+        """
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config2 = IntxWeightOnlyConfig()
+        config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_fullmatch(self):
+        """Test that we only match the fqns that fully
+        match the regex.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self, dtype, device):
+                super().__init__()
+                self.dtype = dtype
+                self.device = device
+                self.linear1 = torch.nn.Linear(32, 64, dtype=dtype, device=device)
+                self.not_full_match_linear2 = torch.nn.Linear(
+                    64, 32, dtype=dtype, device=device
+                )
+                self.linear3_full_match = torch.nn.Linear(
+                    32, 32, dtype=dtype, device=device
+                )
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.not_full_match_linear2(x)
+                x = self.linear3_full_match(x)
+                return x
+
+            def example_inputs(self):
+                return (torch.randn(1, 32, dtype=self.dtype, device=self.device),)
+
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config2 = IntxWeightOnlyConfig()
+        config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
+        model = M(dtype=torch.bfloat16, device="cuda")
+        example_inputs = model.example_inputs()
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        # since its fqn does not fully match `linear.*`, it should not be quantized
+        assert not isinstance(
+            model.not_full_match_linear2.weight, IntxUnpackedToInt8Tensor
+        )
+        # `linear3_full_match` fully matches `linear.*`, so it should be quantized
+        assert isinstance(model.linear3_full_match.weight, IntxUnpackedToInt8Tensor)

     def test_module_fqn_to_config_embedding_linear(self):
         weight_dtype = torch.int8
         granularity = PerGroup(8)
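
The fullmatch test above relies on standard `re.fullmatch` semantics: a `re:` key only applies when the pattern consumes the entire module fqn. A quick stdlib-only illustration, independent of torchao:

import re

# The whole fqn must match the pattern for a "re:" key to apply.
assert re.fullmatch("linear.*", "linear3_full_match") is not None   # quantized
assert re.fullmatch("linear.*", "not_full_match_linear2") is None   # skipped
# re.search would also hit the second fqn, which is why the handler below
# uses re.fullmatch instead.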

torchao/quantization/quant_api.py

Lines changed: 34 additions & 10 deletions
@@ -16,10 +16,13 @@
 """

 import logging
+import re
 import types
 import warnings
+from collections import OrderedDict
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import OrderedDict as OrderedDictType

 import torch
 import torch.nn as nn
@@ -2366,15 +2369,28 @@ class ModuleFqnToConfig(AOBaseConfig):
     """Per module configurations for torchao quantize_ API

     Args:
-        `module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from
-        the fully qualified name of module to the AOBaseConfig that we want to apply to the module.
-        Also has a special key: "_default", if "_default" is present in the dictionary,
-        the config for "_default" will be applied to all the remaining modules that does not have
-        per module configuration specified.
+        `module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
+        ordered dictionary from
+        (1) the fully qualified name (fqn) of a module, or
+        (2) a regex of the fully qualified name (in Python `re` module regex format),
+        which must start with the prefix "re:", or
+        (3) "_default"
+        to the config that we want to apply to the module, or None.
+
+        Config keys ordered by precedence:
+        * fully qualified module name, e.g. `language.layers.0.q_proj`
+        * regex for module names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`;
+          whichever regex fully matches the module fqn first will be applied
+          (the key order of the dictionary is preserved since we are using an OrderedDict)
+        * "_default", a fallback for **all modules** if none of the previous keys match
+          (Note: when using `_default`, the config is applied to all modules; to apply
+          it to only a subset of modules, e.g. only some module types, it's better to filter
+          out the modules that we don't want to quantize beforehand and configure them to
+          None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`)
     """

-    module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field(
-        default_factory=dict
+    module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field(
+        default_factory=OrderedDict
     )

     def __post_init__(self):
@@ -2389,8 +2405,16 @@ def _module_fqn_to_config_handler(
         # Maybe: we can add module type specific config in the future, in needed
         c = config.module_fqn_to_config[module_fqn]
     else:
-        # fallback to use default if no module specific config is provided
-        c = config.module_fqn_to_config.get("_default", None)
+        for maybe_module_fqn_pattern in config.module_fqn_to_config:
+            if not maybe_module_fqn_pattern.startswith("re:"):
+                continue
+            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                # we'll apply the config for the first fully matched pattern
+                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                break
+        else:
+            # fall back to the default if no module-specific config is provided
+            c = config.module_fqn_to_config.get("_default", None)

     if c is not None:
         handler = _QUANTIZE_CONFIG_HANDLER[type(c)]
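
To summarize the lookup order the handler now implements (exact fqn first, then the first `re:` key that fully matches in insertion order, then `_default`), here is a standalone sketch; `resolve_config` is an illustrative helper, not a torchao API:

import re
from collections import OrderedDict
from typing import Optional


def resolve_config(module_fqn: str, module_fqn_to_config: OrderedDict) -> Optional[object]:
    # 1. an exact fully qualified module name wins
    if module_fqn in module_fqn_to_config:
        return module_fqn_to_config[module_fqn]
    # 2. otherwise, the first "re:" key (in insertion order) whose regex
    #    fully matches the fqn is used
    for key in module_fqn_to_config:
        if key.startswith("re:") and re.fullmatch(key[3:], module_fqn):
            return module_fqn_to_config[key]
    # 3. finally, fall back to "_default" if present
    return module_fqn_to_config.get("_default", None)


configs = OrderedDict(
    [
        (r"re:language\.layers\..+\.q_proj", "regex config"),
        ("language.layers.0.q_proj", "exact config"),
        ("_default", "default config"),
    ]
)
assert resolve_config("language.layers.0.q_proj", configs) == "exact config"
assert resolve_config("language.layers.1.q_proj", configs) == "regex config"
assert resolve_config("language.norm", configs) == "default config"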
