In [18]:
import torch

def round_pass(x):
    y = x.round()
    y_grad = x
    return y.detach() - y_grad.detach() + y_grad

class Quantizer():
    def __init__(self, N_bits: int, type: str = "per_tensor",  signed: bool = True, symmetric: bool = True):
        super().__init__()
            
        self.N_bits = N_bits
        self.signed = signed
        self.symmetric = symmetric
        self.q_type = type
        # self.eps = torch.iinfo(dtype).eps
        # self.minimum_range = torch.iinfo(dtype).eps
        if self.N_bits is None:
            return 

        if self.signed:
            self.Qn = - 2 ** (self.N_bits - 1)
            self.Qp = 2 ** (self.N_bits - 1) - 1
        else:
            self.Qn = 0
            self.Qp = 2 ** self.N_bits - 1

    def __call__(self, x):  
        return self.forward(x)

    def forward(self, x): 
        if self.N_bits is None:
            return x, 1

        if self.symmetric:
            if self.q_type == 'per_tensor': 
                max_x = x.abs().max().detach()
            elif self.q_type == 'per_token': 
                max_x = x.abs().amax(dim=-1, keepdim=True).detach()
            elif self.q_type == 'per_channel': 
                max_x = x.abs().amax(dim=0, keepdim=True).detach()

            print(max_x)
            scale = max_x / self.Qp
            x = x / scale 
            x = round_pass(x.clamp_(self.Qn, self.Qp)) 
            
        else: #Asymmetric
            min_x = x.min().detach()
            max_x = x.max().detach()
            range_x = (max_x - min_x).detach().clamp_(min=self.minimum_range)
            scale = range_x / (self.Qp - self.Qn)

            zero_point = torch.round((min_x / scale) - self.Qn)

            x = (x / scale) + zero_point
            x = round_pass(x.clamp_(self.Qn, self.Qp))

        return x, scale

quantizer = Quantizer (4, 'per_channel')

x = torch.rand(3,5) 
print("x",x)

q_x, s_x = quantizer(x)
print("q_x", q_x)
print("s_x", s_x)

x tensor([[0.7406, 0.0414, 0.5075, 0.0692, 0.8851],
        [0.1557, 0.5919, 0.0637, 0.9499, 0.7746],
        [0.0336, 0.8756, 0.5281, 0.3327, 0.2997]])
tensor([[0.7406, 0.8756, 0.5281, 0.9499, 0.8851]])
q_x tensor([[7., 0., 7., 1., 7.],
        [1., 5., 1., 7., 6.],
        [0., 7., 7., 2., 2.]])
s_x tensor([[0.1058, 0.1251, 0.0754, 0.1357, 0.1264]])


In [9]:
pip install torch

Collecting torch
  Downloading torch-2.4.1-cp38-cp38-manylinux1_x86_64.whl (797.1 MB)
[K     |████████████████████████████████| 797.1 MB 28 kB/s s eta 0:00:01
[?25hCollecting nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[K     |████████████████████████████████| 23.7 MB 10.8 MB/s eta 0:00:01
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
[K     |████████████████████████████████| 14.1 MB 10.5 MB/s eta 0:00:01
[?25hCollecting filelock
  Downloading filelock-3.16.1-py3-none-any.whl (16 kB)
Collecting nvidia-curand-cu12==10.3.2.106; platform_system == "Linux" and platform_machine == "x86_64"
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
[K     |████████████████████