Skip to content

Commit 5a419f3

Browse files
Lee, Kyunggeunquic-kyunggeu
authored andcommitted
Add QuantizationMixin.ignore API
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com> Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent e026fd1 commit 5a419f3

File tree

5 files changed

+55
-1
lines changed

5 files changed

+55
-1
lines changed

Docs/apiref/torch/nn.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ for :class:`torch.nn.Conv2d` or :class:`QuantizedSoftmax` for :class:`torch.nn.S
2121

2222
.. autoclass:: QuantizationMixin
2323
:noindex:
24-
:members: __quant_init__, forward, compute_encodings
24+
:members: __quant_init__, forward, compute_encodings, ignore
2525

2626
Configuration
2727
-------------

TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def generate_err_msg(self) -> str:
107107
f"Please register the quantized module definition of {module_cls} "
108108
f"using `@{mixin_cls.__name__}.implements({module_cls.__name__})` decorator.\n\n"
109109
f"For example:\n\n{code_example}\n\n"
110+
"If you believe this module need not be quantized, please exclude it from quantization by calling "
111+
f"`QuantizationMixin.ignore({module_cls.__name__})`.\n\n"
110112
f"For more details, please refer to the official API reference:\n{self.api_reference_url}"
111113
)
112114

@@ -231,6 +233,8 @@ class BaseQuantizationMixin(abc.ABC):
231233
cls_to_qcls: dict
232234
qcls_to_cls: dict
233235

236+
_ignored_module_types = set()
237+
234238
def __init__(self, *args, **kwargs):
235239
super().__init__(*args, **kwargs)
236240
self.__quant_init__()
@@ -365,6 +369,16 @@ def wrap(cls, module_cls: Type[nn.Module]):
365369
Wrap a regular module class into a quantized module class
366370
"""
367371

372+
@classmethod
373+
def ignore(cls, module_cls):
374+
"""
375+
Exempt given module type from quantization
376+
"""
377+
if not issubclass(module_cls, torch.nn.Module):
378+
raise RuntimeError
379+
380+
cls._ignored_module_types.add(module_cls)
381+
368382
@classmethod
369383
def implements(cls, module_cls):
370384
"""
@@ -414,6 +428,8 @@ def from_module(cls, module: nn.Module):
414428
qtzn_module_cls = cls.cls_to_qcls.get(module_cls, None)
415429

416430
if not qtzn_module_cls:
431+
if module_cls in cls._ignored_module_types:
432+
return module
417433
raise UnknownModuleError(module_cls, cls)
418434

419435
qtzn_module = cls.__new__(qtzn_module_cls)

TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,26 @@ def from_module(cls, module: nn.Module):
373373
"""
374374
return super().from_module(module)
375375

376+
@classmethod
377+
def ignore(cls, module_cls):
378+
"""
379+
Exempt given module type from quantization
380+
381+
Example:
382+
383+
>>> class MyModule(torch.nn.Module):
384+
... def forward(self, x):
385+
... return x ** 2
386+
>>> QuantizationMixin.ignore(MyModule)
387+
>>> model = torch.nn.Sequential(MyModule())
388+
>>> sim = aimet_torch.QuantizationSimModel(model, torch.randn(10, 10))
389+
>>> print(sim.model)
390+
Sequential(
391+
(0): MyModule()
392+
)
393+
"""
394+
super().ignore(module_cls)
395+
376396
@classmethod
377397
def implements(cls, module_cls):
378398
r"""

TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ def _convert_to_qmodule(module: torch.nn.Module):
185185
f"using `@{e.mixin_cls.__name__}.implements()` decorator.",
186186
"For example:",
187187
*(e.generate_code_example() for e in exceptions.values()),
188+
"If you believe these modules need not be quantized, "
189+
"please exclude them from quantization by calling `QuantizationMixin.ignore` "
190+
f"(for example, `QuantizationMixin.ignore({e.module_cls.__name__})`) "
191+
"before creating QuantizationSimModel.",
188192
f"For more details, please refer to the official API reference:\n{e.api_reference_url}",
189193
]
190194
)

TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,3 +1802,17 @@ def test_fold_param_quantizers(device, requires_grad):
18021802
assert isinstance(qlinear.bias, torch.Tensor)
18031803
assert isinstance(qlinear.bias, torch.nn.Parameter)
18041804
assert torch.equal(qlinear.bias, original_bias)
1805+
1806+
1807+
def test_ignore():
1808+
class MyModule(torch.nn.Module):
1809+
def forward(self, x):
1810+
return x**2
1811+
1812+
QuantizationMixin.ignore(MyModule)
1813+
1814+
model = torch.nn.Sequential(MyModule())
1815+
sim = aimet_torch.QuantizationSimModel(
1816+
model, dummy_input=torch.randn(1, 3, 224, 224)
1817+
)
1818+
assert type(sim.model[0]) == MyModule

0 commit comments

Comments
 (0)