In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class VQVAE(nn.Module):
    """Small convolutional VQ-VAE for 3-channel images.

    The encoder applies two stride-2 convolutions (4x spatial downsampling
    for sizes divisible by 4) followed by a 1x1 projection to
    ``embedding_dim`` channels. Every spatial latent vector is replaced by
    its nearest codebook entry, and the decoder upsamples the quantized map
    back to the input resolution, ending in ``Tanh``.

    Args:
        num_embeddings: Number of vectors in the codebook.
        embedding_dim: Size of each codebook vector; also the number of
            channels produced by the encoder.
    """

    def __init__(self, num_embeddings=512, embedding_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(128, embedding_dim, 1)
        )
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 128, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 3, 1), nn.Tanh()
        )

    def forward(self, x):
        """Encode, vector-quantize and decode a batch of images.

        Args:
            x: Image batch of shape ``[B, 3, H, W]``.

        Returns:
            Tuple ``(x_recon, z_e, z_q)`` where ``x_recon`` is the decoder
            output, ``z_e`` is the continuous encoder output and ``z_q`` is
            the quantized latent map. ``z_e`` and ``z_q`` are both
            ``[B, D, H', W']``; ``z_q`` carries gradient to the codebook only.
        """
        z_e = self.encoder(x)  # [B, D, H, W]
        batch, dim, height, width = z_e.shape

        # Channels-last, so each row of flat_z is the latent vector of one
        # spatial position, ordered (b, h, w).
        flat_z = z_e.permute(0, 2, 3, 1).contiguous().view(-1, dim)

        # Nearest-codebook lookup. argmin is not differentiable, so the
        # distances are computed without building an autograd graph, and
        # cdist yields the [B*H*W, K] matrix directly rather than going
        # through a broadcasted [B*H*W, K, D] difference tensor.
        with torch.no_grad():
            dists = torch.cdist(flat_z, self.codebook.weight)
            indices = dists.argmin(dim=1)

        # The looked-up rows are still in (b, h, w) order, so they must be
        # viewed as [B, H, W, D] *before* moving channels to the front.
        # Viewing them directly as [B, D, H, W] would shuffle vectors across
        # positions and channels.
        z_q = self.codebook(indices).view(batch, height, width, dim)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        # Straight-through estimator: forward pass uses z_q, backward pass
        # copies the decoder gradient onto z_e.
        z_q_st = z_e + (z_q - z_e).detach()

        x_recon = self.decoder(z_q_st)
        return x_recon, z_e, z_q


In [None]:
def vqvae_loss(x, x_recon, z_e, z_q, beta=0.25):
    """VQ-VAE training objective.

    ``loss = MSE(x_recon, x) + MSE(z_q, sg[z_e]) + beta * MSE(z_e, sg[z_q])``
    where ``sg`` is stop-gradient (``.detach()``).

    Args:
        x: Target images.
        x_recon: Reconstructed images, same shape as ``x``.
        z_e: Continuous encoder output.
        z_q: Quantized latents (must carry gradient to the codebook).
        beta: Weight of the commitment term.

    Returns:
        Scalar loss tensor.
    """
    recon_loss = F.mse_loss(x_recon, x)
    # Codebook term: gradient reaches only z_q, pulling the selected
    # embeddings toward the (frozen) encoder outputs. Weight 1.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    # Commitment term: gradient reaches only z_e, keeping the encoder close
    # to its chosen embeddings. This is the term that beta scales.
    commit_loss = F.mse_loss(z_e, z_q.detach())
    return recon_loss + codebook_loss + beta * commit_loss


In [None]:
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose, Normalize, Resize

# Preprocessing: ToTensor yields values in [0, 1]; Normalize with mean 0.5 and
# std 0.5 per channel then rescales them to [-1, 1].
# NOTE(review): CIFAR-10 images should already be 32x32, making Resize(32)
# effectively a no-op -- confirm before removing.
transform = Compose([
    Resize(32),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Training split, downloaded into ./data on first use.
trainset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 images.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
)


In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Number of full passes over `trainloader` performed by the training loop.
num_epochs = 50

In [None]:
from tqdm import tqdm

# Fresh model on the selected device, optimized with Adam.
model = VQVAE()
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    # Running sum of per-batch losses, averaged at the end of the epoch.
    epoch_loss = 0
    pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}", leave=False)

    for x, _ in pbar:
        x = x.to(device)

        # Clear stale gradients, then forward / backward / update.
        opt.zero_grad()
        x_recon, z_e, z_q = model(x)
        loss = vqvae_loss(x, x_recon, z_e, z_q)
        loss.backward()
        opt.step()

        # Read the scalar once and reuse it for both the sum and the bar.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        pbar.set_postfix(loss=batch_loss)

    print(f"Epoch {epoch+1}: avg_loss = {epoch_loss / len(trainloader):.4f}")


In [None]:
import matplotlib.pyplot as plt

# Take one batch from the (shuffled) training loader; only the images are
# needed, the labels are ignored.
first_batch = next(iter(trainloader))
x = first_batch[0].to(device)

# Run the model without tracking gradients and keep the reconstruction only.
with torch.no_grad():
    outputs = model(x)
x_recon, _, _ = outputs

def show_images(orig, recon, num=6):
    """Plot original images (top row) above their reconstructions (bottom row).

    Both batches are expected in the notebook's normalized range [-1, 1]
    (inputs are normalized with mean 0.5 / std 0.5 and the decoder ends in
    ``Tanh``), so they are mapped back to [0, 1] before display. Clamping
    without un-normalizing would render every negative value as black.

    Args:
        orig: Batch of original images, shape ``[B, C, H, W]``.
        recon: Batch of reconstructed images, same layout as ``orig``.
        num: Maximum number of image pairs to show; capped at the batch size.
    """
    # Never index past the end of either batch.
    num = min(num, len(orig), len(recon))
    if num <= 0:
        return

    def to_display(batch):
        # [-1, 1] -> [0, 1], then [N, C, H, W] -> [N, H, W, C] for imshow.
        imgs = batch[:num].detach().cpu()
        return (imgs * 0.5 + 0.5).clamp(0, 1).permute(0, 2, 3, 1)

    rows = (
        ("Original", to_display(orig)),
        ("Reconstructed", to_display(recon)),
    )

    plt.figure(figsize=(num * 2, 4))
    for row, (title, imgs) in enumerate(rows):
        for i in range(num):
            plt.subplot(2, num, row * num + i + 1)
            plt.imshow(imgs[i])
            plt.axis('off')
            # Label each row once, on its first panel.
            if i == 0:
                plt.title(title)

    plt.tight_layout()
    plt.show()

# Compare the first 6 inputs (the default `num`) with their reconstructions.
show_images(x, x_recon)


In [None]:
# Convert the batch `x` into discrete token IDs: encode, then assign every
# spatial latent vector the index of its nearest codebook entry.
with torch.no_grad():
    z_e = model.encoder(x)  # [B, D, H, W]
    n_images, grid_h, grid_w = x.size(0), z_e.size(2), z_e.size(3)

    # Channels-last, one row per spatial position.
    z_flat = z_e.permute(0, 2, 3, 1)
    z_flat = z_flat.reshape(-1, model.codebook.embedding_dim)

    # Pairwise distances between latent vectors and all codebook vectors.
    dists = torch.cdist(z_flat, model.codebook.weight)
    nearest = torch.argmin(dists, dim=1)

    tokens = nearest.reshape(n_images, grid_h, grid_w)  # [B, H, W]
    token_seqs = tokens.view(n_images, -1)  # [B, H*W], flattened for a transformer


In [None]:
import matplotlib.pyplot as plt

def visualize_token_grid(tokens, num_images=6):
    """Draw the token-ID grids of the first ``num_images`` samples in one row.

    Args:
        tokens: Tensor indexed as ``tokens[i]`` -> 2-D grid of token IDs.
        num_images: Number of grids to plot (one panel each).
    """
    panel_width, panel_height = 2, 2.5
    plt.figure(figsize=(num_images * panel_width, panel_height))

    for i in range(num_images):
        # Move the grid to host memory as a NumPy array for matplotlib.
        token_ids = tokens[i].cpu().numpy()  # shape [H, W]

        plt.subplot(1, num_images, i + 1)
        plt.imshow(token_ids, cmap='viridis')
        plt.title(f"Token Grid {i+1}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Plot the token-ID grids for the first six images of the batch.
visualize_token_grid(tokens, num_images=6)


In [None]:
def show_image_and_tokens(imgs, tokens, index=0):
    """Show one image next to the grid of token IDs assigned to it.

    Args:
        imgs: Image batch, ``[B, C, H, W]``, normalized with mean 0.5 / std 0.5.
        tokens: Token-ID grids, indexed as ``tokens[index]`` -> 2-D grid.
        index: Which sample of the batch to display.
    """
    # [C, H, W] -> [H, W, C], then undo the normalization and clip to [0, 1].
    picture = imgs[index].cpu().permute(1, 2, 0)
    picture = picture.mul(0.5).add(0.5).clamp(0, 1)

    token_grid = tokens[index].cpu().numpy()

    # (data, extra imshow kwargs, title) for the left and right panels.
    panels = (
        (picture, {}, "Original Image"),
        (token_grid, {"cmap": 'viridis'}, "Token IDs"),
    )

    plt.figure(figsize=(6, 3))
    for slot, (data, imshow_kwargs, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(data, **imshow_kwargs)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Show the first image of the batch beside its token-ID grid.
show_image_and_tokens(x, tokens, index=0)
