quant: nice error message on convtranspose with per-channel weight #49899

Closed · wants to merge 2 commits

Changes from 1 commit
14 changes: 14 additions & 0 deletions test/quantization/test_quantize.py
@@ -726,6 +726,20 @@ def forward(self, x):
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        with self.assertRaises(AssertionError) as context:
            mp = torch.quantization.prepare(m)
        self.assertTrue(
            str(context.exception) ==
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')


@skipIfNoFBGEMM
class TestPostTrainingDynamic(QuantizationTestCase):
15 changes: 15 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -1275,6 +1275,21 @@ def test_fp32_input_fp32_output(self):
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.eval()
        qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
        with self.assertRaises(AssertionError) as context:
            mp = prepare_fx(m, qconfig_dict)
        self.assertTrue(
            str(context.exception) ==
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
17 changes: 17 additions & 0 deletions torch/quantization/qconfig.py
@@ -3,6 +3,8 @@
from .fake_quantize import *
import torch.nn as nn

from typing import Optional, Union

class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
"""
Describes how to quantize a layer or a part of the network by providing
@@ -109,3 +111,18 @@ def get_default_qat_qconfig(backend='fbgemm'):
    else:
        qconfig = default_qat_qconfig
    return qconfig

def assert_valid_qconfig(qconfig: Optional[Union[QConfig, QConfigDynamic]],
                         mod: torch.nn.Module) -> None:
    if qconfig is None:
        return
    is_conv_transpose_mod = (
        isinstance(mod, torch.nn.ConvTranspose1d) or
        isinstance(mod, torch.nn.ConvTranspose2d) or
        isinstance(mod, torch.nn.ConvTranspose3d))
    if is_conv_transpose_mod:
        # ConvTranspose weights are laid out (in_channels, out_channels, ...),
        # so per-channel observers, which assume axis 0 is the output channel,
        # are not supported yet; fail loudly at prepare time
        example_observer = qconfig.weight()
        is_per_channel = (
            isinstance(example_observer, torch.quantization.PerChannelMinMaxObserver) or
            isinstance(example_observer, torch.quantization.MovingAveragePerChannelMinMaxObserver))
        assert not is_per_channel, \
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'
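
For anyone hitting the new assertion, a minimal workaround sketch (not part of this diff, and assuming the stock torch.quantization observers): keep the ConvTranspose module on a per-tensor weight observer, which the check accepts.

import torch
from torch.quantization import QConfig, default_observer, default_weight_observer

# default_weight_observer is a per-tensor MinMaxObserver (per_tensor_symmetric),
# so assert_valid_qconfig lets it through for ConvTranspose modules
m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
m.qconfig = QConfig(activation=default_observer, weight=default_weight_observer)
mp = torch.quantization.prepare(m)  # no AssertionError: weight observed per-tensor
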
2 changes: 2 additions & 0 deletions torch/quantization/quantize.py
@@ -49,6 +49,8 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)

    module.qconfig = module_qconfig
    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
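
Because _propagate_qconfig_helper recurses over named_children, the validation also catches a ConvTranspose nested anywhere in the module tree, still at prepare time rather than convert time. A small sketch of that behavior (assumed usage, not taken from the PR's tests):

import torch

# the ConvTranspose2d is a grandchild of the top-level module; the qconfig
# propagates down and the assertion still fires during prepare
m = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)))
m.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # per-channel weights
try:
    torch.quantization.prepare(m)
except AssertionError as e:
    print(e)  # Per channel weight observer is not supported yet for ConvTranspose{n}d.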