In [1]:
# CELL 0 - SELECT DEVICE & LOAD DEPENDENCIES

!pip install -q timm
import os
import json
import pickle
import torch
import random
import timm
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔥 Using device: {device}")


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
🔥 Using device: cuda


In [2]:
# Cell 1 - Encoder Only (decoder-ready, no projection head)


import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

# ✅ Encoder that returns patch-wise feature map: (B, 64, C)
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(
            "efficientnetv2_s",
            pretrained=False,
            features_only=True
        )
        self.pool = nn.AdaptiveAvgPool2d((8, 8))  # ensures 8×8 spatial output

    def forward(self, x):
        x = self.backbone(x)[-1]              # (B, C, H, W) → final block
        x = self.pool(x)                      # (B, C, 8, 8)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1)             # (B, 8, 8, C)
        x = x.reshape(B, H * W, C)            # (B, 64, C)
        return x

# Instantiate encoder first
encoder = Encoder().to(device)

# 🔍 Dynamically detect feature dim (C) from dummy input
with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256).to(device)
    out = encoder(dummy)
    feature_dim = out.shape[-1]  # usually 1280 for effnetv2_s
    print(f"🧪 Patch token shape: {out.shape}")  # (1, 64, 1280)

# 💾 Load encoder weights with fallback
try:
    encoder.load_state_dict(torch.load("encoder_epoch_50.pt", map_location=device), strict=True)
    print("✅ Encoder weights loaded with strict=True")
except RuntimeError as e:
    print("⚠️ Strict loading failed:", e)
    print("🔁 Retrying with strict=False")
    encoder.load_state_dict(torch.load("encoder_epoch_50.pt", map_location=device), strict=False)

print(f"🧠 Encoder ready (feature dim = {feature_dim})")



🧪 Patch token shape: torch.Size([1, 64, 256])
✅ Encoder weights loaded with strict=True
🧠 Encoder ready (feature dim = 256)


  encoder.load_state_dict(torch.load("encoder_epoch_50.pt", map_location=device), strict=True)


In [3]:
# CELL 2 - CAPTION PROCESSING


with open("captions_train2017.json", 'r') as f:
    annotations = json.load(f)["annotations"]

captions_dict = {}
for ann in annotations:
    img_id = ann["image_id"]
    cap = ann["caption"]
    captions_dict.setdefault(img_id, []).append(cap)

# Basic tokenizer & vocab
def tokenize(text):
    return text.lower().strip().split()

word_freq = {}
for caps in captions_dict.values():
    for cap in caps:
        for token in tokenize(cap):
            word_freq[token] = word_freq.get(token, 0) + 1

# Build vocab
vocab = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
for word, freq in word_freq.items():
    if freq >= 5:  # min frequency cutoff
        vocab[word] = len(vocab)

word2idx = vocab
idx2word = {idx: word for word, idx in vocab.items()}
vocab_size = len(vocab)

with open("vocab.pkl", "wb") as f:
    pickle.dump({"word2idx": word2idx, "idx2word": idx2word}, f)


In [4]:
# CELL 3 - DATASET + DATALOADER

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # ✅ allows loading broken JPEGs

class ImageCaptionDataset(Dataset):
    def __init__(self, folder, image2caption, word2idx, transform):
        self.folder = folder
        self.mapping = list(image2caption.items())
        self.word2idx = word2idx
        self.transform = transform
        self.fallback_count = 0  # 🧠 optional: track black image fallbacks

    def __getitem__(self, i):
        img_id, captions = self.mapping[i]
        img_path = os.path.join(self.folder, f"{img_id:012}.jpg")

        # ✅ Safe image load with grayscale/corrupt fallback
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            self.fallback_count += 1
            print(f"⚠️ Fallback: Could not load {img_path} — {e}")
            image = Image.new("RGB", (256, 256), color=(0, 0, 0))  # black dummy

        image = self.transform(image)

        # 🧾 Caption tokenization
        caption = ["<SOS>"] + random.choice(captions).lower().strip().split() + ["<EOS>"]
        tokens = [self.word2idx.get(w.strip(".,!?"), self.word2idx["<UNK>"]) for w in caption]
        tokens = tokens[:20] + [self.word2idx["<PAD>"]] * (20 - len(tokens))

        return image, torch.tensor(tokens)

    def __len__(self):
        return len(self.mapping)

# ✅ Standard transform (resize, normalize)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# ✅ Instantiate dataset
dataset = ImageCaptionDataset("images", captions_dict, word2idx, transform)

# ✅ FINAL DataLoader with high batch size + PIL-safe settings
loader = DataLoader(
    dataset,
    batch_size=1024,      # 🧠 set to 1024 if you tested and confirmed it's safe
    shuffle=True,
    num_workers=0,       # ✅ Single-threaded (PIL-safe)
    pin_memory=False     # ✅ Avoid async GPU transfer issues
)


In [5]:

# CELL 4 - TRANSFORMER DECODER

class CaptionDecoder(nn.Module):
    def __init__(self, vocab_size, feature_dim, hidden_dim=512, num_layers=6, nhead=8, max_len=20):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embed = nn.Parameter(torch.randn(max_len, hidden_dim))
        self.img_proj = nn.Linear(feature_dim, hidden_dim)

        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=nhead, activation='gelu', batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, image_tokens, caption_tokens):
        B, T = caption_tokens.shape
        tgt = self.token_embed(caption_tokens) + self.pos_embed[:T]
        memory = self.img_proj(image_tokens)

        tgt_mask = nn.Transformer.generate_square_subsequent_mask(T).to(caption_tokens.device)
        return self.fc_out(self.decoder(tgt, memory, tgt_mask=tgt_mask))


In [6]:
# CELL 5 - ENCODER + DECODER 

class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        image_tokens = self.encoder(images)           # (B, 64, C)
        return self.decoder(image_tokens, captions)   # (B, T, vocab_size)


In [7]:
# CELL 6 - LOSS, OPTIMIZER & AMP 

PAD_ID = word2idx["<PAD>"]

decoder = CaptionDecoder(
    vocab_size=len(word2idx),
    feature_dim=encoder(torch.randn(1, 3, 256, 256).to(device)).shape[-1]
).to(device)

model = ImageCaptioningModel(encoder, decoder).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

optimizer = torch.optim.AdamW([
    {"params": model.encoder.parameters(), "lr": 1e-5},
    {"params": model.decoder.parameters(), "lr": 1e-4}
])

from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()




  scaler = GradScaler()


In [8]:
# CELL 7 - TRAINING LOOP

import os
from torch.cuda.amp import autocast, GradScaler

EPOCHS = 10
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

checkpoint_path = os.path.join(CHECKPOINT_DIR, "captioning_latest.pth")

# 🔁 Resume logic
start_epoch = 0
scaler = GradScaler()

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scaler.load_state_dict(checkpoint["scaler_state_dict"])
    start_epoch = checkpoint["epoch"]
    print(f"🔄 Resuming training from epoch {start_epoch}")

# 🏋️ Training loop
for epoch in range(start_epoch, EPOCHS):
    model.train()
    total_loss = 0.0

    for step, (images, captions) in enumerate(loader):
        images, captions = images.to(device), captions.to(device)
        optimizer.zero_grad()

        with autocast():
            outputs = model(images, captions[:, :-1])  # input
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)),
                captions[:, 1:].reshape(-1)
            )

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        # ✅ Live GPU memory print (once per 5 steps)
        if step % 5 == 0:
            mem_gb = torch.cuda.memory_allocated() / 1024**3
            print(f"🧠 Step {step:03d} | Loss: {loss.item():.4f} | GPU Mem: {mem_gb:.2f} GB")

    avg_loss = total_loss / len(loader)
    print(f"\n✅ Epoch {epoch+1}/{EPOCHS} | Avg Loss: {avg_loss:.6f}\n")

    # 💾 Save checkpoint after each epoch
    ckpt = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'config': {
            'vocab_size': len(word2idx),
            'hidden_dim': 512,
            'num_heads': 8,
            'num_layers': 6
        }
    }
    torch.save(ckpt, checkpoint_path)
    print(f"💾 Checkpoint saved to {checkpoint_path}")


  scaler = GradScaler()
  with autocast():


🧠 Step 000 | Loss: 9.7175 | GPU Mem: 2.23 GB
🧠 Step 005 | Loss: 8.1538 | GPU Mem: 2.23 GB
🧠 Step 010 | Loss: 7.3898 | GPU Mem: 2.23 GB
🧠 Step 015 | Loss: 7.0409 | GPU Mem: 2.23 GB
🧠 Step 020 | Loss: 6.7757 | GPU Mem: 2.23 GB
🧠 Step 025 | Loss: 6.5165 | GPU Mem: 2.23 GB
🧠 Step 030 | Loss: 6.2754 | GPU Mem: 2.23 GB
🧠 Step 035 | Loss: 6.0958 | GPU Mem: 2.23 GB
🧠 Step 040 | Loss: 5.9553 | GPU Mem: 2.23 GB
🧠 Step 045 | Loss: 5.8152 | GPU Mem: 2.23 GB
🧠 Step 050 | Loss: 5.6288 | GPU Mem: 2.23 GB
🧠 Step 055 | Loss: 5.4940 | GPU Mem: 2.23 GB
🧠 Step 060 | Loss: 5.4055 | GPU Mem: 2.23 GB
🧠 Step 065 | Loss: 5.3146 | GPU Mem: 2.23 GB
🧠 Step 070 | Loss: 5.2387 | GPU Mem: 2.23 GB
🧠 Step 075 | Loss: 5.1555 | GPU Mem: 2.23 GB
🧠 Step 080 | Loss: 5.0369 | GPU Mem: 2.23 GB
🧠 Step 085 | Loss: 5.0171 | GPU Mem: 2.23 GB
🧠 Step 090 | Loss: 4.8977 | GPU Mem: 2.23 GB
🧠 Step 095 | Loss: 4.8328 | GPU Mem: 2.23 GB
🧠 Step 100 | Loss: 4.7934 | GPU Mem: 2.23 GB
🧠 Step 105 | Loss: 4.7507 | GPU Mem: 2.23 GB
🧠 Step 110

In [9]:
import torch
from model import Encoder  # use your actual Encoder class definition

# Load full training checkpoint
ckpt = torch.load("checkpoints/captioning_latest.pth", map_location="cpu")

# Rebuild encoder
encoder = Encoder()
encoder.load_state_dict(ckpt["model_state_dict"], strict=False)

# Save entire encoder object as .pt (not just weights)
torch.save(encoder, "encoder_epoch_52.pt")
print("✅ Saved fine-tuned encoder to encoder_epoch_52.pt")
