99 changes: 96 additions & 3 deletions test/quantization/test_quant_api.py
@@ -34,6 +34,8 @@
TensorCoreTiledLayout,
)
from torchao.quantization import (
Int4TilePackedTo4dTensor,
IntxUnpackedToInt8Tensor,
LinearActivationQuantizedTensor,
PerGroup,
)
@@ -57,9 +59,6 @@
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
IntxUnpackedToInt8Tensor,
)
from torchao.quantization.subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
@@ -691,6 +690,100 @@ def test_module_fqn_to_config_module_name(self):
assert isinstance(model.linear2.weight, AffineQuantizedTensor)
assert isinstance(model.linear2.weight._layout, PlainLayout)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_module_fqn_to_config_regex_basic(self):
config1 = Int4WeightOnlyConfig(
group_size=32, int4_packing_format="tile_packed_to_4d"
)
config = ModuleFqnToConfig({"re:linear.*": config1})
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
quantize_(model, config)
model(*example_inputs)
assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
assert isinstance(model.linear2.weight, Int4TilePackedTo4dTensor)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_module_fqn_to_config_regex_precedence(self):
"""Testing that full path config takes precedence over
regex config in ModuleFqnToConfig
"""
config1 = Int4WeightOnlyConfig(
group_size=32, int4_packing_format="tile_packed_to_4d"
)
config2 = IntxWeightOnlyConfig()
config = ModuleFqnToConfig({"linear1": config1, "re:linear.*": config2})
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
quantize_(model, config)
model(*example_inputs)
assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_module_fqn_to_config_regex_precedence2(self):
"""Testing that full path config takes precedence over
regex config in ModuleFqnToConfig, swapping
the order of `re:linear.*` and `linear1` to make sure that
`linear1` config has precedence even if it comes after `re:linear.*`
"""
config1 = Int4WeightOnlyConfig(
group_size=32, int4_packing_format="tile_packed_to_4d"
)
config2 = IntxWeightOnlyConfig()
config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
quantize_(model, config)
model(*example_inputs)
assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_module_fqn_to_config_regex_fullmatch(self):
"""Testing that we will only match the fqns that fully
matches the regex
"""

class M(torch.nn.Module):
def __init__(self, dtype, device):
super().__init__()
self.dtype = dtype
self.device = device
self.linear1 = torch.nn.Linear(32, 64, dtype=dtype, device=device)
self.not_full_match_linear2 = torch.nn.Linear(
64, 32, dtype=dtype, device=device
)
self.linear3_full_match = torch.nn.Linear(
32, 32, dtype=dtype, device=device
)

def forward(self, x):
x = self.linear1(x)
x = self.not_full_match_linear2(x)
x = self.linear3_full_match(x)
return x

def example_inputs(self):
return (torch.randn(1, 32, dtype=self.dtype, device=self.device),)

config1 = Int4WeightOnlyConfig(
group_size=32, int4_packing_format="tile_packed_to_4d"
)
config2 = IntxWeightOnlyConfig()
config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
model = M(dtype=torch.bfloat16, device="cuda")
example_inputs = model.example_inputs()
quantize_(model, config)
model(*example_inputs)
assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
# since the fqn does not fully match `linear.*`, it should not be quantized
assert not isinstance(
model.not_full_match_linear2.weight, IntxUnpackedToInt8Tensor
)
# linear3_full_match fully matches `linear.*`, so it should be quantized
assert isinstance(model.linear3_full_match.weight, IntxUnpackedToInt8Tensor)

def test_module_fqn_to_config_embedding_linear(self):
weight_dtype = torch.int8
granularity = PerGroup(8)
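As an aside on the full-match semantics exercised by test_module_fqn_to_config_regex_fullmatch above, here is a minimal standalone sketch using only the standard `re` module; the pattern and fqn strings mirror the test and are otherwise illustrative:

import re

pattern = "linear.*"  # the part after the "re:" prefix of the config key

# re.fullmatch requires the entire fqn to match the pattern
assert re.fullmatch(pattern, "linear1") is not None
assert re.fullmatch(pattern, "linear3_full_match") is not None
# a match somewhere inside the fqn is not enough, so this module is skipped
assert re.fullmatch(pattern, "not_full_match_linear2") is None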
44 changes: 34 additions & 10 deletions torchao/quantization/quant_api.py
@@ -16,10 +16,13 @@
"""

import logging
import re
import types
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import OrderedDict as OrderedDictType

import torch
import torch.nn as nn
@@ -2366,15 +2369,28 @@ class ModuleFqnToConfig(AOBaseConfig):
"""Per module configurations for torchao quantize_ API

Args:
`module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from
the fully qualified name of module to the AOBaseConfig that we want to apply to the module.
Also has a special key: "_default", if "_default" is present in the dictionary,
the config for "_default" will be applied to all the remaining modules that does not have
per module configuration specified.
`module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
ordered dictionary from
(1) the fully qualified name (fqn) of a module, or
(2) a regex for the fully qualified name (in Python `re` module format), which
must start with the prefix "re:", or
(3) "_default"
to the config that we want to apply to the module, or None

Config keys, ordered by precedence:
* fully qualified module name, e.g. `language.layers.0.q_proj`
* regex for module names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`;
whichever regex fully matches the module fqn first will be applied
(key order is preserved since we are using an OrderedDict)
* "_default", fallback for **all modules** if none of the previous keys match
(Note: when using `_default`, the config is applied to all remaining modules; to apply
it to only a subset of modules, e.g. only certain module types, it's better to filter
out the modules that we don't want to quantize beforehand and map them to
None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`)
"""

module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field(
default_factory=dict
module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field(
default_factory=OrderedDict
)

def __post_init__(self):
@@ -2389,8 +2405,16 @@ def _module_fqn_to_config_handler(
# Maybe: we can add module type specific config in the future, if needed
c = config.module_fqn_to_config[module_fqn]
else:
# fallback to use default if no module specific config is provided
c = config.module_fqn_to_config.get("_default", None)
for maybe_module_fqn_pattern in config.module_fqn_to_config:
if not maybe_module_fqn_pattern.startswith("re:"):
continue
elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
# we'll apply the config for the first fully matched pattern
c = config.module_fqn_to_config[maybe_module_fqn_pattern]
break
else:
# fallback to use default if no module specific config is provided
c = config.module_fqn_to_config.get("_default", None)

if c is not None:
handler = _QUANTIZE_CONFIG_HANDLER[type(c)]
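To make the precedence rules in the ModuleFqnToConfig docstring above concrete, here is a minimal configuration sketch. It assumes Int4WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, and quantize_ are importable from torchao.quantization, as in the test file above; the module names are illustrative.

from collections import OrderedDict

from torchao.quantization import (
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,
    quantize_,
)

config = ModuleFqnToConfig(
    OrderedDict(
        [
            # exact fqn: highest precedence, wins over any regex key regardless of order
            (
                "linear1",
                Int4WeightOnlyConfig(
                    group_size=32, int4_packing_format="tile_packed_to_4d"
                ),
            ),
            # regex key: applied to any remaining module whose fqn fully matches
            ("re:linear.*", IntxWeightOnlyConfig()),
            # explicitly skip norm layers, then fall back to a default for everything else
            ("re:.+norm.+", None),
            ("_default", IntxWeightOnlyConfig()),
        ]
    )
)
# quantize_(model, config)  # applied in place to a model; the int4 packing format above needs CUDA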