In [13]:
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder, CIFAR100, CIFAR10
from tqdm import tqdm

from compvit.factory import compvit_factory
from compvit.layers.bottleneck import conv_bottleneck
from compvit.layers.compressor import Compressor
from compvit.layers.inverted_bottleneck import inverted_mlp
from dinov2.factory import dinov2_factory
from dinov2.layers import Mlp

In [14]:
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

Using cache found in /home/jajal/.cache/torch/hub/facebookresearch_dinov2_main


ModuleNotFoundError: No module named 'dinov2.hub'

In [2]:
class CompresserDecoder(nn.Module):
    def __init__(self, compressor, total_tokens, num_compressed_tokens) -> None:
        super().__init__()
        self.compressor = compressor
        self.decoder = nn.Linear(num_compressed_tokens, total_tokens)
    
    def forward(self, x):
        x = self.compressor(x)
        x = self.decoder(x.mT).mT
        return x

# Setup Dataset

In [3]:
def calculate_mean_std(dataset):
    mean = 0
    std = 0
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    for img, _ in loader:
        batch_mean = torch.mean(img, dim=(0, 2, 3))
        batch_std = torch.std(img, dim=(0, 2, 3))
        mean += batch_mean
        std += batch_std
    mean /= len(loader)
    std /= len(loader)
    return mean, std

In [4]:
mean, std = calculate_mean_std(ImageFolder(root="data/synthetic-checkerboard/train", transform=tvt.ToTensor()))
transform = tvt.Compose([tvt.ToTensor(), tvt.Normalize(mean, std)])
train_dataset = ImageFolder("data/synthetic-checkerboard/train", transform=transform)
val_dataset = ImageFolder("data/synthetic-checkerboard/val", transform=transform)

In [5]:
TRANSFORM = tvt.Compose(
    [
        tvt.RandomCrop(32, padding=4),
        tvt.Resize(224),
        tvt.RandomHorizontalFlip(),
        tvt.ToTensor(),
        tvt.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

TRANSFORM_TEST = tvt.Compose(
    [
        tvt.Resize(224),
        tvt.ToTensor(),
        tvt.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_dataset = CIFAR10("toy_experiments/data", transform=TRANSFORM, download=True)
val_dataset = CIFAR10(
    "toy_experiments/data", transform=TRANSFORM_TEST, train=False, download=True
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=1)

# Evaluate and Training Code

In [7]:
def evaluate_reconstruction(model, compressor_decoder, val_loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for img, _ in tqdm(val_loader):
            img = img.to("cuda")
            tokens = model.prepare_tokens_with_masks(img)
            pred_tokens = compressor_decoder(tokens)
            loss = F.mse_loss(pred_tokens, tokens, reduction="mean")
            total_loss += loss.item()
    return total_loss / len(val_loader)

In [8]:
def train(model, compressor_decoder, train_loader, epochs, optimizer):
    for epoch in range(epochs):
        compressor_decoder.train()
        total_loss = 0.0
        for img, _ in tqdm(train_loader):
            img = img.to("cuda")
            with torch.no_grad():
                tokens = model.prepare_tokens_with_masks(img)
            pred_tokens = compressor_decoder(tokens)
            loss = F.mse_loss(pred_tokens, tokens, reduction="mean")
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss = loss.item()
        print(f"Epoch {epoch}: {total_loss / len(train_loader):.2e}")
        print(f"Validation loss: {evaluate_reconstruction(model, compressor_decoder, val_loader):.2e}")

# Setup Default Models and Parameters

In [9]:
embed_dim = 384
num_heads = 6
mlp_ratio = 4
qkv_bias = True
ffn_bias = True
proj_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
ffn_layer = Mlp
init_values = 1.0
num_compressed_tokens = 17
total_tokens = 257

bottleneck_size = 1
bottleneck = partial(
    conv_bottleneck,
    dim=embed_dim,
    ratio=mlp_ratio,
    bottleneck_size=bottleneck_size,
)

inv_bottle_size = 1
codebook_ratio = 2
inv_bottleneck = partial(
    inverted_mlp,
    dim=embed_dim,
    ratio=codebook_ratio,
)

In [10]:
model, cfg = dinov2_factory("dinov2_vits14")
model.load_state_dict(torch.load("dinov2/checkpoints/dinov2_vits14_pretrain.pth"))

<All keys matched successfully>

In [11]:
def create_compressor(num_codebook_tokens):
    compressor = Compressor(
        dim=embed_dim,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        proj_bias=proj_bias,
        ffn_bias=ffn_bias,
        norm_layer=norm_layer,
        act_layer=act_layer,
        ffn_layer=ffn_layer,
        init_values=init_values,
        num_compressed_tokens=num_compressed_tokens,
        num_tokens=total_tokens,
        bottleneck=bottleneck,
        num_codebook_tokens=num_codebook_tokens,
        inv_bottleneck=inv_bottleneck,
    )
    return compressor

# Setup Up Models -- codebook tokens > total tokens

In [12]:
num_codebook_tokens = 384
compressor = create_compressor(num_codebook_tokens)

In [13]:
model = model.to("cuda")
compressor_decoder = CompresserDecoder(
    compressor,
    total_tokens,
    num_compressed_tokens,
).to("cuda")
total_params = sum(p.numel() for p in compressor_decoder.parameters())
print(f"Number of parameters in CompressorDecoder: {total_params:,}")

Number of parameters in CompressorDecoder: 5,160,476


In [14]:
epochs = 50
optimizer = torch.optim.AdamW(compressor_decoder.parameters(), lr=1e-6, weight_decay=5e-2)

In [15]:
train(model, compressor_decoder, train_loader, epochs, optimizer)
test_loss = evaluate_reconstruction(model, compressor_decoder, val_loader)
print(f"Final test loss: {test_loss:.2e}")

100%|██████████| 782/782 [00:32<00:00, 24.07it/s]


Epoch 0: 3.49e-05


100%|██████████| 313/313 [00:06<00:00, 47.50it/s]


Validation loss: 2.67e-02


100%|██████████| 782/782 [00:32<00:00, 24.12it/s]


Epoch 1: 3.31e-05


100%|██████████| 313/313 [00:06<00:00, 48.26it/s]


Validation loss: 2.53e-02


100%|██████████| 782/782 [00:32<00:00, 24.12it/s]


Epoch 2: 3.19e-05


100%|██████████| 313/313 [00:06<00:00, 47.28it/s]


Validation loss: 2.45e-02


100%|██████████| 782/782 [00:32<00:00, 24.20it/s]


Epoch 3: 3.06e-05


100%|██████████| 313/313 [00:06<00:00, 48.20it/s]


Validation loss: 2.38e-02


100%|██████████| 782/782 [00:32<00:00, 24.15it/s]


Epoch 4: 3.03e-05


100%|██████████| 313/313 [00:06<00:00, 48.45it/s]


Validation loss: 2.32e-02


100%|██████████| 782/782 [00:32<00:00, 24.11it/s]


Epoch 5: 2.90e-05


100%|██████████| 313/313 [00:06<00:00, 48.16it/s]


Validation loss: 2.26e-02


100%|██████████| 782/782 [00:32<00:00, 24.09it/s]


Epoch 6: 2.86e-05


100%|██████████| 313/313 [00:06<00:00, 47.45it/s]


Validation loss: 2.20e-02


100%|██████████| 782/782 [00:32<00:00, 24.02it/s]


Epoch 7: 2.77e-05


100%|██████████| 313/313 [00:06<00:00, 47.79it/s]


Validation loss: 2.15e-02


100%|██████████| 782/782 [00:32<00:00, 24.12it/s]


Epoch 8: 2.75e-05


100%|██████████| 313/313 [00:06<00:00, 47.58it/s]


Validation loss: 2.11e-02


100%|██████████| 782/782 [00:32<00:00, 24.08it/s]


Epoch 9: 2.67e-05


100%|██████████| 313/313 [00:06<00:00, 47.40it/s]


Validation loss: 2.06e-02


100%|██████████| 782/782 [00:32<00:00, 24.13it/s]


Epoch 10: 2.64e-05


100%|██████████| 313/313 [00:06<00:00, 47.95it/s]


Validation loss: 2.02e-02


100%|██████████| 782/782 [00:32<00:00, 24.10it/s]


Epoch 11: 2.57e-05


100%|██████████| 313/313 [00:06<00:00, 46.87it/s]


Validation loss: 1.97e-02


100%|██████████| 782/782 [00:32<00:00, 24.09it/s]


Epoch 12: 2.53e-05


100%|██████████| 313/313 [00:06<00:00, 47.65it/s]


Validation loss: 1.93e-02


100%|██████████| 782/782 [00:32<00:00, 24.09it/s]


Epoch 13: 2.44e-05


100%|██████████| 313/313 [00:06<00:00, 47.65it/s]


Validation loss: 1.88e-02


100%|██████████| 782/782 [00:32<00:00, 24.10it/s]


Epoch 14: 2.37e-05


100%|██████████| 313/313 [00:06<00:00, 45.90it/s]


Validation loss: 1.84e-02


100%|██████████| 782/782 [00:32<00:00, 24.15it/s]


Epoch 15: 2.33e-05


100%|██████████| 313/313 [00:06<00:00, 48.06it/s]


Validation loss: 1.79e-02


100%|██████████| 782/782 [00:32<00:00, 24.05it/s]


Epoch 16: 2.25e-05


100%|██████████| 313/313 [00:06<00:00, 46.43it/s]


Validation loss: 1.74e-02


100%|██████████| 782/782 [00:32<00:00, 24.06it/s]


Epoch 17: 2.24e-05


100%|██████████| 313/313 [00:06<00:00, 47.75it/s]


Validation loss: 1.70e-02


100%|██████████| 782/782 [00:32<00:00, 24.08it/s]


Epoch 18: 2.12e-05


100%|██████████| 313/313 [00:06<00:00, 47.95it/s]


Validation loss: 1.65e-02


100%|██████████| 782/782 [00:32<00:00, 24.09it/s]


Epoch 19: 2.11e-05


100%|██████████| 313/313 [00:06<00:00, 46.75it/s]


Validation loss: 1.61e-02


100%|██████████| 782/782 [00:32<00:00, 24.08it/s]


Epoch 20: 2.02e-05


100%|██████████| 313/313 [00:06<00:00, 47.45it/s]


Validation loss: 1.56e-02


100%|██████████| 782/782 [00:32<00:00, 24.04it/s]


Epoch 21: 2.00e-05


100%|██████████| 313/313 [00:06<00:00, 47.62it/s]


Validation loss: 1.52e-02


100%|██████████| 782/782 [00:32<00:00, 24.10it/s]


Epoch 22: 1.91e-05


100%|██████████| 313/313 [00:06<00:00, 47.87it/s]


Validation loss: 1.48e-02


100%|██████████| 782/782 [00:32<00:00, 24.10it/s]


Epoch 23: 1.88e-05


100%|██████████| 313/313 [00:06<00:00, 47.80it/s]


Validation loss: 1.43e-02


100%|██████████| 782/782 [00:32<00:00, 23.97it/s]


Epoch 24: 1.81e-05


100%|██████████| 313/313 [00:06<00:00, 47.82it/s]


Validation loss: 1.39e-02


100%|██████████| 782/782 [00:32<00:00, 24.03it/s]


Epoch 25: 1.74e-05


100%|██████████| 313/313 [00:06<00:00, 46.95it/s]


Validation loss: 1.35e-02


100%|██████████| 782/782 [00:32<00:00, 24.02it/s]


Epoch 26: 1.69e-05


100%|██████████| 313/313 [00:06<00:00, 47.01it/s]


Validation loss: 1.31e-02


100%|██████████| 782/782 [00:32<00:00, 24.10it/s]


Epoch 27: 1.67e-05


100%|██████████| 313/313 [00:06<00:00, 48.19it/s]


Validation loss: 1.27e-02


100%|██████████| 782/782 [00:32<00:00, 24.12it/s]


Epoch 28: 1.60e-05


100%|██████████| 313/313 [00:06<00:00, 47.62it/s]


Validation loss: 1.23e-02


100%|██████████| 782/782 [00:32<00:00, 24.15it/s]


Epoch 29: 1.55e-05


100%|██████████| 313/313 [00:06<00:00, 47.83it/s]


Validation loss: 1.19e-02


100%|██████████| 782/782 [00:32<00:00, 24.07it/s]


Epoch 30: 1.53e-05


100%|██████████| 313/313 [00:06<00:00, 47.67it/s]


Validation loss: 1.15e-02


100%|██████████| 782/782 [00:32<00:00, 24.06it/s]


Epoch 31: 1.46e-05


100%|██████████| 313/313 [00:06<00:00, 47.24it/s]


Validation loss: 1.12e-02


100%|██████████| 782/782 [00:32<00:00, 24.09it/s]


Epoch 32: 1.43e-05


100%|██████████| 313/313 [00:06<00:00, 48.25it/s]


Validation loss: 1.08e-02


100%|██████████| 782/782 [00:32<00:00, 24.00it/s]


Epoch 33: 1.38e-05


100%|██████████| 313/313 [00:06<00:00, 46.63it/s]


Validation loss: 1.05e-02


100%|██████████| 782/782 [00:32<00:00, 24.07it/s]


Epoch 34: 1.34e-05


100%|██████████| 313/313 [00:06<00:00, 46.08it/s]


Validation loss: 1.01e-02


100%|██████████| 782/782 [00:32<00:00, 24.07it/s]


Epoch 35: 1.30e-05


100%|██████████| 313/313 [00:06<00:00, 47.02it/s]


Validation loss: 9.80e-03


100%|██████████| 782/782 [00:32<00:00, 24.12it/s]


Epoch 36: 1.26e-05


100%|██████████| 313/313 [00:06<00:00, 47.13it/s]


Validation loss: 9.48e-03


100%|██████████| 782/782 [00:32<00:00, 24.03it/s]


Epoch 37: 1.20e-05


100%|██████████| 313/313 [00:06<00:00, 47.21it/s]


Validation loss: 9.17e-03


100%|██████████| 782/782 [00:32<00:00, 24.07it/s]


Epoch 38: 1.16e-05


100%|██████████| 313/313 [00:06<00:00, 48.27it/s]


Validation loss: 8.86e-03


100%|██████████| 782/782 [00:32<00:00, 24.08it/s]


Epoch 39: 1.14e-05


100%|██████████| 313/313 [00:06<00:00, 48.68it/s]


Validation loss: 8.57e-03


100%|██████████| 782/782 [00:32<00:00, 24.08it/s]


Epoch 40: 1.07e-05


100%|██████████| 313/313 [00:06<00:00, 47.38it/s]


Validation loss: 8.28e-03


100%|██████████| 782/782 [00:32<00:00, 24.06it/s]


Epoch 41: 1.05e-05


100%|██████████| 313/313 [00:06<00:00, 48.28it/s]


Validation loss: 8.00e-03


100%|██████████| 782/782 [00:32<00:00, 24.01it/s]


Epoch 42: 9.93e-06


100%|██████████| 313/313 [00:06<00:00, 46.34it/s]


Validation loss: 7.73e-03


100%|██████████| 782/782 [00:32<00:00, 24.09it/s]


Epoch 43: 9.83e-06


100%|██████████| 313/313 [00:06<00:00, 46.95it/s]


Validation loss: 7.47e-03


100%|██████████| 782/782 [00:32<00:00, 24.00it/s]


Epoch 44: 9.44e-06


100%|██████████| 313/313 [00:06<00:00, 47.40it/s]


Validation loss: 7.21e-03


100%|██████████| 782/782 [00:32<00:00, 24.10it/s]


Epoch 45: 9.23e-06


100%|██████████| 313/313 [00:06<00:00, 47.50it/s]


Validation loss: 6.97e-03


100%|██████████| 782/782 [00:32<00:00, 24.06it/s]


Epoch 46: 9.08e-06


100%|██████████| 313/313 [00:06<00:00, 47.85it/s]


Validation loss: 6.73e-03


100%|██████████| 782/782 [00:32<00:00, 24.08it/s]


Epoch 47: 8.96e-06


100%|██████████| 313/313 [00:06<00:00, 48.43it/s]


Validation loss: 6.50e-03


100%|██████████| 782/782 [00:32<00:00, 24.04it/s]


Epoch 48: 8.53e-06


100%|██████████| 313/313 [00:06<00:00, 48.02it/s]


Validation loss: 6.27e-03


100%|██████████| 782/782 [00:32<00:00, 24.05it/s]


Epoch 49: 8.22e-06


100%|██████████| 313/313 [00:06<00:00, 47.31it/s]


Validation loss: 6.05e-03


100%|██████████| 313/313 [00:06<00:00, 48.05it/s]

Final test loss: 6.05e-03





# Setup Up Models -- codebook tokens = total tokens

In [16]:
num_codebook_tokens = 257
compressor = create_compressor(num_codebook_tokens)

In [17]:
model = model.to("cuda")
compressor_decoder = CompresserDecoder(
    compressor,
    total_tokens,
    num_compressed_tokens,
).to("cuda")
total_params = sum(p.numel() for p in compressor_decoder.parameters())
print(f"Number of parameters in CompressorDecoder: {total_params:,}")

Number of parameters in CompressorDecoder: 5,160,476


In [18]:
epochs = 50
optimizer = torch.optim.AdamW(compressor_decoder.parameters(), lr=1e-6, weight_decay=5e-2)

In [19]:
train(model, compressor_decoder, train_loader, epochs, optimizer)
test_loss = evaluate_reconstruction(model, compressor_decoder, val_loader)
print(f"Final test loss: {test_loss:.2e}")

100%|██████████| 782/782 [00:29<00:00, 26.55it/s]


Epoch 0: 3.44e-05


100%|██████████| 313/313 [00:06<00:00, 48.27it/s]


Validation loss: 2.64e-02


100%|██████████| 782/782 [00:29<00:00, 26.53it/s]


Epoch 1: 3.32e-05


100%|██████████| 313/313 [00:06<00:00, 46.88it/s]


Validation loss: 2.53e-02


100%|██████████| 782/782 [00:29<00:00, 26.49it/s]


Epoch 2: 3.16e-05


100%|██████████| 313/313 [00:06<00:00, 47.91it/s]


Validation loss: 2.46e-02


100%|██████████| 782/782 [00:29<00:00, 26.44it/s]


Epoch 3: 3.08e-05


100%|██████████| 313/313 [00:06<00:00, 46.60it/s]


Validation loss: 2.39e-02


100%|██████████| 782/782 [00:29<00:00, 26.56it/s]


Epoch 4: 2.99e-05


100%|██████████| 313/313 [00:06<00:00, 47.39it/s]


Validation loss: 2.32e-02


100%|██████████| 782/782 [00:29<00:00, 26.53it/s]


Epoch 5: 2.91e-05


100%|██████████| 313/313 [00:06<00:00, 46.87it/s]


Validation loss: 2.25e-02


100%|██████████| 782/782 [00:29<00:00, 26.49it/s]


Epoch 6: 2.89e-05


100%|██████████| 313/313 [00:06<00:00, 47.81it/s]


Validation loss: 2.20e-02


100%|██████████| 782/782 [00:29<00:00, 26.49it/s]


Epoch 7: 2.79e-05


100%|██████████| 313/313 [00:06<00:00, 47.66it/s]


Validation loss: 2.16e-02


100%|██████████| 782/782 [00:29<00:00, 26.49it/s]


Epoch 8: 2.74e-05


100%|██████████| 313/313 [00:06<00:00, 45.75it/s]


Validation loss: 2.11e-02


100%|██████████| 782/782 [00:30<00:00, 25.91it/s]


Epoch 9: 2.69e-05


100%|██████████| 313/313 [00:06<00:00, 46.22it/s]


Validation loss: 2.07e-02


100%|██████████| 782/782 [00:30<00:00, 25.82it/s]


Epoch 10: 2.64e-05


100%|██████████| 313/313 [00:06<00:00, 46.50it/s]


Validation loss: 2.03e-02


100%|██████████| 782/782 [00:30<00:00, 25.92it/s]


Epoch 11: 2.57e-05


100%|██████████| 313/313 [00:06<00:00, 45.33it/s]


Validation loss: 1.99e-02


100%|██████████| 782/782 [00:30<00:00, 25.87it/s]


Epoch 12: 2.50e-05


100%|██████████| 313/313 [00:06<00:00, 45.02it/s]


Validation loss: 1.95e-02


100%|██████████| 782/782 [00:30<00:00, 25.84it/s]


Epoch 13: 2.47e-05


100%|██████████| 313/313 [00:06<00:00, 46.25it/s]


Validation loss: 1.91e-02


100%|██████████| 782/782 [00:30<00:00, 25.78it/s]


Epoch 14: 2.43e-05


100%|██████████| 313/313 [00:06<00:00, 46.23it/s]


Validation loss: 1.87e-02


100%|██████████| 782/782 [00:30<00:00, 25.91it/s]


Epoch 15: 2.36e-05


100%|██████████| 313/313 [00:06<00:00, 46.41it/s]


Validation loss: 1.83e-02


100%|██████████| 782/782 [00:30<00:00, 25.94it/s]


Epoch 16: 2.34e-05


100%|██████████| 313/313 [00:06<00:00, 46.76it/s]


Validation loss: 1.79e-02


100%|██████████| 782/782 [00:30<00:00, 26.00it/s]


Epoch 17: 2.29e-05


100%|██████████| 313/313 [00:06<00:00, 46.70it/s]


Validation loss: 1.75e-02


100%|██████████| 782/782 [00:30<00:00, 26.02it/s]


Epoch 18: 2.21e-05


100%|██████████| 313/313 [00:06<00:00, 46.59it/s]


Validation loss: 1.70e-02


100%|██████████| 782/782 [00:30<00:00, 25.97it/s]


Epoch 19: 2.17e-05


100%|██████████| 313/313 [00:06<00:00, 46.82it/s]


Validation loss: 1.66e-02


100%|██████████| 782/782 [00:30<00:00, 25.98it/s]


Epoch 20: 2.09e-05


100%|██████████| 313/313 [00:07<00:00, 44.62it/s]


Validation loss: 1.62e-02


100%|██████████| 782/782 [00:30<00:00, 25.86it/s]


Epoch 21: 2.02e-05


100%|██████████| 313/313 [00:06<00:00, 46.35it/s]


Validation loss: 1.58e-02


100%|██████████| 782/782 [00:30<00:00, 25.77it/s]


Epoch 22: 2.04e-05


100%|██████████| 313/313 [00:06<00:00, 45.86it/s]


Validation loss: 1.54e-02


100%|██████████| 782/782 [00:29<00:00, 26.23it/s]


Epoch 23: 1.98e-05


100%|██████████| 313/313 [00:06<00:00, 46.27it/s]


Validation loss: 1.50e-02


100%|██████████| 782/782 [00:29<00:00, 26.35it/s]


Epoch 24: 1.90e-05


100%|██████████| 313/313 [00:06<00:00, 46.40it/s]


Validation loss: 1.46e-02


100%|██████████| 782/782 [00:30<00:00, 25.78it/s]


Epoch 25: 1.86e-05


100%|██████████| 313/313 [00:06<00:00, 45.72it/s]


Validation loss: 1.42e-02


100%|██████████| 782/782 [00:30<00:00, 25.68it/s]


Epoch 26: 1.82e-05


100%|██████████| 313/313 [00:06<00:00, 46.70it/s]


Validation loss: 1.39e-02


100%|██████████| 782/782 [00:30<00:00, 25.63it/s]


Epoch 27: 1.74e-05


100%|██████████| 313/313 [00:06<00:00, 45.84it/s]


Validation loss: 1.35e-02


100%|██████████| 782/782 [00:30<00:00, 25.47it/s]


Epoch 28: 1.70e-05


100%|██████████| 313/313 [00:06<00:00, 45.57it/s]


Validation loss: 1.31e-02


100%|██████████| 782/782 [00:30<00:00, 25.60it/s]


Epoch 29: 1.65e-05


100%|██████████| 313/313 [00:06<00:00, 45.67it/s]


Validation loss: 1.27e-02


100%|██████████| 782/782 [00:30<00:00, 25.63it/s]


Epoch 30: 1.61e-05


100%|██████████| 313/313 [00:06<00:00, 45.67it/s]


Validation loss: 1.23e-02


100%|██████████| 782/782 [00:30<00:00, 25.58it/s]


Epoch 31: 1.56e-05


100%|██████████| 313/313 [00:06<00:00, 44.85it/s]


Validation loss: 1.20e-02


100%|██████████| 782/782 [00:30<00:00, 25.99it/s]


Epoch 32: 1.52e-05


100%|██████████| 313/313 [00:06<00:00, 46.39it/s]


Validation loss: 1.16e-02


100%|██████████| 782/782 [00:30<00:00, 25.95it/s]


Epoch 33: 1.47e-05


100%|██████████| 313/313 [00:06<00:00, 46.22it/s]


Validation loss: 1.13e-02


100%|██████████| 782/782 [00:30<00:00, 25.83it/s]


Epoch 34: 1.42e-05


100%|██████████| 313/313 [00:06<00:00, 45.72it/s]


Validation loss: 1.09e-02


100%|██████████| 782/782 [00:30<00:00, 25.72it/s]


Epoch 35: 1.42e-05


100%|██████████| 313/313 [00:06<00:00, 45.76it/s]


Validation loss: 1.05e-02


100%|██████████| 782/782 [00:30<00:00, 25.76it/s]


Epoch 36: 1.34e-05


100%|██████████| 313/313 [00:06<00:00, 45.68it/s]


Validation loss: 1.02e-02


100%|██████████| 782/782 [00:30<00:00, 25.75it/s]


Epoch 37: 1.29e-05


100%|██████████| 313/313 [00:06<00:00, 45.78it/s]


Validation loss: 9.86e-03


100%|██████████| 782/782 [00:30<00:00, 25.77it/s]


Epoch 38: 1.28e-05


100%|██████████| 313/313 [00:06<00:00, 46.61it/s]


Validation loss: 9.53e-03


100%|██████████| 782/782 [00:30<00:00, 25.77it/s]


Epoch 39: 1.22e-05


100%|██████████| 313/313 [00:06<00:00, 46.09it/s]


Validation loss: 9.20e-03


100%|██████████| 782/782 [00:30<00:00, 25.72it/s]


Epoch 40: 1.17e-05


100%|██████████| 313/313 [00:06<00:00, 45.90it/s]


Validation loss: 8.88e-03


100%|██████████| 782/782 [00:30<00:00, 25.39it/s]


Epoch 41: 1.13e-05


100%|██████████| 313/313 [00:06<00:00, 45.73it/s]


Validation loss: 8.57e-03


100%|██████████| 782/782 [00:30<00:00, 25.60it/s]


Epoch 42: 1.08e-05


100%|██████████| 313/313 [00:06<00:00, 45.62it/s]


Validation loss: 8.26e-03


100%|██████████| 782/782 [00:30<00:00, 25.55it/s]


Epoch 43: 1.05e-05


100%|██████████| 313/313 [00:06<00:00, 45.70it/s]


Validation loss: 7.96e-03


100%|██████████| 782/782 [00:30<00:00, 25.65it/s]


Epoch 44: 1.02e-05


100%|██████████| 313/313 [00:06<00:00, 46.12it/s]


Validation loss: 7.67e-03


100%|██████████| 782/782 [00:30<00:00, 25.58it/s]


Epoch 45: 1.01e-05


100%|██████████| 313/313 [00:06<00:00, 45.74it/s]


Validation loss: 7.39e-03


100%|██████████| 782/782 [00:30<00:00, 25.55it/s]


Epoch 46: 9.47e-06


100%|██████████| 313/313 [00:06<00:00, 45.60it/s]


Validation loss: 7.11e-03


100%|██████████| 782/782 [00:30<00:00, 25.60it/s]


Epoch 47: 9.00e-06


100%|██████████| 313/313 [00:06<00:00, 45.73it/s]


Validation loss: 6.84e-03


100%|██████████| 782/782 [00:30<00:00, 25.56it/s]


Epoch 48: 8.83e-06


100%|██████████| 313/313 [00:06<00:00, 45.66it/s]


Validation loss: 6.58e-03


100%|██████████| 782/782 [00:30<00:00, 25.66it/s]


Epoch 49: 8.38e-06


100%|██████████| 313/313 [00:06<00:00, 45.72it/s]


Validation loss: 6.33e-03


100%|██████████| 313/313 [00:06<00:00, 46.75it/s]

Final test loss: 6.33e-03





# Setup Up Models -- codebook tokens < total tokens

In [20]:
num_codebook_tokens = 128
compressor = create_compressor(num_codebook_tokens)

In [21]:
model = model.to("cuda")
compressor_decoder = CompresserDecoder(
    compressor,
    total_tokens,
    num_compressed_tokens,
).to("cuda")
total_params = sum(p.numel() for p in compressor_decoder.parameters())
print(f"Number of parameters in CompressorDecoder: {total_params:,}")


Number of parameters in CompressorDecoder: 5,160,476


In [22]:
epochs = 50
optimizer = torch.optim.AdamW(compressor_decoder.parameters(), lr=1e-6, weight_decay=5e-2)

In [23]:
train(model, compressor_decoder, train_loader, epochs, optimizer)
test_loss = evaluate_reconstruction(model, compressor_decoder, val_loader)
print(f"Final test loss: {test_loss:.2e}")

100%|██████████| 782/782 [00:27<00:00, 28.82it/s]


Epoch 0: 3.40e-05


100%|██████████| 313/313 [00:06<00:00, 45.84it/s]


Validation loss: 2.65e-02


100%|██████████| 782/782 [00:27<00:00, 28.91it/s]


Epoch 1: 3.12e-05


100%|██████████| 313/313 [00:06<00:00, 46.21it/s]


Validation loss: 2.43e-02


100%|██████████| 782/782 [00:27<00:00, 28.86it/s]


Epoch 2: 3.02e-05


100%|██████████| 313/313 [00:06<00:00, 46.26it/s]


Validation loss: 2.33e-02


100%|██████████| 782/782 [00:27<00:00, 28.95it/s]


Epoch 3: 2.89e-05


100%|██████████| 313/313 [00:06<00:00, 46.29it/s]


Validation loss: 2.25e-02


100%|██████████| 782/782 [00:26<00:00, 28.99it/s]


Epoch 4: 2.88e-05


100%|██████████| 313/313 [00:06<00:00, 45.99it/s]


Validation loss: 2.18e-02


100%|██████████| 782/782 [00:26<00:00, 28.98it/s]


Epoch 5: 2.72e-05


100%|██████████| 313/313 [00:06<00:00, 45.99it/s]


Validation loss: 2.13e-02


100%|██████████| 782/782 [00:26<00:00, 28.97it/s]


Epoch 6: 2.68e-05


100%|██████████| 313/313 [00:06<00:00, 45.91it/s]


Validation loss: 2.08e-02


100%|██████████| 782/782 [00:26<00:00, 28.97it/s]


Epoch 7: 2.67e-05


100%|██████████| 313/313 [00:06<00:00, 46.22it/s]


Validation loss: 2.03e-02


100%|██████████| 782/782 [00:27<00:00, 28.96it/s]


Epoch 8: 2.59e-05


100%|██████████| 313/313 [00:06<00:00, 46.71it/s]


Validation loss: 1.99e-02


100%|██████████| 782/782 [00:26<00:00, 28.99it/s]


Epoch 9: 2.51e-05


100%|██████████| 313/313 [00:06<00:00, 46.72it/s]


Validation loss: 1.94e-02


100%|██████████| 782/782 [00:27<00:00, 28.91it/s]


Epoch 10: 2.45e-05


100%|██████████| 313/313 [00:06<00:00, 46.43it/s]


Validation loss: 1.90e-02


100%|██████████| 782/782 [00:26<00:00, 28.99it/s]


Epoch 11: 2.42e-05


100%|██████████| 313/313 [00:06<00:00, 46.11it/s]


Validation loss: 1.86e-02


100%|██████████| 782/782 [00:26<00:00, 29.04it/s]


Epoch 12: 2.37e-05


100%|██████████| 313/313 [00:06<00:00, 46.43it/s]


Validation loss: 1.82e-02


100%|██████████| 782/782 [00:27<00:00, 28.94it/s]


Epoch 13: 2.29e-05


100%|██████████| 313/313 [00:06<00:00, 46.13it/s]


Validation loss: 1.78e-02


100%|██████████| 782/782 [00:26<00:00, 28.97it/s]


Epoch 14: 2.26e-05


100%|██████████| 313/313 [00:06<00:00, 46.35it/s]


Validation loss: 1.73e-02


100%|██████████| 782/782 [00:26<00:00, 28.98it/s]


Epoch 15: 2.18e-05


100%|██████████| 313/313 [00:06<00:00, 46.35it/s]


Validation loss: 1.69e-02


100%|██████████| 782/782 [00:27<00:00, 28.93it/s]


Epoch 16: 2.14e-05


100%|██████████| 313/313 [00:06<00:00, 46.36it/s]


Validation loss: 1.65e-02


100%|██████████| 782/782 [00:27<00:00, 28.91it/s]


Epoch 17: 2.10e-05


100%|██████████| 313/313 [00:06<00:00, 46.21it/s]


Validation loss: 1.61e-02


100%|██████████| 782/782 [00:26<00:00, 29.13it/s]


Epoch 18: 2.05e-05


100%|██████████| 313/313 [00:06<00:00, 46.46it/s]


Validation loss: 1.57e-02


100%|██████████| 782/782 [00:26<00:00, 29.05it/s]


Epoch 19: 2.00e-05


100%|██████████| 313/313 [00:06<00:00, 45.63it/s]


Validation loss: 1.53e-02


100%|██████████| 782/782 [00:26<00:00, 29.13it/s]


Epoch 20: 1.92e-05


100%|██████████| 313/313 [00:06<00:00, 46.61it/s]


Validation loss: 1.49e-02


100%|██████████| 782/782 [00:26<00:00, 29.01it/s]


Epoch 21: 1.89e-05


100%|██████████| 313/313 [00:06<00:00, 46.38it/s]


Validation loss: 1.46e-02


100%|██████████| 782/782 [00:26<00:00, 29.07it/s]


Epoch 22: 1.81e-05


100%|██████████| 313/313 [00:06<00:00, 46.24it/s]


Validation loss: 1.42e-02


100%|██████████| 782/782 [00:26<00:00, 29.09it/s]


Epoch 23: 1.80e-05


100%|██████████| 313/313 [00:06<00:00, 46.69it/s]


Validation loss: 1.38e-02


100%|██████████| 782/782 [00:26<00:00, 29.00it/s]


Epoch 24: 1.76e-05


100%|██████████| 313/313 [00:06<00:00, 46.56it/s]


Validation loss: 1.35e-02


100%|██████████| 782/782 [00:26<00:00, 29.06it/s]


Epoch 25: 1.78e-05


100%|██████████| 313/313 [00:06<00:00, 46.25it/s]


Validation loss: 1.31e-02


100%|██████████| 782/782 [00:26<00:00, 29.04it/s]


Epoch 26: 1.66e-05


100%|██████████| 313/313 [00:06<00:00, 46.49it/s]


Validation loss: 1.27e-02


100%|██████████| 782/782 [00:26<00:00, 28.97it/s]


Epoch 27: 1.62e-05


100%|██████████| 313/313 [00:06<00:00, 46.55it/s]


Validation loss: 1.24e-02


100%|██████████| 782/782 [00:26<00:00, 29.08it/s]


Epoch 28: 1.57e-05


100%|██████████| 313/313 [00:06<00:00, 46.32it/s]


Validation loss: 1.21e-02


100%|██████████| 782/782 [00:26<00:00, 29.04it/s]


Epoch 29: 1.54e-05


100%|██████████| 313/313 [00:06<00:00, 46.03it/s]


Validation loss: 1.17e-02


100%|██████████| 782/782 [00:26<00:00, 29.07it/s]


Epoch 30: 1.47e-05


100%|██████████| 313/313 [00:06<00:00, 45.96it/s]


Validation loss: 1.14e-02


100%|██████████| 782/782 [00:26<00:00, 28.99it/s]


Epoch 31: 1.45e-05


100%|██████████| 313/313 [00:06<00:00, 46.25it/s]


Validation loss: 1.11e-02


100%|██████████| 782/782 [00:27<00:00, 28.78it/s]


Epoch 32: 1.43e-05


100%|██████████| 313/313 [00:06<00:00, 46.03it/s]


Validation loss: 1.08e-02


100%|██████████| 782/782 [00:27<00:00, 28.90it/s]


Epoch 33: 1.37e-05


100%|██████████| 313/313 [00:06<00:00, 46.06it/s]


Validation loss: 1.05e-02


100%|██████████| 782/782 [00:26<00:00, 29.32it/s]


Epoch 34: 1.33e-05


100%|██████████| 313/313 [00:06<00:00, 46.87it/s]


Validation loss: 1.02e-02


100%|██████████| 782/782 [00:26<00:00, 29.25it/s]


Epoch 35: 1.30e-05


100%|██████████| 313/313 [00:06<00:00, 46.60it/s]


Validation loss: 9.91e-03


100%|██████████| 782/782 [00:26<00:00, 29.10it/s]


Epoch 36: 1.26e-05


100%|██████████| 313/313 [00:06<00:00, 45.65it/s]


Validation loss: 9.62e-03


100%|██████████| 782/782 [00:26<00:00, 29.20it/s]


Epoch 37: 1.21e-05


100%|██████████| 313/313 [00:06<00:00, 46.24it/s]


Validation loss: 9.35e-03


100%|██████████| 782/782 [00:26<00:00, 29.04it/s]


Epoch 38: 1.23e-05


100%|██████████| 313/313 [00:06<00:00, 46.32it/s]


Validation loss: 9.08e-03


100%|██████████| 782/782 [00:27<00:00, 28.85it/s]


Epoch 39: 1.18e-05


100%|██████████| 313/313 [00:06<00:00, 46.00it/s]


Validation loss: 8.82e-03


100%|██████████| 782/782 [00:27<00:00, 28.77it/s]


Epoch 40: 1.13e-05


100%|██████████| 313/313 [00:06<00:00, 46.27it/s]


Validation loss: 8.56e-03


100%|██████████| 782/782 [00:26<00:00, 29.23it/s]


Epoch 41: 1.11e-05


100%|██████████| 313/313 [00:06<00:00, 46.62it/s]


Validation loss: 8.32e-03


100%|██████████| 782/782 [00:26<00:00, 29.13it/s]


Epoch 42: 1.06e-05


100%|██████████| 313/313 [00:06<00:00, 45.67it/s]


Validation loss: 8.07e-03


100%|██████████| 782/782 [00:26<00:00, 29.14it/s]


Epoch 43: 1.06e-05


100%|██████████| 313/313 [00:06<00:00, 46.62it/s]


Validation loss: 7.84e-03


100%|██████████| 782/782 [00:26<00:00, 29.21it/s]


Epoch 44: 1.01e-05


100%|██████████| 313/313 [00:06<00:00, 46.75it/s]


Validation loss: 7.61e-03


100%|██████████| 782/782 [00:26<00:00, 29.33it/s]


Epoch 45: 9.61e-06


100%|██████████| 313/313 [00:06<00:00, 46.58it/s]


Validation loss: 7.38e-03


100%|██████████| 782/782 [00:26<00:00, 29.27it/s]


Epoch 46: 9.49e-06


100%|██████████| 313/313 [00:06<00:00, 46.47it/s]


Validation loss: 7.17e-03


100%|██████████| 782/782 [00:26<00:00, 29.21it/s]


Epoch 47: 9.48e-06


100%|██████████| 313/313 [00:06<00:00, 45.43it/s]


Validation loss: 6.96e-03


100%|██████████| 782/782 [00:26<00:00, 29.13it/s]


Epoch 48: 9.00e-06


100%|██████████| 313/313 [00:06<00:00, 46.40it/s]


Validation loss: 6.75e-03


100%|██████████| 782/782 [00:26<00:00, 29.16it/s]


Epoch 49: 8.80e-06


100%|██████████| 313/313 [00:06<00:00, 46.26it/s]


Validation loss: 6.55e-03


100%|██████████| 313/313 [00:06<00:00, 46.49it/s]

Final test loss: 6.55e-03



