Skip to content

Commit 501eebd

Browse files
Lee, Kyunggeunquic-kyunggeu
authored andcommitted
Add QuantizationMixin.ignore_unknown_modules API
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com> Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent dc201dc commit 501eebd

File tree

4 files changed

+58
-4
lines changed

4 files changed

+58
-4
lines changed

Docs/apiref/torch/nn.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ for :class:`torch.nn.Conv2d` or :class:`QuantizedSoftmax` for :class:`torch.nn.S
2121

2222
.. autoclass:: QuantizationMixin
2323
:noindex:
24-
:members: __quant_init__, forward, compute_encodings, ignore
24+
:members: __quant_init__, forward, compute_encodings, ignore, ignore_unknown_modules
2525

2626
Configuration
2727
-------------

TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ class BaseQuantizationMixin(abc.ABC):
233233
cls_to_qcls: dict
234234
qcls_to_cls: dict
235235

236+
_ignore_unknown_modules: bool = False
236237
_ignored_module_types = set()
237238

238239
def __init__(self, *args, **kwargs):
@@ -379,6 +380,13 @@ def ignore(cls, module_cls):
379380

380381
cls._ignored_module_types.add(module_cls)
381382

383+
@classmethod
384+
def ignore_unknown_modules(cls, ignore: bool = True):
385+
"""
386+
Exempt all unkown module types from quantization
387+
"""
388+
cls._ignore_unknown_modules = ignore
389+
382390
@classmethod
383391
def implements(cls, module_cls):
384392
"""
@@ -428,7 +436,7 @@ def from_module(cls, module: nn.Module):
428436
qtzn_module_cls = cls.cls_to_qcls.get(module_cls, None)
429437

430438
if not qtzn_module_cls:
431-
if module_cls in cls._ignored_module_types:
439+
if cls._ignore_unknown_modules or module_cls in cls._ignored_module_types:
432440
return module
433441
raise UnknownModuleError(module_cls, cls)
434442

TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,8 @@ def ignore(cls, module_cls):
381381
Example:
382382
383383
>>> class MyModule(torch.nn.Module):
384-
... def forward(self, x):
385-
... return x ** 2
384+
... def forward(self, x):
385+
... return x ** 2
386386
>>> QuantizationMixin.ignore(MyModule)
387387
>>> model = torch.nn.Sequential(MyModule())
388388
>>> sim = aimet_torch.QuantizationSimModel(model, torch.randn(10, 10))
@@ -393,6 +393,26 @@ def ignore(cls, module_cls):
393393
"""
394394
super().ignore(module_cls)
395395

396+
@classmethod
397+
def ignore_unknown_modules(cls, ignore: bool = True):
398+
"""
399+
Exempt all unkown module types from quantization
400+
401+
Example:
402+
403+
>>> class MyModule(torch.nn.Module):
404+
... def forward(self, x):
405+
... return x ** 2
406+
>>> QuantizationMixin.ignore_unknown_modules(True)
407+
>>> model = torch.nn.Sequential(MyModule())
408+
>>> sim = aimet_torch.QuantizationSimModel(model, torch.randn(10, 10))
409+
>>> print(sim.model)
410+
Sequential(
411+
(0): MyModule()
412+
)
413+
"""
414+
super().ignore_unknown_modules(ignore)
415+
396416
@classmethod
397417
def implements(cls, module_cls):
398418
r"""

TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,6 +1805,11 @@ def test_fold_param_quantizers(device, requires_grad):
18051805

18061806

18071807
def test_ignore():
1808+
"""
1809+
When: Call QuantizationMixin.ignore
1810+
Then: Unknown modules should be ignored during quantization
1811+
"""
1812+
18081813
class MyModule(torch.nn.Module):
18091814
def forward(self, x):
18101815
return x**2
@@ -1816,3 +1821,24 @@ def forward(self, x):
18161821
model, dummy_input=torch.randn(1, 3, 224, 224)
18171822
)
18181823
assert type(sim.model[0]) == MyModule
1824+
1825+
"""
1826+
When: Call QuantizationMixin.ignore_unknown_modules
1827+
Then: Unknown modules should be ignored during quantization
1828+
"""
1829+
1830+
class MyModule(torch.nn.Module):
1831+
def forward(self, x):
1832+
return x**2
1833+
1834+
orig = QuantizationMixin._ignore_unknown_modules
1835+
try:
1836+
QuantizationMixin.ignore_unknown_modules(True)
1837+
1838+
model = torch.nn.Sequential(MyModule())
1839+
sim = aimet_torch.QuantizationSimModel(
1840+
model, dummy_input=torch.randn(1, 3, 224, 224)
1841+
)
1842+
assert type(sim.model[0]) == MyModule
1843+
finally:
1844+
QuantizationMixin.ignore_unknown_modules(orig)

0 commit comments

Comments
 (0)