<a href="https://colab.research.google.com/github/samitha278/transformer-optim/blob/main/quantization_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Quantization

In [1]:
import torch
import torch.nn as nn

## Asymmetric quantization

In [25]:
x = torch.randn((2,8))
print("x:\n", x)

xmin = torch.min(x, dim=-1).values      # shape (2,)
xmax = torch.max(x, dim=-1).values

n = 4  # bits
levels = 2**n - 1

s = (xmax - xmin) / levels              # (2,)
z = torch.round(-xmin / s)              # (2,)

# reshape for broadcasting
s = s.unsqueeze(-1)                     # (2,1)
z = z.unsqueeze(-1)                     # (2,1)
xmin_q = torch.zeros_like(z)
xmax_q = torch.full_like(z, levels)

x_q = torch.floor(x / s + z)
x_q = torch.clamp(x_q, xmin_q, xmax_q)

print("x_q:\n", x_q)


x:
 tensor([[ 0.1639, -0.2898, -0.4479, -0.7410, -0.3116,  1.5138, -0.3913, -0.1784],
        [-0.6064, -0.2402,  0.3083, -0.9694, -0.7455, -0.2932, -0.6298, -0.1996]])
x_q:
 tensor([[ 6.,  3.,  2.,  0.,  2., 15.,  2.,  3.],
        [ 3.,  8., 14.,  0.,  2.,  7.,  3.,  8.]])


### Dequantize

In [27]:
s*(x_q-z)

tensor([[ 0.1503, -0.3006, -0.4510, -0.7516, -0.4510,  1.5032, -0.4510, -0.3006],
        [-0.6815, -0.2555,  0.2555, -0.9370, -0.7666, -0.3407, -0.6815, -0.2555]])

## Symmetric quantization

In [28]:
x = torch.randn((2,8))
print("x:\n", x)

xmin = torch.min(x, dim=-1).values      # shape (2,)
xmax = torch.max(x, dim=-1).values

n = 4  # bits
levels = 2**(n-1) - 1

s = torch.abs(xmax) / levels              # (2,)


# reshape for broadcasting
s = s.unsqueeze(-1)                     # (2,1)
xmin_q =  torch.full_like(s, -levels)
xmax_q = torch.full_like(s, levels)

x_q = torch.floor(x / s)
x_q = torch.clamp(x_q, xmin_q, xmax_q)

print("x_q:\n", x_q)


x:
 tensor([[ 1.2749, -1.0468,  1.1475, -1.5243, -2.5364, -0.2713, -1.7483, -0.5429],
        [-0.2785, -0.5994, -2.6728, -0.6100,  0.2030,  0.9089,  1.0148,  0.2108]])
x_q:
 tensor([[ 7., -6.,  6., -7., -7., -2., -7., -3.],
        [-2., -5., -7., -5.,  1.,  6.,  7.,  1.]])


### Dequantize

In [29]:
s * x_q

tensor([[ 1.2749, -1.0928,  1.0928, -1.2749, -1.2749, -0.3643, -1.2749, -0.5464],
        [-0.2899, -0.7248, -1.0148, -0.7248,  0.1450,  0.8698,  1.0148,  0.1450]])