# bag of holding (aka. quantization)

## setup

In [1]:
import torch as t
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import triton
import triton.language as tl
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from datetime import datetime

device = 'cuda' if t.cuda.is_available() else 'cpu'

In [2]:
batch_size = 640

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [3]:
class Mnist(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1024, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        # pad x to [..., 32, 32] because tritron's tl.dot makes life hard if it doesn't like the shape >_<'
        x = nn.functional.pad(x, (2, 2, 2, 2))
        return self.mlp(x.flatten(start_dim=-3))

In [4]:
@t.no_grad()
def eval(model, dataloader=test_loader):
    model.eval()
    loss, accuracy, count = 0, 0, 0
    for x, y in dataloader:
        logits = model(x.to(device))
        loss += F.cross_entropy(logits, y.to(device)).item()
        accuracy += (logits.argmax(1) == y.to(device)).sum().item()
        count += len(x)
    model.train()
    return loss / count, accuracy / count

def eval_dict(model, dataloader=test_loader, name='test'):
    loss, accuracy = eval(model, dataloader=dataloader)
    return {f'{name}_loss': loss, f'{name}_accuracy': accuracy}

In [5]:
def train(model, opt, dataloader=train_loader, wnb=True, epochs=1000):
    model.train()
    if wnb: wandb.init()
    for epoch in tqdm(range(epochs)):
        for x, y in dataloader:
            logits = model(x.to(device))
            loss = F.cross_entropy(logits, y.to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
        if epoch % 10 == 0 or epoch == epochs - 1:
            if wnb: wandb.log({'epoch': epoch} | eval_dict(model) | eval_dict(model, dataloader=train_loader, name='train'))
    if wnb: wandb.finish()

In [None]:
mnist = Mnist().to(device)
opt = optim.AdamW(mnist.parameters(), lr=3e-4)
train(mnist, opt, epochs=10)

In [440]:
t.save(mnist.state_dict(), f'weights/mnist-1024-128-64-10_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

## kernel

### tiled triton int8 matmul

In [6]:
# heavily inspired by: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    ''' Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    '''
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetic` section for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.int32)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul_i8i32(a, b):
    ''' matmul for int8 with int32 accumulators '''
    assert a.shape[1] == b.shape[0], 'incompatible dimensions'
    assert a.is_contiguous(), 'matrix A must be contiguous'
    M, K = a.shape
    K, N = b.shape
    c = t.empty((M, N), device=a.device, dtype=t.int32)
    # launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=32,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )
    return c

a = t.randint(-128, 128, (512, 1024), device='cuda', dtype=t.int8)
b = t.randint(-128, 128, (1024, 2048), device='cuda', dtype=t.int8)
triton_output = matmul_i8i32(a, b)
torch_output = t.matmul(a.to(t.float32), b.to(t.float32))

print(f'atol: {t.max(t.abs(triton_output - torch_output))}')
print(f'diff > 1e-5: {t.sum(t.abs(triton_output - torch_output) > 1e-5)} / {triton_output.numel()}')
print(f'allclose: {triton_output.float().allclose(torch_output)}')

atol: 0.0
diff > 1e-5: 0 / 1048576
allclose: True


### alternative kernel

In [7]:
from torch._higher_order_ops.out_dtype import out_dtype

a = t.randint(-128, 128, (512, 1024), device='cuda', dtype=t.int8)
b = t.randint(-128, 128, (1024, 2048), device='cuda', dtype=t.int8)
triton_output = out_dtype(t.ops.aten.mm.default, t.int32, a, b)
torch_output = t.matmul(a.to(t.float32), b.to(t.float32))

print(f'atol: {t.max(t.abs(triton_output - torch_output))}')
print(f'diff > 1e-5: {t.sum(t.abs(triton_output - torch_output) > 1e-5)} / {triton_output.numel()}')
print(f'allclose: {triton_output.float().allclose(torch_output)}')

atol: 0.0
diff > 1e-5: 0 / 1048576
allclose: True


## symmetric quantization

In [7]:
def squantize(weights, bits=8):
    assert bits <= 8 # keep my life simple
    maxi = weights.abs().max()
    q_maxi = 2 ** (bits - 1) - 1
    scale = maxi / q_maxi
    quantized_weights = t.clamp(t.round(weights / scale), -q_maxi, q_maxi).to(t.int8)
    return quantized_weights, scale

def sunquantize(quantized_weights, scale):
    return quantized_weights * scale

In [9]:
def test_round_trip(weights, bits=8, threshold=0.02):
    ''' this is flaky at best '''
    qw, scale = squantize(weights, bits=bits)
    w = sunquantize(qw, scale)
    assert weights.allclose(w, atol=threshold), f'{(weights - w).abs().max()=}'

def test_more_bits_do_better(weights):
    ''' flaky too '''
    def encode_decode_max_abs_error(bits):
        qw, scale = squantize(weights, bits=bits)
        w = sunquantize(qw, scale)
        return (weights - w).abs().max()
    error1 = encode_decode_max_abs_error(bits=4)
    error2 = encode_decode_max_abs_error(bits=6)
    error3 = encode_decode_max_abs_error(bits=8)
    assert error1 > error2 > error3, f'{error1=}, {error2=}, {error3=}'

for _ in range(100):
    xs = t.randn(100, 100)
    test_round_trip(xs)
    test_more_bits_do_better(xs)

In [10]:
def squantization_stats(weights):
    qw, scale = squantize(weights, bits=8)
    w = sunquantize(qw, scale)
    print(f'mean: {weights.mean()} {w.mean()}')
    print(f'max: {weights.max()} {w.max()}')
    print(f'min: {weights.min()} {w.min()}')
    print(f'atol: {(weights - w).abs().max()}')
    print(f'mse: {(weights - w).pow(2).mean()}')

xs = t.rand(10000)
squantization_stats(xs)

mean: 0.5008046627044678 0.5007808804512024
max: 0.9999886751174927 0.9999886155128479
min: 0.00014400482177734375 0.0
atol: 0.003935694694519043
mse: 5.147882347955601e-06


$$
\begin{align}
x &= x_{quant} * scale_x \\
w &= w_{quant} * scale_w \\
y &= x @ w \\
  &= (x_{quant} * scale_x) @ (w_{quant} * scale_w) \\
  &= scale_x * scale_w * x_{quant} @ w_{quant}
\end{align}
$$

In [11]:
def squantization_matmul_stats(x, w):
    y = x @ w
    xq, scale_x = squantize(x)
    wq, scale_w = squantize(w)
    yq = matmul_i8i32(xq, wq)
    scale_y = scale_x * scale_w
    yy = sunquantize(yq, scale_y)
    print(f'mean: {y.mean()} {yy.mean()}')
    print(f'max: {y.max()} {yy.max()}')
    print(f'min: {y.min()} {yy.min()}')
    print(f'atol: {(y - yy).abs().max()}')
    print(f'mse: {(y - yy).pow(2).mean()}')
    return y, yy

N=100
x = t.randn(N, N, device=device)
w = t.randn(N, N, device=device)
_, _ = squantization_matmul_stats(x, w)

mean: -0.08168823271989822 -0.08109944313764572
max: 37.009124755859375 36.74298858642578
min: -35.65652084350586 -35.63026809692383
atol: 0.5931930541992188
mse: 0.017821861431002617


In [12]:
class SQuantizedLinear(nn.Module):
    def __init__(self, linear):
        super().__init__()
        w, scale = squantize(linear.weight.T)
        self.w = w
        self.register_buffer('w_matrix', self.w)
        self.scale_w = scale
        self.bias = linear.bias # keep bias unquantized

    def forward(self, x):
        xq, scale_x = squantize(x)
        yq = matmul_i8i32(xq, self.w)
        scale_y = self.scale_w * scale_x
        y = sunquantize(yq, scale_y)
        y = y + self.bias
        return y

In [13]:
def squantize_module(module):
    for name, node in module.named_children():
        if isinstance(node, nn.Linear):
            setattr(module, name, SQuantizedLinear(node))
        else:
            squantize_module(node)

weights = t.load('weights/mnist-1024-128-64-10_2024-08-12_00h35.pt')
mnist_base = Mnist().to(device)
mnist_base.load_state_dict(weights)
mnist = Mnist().to(device)
mnist.load_state_dict(weights)
squantize_module(mnist)

print(f'base fp32: {eval(mnist_base)[1]}')
print(f'quantized int8: {eval(mnist)[1]}')

base fp32: 0.9468
quantized int8: 0.9464


In [17]:
def model_size(model):
    return sum(p.numel() * p.element_size() for p in model.parameters()) + \
        sum(b.numel() * b.element_size() for b in model.buffers())

print(f'base fp32 size: {model_size(mnist_base)}')
print(f'quantized int8 size: {model_size(mnist)}')
print(f'quantize / base ratio: {model_size(mnist) / model_size(mnist_base):.2f}')

base fp32 size: 560424
quantized int8 size: 140712
quantize / base ratio: 0.25


## asymmetric quantization

In [18]:
def quantize(weights, bits=8):
    ''' using the min-max strategy, this is vulnerable to outliers '''
    assert bits <= 8 # keep my life simple
    maxi = weights.max()
    mini = weights.min()
    qmaxi = 2 ** bits - 1
    scale = (maxi - mini) / qmaxi
    zero = int(t.round(-mini / scale))
    quantized_weights = t.clamp(t.round(weights / scale) + zero, 0, qmaxi).to(t.uint8)
    return quantized_weights, scale, zero

def unquantize(quantized_weights, scale, zero):
    quantized_weights = quantized_weights.to(t.int32)
    return (quantized_weights - zero) * scale

In [19]:
def test_bounds(weights, bits=8):
    qw, _, _ = quantize(weights, bits=bits)
    assert qw.min() == 0
    assert qw.max() == 2**bits - 1

def test_round_trip(weights, bits=8, threshold=0.02):
    ''' this is flaky at best '''
    qw, scale, zero = quantize(weights, bits=bits)
    w = unquantize(qw, scale, zero)
    assert weights.allclose(w, atol=threshold), f'{(weights - w).abs().max()=}'

def test_more_bits_do_better(weights):
    ''' flaky too '''
    def encode_decode_max_abs_error(bits):
        qw, scale, zero = quantize(weights, bits=bits)
        w = unquantize(qw, scale, zero)
        return (weights - w).abs().max()
    error1 = encode_decode_max_abs_error(bits=4)
    error2 = encode_decode_max_abs_error(bits=6)
    error3 = encode_decode_max_abs_error(bits=8)
    assert error1 > error2 > error3, f'{error1=}, {error2=}, {error3=}'

for _ in range(100):
    xs = t.randn(100, 100)
    test_bounds(xs)
    test_round_trip(xs)
    test_more_bits_do_better(xs)

In [20]:
def quantization_stats(weights):
    qw, scale, zero = quantize(weights, bits=8)
    w = unquantize(qw, scale, zero)
    print(f'mean: {weights.mean()} {w.mean()}')
    print(f'max: {weights.max()} {w.max()}')
    print(f'min: {weights.min()} {w.min()}')
    print(f'atol: {(weights - w).abs().max()}')
    print(f'mse: {(weights - w).pow(2).mean()}')

xs = t.rand(10000)
quantization_stats(xs)

mean: 0.49838921427726746 0.498381644487381
max: 0.9999037981033325 0.9997287392616272
min: 0.00017511844635009766 0.0
atol: 0.0019600391387939453
mse: 1.2681715588769293e-06


$$
\begin{align}
x &= (x_{quant} - zero_x) * scale_x \\
w &= (w_{quant} - zero_w) * scale_w \\
y &= x @ w \\
  &= ((x_{quant} - zero_x) * scale_x) @ ((w_{quant} - zero_w) * scale_w) \\
  &= scale_x * scale_w * (x_{quant} - zero_x) @ (w_{quant} - zero_w) \\
  &= scale_x * scale_w * (x_{quant} @ w_{quant} - x_{quant} @ zero_w - zero_x @ w_{quant} + zero_x @ zero_w) \\
\end{align}
$$

In [21]:
def fake_uint8_int32_matmul(x, w):
    ''' pretend we have a uint8 matmul with int32 accumulator '''
    x = x.to(t.int32)
    w = w.to(t.int32)
    assert (x >= 0).all()
    assert (w >= 0).all()
    y = t.matmul(x, w)
    assert (y >= 0).all()
    return y

def asymmetric_quant_matmul(xq, wq, scale_x, scale_w, zero_x, zero_w):
    '''
    NOTE: This not efficient, and defeat the purpose of quantization.
    I just want to test the correctness of the formula.
    Original formula lives in https://arxiv.org/pdf/2106.08295 equation (13)
    '''
    # this kills the purpose of quantization.
    xq = xq.to(t.int32)
    wq = wq.to(t.int32)

    # several tricks here, the .sum() make the dimensions match
    # and note the xq.shape[1] here that doesn't appear in the paper!!
    unscaled_y = (
        fake_uint8_int32_matmul(xq, wq)
        - xq.sum(1, keepdim=True) * zero_w
        - zero_x * wq.sum(0, keepdim=True)
        + xq.shape[1] * zero_x * zero_w) # here
    y = scale_x * scale_w * unscaled_y
    return y

In [22]:
def quantization_matmul_stats(x, w):
    y = x @ w
    xq, scale_x, zero_x = quantize(x)
    wq, scale_w, zero_w = quantize(w)
    yy = asymmetric_quant_matmul(xq, wq, scale_x, scale_w, zero_x, zero_w)

    print(f'mean: {y.mean()} {yy.mean()}')
    print(f'max: {y.max()} {yy.max()}')
    print(f'min: {y.min()} {yy.min()}')
    print(f'atol: {(y - yy).abs().max()}')
    print(f'mse: {(y - yy).pow(2).mean()}')
    return y, yy

N=100
x = t.randn(N, N)
w = t.randn(N, N)
_, _ = quantization_matmul_stats(x, w)

mean: 0.017011694610118866 0.01569465361535549
max: 43.04981994628906 43.039249420166016
min: -47.28977966308594 -47.233360290527344
atol: 0.49615478515625
mse: 0.017323683947324753
