In [14]:
# notebooks/visualize_decoder.ipynb

# ----------------------------
# 1. Imports
# ----------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
import matplotlib.pyplot as plt
import sys
import os

# Make the project's source tree importable. "../src" is resolved against the
# current working directory, so this assumes the kernel was started from the
# notebooks/ directory -- TODO confirm.
sys.path.append(os.path.abspath("../src"))

# Project-local modules, presumably found through the path entry added above.
# NOTE(review): VisionTransformerPredictorAC is not referenced in any visible
# cell, while the encoder cell calls `Visual(...)`, which is not imported here.
from models.visual_decoder import VisualDecoder
from models.ac_predictor import VisionTransformerPredictorAC



  from .autonotebook import tqdm as notebook_tqdm


In [15]:
# ----------------------------
# 2. Config
# ----------------------------
BATCH_SIZE = 8   # images per DataLoader batch
EPOCHS = 2       # full passes over the dataset in the training loop
LR = 1e-4        # learning rate handed to the Adam optimizer
SEED = 0         # fixed seed so a fresh "Restart & Run All" is repeatable
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global RNG up front: the notebook shuffles the data and
# builds a randomly initialised decoder, both of which would otherwise differ
# from run to run.
torch.manual_seed(SEED)

In [16]:
# ----------------------------
# 3. Dataset (CIFAR-10 for demo)
# ----------------------------
# Every image is resized to 128x128 and then converted to a tensor.
transform = transforms.Compose(
    [transforms.Resize((128, 128)), transforms.ToTensor()]
)

# download=True fetches and extracts the CIFAR-10 archive under ./data,
# so that location needs enough free disk space.
dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of BATCH_SIZE images for the training loop.
loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100.0%


Extracting ./data/cifar-10-python.tar.gz to ./data


OSError: [Errno 28] No space left on device

In [None]:
# ----------------------------
# 4. Encoder (frozen) + Decoder
# ----------------------------
# NOTE(review): `Visual` is neither imported nor defined in any visible cell,
# so this line raises NameError on a fresh kernel -- confirm which encoder
# class is intended and import it in the imports cell.
encoder = Visual(pretrained=True)
encoder = encoder.eval().to(DEVICE)

# Freeze the encoder: none of its parameters should receive gradients.
for param in encoder.parameters():
    param.requires_grad = False

# Reconstruction objective: mean squared error.
criterion = nn.MSELoss()

# Only the decoder is handed to the optimizer, so only it is trained.
decoder = VisualDecoder(latent_dim=1024, out_channels=3).to(DEVICE)  # adjust latent_dim if needed
optimizer = optim.Adam(decoder.parameters(), lr=LR)

In [None]:
# ----------------------------
# 5. Training Loop (demo, 2 epochs)
# ----------------------------
for epoch in range(EPOCHS):
    total_loss = 0.0  # sum of per-batch losses for this epoch

    for imgs, _labels in loader:
        imgs = imgs.to(DEVICE)

        # The encoder is frozen, so its forward pass runs without autograd.
        with torch.no_grad():
            feats = encoder(imgs)   # check shape, may need reshape/permute

        # Clear old gradients, reconstruct the batch, and score it against
        # the original images.
        optimizer.zero_grad()
        recons = decoder(feats)
        loss = criterion(recons, imgs)

        # Backpropagate and update the parameters held by the optimizer.
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch+1}/{EPOCHS} Loss: {total_loss/len(loader):.4f}")

In [None]:
# ----------------------------
# 6. Visualization
# ----------------------------
def show_side_by_side(originals, recons, n=4):
    """Plot up to ``n`` originals (top row) above their reconstructions (bottom row).

    Args:
        originals: Batch of input images, indexed along dim 0. Assumed to be
            image tensors that ``make_grid`` accepts -- TODO confirm shape.
        recons: Batch of reconstructions; it is concatenated with
            ``originals`` along dim 0, so the remaining dims must match.
        n: Maximum number of image pairs to show. It is clamped to the number
            of pairs actually available in both batches.

    Raises:
        ValueError: If there is not at least one original/reconstruction pair
            to display.
    """
    # Clamp n so both rows hold the same number of images. Without this, a
    # batch smaller than n would let reconstructions spill into the top row
    # (nrow would exceed the number of originals) and the title would lie.
    n = min(n, len(originals), len(recons))
    if n < 1:
        raise ValueError(
            "show_side_by_side needs at least one original/reconstruction pair"
        )

    originals = originals[:n].cpu()
    recons = recons[:n].cpu()

    # Originals first, reconstructions second; with nrow=n that yields
    # exactly two rows.
    grid = torch.cat([originals, recons], dim=0)
    grid = vutils.make_grid(grid, nrow=n, normalize=True, scale_each=True)

    _, ax = plt.subplots(figsize=(12, 6))
    ax.axis("off")
    ax.set_title("Top: Original | Bottom: Reconstruction")
    # Move the first dim of the grid to the end before handing it to imshow.
    ax.imshow(grid.permute(1, 2, 0))
    plt.show()

# get a sample batch
imgs, _ = next(iter(loader))
imgs = imgs.to(DEVICE)

# Run the decoder in eval mode so that any train-only layer behaviour it may
# have (e.g. dropout or batch-norm batch statistics -- not visible from this
# notebook) does not distort the reconstructions. The previous mode is
# restored afterwards so re-running the training cell behaves as before.
# Assumes decoder is an nn.Module -- TODO confirm.
was_training = decoder.training
decoder.eval()
try:
    with torch.no_grad():
        feats = encoder(imgs)
        recons = decoder(feats)
finally:
    decoder.train(was_training)

show_side_by_side(imgs, recons)