# NICE

This is an algorithm for compressing images, modeled as arrays of 8x8 RGB blocks, into much smaller latents.

### Dataset

Download the dataset from Kaggle. It was built from COCO by sampling 8x8 RGB blocks from each image.
The dataset consists of a training split and a validation split. Each split is a flat array of N blocks (an Nx3x8x8 tensor) with an unsigned 8-bit integer data type.

In [None]:
import kagglehub
from pathlib import Path
# Fetch the dataset through kagglehub; the returned location is wrapped in a
# Path because the files are addressed below as `datasets_path / '<name>.bin'`.
_dataset_dir = kagglehub.dataset_download('tay10r/image-block-compression')
datasets_path = Path(_dataset_dir)

import torch
class BinDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over a raw .bin file containing an array of 3x8x8 blocks.

    The whole file is read into memory as bytes. Each block occupies
    3 * 8 * 8 = 192 consecutive bytes and is returned as a ``torch.uint8``
    tensor of shape (3, 8, 8). Trailing bytes that do not fill a complete
    block are ignored.
    """

    # Size of one block in bytes (one byte per value).
    _BLOCK_BYTES = 3 * 8 * 8

    def __init__(self, filename: Path):
        """Read the file at ``filename`` fully into memory."""
        with open(filename, 'rb') as f:
            self.__data = f.read()
        self.__num_samples = len(self.__data) // self._BLOCK_BYTES

    def __len__(self) -> int:
        """Return the number of complete blocks in the file."""
        return self.__num_samples

    def __getitem__(self, index: int) -> torch.Tensor:
        """
        Return block ``index`` as a (3, 8, 8) uint8 tensor.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                Raising IndexError (rather than failing inside torch) lets
                plain ``for block in dataset`` iteration terminate.
        """
        if index < 0:
            index += self.__num_samples
        if not 0 <= index < self.__num_samples:
            raise IndexError(index)
        start = index * self._BLOCK_BYTES
        # Copy the block into a writable bytearray (one copy, via memoryview).
        # torch.frombuffer warns when handed a read-only buffer such as bytes,
        # and the copy also keeps each returned tensor independent of the file
        # data held by the dataset.
        block = bytearray(memoryview(self.__data)[start:start + self._BLOCK_BYTES])
        return torch.frombuffer(block, dtype=torch.uint8).reshape(3, 8, 8)

import numpy as np
class RemoteImageDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that downloads one image and serves it as 8x8 tiles.

    The image is stored channel-first as a (3, H, W) uint8 array. Tiles are
    numbered row by row (left to right, then top to bottom). Pixels beyond
    the last full tile in either direction are not served.
    """

    def __init__(self, image_url: str, timeout: float = 30.0):
        """
        Download and decode the image.

        Args:
            image_url: HTTP(S) location of the image.
            timeout: Seconds passed to ``requests.get`` so that a stalled
                connection fails instead of hanging.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        import requests
        from io import BytesIO
        from PIL import Image
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        with Image.open(fp=BytesIO(response.content)) as image:
            # Force three channels: RGBA, palette or greyscale input would
            # otherwise give tiles that are not 3x8x8. np.array makes a
            # writable copy, which torch.from_numpy can wrap without warning.
            pixels = np.array(image.convert('RGB'))
        # (H, W, C) -> (C, H, W)
        self.__image = np.transpose(pixels, axes=(2, 0, 1))

    def __len__(self) -> int:
        """Return the number of complete 8x8 tiles in the image."""
        _, H, W = self.__image.shape
        return (H // 8) * (W // 8)

    def __getitem__(self, index: int) -> torch.Tensor:
        """
        Return tile ``index`` as a tensor view of shape (channels, 8, 8).

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``,
                including any index on an image too small to hold one tile.
        """
        _, H, W = self.__image.shape
        tiles_w = W // 8
        num_tiles = tiles_w * (H // 8)
        if index < 0:
            index += num_tiles
        # Checked before the division below so that tiles_w == 0 cannot
        # surface as ZeroDivisionError.
        if not 0 <= index < num_tiles:
            raise IndexError(index)
        row, col = divmod(index, tiles_w)
        y0 = row * 8
        x0 = col * 8
        return torch.from_numpy(self.__image[:, y0:y0 + 8, x0:x0 + 8])

# Block datasets backed by the two downloaded split files.
train_data = BinDataset(datasets_path.joinpath('train.bin'))
val_data = BinDataset(datasets_path.joinpath('val.bin'))
# Single demo image, fetched over HTTP and served as 8x8 tiles.
demo_data = RemoteImageDataset(
    image_url='https://raw.githubusercontent.com/mikolalysenko/baboon-image/master/baboon.png',
)

### Network

Now we'll define the network, which consists of an encoder and decoder.
We define these as two separate networks, since in practice they are used separately.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """
    Maps a batch of 3x8x8 blocks to ``num_bits`` binary values per block.

    The forward value is a hard 0/1 threshold; gradients flow through the
    underlying softsign activation (a straight-through estimator).
    """

    def __init__(self, hidden: int, num_bits: int):
        """
        Args:
            hidden: Width of the hidden layers.
            num_bits: Number of output bits per block.
        """
        super().__init__()
        # NOTE(review): the first two Linear layers have no activation between
        # them, so together they form a single affine map — confirm this is
        # intentional. Layer order and the attribute name are kept as-is so
        # existing state_dict checkpoints still load.
        self.__layers = nn.Sequential(
            nn.Linear(3 * 8 * 8, hidden, bias=False),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(),
            nn.Linear(hidden, num_bits, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of blocks.

        Args:
            x: Float tensor whose dimensions after the first flatten to
                3 * 8 * 8 = 192 values per sample.

        Returns:
            Float tensor of shape (batch, num_bits) whose values are exactly
            0.0 or 1.0, with the gradient of the softsign output attached.
        """
        x = x.flatten(start_dim=1)
        x = self.__layers(x)
        x = F.softsign(x)
        hard = (x > 0.0).to(x.dtype)
        # Straight-through estimator. `x - x.detach()` is exactly zero in the
        # forward pass, so the result equals `hard` bit-for-bit, while its
        # gradient w.r.t. x is 1. (Evaluating `hard + x - x.detach()` left to
        # right instead can round to values slightly off 0/1.)
        return hard + (x - x.detach())

class Decoder(nn.Module):
    """Maps a batch of ``num_bits``-wide codes back to 3x8x8 blocks."""

    def __init__(self, hidden: int, num_bits: int):
        """
        Args:
            hidden: Width of the hidden layers.
            num_bits: Number of input values per block.
        """
        super().__init__()
        values_per_block = 8 * 8 * 3
        # Built in the same order as before so parameter names and seeded
        # initialisation are unchanged.
        stages = [
            nn.Linear(num_bits, hidden, bias=False),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(),
            nn.Linear(hidden, values_per_block, bias=False),
        ]
        self.__layers = nn.Sequential(*stages)

    def forward(self, bits: torch.Tensor) -> torch.Tensor:
        """Decode ``bits`` of shape (batch, num_bits) into (batch, 3, 8, 8)."""
        flat: torch.Tensor = self.__layers(bits)
        batch = flat.shape[0]
        return flat.reshape(batch, 3, 8, 8)

class Net(nn.Module):
    """Autoencoder that pairs an Encoder and a Decoder of the same bit width."""

    def __init__(self, num_bits: int, encoder_hidden: int, decoder_hidden: int):
        """
        Args:
            num_bits: Code width shared by the encoder and the decoder.
            encoder_hidden: Hidden width of the encoder.
            decoder_hidden: Hidden width of the decoder.
        """
        super().__init__()
        # Encoder is constructed before the decoder, as before, so seeded
        # weight initialisation is unchanged.
        self.encoder = Encoder(num_bits=num_bits, hidden=encoder_hidden)
        self.decoder = Decoder(num_bits=num_bits, hidden=decoder_hidden)

    def forward(self, x: torch.Tensor):
        """Return ``(reconstruction, bits)`` for the input batch ``x``."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return reconstruction, code


# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
print(f"compute device: {device_name}")
dev = torch.device(device_name)

# 48-bit code with 256-wide hidden layers, placed on the selected device.
net = Net(num_bits=48, encoder_hidden=256, decoder_hidden=256)
net = net.to(dev)


### Training the Network

We're going to train on the COCO dataset. Since it's very large and training takes a long time,
we're going to break training up into chunks of 1000 iterations. So every 1000 training iterations (batches, not individual samples), we will
check the validation loss and evaluate the demo image.

In [None]:
from torch.utils.tensorboard import SummaryWriter
import math

def forward(net: Net, x: torch.Tensor) -> torch.Tensor:
    """
    Return the mean-squared reconstruction error of ``net`` on ``x``.

    ``net(x)`` must return a pair whose first element is the reconstruction;
    the second element is not used here.
    """
    reconstruction, _ = net(x)
    return torch.nn.functional.mse_loss(input=reconstruction, target=x)

def val(net: Net, data: torch.utils.data.DataLoader, dev: torch.device) -> float:
    """
    Compute the mean reconstruction loss of ``net`` over ``data``.

    Each batch is moved to ``dev``, converted to float and scaled by 1/255
    before being passed to ``forward``. Batch losses are weighted by batch
    size, so a shorter final batch does not skew the average.

    The network is switched to eval mode and is left in eval mode on return.

    Args:
        net: Model to evaluate.
        data: Iterable of batches (tensors with the batch on dimension 0).
        dev: Device on which to run the evaluation.

    Returns:
        The sample-weighted mean of the per-batch losses.

    Raises:
        ValueError: if ``data`` yields no samples.
    """
    net.eval()
    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for batch in data:
            x: torch.Tensor = batch
            x = x.to(dev).float() * (1.0 / 255.0)
            batch_len = x.shape[0]
            # forward() returns a per-batch mean; undo the mean so that
            # batches of different sizes contribute proportionally.
            loss_sum += forward(net, x).item() * batch_len
            num_samples += batch_len
    if num_samples == 0:
        raise ValueError('validation data is empty')
    return loss_sum / num_samples

def _validate_and_checkpoint(net, val_loader, dev, writer, step, best_loss, checkpoint_path) -> float:
    """Validate, log the loss under ``step``, save weights on improvement; return the best loss so far."""
    val_loss = val(net, val_loader, dev)
    writer.add_scalar('loss', val_loss, global_step=step)
    writer.flush()
    print(f'[{step}]: {val_loss:.06f}')
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(net.state_dict(), checkpoint_path)
    # val() puts the network in eval mode; switch back before training resumes.
    net.train()
    return best_loss


def train(
    net: Net,
    train_data: torch.utils.data.Dataset,
    val_data: torch.utils.data.Dataset,
    dev: torch.device,
    *,
    epochs: int = 16,
    batch_size: int = 4096,
    val_batch_size: int = 128,
    lr: float = 1.0e-3,
    eval_interval: int = 1000,
    checkpoint_path: str = 'best_weights.pt',
):
    """
    Train ``net`` with AdamW and keep the weights with the best validation loss.

    Every ``eval_interval`` optimizer steps the validation loss is computed,
    logged to TensorBoard under the tag 'loss', printed, and — if it is the
    lowest seen so far — the model's state_dict is saved to
    ``checkpoint_path``. If training ends between two evaluation points, one
    final evaluation is run so that the trailing steps are also considered
    and a checkpoint exists even for runs shorter than ``eval_interval``.

    Args:
        net: Model to train (already on ``dev``).
        train_data: Dataset of uint8 blocks; batches are scaled by 1/255.
        val_data: Dataset used for the periodic validation.
        dev: Device on which batches are placed.
        epochs: Number of passes over ``train_data``.
        batch_size: Training batch size.
        val_batch_size: Validation batch size.
        lr: AdamW learning rate.
        eval_interval: Number of optimizer steps between validations.
        checkpoint_path: Where the best state_dict is written.

    Raises:
        ValueError: if ``eval_interval`` is smaller than 1.
    """
    if eval_interval < 1:
        raise ValueError('eval_interval must be at least 1')
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, shuffle=False)
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr)
    writer = SummaryWriter()
    counter = 0
    best_loss = math.inf
    net.train()
    try:
        for _ in range(epochs):
            for batch in train_loader:
                x: torch.Tensor = batch
                x = x.to(dev).float() * (1.0 / 255.0)
                loss = forward(net, x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                counter += 1
                if counter % eval_interval == 0:
                    best_loss = _validate_and_checkpoint(
                        net, val_loader, dev, writer,
                        counter // eval_interval, best_loss, checkpoint_path,
                    )
        if counter % eval_interval != 0:
            # Steps taken since the last evaluation point have not been
            # validated yet; evaluate once more under the next step index.
            best_loss = _validate_and_checkpoint(
                net, val_loader, dev, writer,
                counter // eval_interval + 1, best_loss, checkpoint_path,
            )
    finally:
        # Release the event-file handle even if training is interrupted.
        writer.close()

# Run the training loop on the datasets and device set up above.
train(net=net, train_data=train_data, val_data=val_data, dev=dev)

## Analyzing the Bit Distribution

A fully utilized bit will average 0.5 across a large dataset.
Let's verify and see how uniform our bit utilization is.

In [None]:
import matplotlib.pyplot as plt

def bit_means(net, val_loader, dev):
    """
    Return the mean value of every code bit over all samples in ``val_loader``.

    The code width is taken from the network's output, so this works for any
    number of bits. The network is switched to eval mode and left there.

    Args:
        net: Model whose call returns a pair; the second element is the
            (batch, num_bits) code tensor.
        val_loader: Iterable of uint8 batches; each is scaled by 1/255.
        dev: Device on which to run the model.

    Returns:
        A CPU float tensor of shape (num_bits,). If ``val_loader`` yields no
        samples, an empty tensor is returned.
    """
    net.eval()
    sums = None
    count = 0

    with torch.no_grad():
        for batch in val_loader:
            x = batch.to(dev).float() * (1.0 / 255.0)
            _, bits = net(x)
            batch_sums = bits.float().sum(dim=0)
            # Allocate the accumulator lazily: the first batch tells us the
            # code width.
            sums = batch_sums if sums is None else sums + batch_sums
            count += bits.shape[0]

    if sums is None or count == 0:
        return torch.empty(0)
    return (sums / count).cpu()

# Rebuild the architecture used for training and restore the best checkpoint.
best_net = Net(num_bits=48, encoder_hidden=256, decoder_hidden=256)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written during a CUDA run still loads in a CPU-only session.
best_net.load_state_dict(torch.load('best_weights.pt', map_location=dev))
best_net.to(dev)

# usage
means = bit_means(
    best_net,
    val_loader=torch.utils.data.DataLoader(val_data, batch_size=128, shuffle=False),
    dev=dev,
)

# One bar per bit; the dashed line marks the 0.5 mean of a fully utilized bit.
plt.figure(figsize=(10,4))
plt.bar(range(len(means)), means.numpy())
plt.axhline(0.5, linestyle='--')
plt.title("Bit Activation Means (Validation Set)")
plt.xlabel("Bit Index")
plt.ylabel("Mean Activation")
plt.show()
