## Encoding a matrix in a single bitstream using ANS coding

snippets taken from `https://github.com/bamler-lab/webgl-entropy-coding`

Quantize weights using uniform grid quantization to 4 bits per weight. 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [167]:
np.random.seed(2025)
input_dim =    4096    # LLAMA 3 8B hidden dim 
output_dim = 14336  # and FFN Dimension (https://arxiv.org/pdf/2407.21783 Table 3)

weight_bits = 8 # TODO add different coarsness
prob_bits = 12


""""
If our weights are sampled iid from an standard normal distribution the 
"""
delta = 1/52 # TODO this value here needs to be tuned to avert an assertion error

w = np.random.randn(output_dim, input_dim)
quant_w = np.round(w / delta).astype(np.int8) # this requires manual tuning
w_min = quant_w.flatten().min()
# plt.hist(w.flatten()) # check that distribution of weights is not uniform

v, c = np.unique(quant_w, return_counts=True)
order = np.argsort(c)[::-1]
v = v[order]
c = c[order]

print(f"Unique Values: {len(v)}")
assert len(v) == 2**weight_bits
assert quant_w.flatten().astype(np.uint32).max() - w_min >= 2


Unique Values: 256


In [None]:
# creating 12 bit cdf for 
eps = (input_dim*output_dim /(1 << (prob_bits)))
print(eps)

# c12bit is a misnomer since also other bit sizes are supported
c12bit = np.maximum(np.round(c/ eps).astype(np.uint32), 1)

# normalizing c12bit model
excess = sum(c12bit) - (1 << prob_bits)
print(f"Excess: {excess}")
assert excess >= 0 and excess <= len(c)
if excess != 0:
    assert c12bit[excess - 1] > 1
c12bit[:excess] -= 1
assert c12bit.sum() == (1<< prob_bits)

# creating cdf for 
cdf = np.cumsum(c12bit)

# print(cdf)
ent = np.log2(input_dim*output_dim) - c @np.log2(c) /(input_dim*output_dim)
ce =  prob_bits - c @ np.log2(c12bit) / (c.sum())

print(f"Entropy of 4 bit  weights with full precision: {ent:.4f} bits.")
print(f"Cross entropy of 4 bit  weights using {prob_bits}-bit model: {ce:.4f} bits.")

13619.199999999999
Excess: 211
Entropy of 4 bit  weights with full precision: 7.6923 bits.
Cross entropy of 4 bit  weights using 12-bit model: 7.6936 bits.
