In [1]:
import numpy as np
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

def weight_quantization(b):
    """Build a symmetric uniform quantizer with a learnable clipping scale.

    Args:
        b: Number of bits used for the magnitude of each value. Magnitudes
            are rounded onto the grid ``k / (2**b - 1)``, ``k = 0 .. 2**b - 1``.

    Returns:
        A callable ``f(input, alpha)`` (the ``apply`` of a custom autograd
        function). It returns ``input`` clipped to ``[-alpha, alpha]`` and
        rounded onto the grid above, and defines gradients for both
        ``input`` and ``alpha``.
    """

    def uniform_quant(x, b):
        """Round ``x`` to the nearest multiple of ``1 / (2**b - 1)``.

        Called below with magnitudes that were already clamped to [0, 1].
        """
        levels = 2 ** b - 1
        return x.mul(levels).round().div(levels)

    class quant(torch.autograd.Function):
        """Quantize in forward; pass the weight gradient straight through in backward."""

        @staticmethod
        def forward(ctx, input, alpha):
            # Out-of-place division: forward() receives the caller's tensor,
            # so an in-place ``div_`` would overwrite the real weights with
            # ``w / alpha`` on every call.
            scaled = input.div(alpha)
            clipped = scaled.clamp(min=-1, max=1)       # clip to [-1, 1]
            quantized = uniform_quant(clipped.abs(), b).mul(clipped.sign())
            # Only the scaled (unclipped) input is needed by backward().
            ctx.save_for_backward(scaled)
            return quantized.mul(alpha)                 # rescale to the original range

        @staticmethod
        def backward(ctx, grad_output):
            (scaled,) = ctx.saved_tensors
            # Straight-through estimator: the gradient w.r.t. the input is the
            # upstream gradient, unclipped.
            grad_input = grad_output.clone()
            # 1.0 where the scaled input fell outside [-1, 1] (clipped regime).
            clipped_mask = (scaled.abs() > 1.).float()
            # In the clipped regime the output is sign * alpha, so each such
            # element contributes grad_output * sign to alpha's gradient.
            # NOTE(review): elements inside the clipping range contribute
            # nothing here (no rounding-residual term) -- confirm this
            # simplification is intended.
            grad_alpha = (grad_output * scaled.sign() * clipped_mask).sum()

            # forward() takes two inputs, so backward() returns two gradients.
            return grad_input, grad_alpha

    # ``apply`` is a classmethod; autograd Functions should not be instantiated.
    return quant.apply



class weight_quantize_fn(nn.Module):
    """Module wrapper that quantizes a weight tensor with a learnable scale.

    Attributes:
        w_bit: Magnitude bit-width handed to the quantizer (``w_bit - 1`` of
            the constructor argument).
        weight_q: Callable produced by ``weight_quantization``.
        wgt_alpha: Learnable clipping scale, initialised to 3.0.
    """

    def __init__(self, w_bit):
        super().__init__()
        magnitude_bits = w_bit - 1
        self.w_bit = magnitude_bits
        self.wgt_alpha = torch.nn.Parameter(torch.tensor(3.0))
        self.weight_q = weight_quantization(b=magnitude_bits)

    def forward(self, weight):
        """Return ``weight`` quantized with the current ``wgt_alpha``.

        The weight is passed to the quantizer as-is; no mean/std
        normalisation is applied first.
        """
        return self.weight_q(weight, self.wgt_alpha)

In [2]:
# w_bit=5 leaves 4 bits for the magnitude (the module passes w_bit - 1 to the quantizer).
weight_quant_fn = weight_quantize_fn(w_bit=5)
a  = torch.tensor([1.5,-1.,2.5], requires_grad=True)
# Alternative input with one value above the initial alpha (3.0), so it gets clipped:
#a  = torch.tensor([1.5,-1.,3.5], requires_grad=True)

# Quantize the tensor and show the result.
a_q = weight_quant_fn(a)
print(a_q)

tensor([ 1.6000, -1.0000,  2.4000], grad_fn=<quantBackward>)


In [3]:
# Inspect the clipping scale and its gradient; no backward pass has run yet.
print(weight_quant_fn.wgt_alpha)
print(weight_quant_fn.wgt_alpha.grad)

Parameter containing:
tensor(3., requires_grad=True)
None


In [4]:
# Reduce to a scalar so backward() can be called without an explicit gradient;
# every element of a_q then receives an upstream gradient of 1.
c = a_q.sum()
c.backward()


In [5]:
# alpha's gradient only collects elements with |a / alpha| > 1; none of the inputs exceed alpha = 3.
print(weight_quant_fn.wgt_alpha.grad)
# The weight gradient is the upstream gradient passed through unchanged.
print(a.grad)

tensor(0.)
tensor([1., 1., 1.])
