diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst
index 02e262a3d646..5b6dac8081dd 100644
--- a/docs/source/quantization.rst
+++ b/docs/source/quantization.rst
@@ -243,6 +243,9 @@ Layers for the quantization-aware training
 * :class:`~torch.nn.qat.Hardswish` — Hardswish
 * :class:`~torch.nn.qat.LayerNorm` — LayerNorm
 * :class:`~torch.nn.qat.GroupNorm` — GroupNorm
+* :class:`~torch.nn.qat.InstanceNorm1d` — InstanceNorm1d
+* :class:`~torch.nn.qat.InstanceNorm2d` — InstanceNorm2d
+* :class:`~torch.nn.qat.InstanceNorm3d` — InstanceNorm3d
 
 ``torch.quantization``
 ~~~~~~~~~~~~~~~~~~~~~~
@@ -356,6 +359,9 @@ Quantized version of standard NN layers.
 * :class:`~torch.nn.quantized.Hardswish` — Hardswish
 * :class:`~torch.nn.quantized.LayerNorm` — LayerNorm. *Note: performance on ARM is not optimized*.
 * :class:`~torch.nn.quantized.GroupNorm` — GroupNorm. *Note: performance on ARM is not optimized*.
+* :class:`~torch.nn.quantized.InstanceNorm1d` — InstanceNorm1d. *Note: performance on ARM is not optimized*.
+* :class:`~torch.nn.quantized.InstanceNorm2d` — InstanceNorm2d. *Note: performance on ARM is not optimized*.
+* :class:`~torch.nn.quantized.InstanceNorm3d` — InstanceNorm3d. *Note: performance on ARM is not optimized*.
 
 ``torch.nn.quantized.dynamic``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -721,6 +727,20 @@ GroupNorm
 .. autoclass:: GroupNorm
     :members:
 
+InstanceNorm1d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm1d
+    :members:
+
+InstanceNorm2d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm2d
+    :members:
+
+InstanceNorm3d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm3d
+    :members:
 
 torch.nn.quantized
 ----------------------------
@@ -814,6 +834,21 @@ GroupNorm
 .. autoclass:: GroupNorm
     :members:
 
+InstanceNorm1d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm1d
+    :members:
+
+InstanceNorm2d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm2d
+    :members:
+
+InstanceNorm3d
+~~~~~~~~~~~~~~~
+.. autoclass:: InstanceNorm3d
+    :members:
+
 torch.nn.quantized.dynamic
 ----------------------------
diff --git a/torch/nn/qat/modules/normalization.py b/torch/nn/qat/modules/normalization.py
index 13c6a4f15843..c16c162aaae9 100644
--- a/torch/nn/qat/modules/normalization.py
+++ b/torch/nn/qat/modules/normalization.py
@@ -47,8 +47,8 @@ def from_float(cls, mod, qconfig=None):
 
 class InstanceNorm1d(nn.InstanceNorm1d):
     r"""
-    A InstanceNorm1d module attached with FakeQuantize modules for both output
-    activation and weight, used for quantization aware training.
+    An InstanceNorm1d module attached with FakeQuantize modules for output
+    activation, used for quantization aware training.
 
     Similar to `torch.nn.InstanceNorm1d`, with FakeQuantize modules initialized to
     default.
@@ -92,8 +92,8 @@ def from_float(cls, mod, qconfig=None):
 
 class InstanceNorm2d(nn.InstanceNorm2d):
     r"""
-    A InstanceNorm2d module attached with FakeQuantize modules for both output
-    activation and weight, used for quantization aware training.
+    An InstanceNorm2d module attached with FakeQuantize modules for output
+    activation, used for quantization aware training.
 
     Similar to `torch.nn.InstanceNorm2d`, with FakeQuantize modules initialized to
     default.
@@ -137,8 +137,8 @@ def from_float(cls, mod, qconfig=None):
 
 class InstanceNorm3d(nn.InstanceNorm3d):
     r"""
-    A InstanceNorm3d module attached with FakeQuantize modules for both output
-    activation and weight, used for quantization aware training.
+    An InstanceNorm3d module attached with FakeQuantize modules for output
+    activation, used for quantization aware training.
 
     Similar to `torch.nn.InstanceNorm3d`, with FakeQuantize modules initialized to
     default.
diff --git a/torch/nn/quantized/modules/normalization.py b/torch/nn/quantized/modules/normalization.py
index 43619942830b..a54ca5f1cc44 100644
--- a/torch/nn/quantized/modules/normalization.py
+++ b/torch/nn/quantized/modules/normalization.py
@@ -76,7 +76,12 @@ def from_float(cls, mod):
         return new_mod
 
 class InstanceNorm1d(torch.nn.InstanceNorm1d):
-    r"""This is the quantized version of `torch.nn.InstanceNorm1d`.
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+    """
     def __init__(self, num_features, weight, bias, scale, zero_point,
                  eps=1e-5, momentum=0.1, affine=False,
@@ -106,7 +111,12 @@ def from_float(cls, mod):
         return new_mod
 
 class InstanceNorm2d(torch.nn.InstanceNorm2d):
-    r"""This is the quantized version of `torch.nn.InstanceNorm2d`.
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+    """
     def __init__(self, num_features, weight, bias, scale, zero_point,
                  eps=1e-5, momentum=0.1, affine=False,
@@ -136,7 +146,12 @@ def from_float(cls, mod):
         return new_mod
 
 class InstanceNorm3d(torch.nn.InstanceNorm3d):
-    r"""This is the quantized version of `torch.nn.InstanceNorm2d`.
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+    """
     def __init__(self, num_features, weight, bias, scale, zero_point,
                  eps=1e-5, momentum=0.1, affine=False,
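
A minimal usage sketch for the QAT modules added above (not part of the diff). It assumes a PyTorch release that still ships torch.nn.qat.InstanceNorm2d and leans on the from_float classmethod visible in the hunk headers; the qconfig choice is illustrative:

import torch
import torch.nn as nn
import torch.nn.qat as nnqat
from torch.quantization import get_default_qat_qconfig

# Annotate a float module with a QAT qconfig, then swap in the QAT variant
# via from_float (the classmethod shown in the diff's hunk context).
float_norm = nn.InstanceNorm2d(4)
float_norm.qconfig = get_default_qat_qconfig('fbgemm')
qat_norm = nnqat.InstanceNorm2d.from_float(float_norm)

x = torch.randn(2, 4, 8, 8)
y = qat_norm(x)  # per the new docstrings, only the output activation is fake-quantized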
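Similarly, a sketch of constructing one of the quantized modules directly, following the constructor signature shown in the diff (num_features, weight, bias, then the output scale and zero_point). The quantization parameters here are illustrative values; typical workflows would reach these modules through torch.quantization.convert rather than direct construction:

import torch
import torch.nn.quantized as nnq

# Quantize a float input tensor; scale/zero_point are illustrative.
x = torch.randn(2, 4, 8, 8)
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# Constructor args per the diff: num_features, weight, bias, then the
# output quantization parameters (scale: double, zero_point: long).
inorm = nnq.InstanceNorm2d(4, torch.ones(4), torch.zeros(4),
                           scale=0.1, zero_point=128)

yq = inorm(xq)        # quantized output tensor
y = yq.dequantize()   # back to float for inspection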