# In [7]:
# ──────────────────────────────────────────────────────────────
"""
Input: a batch of four random feature vectors (size 6 in the toy run or 784 in the MNIST run).
Output: 3-class logits (toy) or 10-class logits (MNIST) **after INT-8 activation quantisation**.
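Values lie on a uniform grid with spacing ≈ 0.047 (the INT-8 step `2*rng/255` when `rng = 6`) because activations are uniformly quantised; the grid is `k*step - rng`, so it is offset and does not contain exactly 0.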

### Key take-aways

1. Posit-16 behaves almost like FP32
   Same accuracy, no overflow/denormal fuss, half the memory.
2. FP16 and BF16 need a bit more care (learning-rate, scaling) but work.
3. **8-bit activations** still converge on an easy task, just lose a few percent accuracy.
4. The demo gives you a plug-and-play template: swap one line (`ACTIVE_Q = ...`) to benchmark any other precision scheme you invent.

This program trains a tiny two-layer neural network on the MNIST digits, but lets you swap the numeric format used inside the network with one line.
It first defines quick “quantisers” for FP32, FP16, BF16, Posit-16, Posit-8, FP8 and INT-8. Whichever quantiser you set in ACTIVE_Q is applied to every weight, bias and activation during the forward pass, while gradients flow normally thanks to a straight-through estimator.
The model then runs three training epochs, printing each epoch’s training loss, test accuracy, and time, to show how different precisions affect learning.
"""
# ──────────────────────────────────────────────────────────────
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np, struct, math, time, random, os
torch.manual_seed(0);  random.seed(0);  np.random.seed(0)

# ──────────────────────────────────────────────────────────────
# 1.  Quantiser zoo
#     (same helpers as before, plus wrapper make_q)
# ──────────────────────────────────────────────────────────────
def posit_quant_np(x, nbits=16, es=2):
    x = np.asarray(x, np.float32)
    useed = 2 ** (2 ** es)
    max_pos = np.finfo(np.float32).max if nbits > 12 else useed ** ((2 ** (nbits-2))-1)
    x = np.clip(x, -max_pos, max_pos)
    sign = np.sign(x);  mag = np.abs(x) + 1e-30
    log2_mag = np.log2(mag)
    scale = 2 ** (nbits - 2)
    logq = np.round(log2_mag * scale) / scale
    return sign * np.exp2(logq)

def fp8_e4m3_np(x):
    """Round values to a simplified FP8 grid with 3 mantissa bits.

    The magnitude is split into ``(1 + mant) * 2 ** exp``; ``mant`` is
    rounded to eighths and ``exp`` is limited to ``[-8, 7]``. Inputs whose
    exponent is below -8 flush to zero, and inputs whose exponent is above 7
    saturate to ``±2 ** 8``.

    Args:
        x: Scalar or array-like of values; converted to ``float32``.

    Returns:
        ``float32`` array with the same shape as ``x`` (0-d for a scalar).
    """
    x = np.asarray(x, np.float32)
    sign = np.sign(x)
    mag = np.abs(x) + 1e-30                      # keep log2 finite for zeros
    exp = np.floor(np.log2(mag)).astype(int)
    mant = mag / np.exp2(exp) - 1.0              # fractional part in [0, 1)
    mant_q = np.round(mant * 8) / 8              # 3 mantissa bits
    exp_q = np.clip(exp, -8, 7)

    out = sign * (1.0 + mant_q) * np.exp2(exp_q)
    # np.where (rather than boolean-mask assignment) also works for 0-d input.
    out = np.where(exp < -8, 0.0, out)                   # underflow → 0
    out = np.where(exp > 7, sign * np.exp2(8), out)      # overflow → ±256
    # np.exp2 on an integer array yields float64; return float32 like the
    # other quantisers in this file.
    return out.astype(np.float32)

def int_uniform_np(bits, rng=6.0):
    """Build a uniform quantiser covering ``[-rng, rng]``.

    Args:
        bits: Bit width; the grid has ``2 ** bits - 1`` intervals.
        rng: Half-width of the representable range.

    Returns:
        A function that snaps its input to the grid ``k * step - rng`` and
        then clamps the result to ``[-rng, rng]``.
    """
    intervals = 2 ** bits - 1
    step = (2 * rng) / intervals

    def quantise(x):
        snapped = np.round((x + rng) / step) * step - rng
        return np.clip(snapped, -rng, rng)

    return quantise

def f16_np(x):
    """Round ``x`` through IEEE half precision and return it as ``float32``."""
    as_half = np.asarray(x, dtype=np.float16)
    return as_half.astype(np.float32)
def bf16_np(x):
    """Truncate ``float32`` values to bfloat16 precision, element-wise.

    The low 16 bits of each float32 bit pattern are cleared (truncation, not
    round-to-nearest), which keeps the sign, the 8 exponent bits and the top
    7 mantissa bits.

    Args:
        x: Scalar or array-like of values; converted to ``float32``.

    Returns:
        ``float32`` array with the same shape as ``x`` (0-d for a scalar).
    """
    bits = np.asarray(x, np.float32).view(np.uint32)
    # Masking is equivalent to the original ``>> 16 << 16`` but works on
    # whole arrays instead of a single struct-packed integer.
    kept = np.bitwise_and(bits, np.uint32(0xFFFF0000))
    return np.asarray(kept, dtype=np.uint32).view(np.float32)

class ErrorProbe:
    """
    Compares a quantised tensor against a stored baseline and records
    1) mean relative error
    2) average bits of precision (mean of ``-log2(relative error)``)

    The probe is a passthrough: calling it returns its argument unchanged,
    so it can be dropped into a forward pass.
    """
    def __init__(self, name, baseline_tensor):
        self.name = name
        # Baseline is snapshotted as a NumPy array, detached from autograd.
        self.x_ref = baseline_tensor.detach().cpu().numpy()
        # Filled in by __call__; None until the probe has been applied, so
        # reading them early does not raise AttributeError.
        self.rel_err = None
        self.bits = None

    def __call__(self, quantised_tensor):
        q = quantised_tensor.detach().cpu().numpy()
        # The 1e-30 terms guard against division by zero and log2(0).
        rel = np.abs(q - self.x_ref) / (np.abs(self.x_ref) + 1e-30)
        bits = -np.log2(rel + 1e-30)
        self.rel_err = rel.mean()
        self.bits    = bits.mean()
        return quantised_tensor          # passthrough

class Net(nn.Module):
    """Two-layer MLP (784 → 256 → 10) with quantised activations and probes.

    ``q_fn`` is applied to the input and to the input/output of each linear
    layer; the ``nn.Linear`` parameters themselves are not quantised here.

    NOTE: a second ``Net`` class defined further down in this file rebinds
    the name, so this version is only in effect until that definition runs.
    """
    def __init__(self, q_fn):
        super().__init__()
        self.q = q_fn
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x, probes=None):
        """Run the quantised forward pass.

        Args:
            x: Input batch fed to ``fc1`` (784 features per row).
            probes: Optional dict. When given, it is filled with
                ``ErrorProbe`` objects under ``'inp'``, ``'h'`` and ``'out'``
                comparing each quantised tensor with its unquantised (FP32)
                counterpart.

        Returns:
            The quantised output of ``fc2``.
        """
        probing = probes is not None
        if probing:
            probes['inp'] = ErrorProbe('inp', x)
            # Unquantised reference path. Previously the quantised tensors
            # were used as their own baselines, so the 'h' and 'out' probes
            # always reported zero error.
            with torch.no_grad():
                h_ref = F.relu(self.fc1(x))
                out_ref = self.fc2(h_ref)

        x_q = self.q(x)                    # quantise input
        if probing:
            probes['inp'](x_q)

        h = F.relu(self.q(self.fc1(self.q(x_q))))
        if probing:
            probes['h'] = ErrorProbe('h', h_ref)
            probes['h'](h)

        out = self.q(self.fc2(self.q(h)))
        if probing:
            probes['out'] = ErrorProbe('out', out_ref)
            probes['out'](out)
        return out

class STE(torch.autograd.Function):
    """Straight-through estimator around a NumPy quantiser.

    Forward applies ``fn`` to the tensor's values via NumPy; backward passes
    the incoming gradient through unchanged, as if the quantiser were the
    identity.
    """

    @staticmethod
    def forward(ctx, t, fn):
        """Quantise ``t`` with ``fn`` and return a tensor like ``t``.

        Args:
            t: Input tensor.
            fn: Callable taking a NumPy array and returning the quantised
                values (array or scalar).

        Returns:
            Tensor on ``t``'s device with ``t``'s dtype.
        """
        # np.asarray tolerates quantisers that return a scalar; casting back
        # to t.dtype keeps a float64-returning quantiser from silently
        # switching the whole network to double precision.
        quantised = np.asarray(fn(t.detach().cpu().numpy()))
        return torch.from_numpy(quantised).to(device=t.device, dtype=t.dtype)

    @staticmethod
    def backward(ctx, g):
        # Gradient w.r.t. t is passed straight through; fn gets no gradient.
        return g, None


def make_q(fn):
    """Wrap a NumPy quantiser ``fn`` as a tensor → tensor function using STE."""
    return lambda t: STE.apply(t, fn)

# Registry mapping a precision-scheme name to a tensor → tensor quantiser.
# Every entry except "FP32" wraps a NumPy helper with make_q, i.e. the
# quantiser is applied in the forward pass via STE.
QUANTISERS = {
    "FP32" : lambda t: t,                         # identity
    "Posit16": make_q(lambda x: posit_quant_np(x,16,2)),   # nbits=16, es=2
    "FP16"  : make_q(f16_np),
    "BF16"  : make_q(bf16_np),
    "INT8"  : make_q(int_uniform_np(8)),          # 8 bits, default rng=6.0
    "Posit8" : make_q(lambda x: posit_quant_np(x,8,1)),    # nbits=8, es=1
    "FP8"    : make_q(fp8_e4m3_np),
}

# Key of QUANTISERS selecting the scheme used to build the network below.
ACTIVE_Q = "FP32"          # <<< change to any key above: FP32 / Posit16 / FP16 / BF16 / INT8 / Posit8 / FP8

# ──────────────────────────────────────────────────────────────
# 2.  Dataset (MNIST 28×28 → flatten 784)
# ──────────────────────────────────────────────────────────────
# Per-sample transform: convert to a tensor, then flatten it to 1-D with
# view(-1) (784 values per image, per the section header above).
tr = transforms.Compose([transforms.ToTensor(),
                         transforms.Lambda(lambda x: x.view(-1))])
# Data is stored under the current directory. Only the training split passes
# download=True; the test split presumably relies on the files fetched by
# that first call — verify if the two lines are ever reordered.
train_ds = datasets.MNIST(root=".", train=True,  transform=tr, download=True)
test_ds  = datasets.MNIST(root=".", train=False, transform=tr)

# Training batches are shuffled; evaluation uses larger, unshuffled batches.
train_ld = torch.utils.data.DataLoader(train_ds, batch_size=256, shuffle=True)
test_ld  = torch.utils.data.DataLoader(test_ds, batch_size=1024)

# ──────────────────────────────────────────────────────────────
# 3.  Quantised MLP
# ──────────────────────────────────────────────────────────────
class QLinear(nn.Module):
    """Linear layer whose input, weight, bias and output all pass through ``q``."""

    def __init__(self, in_f, out_f, q):
        super().__init__()
        # Small random weights (std 0.02) and zero bias.
        initial_weight = torch.randn(out_f, in_f) * 0.02
        self.w = nn.Parameter(initial_weight)
        self.b = nn.Parameter(torch.zeros(out_f))
        self.q = q

    def forward(self, x):
        quantise = self.q
        x_q = quantise(x)
        w_q = quantise(self.w)
        b_q = quantise(self.b)
        # The affine result is quantised once more before being returned.
        return quantise(F.linear(x_q, w_q, b_q))

class Net(nn.Module):
    """Quantised two-layer MLP: 784 → 256 → ReLU → 10, built from QLinear."""

    def __init__(self, q):
        super().__init__()
        self.fc1 = QLinear(784, 256, q)
        self.fc2 = QLinear(256, 10, q)
        self.q = q

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# Build the model with the quantiser selected by ACTIVE_Q and keep it on CPU.
net = Net(QUANTISERS[ACTIVE_Q]).to("cpu")
# Adam over all model parameters with learning rate 1e-3.
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# Classification loss applied to the network's output logits.
ce  = nn.CrossEntropyLoss()

# ──────────────────────────────────────────────────────────────
# 4.  Training 3 epochs
# ──────────────────────────────────────────────────────────────
for epoch in range(3):
    # ── train for one epoch ──────────────────────────────────────
    net.train()
    t0 = time.time()
    loss_cum = 0
    for xb, yb in train_ld:
        opt.zero_grad()
        loss = ce(net(xb), yb)
        loss.backward()
        opt.step()
        # Scale the batch loss by the batch size so the epoch total can be
        # normalised by len(train_ds) when it is reported.
        loss_cum += loss.item() * xb.size(0)

    # Disabled debugging aids, kept for reference:
    #   * probe dump:
    #         probes = {}
    #         _ = net(batch, probes=probes)
    #         for name, p in probes.items():
    #             print(f"{name}: rel-err {p.rel_err:.2e} | bits {p.bits:.2f}")
    #   * progress print: iterate with enumerate(train_ld, 1) and, every
    #     50 mini-batches,
    #         print(f"epoch {epoch+1} batch {i} / {len(train_ld)}")

    # ── evaluate on the test split ───────────────────────────────
    net.eval()
    acc = 0
    with torch.no_grad():
        for xb, yb in test_ld:
            pred = net(xb).argmax(1)
            acc += (pred == yb).sum().item()

    # The reported time covers both the training and the evaluation pass.
    train_ce = loss_cum / len(train_ds)
    test_acc = acc / len(test_ds)
    elapsed = time.time() - t0
    print(f"[{ACTIVE_Q}] epoch {epoch+1}: "
          f"train CE {train_ce:.3f}  "
          f"test acc {test_acc:.4f}  "
          f"time {elapsed:.1f}s")


"""
Output
[Posit8] epoch 1: train CE 0.492  test acc 0.9296  time 11.6s
[Posit8] epoch 2: train CE 0.206  test acc 0.9515  time 11.4s
[Posit8] epoch 3: train CE 0.146  test acc 0.9589  time 11.5s

[Posit16] epoch 1: train CE 0.492  test acc 0.9299  time 11.4s
[Posit16] epoch 2: train CE 0.206  test acc 0.9516  time 11.8s
[Posit16] epoch 3: train CE 0.146  test acc 0.9600  time 11.5s
\nOutput\n[Posit8] epoch 1: train CE 0.492  test acc 0.9296  time 11.6s\n[Posit8] epoch 2: train CE 0.206  test acc 0.9515  time 11.4s\n[Posit8] epoch 3: train CE 0.146  test acc 0.9589  time 11.5s\n\n\n
"""


# [FP32] epoch 1: train CE 0.492  test acc 0.9299  time 8.0s
# [FP32] epoch 2: train CE 0.206  test acc 0.9514  time 8.7s
# [FP32] epoch 3: train CE 0.146  test acc 0.9598  time 8.7s


'\nOutput\n[Posit8] epoch 1: train CE 0.492  test acc 0.9296  time 11.6s\n[Posit8] epoch 2: train CE 0.206  test acc 0.9515  time 11.4s\n[Posit8] epoch 3: train CE 0.146  test acc 0.9589  time 11.5s\n\n[Posit16] epoch 1: train CE 0.492  test acc 0.9299  time 11.4s\n[Posit16] epoch 2: train CE 0.206  test acc 0.9516  time 11.8s\n[Posit16] epoch 3: train CE 0.146  test acc 0.9600  time 11.5s\n\nOutput\n[Posit8] epoch 1: train CE 0.492  test acc 0.9296  time 11.6s\n[Posit8] epoch 2: train CE 0.206  test acc 0.9515  time 11.4s\n[Posit8] epoch 3: train CE 0.146  test acc 0.9589  time 11.5s\n\n\n\n'