74 changes: 74 additions & 0 deletions test/quantization/test_qat_backward.py
@@ -0,0 +1,74 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import copy
from torch.quantization.fake_quantize import FakeQuantize
from torch.quantization.fake_quantize_backward import _FakeQuantizeWithBackward
from torch.testing._internal.common_utils import TestCase
from .test_workflow_module import to_tensor
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

TORCH_RANDOM_SEED = 1776
tolerance = 1e-6

class TestQATBackward(TestCase):

@given(quantize_forward=st.booleans(),
quantize_backward=st.booleans(),
X=hu.tensor(shapes=hu.array_shapes(1, 5,),
qparams=hu.qparams(dtypes=torch.quint8)))
def test_forward_and_backward(self, quantize_forward, quantize_backward, X):
r"""Tests the forward and backward path of the FakeQuantizeWithBackward module
"""
def fake_quantize_tensor(X):
scale, zero_point = torch._choose_qparams_per_tensor(X, reduce_range=False)
return torch.fake_quantize_per_tensor_affine(X, scale, zero_point, 0, 255)

device = 'cpu' # CUDA support to come in a future PR

torch.manual_seed(TORCH_RANDOM_SEED)
X, (_, _, torch_type) = X
quant_min = torch.iinfo(torch_type).min
quant_max = torch.iinfo(torch_type).max

fake_with_backward = _FakeQuantizeWithBackward(quant_min=quant_min,
quant_max=quant_max,
quantize_forward=quantize_forward,
quantize_backward=quantize_backward)

fake_reference = FakeQuantize(quant_min=quant_min, quant_max=quant_max)

X = to_tensor(X, device)
X.requires_grad_()
X_ref = copy.deepcopy(X)
X_ref.requires_grad_()

Y = fake_with_backward(X)

# If quantize_forward is false, fake_with_backward(X) should be identity
# Otherwise, should be quantized
Y_ref = fake_reference(X_ref) if quantize_forward else X_ref

# Make sure that the forward functions as expected
self.assertEqual(Y, Y_ref, rtol=tolerance, atol=tolerance)

dout = torch.rand(X.shape, dtype=torch.float).to(device)

Y.backward(dout)
Y_ref.backward(dout)

# Behavior of backward depends on quantize_forward and quantize_backward
# If both are true, should be quantized output of the normal FakeQuantize backward
# If just quantize_backward is true, should only quantize the gradient
# If just quantize_forward is true, should be the normal FakeQuantize backward
# If both are false, should be identity
dX_ref = fake_quantize_tensor(X_ref.grad) if quantize_backward else X_ref.grad

# Check the gradients to make sure the backward functions as expected
self.assertEqual(X.grad, dX_ref, rtol=tolerance, atol=tolerance)
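For intuition on the comment block above: the reference fake_quantize_tensor helper mirrors the dynamic per-tensor affine fake quantization that the backward path applies to incoming gradients. A minimal standalone sketch using the same torch calls as the test (the tensor values are illustrative only):

import torch

grad = torch.randn(4, 4)
scale, zero_point = torch._choose_qparams_per_tensor(grad, reduce_range=False)
fq_grad = torch.fake_quantize_per_tensor_affine(grad, scale, zero_point, 0, 255)
# fq_grad can take at most 256 distinct values spaced `scale` apart,
# so it only approximates the original gradient
print(grad.unique().numel(), fq_grad.unique().numel())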
3 changes: 3 additions & 0 deletions test/test_quantization.py
@@ -69,5 +69,8 @@
# Equalization
from quantization.test_equalize import TestEqualizeEager # noqa: F401

# Experimental QAT backward tests
from quantization.test_qat_backward import TestQATBackward # noqa: F401

if __name__ == '__main__':
run_tests()
51 changes: 51 additions & 0 deletions torch/quantization/fake_quantize_backward.py
@@ -0,0 +1,51 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from .fake_quantize import FakeQuantize
from .observer import MovingAverageMinMaxObserver

# Identity on the forward pass; on the backward pass, dynamically fake-quantizes the
# incoming gradient using per-tensor affine quantization to unsigned 8-bit.
class _QuantizeBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, X):
return X

@staticmethod
def backward(ctx, grad_X):
scale, zero_point = torch._choose_qparams_per_tensor(grad_X, reduce_range=False)
grad_X = torch.fake_quantize_per_tensor_affine(grad_X, scale, zero_point, 0, 255)

return grad_X


class _FakeQuantizeWithBackward(FakeQuantize):
r""" Simulate the quantize and dequantize operations in training time.
See documentation for parent module torch.quantization.FakeQuantize.

* :attr:`quantize_backward` controls the application of fake quantization on tensor gradients in
the backward pass. This quantization is always done dynamically, and uses affine per-tensor
quantization on unsigned 8-bit ints.

Args:
observer (module): Module for observing statistics on input tensors and calculating scale
and zero-point.
quant_min (int): The minimum allowable quantized value.
quant_max (int): The maximum allowable quantized value.
quantize_forward (bool): If true, quantize on the forward pass. (default: True)
quantize_backward (bool): If true, quantize on the backward pass. (default: False)
observer_kwargs (optional): Arguments for the observer module

Attributes:
observer (Module): User provided module that collects statistics on the input tensor and
provides a method to calculate scale and zero-point.

"""
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
quantize_forward=True, quantize_backward=False, **observer_kwargs):
super(_FakeQuantizeWithBackward, self).__init__(observer, quant_min, quant_max, **observer_kwargs)
self.enable_fake_quant(quantize_forward)
self.quantize_backward = quantize_backward

def forward(self, X):
X = super(_FakeQuantizeWithBackward, self).forward(X)
if self.quantize_backward:
X = _QuantizeBackwardFunction.apply(X)
return X
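A minimal usage sketch for _FakeQuantizeWithBackward; the tensor shape and flag values below are arbitrary illustrations:

import torch
from torch.quantization.fake_quantize_backward import _FakeQuantizeWithBackward

# Fake-quantize activations on the forward pass and gradients on the backward pass.
fq = _FakeQuantizeWithBackward(quant_min=0, quant_max=255,
                               quantize_forward=True, quantize_backward=True)

x = torch.randn(4, 8, requires_grad=True)
y = fq(x)                        # observer updates scale/zero_point; output is fake-quantized
y.backward(torch.randn_like(y))  # the incoming gradient is first fake-quantized by
                                 # _QuantizeBackwardFunction, then passed through the usual
                                 # straight-through-estimator backward of FakeQuantize
print(x.grad.unique().numel())   # few distinct values -> the gradient was quantized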
28 changes: 28 additions & 0 deletions torch/quantization/qconfig.py
@@ -2,6 +2,8 @@
from collections import namedtuple
from .observer import *
from .fake_quantize import *
from .fake_quantize_backward import _FakeQuantizeWithBackward
import torch
import torch.nn as nn

class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
@@ -76,6 +78,7 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
default_activation_only_qconfig = QConfig(activation=default_fake_quant,
weight=torch.nn.Identity)


def get_default_qconfig(backend='fbgemm'):
if backend == 'fbgemm':
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
@@ -104,3 +107,28 @@ def get_default_qat_qconfig(backend='fbgemm'):
else:
qconfig = default_qat_qconfig
return qconfig

def _get_default_qat_qconfig_backward(quant_forward, quant_backward):
    """QAT qconfig whose fake-quant modules can also fake-quantize gradients
    on the backward pass (see _FakeQuantizeWithBackward)."""
fake_quant = _FakeQuantizeWithBackward.with_args(
observer=MovingAverageMinMaxObserver,
quant_min=0,
quant_max=255,
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
        quantize_forward=quant_forward,
        quantize_backward=quant_backward,
reduce_range=True,
)

weight_fake_quant = _FakeQuantizeWithBackward.with_args(
observer=MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_tensor_symmetric,
        quantize_forward=quant_forward,
        quantize_backward=quant_backward,
reduce_range=False,
)

return QConfig(activation=fake_quant, weight=weight_fake_quant)
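A rough sketch of how this helper could be plugged into an eager-mode QAT flow; the toy model and flag values are assumptions chosen for illustration:

import torch.nn as nn
from torch.quantization import prepare_qat, convert
from torch.quantization.qconfig import _get_default_qat_qconfig_backward

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.train()
# Fake-quantize activations/weights on the forward pass and gradients on the backward pass.
model.qconfig = _get_default_qat_qconfig_backward(quant_forward=True, quant_backward=True)
prepare_qat(model, inplace=True)
# ... run the usual QAT training loop here ...
model.eval()
convert(model, inplace=True)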