add type annotations to torch.nn.quantized.modules.conv #49702

Closed · wants to merge 5 commits
11 changes: 1 addition & 10 deletions mypy.ini
@@ -91,16 +91,7 @@ ignore_errors = True
[mypy-torch.nn.modules.pooling]
ignore_errors = True

[mypy-torch.nn.qat.modules.activations]
ignore_errors = True

[mypy-torch.nn.qat.modules.conv]
ignore_errors = True

[mypy-torch.nn.quantized.dynamic.modules.linear]
ignore_errors = True

[mypy-torch.nn.quantized.modules.conv]
[mypy-torch.nn.parallel._functions]
ignore_errors = True

[mypy-torch._appdirs]
59 changes: 44 additions & 15 deletions torch/nn/quantized/modules/conv.py
@@ -1,7 +1,7 @@
# coding=utf-8
r"""Quantized convolution modules."""

from typing import Optional, List
from typing import Optional, List, TypeVar

import torch
import torch.nn as nn
@@ -16,11 +16,17 @@

class _ConvNd(nn.Module):

def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation,
transposed, output_padding,
groups, bias,
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros'):
# All subclasses have this signature - See PR #49702
raise NotImplementedError

def _init(self, in_channels, out_channels, kernel_size, stride,
padding, dilation,
transposed, output_padding,
groups, bias,
padding_mode='zeros'):
super(_ConvNd, self).__init__()
if padding_mode != 'zeros':
raise NotImplementedError(
@@ -54,6 +60,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
self.scale = 1.0
self.zero_point = 0

def set_weight_bias(self, qweight, bias_float):
raise NotImplementedError

def bias(self):
raise NotImplementedError

def _weight_bias(self):
raise NotImplementedError

def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}, scale={scale}, zero_point={zero_point}')
@@ -155,7 +170,8 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
assert weight_post_process.dtype == torch.qint8, \
'Weight observer must have a dtype of qint8'
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
# the __init__ call used is the one from derived classes and not the one from _ConvNd
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg]
Collaborator:

So the issue is that the subclass __init__ methods have two fewer parameters (no transposed, output_padding) than _ConvNd's. The current fix works, but the way the class hierarchy is organized is technically not correct; a better way to do it would be:

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        # All subclasses have this signature
        raise NotImplementedError

    def _init(self, in_channels, out_channels, kernel_size, stride,
              padding, dilation,
              transposed, output_padding,
              groups, bias,
              padding_mode='zeros'):
        super(_ConvNd, self).__init__()

and then subclasses should call super(Conv1d, self)._init(...).

I think that would be a good change, but it is only possible if there are no subclasses of _ConvNd floating around (it's private, but you never know).

Let's see what other reviewers think.
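
To make the suggestion concrete, here is a minimal, runnable sketch of the proposed __init__/_init split (illustrative class names only, not the actual PyTorch code):

    # Sketch of the __init__/_init split suggested above (illustrative only).
    class _BaseNd:
        def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                     padding=0, dilation=1, groups=1, bias=True,
                     padding_mode='zeros'):
            # Every concrete subclass exposes this narrower signature, so the
            # base constructor itself is never meant to be called.
            raise NotImplementedError

        def _init(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
                  groups, bias, padding_mode='zeros'):
            # The real initialization lives here; subclasses forward to it.
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.transposed = transposed
            self.output_padding = output_padding

    class Conv1dLike(_BaseNd):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                     padding=0, dilation=1, groups=1, bias=True,
                     padding_mode='zeros'):
            # transposed and output_padding are implied by the subclass, so
            # they are hardcoded here rather than exposed to the user.
            super()._init(in_channels, out_channels, kernel_size, stride,
                          padding, dilation, False, (0,),
                          groups, bias, padding_mode)

    m = Conv1dLike(4, 8, 3)
    print(m.in_channels, m.transposed, m.output_padding)  # 4 False (0,)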

Collaborator Author:

OK, I'll wait for the other reviewers before changing this part of the code.

Contributor:

Soooo it is a little tricky. torch.nn._ConvNd is sadly de facto public API; see https://github.com/pytorch/fairseq/blob/master/fairseq/modules/quantization/scalar/modules/qconv.py#L14 for an example. But this is the quantized ConvNd, and this code is relatively new, so there ought to be fewer uses. cc @jerryzh168 for more thoughts.

I'm trying to decide if I agree that the alternate version is better.
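
For readers wondering what "subclasses floating around" would mean in practice: under the reorganization, any existing subclass that still routes through the base __init__ stops working. A tiny hypothetical sketch (assuming the __init__/_init split above; not real PyTorch or fairseq code):

    # Hypothetical base after the refactor: __init__ is no longer callable.
    class _BaseNd:
        def __init__(self, *args, **kwargs):
            raise NotImplementedError

    # Hypothetical third-party subclass written against the old API.
    class ThirdPartyConv(_BaseNd):
        def __init__(self, in_channels, out_channels, kernel_size):
            super().__init__(in_channels, out_channels, kernel_size)

    try:
        ThirdPartyConv(4, 8, 3)
    except NotImplementedError:
        print("old-style subclass broken by the __init__/_init split")

This is why it matters that the quantized _ConvNd is relatively new and unlikely to be subclassed outside the repo.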

Contributor:

I guess the high-level question I have is why the subclass __init__ methods have fewer parameters. This seems like a blatant oversight to me.

Collaborator:

> I guess the high-level question I have is why the subclass __init__ methods have fewer parameters

Because those two parameters are implicitly defined by the dimensionality of the subclass - e.g. Conv1d hardcodes transposed, output_padding to False, (0,). It would not make sense for the user to define those when initializing Conv1d.

transposed is always False, so that parameter really doesn't make much sense - probably the code was just copied from the non-quantized version or something like that.

Contributor:

OK fair enough. The suggested alternative seems like a reasonable way to shut up the type checker then.
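
For context on what the type checker objects to here: in a classmethod like get_qconv/from_float, mypy checks cls(...) against the base class __init__ signature, while at runtime cls is always a concrete subclass with the narrower signature. A rough standalone reproduction (illustrative names; assuming this is the shape of the original mypy complaint):

    # Minimal repro of the call-arg problem (illustrative, not PyTorch code).
    class Base:
        def __init__(self, a, b, transposed, output_padding):
            self.a, self.b = a, b
            self.transposed = transposed
            self.output_padding = output_padding

        @classmethod
        def from_float(cls, a, b):
            # mypy checks this call against Base.__init__, which wants four
            # arguments, so it reports a call-arg error here even though at
            # runtime cls is always a subclass with a two-argument __init__.
            return cls(a, b)

    class Sub(Base):
        def __init__(self, a, b):
            super().__init__(a, b, False, (0,))

    print(Sub.from_float(1, 2).output_padding)  # runs fine: (0,)

Hence the two options discussed in this thread: silence the call site with # type: ignore[call-arg], or give the base __init__ the subclass-shaped signature and move the real initialization into _init.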

Collaborator Author:

I've updated the code to use _init.

Contributor:

I see, are we going to have similar changes in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py as well?

mod.stride, mod.padding, mod.dilation, mod.groups,
mod.bias is not None, mod.padding_mode)
qconv.set_weight_bias(qweight, mod.bias)
@@ -233,7 +249,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding = _pair_from_first(padding)
dilation = _pair_from_first(dilation)

super(Conv1d, self).__init__(
# Subclasses of _ConvNd need to call _init rather than __init__. See
# discussion on PR #49702
super(Conv1d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode)

@@ -319,7 +337,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super(Conv2d, self).__init__(
# Subclasses of _ConvNd need to call _init rather than __init__. See
# discussion on PR #49702
super(Conv2d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode)

@@ -403,7 +423,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)
super(Conv3d, self).__init__(
# Subclasses of _ConvNd need to call _init rather than __init__. See
# discussion on PR #49702
super(Conv3d, self)._init(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _triple(0), groups, bias, padding_mode)

@@ -450,15 +472,20 @@ def from_float(cls, mod):
return cls.get_qconv(mod, activation_post_process)

# === Transposed Convolutions ===
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)

class _ConvTransposeNd(_ConvNd):

_FLOAT_MODULE = MOD

def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode):
if padding_mode != 'zeros':
raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))

super(_ConvTransposeNd, self).__init__(
# Subclasses of _ConvNd need to call _init rather than __init__. See
# discussion on PR #49702
super(_ConvTransposeNd, self)._init(
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode)
@@ -477,9 +504,10 @@ def from_float(cls, mod):
mod (Module): a float module, either produced by torch.quantization
utilities or provided by the user
"""
assert type(mod) == cls._FLOAT_MODULE, \
' nnq.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
# derived classes override cls._FLOAT_MODULE attribute
msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert type(mod) == cls._FLOAT_MODULE, msg
assert hasattr(mod, 'qconfig'), \
'Input float module must have qconfig defined.'
weight_post_process = mod.qconfig.weight()
@@ -488,7 +516,8 @@
assert weight_post_process.dtype == torch.qint8, \
'Weight observer must have a dtype of qint8'
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
# the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg]
mod.stride, mod.padding, mod.output_padding, mod.groups,
mod.bias is not None, mod.dilation, mod.padding_mode)
qconv.set_weight_bias(qweight, mod.bias)