quant docs: add and clean up InstanceNorm{n}d
Fixes docstrings for quantized InstanceNorm and adds it to the quantization docs.

Test plan:
* build on macOS and inspect

ghstack-source-id: f4d38fdc7f9ea84057e680045394522c09093074
Pull Request resolved: #40345
vkuzo committed Jun 20, 2020
1 parent ad62f2e commit a2ee513
Showing 3 changed files with 59 additions and 9 deletions.
35 changes: 35 additions & 0 deletions docs/source/quantization.rst
@@ -243,6 +243,9 @@ Layers for the quantization-aware training
 * :class:`~torch.nn.qat.Hardswish` — Hardswish
 * :class:`~torch.nn.qat.LayerNorm` — LayerNorm
 * :class:`~torch.nn.qat.GroupNorm` — GroupNorm
+* :class:`~torch.nn.qat.InstanceNorm1d` — InstanceNorm1d
+* :class:`~torch.nn.qat.InstanceNorm2d` — InstanceNorm2d
+* :class:`~torch.nn.qat.InstanceNorm3d` — InstanceNorm3d

 ``torch.quantization``
 ~~~~~~~~~~~~~~~~~~~~~~
@@ -356,6 +359,9 @@ Quantized version of standard NN layers.
 * :class:`~torch.nn.quantized.Hardswish` — Hardswish
 * :class:`~torch.nn.quantized.LayerNorm` — LayerNorm. *Note: performance on ARM is not optimized*.
 * :class:`~torch.nn.quantized.GroupNorm` — GroupNorm. *Note: performance on ARM is not optimized*.
+* :class:`~torch.nn.quantized.InstanceNorm1d` — InstanceNorm1d. *Note: performance on ARM is not optimized*.
+* :class:`~torch.nn.quantized.InstanceNorm2d` — InstanceNorm2d. *Note: performance on ARM is not optimized*.
+* :class:`~torch.nn.quantized.InstanceNorm3d` — InstanceNorm3d. *Note: performance on ARM is not optimized*.

 ``torch.nn.quantized.dynamic``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -721,6 +727,20 @@ GroupNorm
 .. autoclass:: GroupNorm
     :members:

+InstanceNorm1d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm1d
+    :members:
+
+InstanceNorm2d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm2d
+    :members:
+
+InstanceNorm3d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm3d
+    :members:

 torch.nn.quantized
 ----------------------------
@@ -814,6 +834,21 @@ GroupNorm
 .. autoclass:: GroupNorm
     :members:

+InstanceNorm1d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm1d
+    :members:
+
+InstanceNorm2d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm2d
+    :members:
+
+InstanceNorm3d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm3d
+    :members:
+
 torch.nn.quantized.dynamic
 ----------------------------

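For context on how the newly documented quantized module is reached in practice, here is a minimal eager-mode post-training static quantization sketch. It is not part of this commit: the QuantStub/DeQuantStub, prepare, and convert calls are the standard torch.quantization workflow, and it assumes the float-to-quantized mapping for InstanceNorm2d is registered in the default conversion mappings.

import torch
import torch.quantization

# Sketch (not from this commit): eager-mode static quantization of a model
# containing InstanceNorm2d. Assumes nn.InstanceNorm2d is mapped to
# nn.quantized.InstanceNorm2d in the default module mappings.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized boundary
        self.norm = torch.nn.InstanceNorm2d(4, affine=True)
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float boundary

    def forward(self, x):
        return self.dequant(self.norm(self.quant(x)))

m = M().eval()
m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(m, inplace=True)
m(torch.randn(2, 4, 8, 8))                  # calibration pass to record observer stats
torch.quantization.convert(m, inplace=True)
print(type(m.norm))                         # expected: torch.nn.quantized.InstanceNorm2d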
12 changes: 6 additions & 6 deletions torch/nn/qat/modules/normalization.py
@@ -47,8 +47,8 @@ def from_float(cls, mod, qconfig=None):

 class InstanceNorm1d(nn.InstanceNorm1d):
     r"""
-    A InstanceNorm1d module attached with FakeQuantize modules for both output
-    activation and weight, used for quantization aware training.
+    An InstanceNorm1d module attached with FakeQuantize modules for output
+    activation, used for quantization aware training.

     Similar to `torch.nn.InstanceNorm1d`, with FakeQuantize modules initialized to
     default.
@@ -92,8 +92,8 @@ def from_float(cls, mod, qconfig=None):

 class InstanceNorm2d(nn.InstanceNorm2d):
     r"""
-    A InstanceNorm2d module attached with FakeQuantize modules for both output
-    activation and weight, used for quantization aware training.
+    An InstanceNorm2d module attached with FakeQuantize modules for output
+    activation, used for quantization aware training.

     Similar to `torch.nn.InstanceNorm2d`, with FakeQuantize modules initialized to
     default.
@@ -137,8 +137,8 @@ def from_float(cls, mod, qconfig=None):

 class InstanceNorm3d(nn.InstanceNorm3d):
     r"""
-    A InstanceNorm3d module attached with FakeQuantize modules for both output
-    activation and weight, used for quantization aware training.
+    An InstanceNorm3d module attached with FakeQuantize modules for output
+    activation, used for quantization aware training.

     Similar to `torch.nn.InstanceNorm3d`, with FakeQuantize modules initialized to
     default.
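Per the corrected docstrings above, the QAT variants attach a FakeQuantize module to the output activation only, not to the weight. A hedged usage sketch, assuming the from_float(cls, mod, qconfig=None) classmethod shown in the hunks and a standard QAT qconfig:

import torch
import torch.nn.qat
import torch.quantization

# Sketch (not from this commit): swap a float InstanceNorm2d for its QAT
# counterpart. The 'fbgemm' qconfig is an illustrative choice.
float_norm = torch.nn.InstanceNorm2d(4, affine=True)
qat_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

qat_norm = torch.nn.qat.InstanceNorm2d.from_float(float_norm, qconfig=qat_qconfig)

# Forward runs the usual instance-norm math in float, then fake-quantizes
# the output activation so training sees the quantization error.
y = qat_norm(torch.randn(2, 4, 8, 8))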
21 changes: 18 additions & 3 deletions torch/nn/quantized/modules/normalization.py
@@ -76,7 +76,12 @@ def from_float(cls, mod):
         return new_mod

 class InstanceNorm1d(torch.nn.InstanceNorm1d):
-    r"""This is the quantized version of `torch.nn.InstanceNorm1d`."""
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+    """
     def __init__(self, num_features, weight, bias, scale, zero_point,
                  eps=1e-5, momentum=0.1, affine=False,
@@ -106,7 +111,12 @@ def from_float(cls, mod):
         return new_mod

 class InstanceNorm2d(torch.nn.InstanceNorm2d):
-    r"""This is the quantized version of `torch.nn.InstanceNorm2d`."""
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+    """
     def __init__(self, num_features, weight, bias, scale, zero_point,
                  eps=1e-5, momentum=0.1, affine=False,
@@ -136,7 +146,12 @@ def from_float(cls, mod):
         return new_mod

 class InstanceNorm3d(torch.nn.InstanceNorm3d):
-    r"""This is the quantized version of `torch.nn.InstanceNorm2d`."""
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+    """
     def __init__(self, num_features, weight, bias, scale, zero_point,
                  eps=1e-5, momentum=0.1, affine=False,
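The scale and zero_point arguments documented above fix the quantization parameters of the module's output. A minimal construction sketch, assuming the __init__ signature visible in the hunks; the concrete numbers are arbitrary illustrations:

import torch
import torch.nn.quantized as nnq

# Sketch (not from this commit): construct the quantized module directly.
num_features = 4
m = nnq.InstanceNorm2d(
    num_features,
    weight=torch.ones(num_features),  # per-channel affine weight
    bias=torch.zeros(num_features),   # per-channel affine bias
    scale=0.1,                        # quantization scale of the output (double)
    zero_point=0,                     # quantization zero point of the output (long)
)

# The input must already be a quantized tensor.
x = torch.randn(2, num_features, 8, 8)
xq = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)
yq = m(xq)                            # output carries the module's scale/zero_point
print(yq.q_scale(), yq.q_zero_point())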
