In [3]:
import json
import os
import random
from collections import defaultdict

# Paths
DATA_DIR = '/kaggle/input/coco-2017-dataset/coco2017'
ANNOTATION_FILE = os.path.join(DATA_DIR, 'annotations', 'captions_train2017.json')
ANNOTATION_FILE2 = os.path.join(DATA_DIR, 'annotations', 'captions_val2017.json')
IMAGE_FOLDER = os.path.join(DATA_DIR, 'train2017')
IMAGE_FOLDER2 = os.path.join(DATA_DIR, 'val2017')

# Load annotations
with open(ANNOTATION_FILE, 'r') as f:
    annotations = json.load(f)

# Build a dictionary: image_id -> list of captions
captions_dict = defaultdict(list)
for ann in annotations['annotations']:
    image_id = ann['image_id']
    caption = ann['caption']
    captions_dict[image_id].append(caption)

# Check sample
# sample_image_id = list(captions_dict.keys())[0]
sample_image_id = random.choice(list(captions_dict.keys()))
print(f"Image ID: {sample_image_id}")
print("Captions:")
for cap in captions_dict[sample_image_id]:
    print("-", cap)

Image ID: 283627
Captions:
- A black dog with a red collar under a pink blanket.
- A dog asleep in a bed with a blanket over him 
- Black dog in a bed on a pillow under a pink blanket 
- a black lab lying in bed under covers
- A dog takes a nap under a blanket. 


In [4]:
import re
import nltk
nltk.download('punkt')  # for word_tokenize
from nltk.tokenize import word_tokenize

def clean_caption(caption):
    caption = caption.lower()                            # Lowercase
    caption = re.sub(r"[^a-z0-9\s]", "", caption)        # Remove punctuation
    caption = re.sub(r"\s+", " ", caption).strip()       # Trim extra spaces
    return caption

# Clean and tokenize all captions
cleaned_captions_dict = {}
for image_id, captions in captions_dict.items():
    cleaned_captions = []
    for cap in captions:
        clean_cap = clean_caption(cap)
        tokens = word_tokenize(clean_cap)
        cleaned_captions.append(tokens)
    cleaned_captions_dict[image_id] = cleaned_captions

# Check cleaned sample
print("Cleaned captions for image ID:", sample_image_id)
for cap in cleaned_captions_dict[sample_image_id]:
    print(cap)

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Cleaned captions for image ID: 283627
['a', 'black', 'dog', 'with', 'a', 'red', 'collar', 'under', 'a', 'pink', 'blanket']
['a', 'dog', 'asleep', 'in', 'a', 'bed', 'with', 'a', 'blanket', 'over', 'him']
['black', 'dog', 'in', 'a', 'bed', 'on', 'a', 'pillow', 'under', 'a', 'pink', 'blanket']
['a', 'black', 'lab', 'lying', 'in', 'bed', 'under', 'covers']
['a', 'dog', 'takes', 'a', 'nap', 'under', 'a', 'blanket']


In [5]:
import os 
# os.makedirs('/kaggle/working/models', exist_ok=True)

In [6]:
# open('/kaggle/working/models/__init__.py', 'a').close()

In [7]:
from collections import Counter

min_word_freq = 5  # You can tune this
word_freq = Counter()

# Count word frequencies
for captions in cleaned_captions_dict.values():
    for tokens in captions:
        word_freq.update(tokens)

# Filter words below the threshold
words = [word for word in word_freq if word_freq[word] >= min_word_freq]

# Special tokens
word_map = {
    '<pad>': 0,
    '<start>': 1,
    '<end>': 2,
    '<unk>': 3
}

# Add the remaining words
for i, word in enumerate(words, start=4):
    word_map[word] = i

# Reverse map
idx2word = {v: k for k, v in word_map.items()}

print(f"Vocabulary size: {len(word_map)}")
print("Sample word map entries:")
for i, (word, idx) in enumerate(list(word_map.items())[:10]):
    print(f"{word}: {idx}")

Vocabulary size: 10307
Sample word map entries:
<pad>: 0
<start>: 1
<end>: 2
<unk>: 3
a: 4
bicycle: 5
replica: 6
with: 7
clock: 8
as: 9


In [8]:
encoded_captions = {}

for image_id, captions in cleaned_captions_dict.items():
    encoded = []
    for tokens in captions:
        # Encode each word or use <unk> if not in vocab
        enc = [word_map.get(word, word_map['<unk>']) for word in tokens]
        # Add <start> and <end> tokens
        enc = [word_map['<start>']] + enc + [word_map['<end>']]
        encoded.append(enc)
    encoded_captions[image_id] = encoded

# Check sample
print("Encoded captions for image ID:", sample_image_id)
for cap in encoded_captions[sample_image_id]:
    print(cap)

Encoded captions for image ID: 283627
[1, 4, 16, 372, 7, 4, 89, 1490, 859, 4, 328, 1018, 2]
[1, 4, 372, 945, 20, 4, 779, 7, 4, 1018, 286, 780, 2]
[1, 16, 372, 20, 4, 779, 39, 4, 1184, 859, 4, 328, 1018, 2]
[1, 4, 16, 6338, 1054, 20, 779, 859, 2302, 2]
[1, 4, 372, 475, 4, 1439, 859, 4, 1018, 2]


In [9]:
import json

# Save encoded captions
with open('encoded_captions.json', 'w') as f:
    json.dump({str(k): v for k, v in encoded_captions.items()}, f)

# Save word map
with open('word_map.json', 'w') as f:
    json.dump(word_map, f)

In [10]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip

--2025-04-20 16:25:20--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2025-04-20 16:25:20--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2025-04-20 16:25:21--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘glove.6B.zip’


202

In [11]:
import numpy as np

# Path to GloVe 300d
glove_path = '/kaggle/working/glove.6B.300d.txt'
embedding_dim = 300
vocab_size = len(word_map)

# Load GloVe embeddings
print("Loading GloVe...")
glove = {}
with open(glove_path, 'r', encoding='utf-8') as f:
    for line in f:
        tokens = line.split()
        word = tokens[0]
        vec = np.array(tokens[1:], dtype=np.float32)
        glove[word] = vec

# Create embedding matrix
print("Building embedding matrix...")
embedding_matrix = np.random.uniform(-0.1, 0.1, (vocab_size, embedding_dim)).astype(np.float32)

for word, idx in word_map.items():
    if word in glove:
        embedding_matrix[idx] = glove[word]

print("Done. Shape:", embedding_matrix.shape)


Loading GloVe...
Building embedding matrix...
Done. Shape: (10307, 300)


In [12]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import random
import json
import os

class CaptionDataset(Dataset):
    def __init__(self, image_folder, encoded_captions_file, word_map_file, transform=None):
        # Load encoded captions and word map
        with open(encoded_captions_file, 'r') as j:
            self.captions = json.load(j)
        with open(word_map_file, 'r') as j:
            self.word_map = json.load(j)

        self.image_folder = image_folder
        self.image_ids = list(self.captions.keys())
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_path = os.path.join(self.image_folder, f"{int(image_id):012}.jpg")
        
        # Load image
        img = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Randomly select one caption for the image
        caps = self.captions[image_id]
        caption = random.choice(caps)
        caption = torch.tensor(caption, dtype=torch.long)

        return img, caption

In [13]:
def caption_collate_fn(batch):
    """
    Custom collate function to handle batches of (image, caption) with variable-length captions.
    """
    images = []
    captions = []

    for img, cap in batch:
        images.append(img)
        captions.append(cap)

    # Stack images (they are all same size)
    images = torch.stack(images, dim=0)

    # Pad captions to the max length in the batch
    lengths = [len(cap) for cap in captions]
    max_len = max(lengths)
    padded_captions = torch.zeros(len(captions), max_len, dtype=torch.long)

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap[:end]

    return images, padded_captions, torch.tensor(lengths)

In [14]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # from ImageNet
                         std=[0.229, 0.224, 0.225])
])

# Dataset
dataset = CaptionDataset(
    image_folder='/kaggle/input/coco-2017-dataset/coco2017/train2017',
    encoded_captions_file='encoded_captions.json',
    word_map_file='word_map.json',
    transform=transform
)

# DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=caption_collate_fn
)

# Check sample batch
for images, captions, lengths in dataloader:
    print("Image batch shape:", images.shape)
    print("Caption batch shape:", captions.shape)
    print("Lengths:", lengths)
    break

Image batch shape: torch.Size([4, 3, 256, 256])
Caption batch shape: torch.Size([4, 13])
Lengths: tensor([10, 11, 12, 13])


In [15]:
# ----------- Process Validation Captions -----------
with open(ANNOTATION_FILE2, 'r') as f:
    val_annotations = json.load(f)

val_captions_dict = defaultdict(list)
for ann in val_annotations['annotations']:
    image_id = ann['image_id']
    caption = ann['caption']
    val_captions_dict[image_id].append(caption)

# Clean and tokenize validation captions
cleaned_val_captions_dict = {}
for image_id, captions in val_captions_dict.items():
    cleaned_captions = []
    for cap in captions:
        clean_cap = clean_caption(cap)
        tokens = word_tokenize(clean_cap)
        cleaned_captions.append(tokens)
    cleaned_val_captions_dict[image_id] = cleaned_captions

# Encode validation captions
encoded_val_captions = {}
for image_id, captions in cleaned_val_captions_dict.items():
    encoded = []
    for tokens in captions:
        enc = [word_map.get(word, word_map['<unk>']) for word in tokens]
        enc = [word_map['<start>']] + enc + [word_map['<end>']]
        encoded.append(enc)
    encoded_val_captions[image_id] = encoded

# Save encoded val captions
with open('encoded_captions_val.json', 'w') as f:
    json.dump({str(k): v for k, v in encoded_val_captions.items()}, f)

In [16]:
# Validation dataset
val_dataset = CaptionDataset(
    image_folder='/kaggle/input/coco-2017-dataset/coco2017/val2017',
    encoded_captions_file='encoded_captions_val.json',  # You need to create this
    word_map_file='word_map.json',
    transform=transform
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=caption_collate_fn
)

# Check sample batch
for images, captions, lengths in dataloader:
    print("Image batch shape:", images.shape)
    print("Caption batch shape:", captions.shape)
    print("Lengths:", lengths)
    break

Image batch shape: torch.Size([4, 3, 256, 256])
Caption batch shape: torch.Size([4, 13])
Lengths: tensor([11, 13, 10, 13])


## Attention Module

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # Linear layer to transform encoder's output
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # Linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim, 1)  # Combine them and produce scalar energy
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # Softmax over the pixels

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out: encoded images, shape -> (batch_size, num_pixels, encoder_dim)
        decoder_hidden: previous decoder hidden state, shape -> (batch_size, decoder_dim)
        """
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)

        return attention_weighted_encoding, alpha

## Decoder with Attention

In [18]:
class DecoderWithAttention(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
        self.embedding.weight.requires_grad = False  # Optional: freeze during training

        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initialize hidden state
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initialize cell state
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # create a gating scalar
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # output layer

        self.init_weights()  # initialize weights

    def init_weights(self):
        # self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.
        :param encoder_out: encoded images, shape (batch_size, num_pixels, encoder_dim)
        :param encoded_captions: encoded captions, shape (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, shape (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort input data by decreasing lengths
        # Corrected line
        caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        decode_lengths = caption_lengths - 1

        # Create tensors to hold word prediction scores and alphas
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(encoder_out.device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(encoder_out.device)

        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar
            attention_weighted_encoding = gate * attention_weighted_encoding

            input_lstm = torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1)
            h, c = self.decode_step(input_lstm, (h[:batch_size_t], c[:batch_size_t]))  # LSTM step
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [19]:
import torchvision.models as models

class Encoder(nn.Module):
    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.cnn = models.resnet101(pretrained=True)
        # self.cnn = models.resnet101(weights=weights)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune(fine_tune=False)

    def forward(self, images):
        x = self.cnn.conv1(images)
        x = self.cnn.bn1(x)
        x = self.cnn.relu(x)
        x = self.cnn.maxpool(x)

        x = self.cnn.layer1(x)
        x = self.cnn.layer2(x)
        x = self.cnn.layer3(x)
        x = self.cnn.layer4(x)  # Shape: (batch_size, 2048, 7, 7)
        
        x = self.adaptive_pool(x)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        x = x.permute(0, 2, 3, 1)  # (batch_size, encoded_size, encoded_size, 2048)
        x = x.view(x.size(0), -1, x.size(-1))  # (batch_size, num_pixels=encoded_size^2, 2048)
        return x

    def fine_tune(self, fine_tune=True):
        for p in self.cnn.parameters():
            p.requires_grad = fine_tune

In [20]:
import os
import glob

def save_checkpoint(encoder, decoder, optimizer, epoch, train_loss, val_loss, word_map, 
                   checkpoint_dir, best_val_loss=float('inf'), is_best=False):
    """Save model checkpoint"""
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Save regular checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch + 1}.pth')
    
    checkpoint = {
        'epoch': epoch + 1,  # Save as next epoch to resume from
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'word_map': word_map,
        'best_val_loss': best_val_loss
    }
    
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")
    
    # Save best model separately if this is the best one
    if is_best:
        best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
        torch.save(checkpoint, best_model_path)
        print(f"New best model saved with validation loss: {val_loss:.4f}")

def resume_from_checkpoint(checkpoint_path, encoder, decoder, optimizer, device):
    """Load checkpoint and resume training"""
    print(f"Loading checkpoint from {checkpoint_path}")
    
    # Load checkpoint on CPU to avoid GPU memory issues
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Load model states
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    
    # Move models to device after loading
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    
    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    # Return the starting epoch and best validation loss
    start_epoch = checkpoint['epoch']  # Continue from next epoch
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    
    print(f"Resuming from epoch {start_epoch} with best validation loss: {best_val_loss:.4f}")
    return encoder, decoder, optimizer, start_epoch, best_val_loss

In [21]:
# Test encoder-decoder integration
import torch
from torchvision.models import ResNet101_Weights

# Check if GPU is available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# encoder = Encoder(weights=ResNet101_Weights.IMAGENET1K_V1).to(device)
encoder = Encoder().to(device)
# encoder = Encoder(weights=ResNet101_Weights.DEFAULT).to(device)

decoder = DecoderWithAttention(
    attention_dim=512,
    embed_dim=300,
    decoder_dim=512,
    vocab_size=len(word_map),
    encoder_dim=2048,
    dropout=0.5
).to(device)

# Test forward pass
images, captions, lengths = next(iter(dataloader))
images = images.to(device)
captions = captions.to(device)

encoder_out = encoder(images)
# predictions, _, _, _, _ = decoder(encoder_out, captions, torch.tensor(lengths, device=device))
predictions, _, _, _, _ = decoder(encoder_out, captions, torch.tensor(lengths).clone().detach().to(device))

print("Encoder output shape:", encoder_out.shape)  # Should be (batch_size, 196, 2048)
print("Predictions shape:", predictions.shape)     # Should be (batch_size, max_len, vocab_size)

Using device: cuda


Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:00<00:00, 223MB/s] 
  predictions, _, _, _, _ = decoder(encoder_out, captions, torch.tensor(lengths).clone().detach().to(device))


Encoder output shape: torch.Size([4, 196, 2048])
Predictions shape: torch.Size([4, 16, 10307])


In [22]:
class MaskedCrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=0)  # ignore <pad>

    def forward(self, predictions, targets, lengths):
        batch_size, max_len, vocab_size = predictions.shape

        predictions = predictions.view(-1, vocab_size)      # (batch_size * max_len, vocab_size)
        targets = targets.contiguous().view(-1)              # (batch_size * max_len)

        losses = self.criterion(predictions, targets)        # (batch_size * max_len)

        # Create mask
        mask = torch.arange(max_len).expand(batch_size, max_len).to(lengths.device)
        mask = (mask < lengths.unsqueeze(1)).float()         # (batch_size, max_len)
        mask = mask.view(-1)                                 # Flatten to (batch_size * max_len)

        losses = losses * mask
        return losses.sum() / mask.sum()


In [23]:
encoder = Encoder().to(device)
decoder = DecoderWithAttention(
    attention_dim=512,
    embed_dim=300,
    decoder_dim=512,
    vocab_size=len(word_map),
    encoder_dim=2048,
    dropout=0.5
).to(device)

# Only fine-tune the encoder's adaptive pool layer
encoder_params = list(encoder.adaptive_pool.parameters()) + list(encoder.cnn.layer4.parameters())
decoder_params = decoder.parameters()

optimizer = torch.optim.Adam(
    params=[
        {'params': encoder_params, 'lr': 1e-4},  # Lower LR for encoder
        {'params': decoder_params, 'lr': 4e-4}    # Higher LR for decoder
    ],
    weight_decay=1e-5
)

In [24]:
def train_epoch(encoder, decoder, dataloader, criterion, optimizer, device, grad_clip=5.0):
    encoder.train()
    decoder.train()
    total_loss = 0
    
    for i, (images, captions, lengths) in enumerate(dataloader):
        images = images.to(device)
        captions = captions.to(device)
        # lengths = torch.tensor(lengths).to(device)
        lengths_tensor = torch.tensor(lengths).to(device)

        # Forward pass
        encoder_out = encoder(images)
        predictions, _, decode_lengths, _, _ = decoder(encoder_out, captions, lengths)
        
        # Remove <start> token and truncate to actual lengths
        targets = captions[:, 1:]  # (batch_size, max_len-1)
        predictions = predictions[:, :max(decode_lengths), :]  # (batch_size, actual_max_len, vocab_size)

        # Calculate loss
        loss = criterion(predictions, targets, lengths_tensor)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), grad_clip)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)
        
        optimizer.step()
        
        total_loss += loss.item()
        
        if i % 100 == 0:
            print(f"Batch [{i}/{len(dataloader)}] Loss: {loss.item():.4f}")
    
    return total_loss / len(dataloader)

In [25]:
def validate(encoder, decoder, val_loader, criterion, device):
    encoder.eval()
    decoder.eval()
    total_loss = 0
    
    with torch.no_grad():
        for images, captions, lengths in val_loader:
            images = images.to(device)
            captions = captions.to(device)
            # lengths = torch.tensor(lengths).to(device)
            lengths_tensor = torch.tensor(lengths).to(device)

            
            encoder_out = encoder(images)
            predictions, _, decode_lengths, _, _ = decoder(encoder_out, captions, lengths)
            
            targets = captions[:, 1:]
            predictions = predictions[:, :max(decode_lengths), :]
            
            loss = criterion(predictions, targets, lengths_tensor)
            total_loss += loss.item()
    
    return total_loss / len(val_loader)

In [26]:
import os
import glob

def save_checkpoint(encoder, decoder, optimizer, epoch, train_loss, val_loss, word_map, 
                   checkpoint_dir, best_val_loss=float('inf'), is_best=False):
    """Save model checkpoint"""
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Save regular checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch + 1}.pth')
    
    checkpoint = {
        'epoch': epoch + 1,  # Save as next epoch to resume from
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'word_map': word_map,
        'best_val_loss': best_val_loss
    }
    
    try:
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")
        
        # Save best model separately if this is the best one
        if is_best:
            best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
            torch.save(checkpoint, best_model_path)
            print(f"New best model saved with validation loss: {val_loss:.4f}")
            
        # Keep only the last 3 checkpoints to save disk space
        checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pth')))
        if len(checkpoint_files) > 3:
            for old_checkpoint in checkpoint_files[:-3]:
                os.remove(old_checkpoint)
                
    except Exception as e:
        print(f"Error saving checkpoint: {e}")
        # Try alternative save location
        torch.save(checkpoint, '/kaggle/working/emergency_checkpoint.pth')

def resume_from_checkpoint(checkpoint_path, encoder, decoder, optimizer, device):
    """Load checkpoint and resume training"""
    print(f"Loading checkpoint from {checkpoint_path}")
    
    # Load checkpoint on CPU to avoid GPU memory issues
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Load model states
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    
    # Move models to device after loading
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    
    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    # Move optimizer state to device
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)
    
    # Return the starting epoch and best validation loss
    start_epoch = checkpoint['epoch']  # Continue from next epoch
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    
    print(f"Resuming from epoch {start_epoch} with best validation loss: {best_val_loss:.4f}")
    return encoder, decoder, optimizer, start_epoch, best_val_loss

In [27]:
# Initialize components
criterion = MaskedCrossEntropyLoss().to(device)
num_epochs = 1  # For initial test

# Quick test with 1 batch
test_images, test_captions, test_lengths = next(iter(dataloader))
test_images = test_images.to(device)
test_captions = test_captions.to(device)
# test_lengths = torch.tensor(test_lengths).to(device)
test_lengths = test_lengths.clone().detach().to(device)


# Forward test
encoder_out = encoder(test_images)
predictions, _, decode_lengths, _, _ = decoder(encoder_out, test_captions, test_lengths)
targets = test_captions[:, 1:]

# Convert decode_lengths to tensor
# decode_lengths = torch.tensor(decode_lengths).to(device)

loss = criterion(predictions, targets, decode_lengths)
# decode_lengths = torch.tensor(decode_lengths).to(device)
# loss = criterion(predictions, targets, decode_lengths)

print(f"Initial loss: {loss.item():.4f}")  # Should be ~log(vocab_size) = ~9.2 for vocab_size=10307
optimizer.step()  # Verify backprop works without errors

Initial loss: 8.3699


In [28]:
# Define checkpoint directory
checkpoint_dir = '/kaggle/working/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Find latest checkpoint if it exists
latest_checkpoint = None
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pth'))
if checkpoint_files:
    latest_checkpoint = max(checkpoint_files, key=os.path.getctime)

# Resume from checkpoint if available
start_epoch = 0
best_val_loss = float('inf')

if latest_checkpoint:
    encoder, decoder, optimizer, start_epoch, best_val_loss = resume_from_checkpoint(
        latest_checkpoint, encoder, decoder, optimizer, device
    )

num_epochs = 1  # Set your desired number of epochs

# Training loop
for epoch in range(start_epoch, num_epochs):
    print(f"\n--- Epoch {epoch + 1}/{num_epochs} ---")
    
    # Train for one epoch
    train_loss = train_epoch(
        encoder, decoder, dataloader,
        criterion, optimizer, device
    )
    print(f"Train Loss: {train_loss:.4f}")
    
    # Validate
    val_loss = validate(
        encoder, decoder, val_dataloader,
        criterion, device
    )
    print(f"Validation Loss: {val_loss:.4f}")
    
    # Check if this is the best model
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
    
    # Save checkpoint
    save_checkpoint(
        encoder, decoder, optimizer,
        epoch, train_loss, val_loss,
        word_map, checkpoint_dir,
        best_val_loss, is_best
    )

# Mark training as complete
with open(os.path.join(checkpoint_dir, 'TRAINING_COMPLETE'), 'w') as f:
    f.write('Training completed successfully')


--- Epoch 1/1 ---


  lengths_tensor = torch.tensor(lengths).to(device)


Batch [0/29572] Loss: 8.7839
Batch [100/29572] Loss: 6.0670
Batch [200/29572] Loss: 5.7006
Batch [300/29572] Loss: 5.1492
Batch [400/29572] Loss: 4.5294
Batch [500/29572] Loss: 4.3419
Batch [600/29572] Loss: 5.4314
Batch [700/29572] Loss: 5.2734
Batch [800/29572] Loss: 5.3129
Batch [900/29572] Loss: 5.0946
Batch [1000/29572] Loss: 5.5115
Batch [1100/29572] Loss: 5.1531
Batch [1200/29572] Loss: 4.2772
Batch [1300/29572] Loss: 5.8138
Batch [1400/29572] Loss: 4.6566
Batch [1500/29572] Loss: 5.5000
Batch [1600/29572] Loss: 5.3487
Batch [1700/29572] Loss: 5.8966
Batch [1800/29572] Loss: 4.5530
Batch [1900/29572] Loss: 4.4664
Batch [2000/29572] Loss: 4.7916
Batch [2100/29572] Loss: 4.8951
Batch [2200/29572] Loss: 4.8632
Batch [2300/29572] Loss: 5.1999
Batch [2400/29572] Loss: 4.2018
Batch [2500/29572] Loss: 5.5036
Batch [2600/29572] Loss: 4.8867
Batch [2700/29572] Loss: 5.1693
Batch [2800/29572] Loss: 5.4849
Batch [2900/29572] Loss: 5.2772
Batch [3000/29572] Loss: 5.0046
Batch [3100/29572] L

  lengths_tensor = torch.tensor(lengths).to(device)


Validation Loss: 4.9426
Checkpoint saved: /kaggle/working/checkpoints/checkpoint_epoch_1.pth
New best model saved with validation loss: 4.9426


In [29]:
# for epoch in range(num_epochs):
#     print(f"\n--- Epoch {epoch + 1} ---")

#     train_loss = train_epoch(
#         encoder, decoder, dataloader,
#         criterion, optimizer, device
#     )
#     print(f"Train Loss: {train_loss:.4f}")

#     val_loss = validate(
#         encoder, decoder, val_dataloader,
#         criterion, device
#     )
#     print(f"Validation Loss: {val_loss:.4f}")

#     # Save the model after each epoch
#     torch.save({
#         'epoch': epoch,
#         'encoder_state_dict': encoder.state_dict(),
#         'decoder_state_dict': decoder.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'loss': train_loss
#     }, f'model_epoch_{epoch + 1}.pth')

In [37]:
checkpoint_path = '/kaggle/working/checkpoint.pth'

def save_checkpoint(state, filename=checkpoint_path):
    torch.save(state, filename)

In [31]:
def load_checkpoint(checkpoint_path, encoder, decoder, optimizer):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

In [32]:
!pip install gensim==4.3.2  # Uses pre-built wheels, avoiding compilation

Collecting gensim==4.3.2
  Downloading gensim-4.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Downloading gensim-4.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.5/26.5 MB[0m [31m46.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hInstalling collected packages: gensim
  Attempting uninstall: gensim
    Found existing installation: gensim 4.3.3
    Uninstalling gensim-4.3.3:
      Successfully uninstalled gensim-4.3.3
Successfully installed gensim-4.3.2


In [33]:
# First, install all required dependencies
!pip install numpy scipy scikit-learn pandas gensim==4.3.2 theano

# Install nlg-eval from GitHub (bypass PyPI)
!pip install git+https://github.com/Maluuba/nlg-eval.git --no-deps

Collecting git+https://github.com/Maluuba/nlg-eval.git
  Cloning https://github.com/Maluuba/nlg-eval.git to /tmp/pip-req-build-w2_2x4hz
  Running command git clone --filter=blob:none --quiet https://github.com/Maluuba/nlg-eval.git /tmp/pip-req-build-w2_2x4hz
  Resolved https://github.com/Maluuba/nlg-eval.git to commit 2ab4528fad5548315cf61e40c2249fec8c8ad233
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: nlg-eval
  Building wheel for nlg-eval (setup.py) ... [?25l[?25hdone
  Created wheel for nlg-eval: filename=nlg_eval-2.4.1-py3-none-any.whl size=98924372 sha256=ff6bcb59a111edc2c2b619eb17d1393ea61f7c3a634977fffaacdd764d264c7e
  Stored in directory: /tmp/pip-ephem-wheel-cache-u6uhjt5d/wheels/89/06/a3/78b62739ab38973883fc8239cfbc41cbf08643e105ddd745d8
Successfully built nlg-eval
Installing collected packages: nlg-eval
Successfully installed nlg-eval-2.4.1


In [34]:
!pip install --no-deps git+https://github.com/Theano/Theano.git@adfe319ce6b781083d8dc3200fb4481b00853791

Collecting git+https://github.com/Theano/Theano.git@adfe319ce6b781083d8dc3200fb4481b00853791
  Cloning https://github.com/Theano/Theano.git (to revision adfe319ce6b781083d8dc3200fb4481b00853791) to /tmp/pip-req-build-da6wcw5c
  Running command git clone --filter=blob:none --quiet https://github.com/Theano/Theano.git /tmp/pip-req-build-da6wcw5c
  Running command git rev-parse -q --verify 'sha^adfe319ce6b781083d8dc3200fb4481b00853791'
  Running command git fetch -q https://github.com/Theano/Theano.git adfe319ce6b781083d8dc3200fb4481b00853791
  Running command git checkout -q adfe319ce6b781083d8dc3200fb4481b00853791
  Resolved https://github.com/Theano/Theano.git to commit adfe319ce6b781083d8dc3200fb4481b00853791
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: Theano
  Building wheel for Theano (setup.py) ... [?25l[?25hdone
  Created wheel for Theano: filename=Theano-0.9.0.dev1-py3-none-any.whl size=2762815 sha256=b92fa48a4a5a2ea9b7ea2d1cee0a

In [35]:
!pip install git+https://github.com/Maluuba/nlg-eval.git

Collecting git+https://github.com/Maluuba/nlg-eval.git
  Cloning https://github.com/Maluuba/nlg-eval.git to /tmp/pip-req-build-pq8riq6w
  Running command git clone --filter=blob:none --quiet https://github.com/Maluuba/nlg-eval.git /tmp/pip-req-build-pq8riq6w
  Resolved https://github.com/Maluuba/nlg-eval.git to commit 2ab4528fad5548315cf61e40c2249fec8c8ad233
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nltk>=3.4.5 (from nlg-eval==2.4.1)
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting gensim~=3.8.3 (from nlg-eval==2.4.1)
  Downloading gensim-3.8.3.tar.gz (23.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.4/23.4 MB[0m [31m78.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0mm
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting xdg (from nlg-eval==2.4.1)
  Downloading xdg-6.0.0-py3-none-any.whl.metadata (1.3 kB)
Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [36]:
from nlgeval import NLGEval
nlgeval = NLGEval()  # Should work without errors

ImportError: cannot import name 'MutableMapping' from 'collections' (/usr/lib/python3.10/collections/__init__.py)

In [None]:
# Load the best model for evaluation
best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
if os.path.exists(best_model_path):
    print(f"Loading best model from {best_model_path}")
    checkpoint = torch.load(best_model_path, map_location=device)
else:
    # Fall back to latest checkpoint if no best model exists
    latest_checkpoint = max(glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pth')), 
                           key=os.path.getctime)
    print(f"No best model found, loading latest checkpoint: {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint, map_location=device)

encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Set models to evaluation mode
encoder.eval()
decoder.eval()

In [None]:
# checkpoint = torch.load('/kaggle/working/model_epoch_1.pth', map_location=device, weights_only=True)

# encoder.load_state_dict(checkpoint['encoder_state_dict'])
# decoder.load_state_dict(checkpoint['decoder_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# start_epoch = checkpoint['epoch'] + 1
# print(f"Resuming from epoch {start_epoch}")

# !pip install git+https://github.com/Maluuba/nlg-eval.git

In [None]:
pip install git+https://github.com/Maluuba/nlg-eval.git@2ab4528fad554831

In [None]:
!pip install nltk rouge pycocoevalcap

In [None]:
!conda install -y gensim

In [None]:
from nlgeval import NLGEval
nlg = NLGEval(no_skipthoughts=True, no_glove=True)

references = []
hypotheses = []

encoder.eval()
decoder.eval()

with torch.no_grad():
    for images, captions, lengths in val_dataloader:
        images = images.to(device)
        encoder_out = encoder(images)

        for i in range(images.size(0)):
            img_enc = encoder_out[i].unsqueeze(0)
            generated_caption = generate_caption(decoder, img_enc)
            hypotheses.append(generated_caption)

            # Use first reference caption
            ref_tokens = [idx2word[idx] for idx in captions[i].tolist() if idx not in {0, 1, 2, 3}]
            references.append([' '.join(ref_tokens)])

# Evaluate
metrics = nlg.evaluate(hypotheses, references)
print(metrics)

In [None]:
def generate_caption(decoder, encoder_out, max_len=20):
    vocab_size = decoder.vocab_size
    decoder.eval()

    h, c = decoder.init_hidden_state(encoder_out)
    encoder_out = encoder_out.view(1, -1, encoder_out.size(-1))

    word_map_rev = {v: k for k, v in word_map.items()}

    word = torch.tensor([word_map['<start>']]).to(device)
    caption = []
    
    for _ in range(max_len):
        embeddings = decoder.embedding(word).unsqueeze(0)  # (1, 1, embed_dim)
        awe, _ = decoder.attention(encoder_out, h)  # (1, encoder_dim)
        gate = decoder.sigmoid(decoder.f_beta(h))
        awe = gate * awe

        h, c = decoder.decode_step(torch.cat([embeddings.squeeze(1), awe], dim=1), (h, c))
        preds = decoder.fc(h)
        word = preds.argmax(1)

        predicted_word = word.item()
        if predicted_word == word_map['<end>']:
            break
        caption.append(word_map_rev.get(predicted_word, '<unk>'))

    return ' '.join(caption)

## Visualizing Attention Maps

In [None]:
def generate_caption_with_attention(decoder, encoder_out, word_map, max_len=20):
    decoder.eval()

    h, c = decoder.init_hidden_state(encoder_out)
    encoder_out = encoder_out.view(1, -1, encoder_out.size(-1))
    word = torch.tensor([word_map['<start>']]).to(device)

    rev_word_map = {v: k for k, v in word_map.items()}

    caption = []
    alphas = []

    for _ in range(max_len):
        embeddings = decoder.embedding(word).unsqueeze(0)  # (1, 1, embed_dim)
        awe, alpha = decoder.attention(encoder_out, h)
        gate = decoder.sigmoid(decoder.f_beta(h))
        awe = gate * awe

        h, c = decoder.decode_step(torch.cat([embeddings.squeeze(1), awe], dim=1), (h, c))
        preds = decoder.fc(h)
        word = preds.argmax(1)

        predicted_word = word.item()
        if predicted_word == word_map['<end>']:
            break

        caption.append(rev_word_map.get(predicted_word, '<unk>'))
        alphas.append(alpha.cpu().detach().numpy())

    return caption, alphas

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def visualize_attention(image_path, caption, alphas, smooth=True):
    image = Image.open(image_path).convert("RGB")
    image = image.resize([224, 224], Image.LANCZOS)

    plt.figure(figsize=(15, 15))
    for t in range(len(caption)):
        plt.subplot(np.ceil(len(caption) / 5.), 5, t + 1)

        plt.text(0, 1, '%s' % caption[t], color='black', backgroundcolor='white', fontsize=12)
        plt.imshow(image)

        alpha = alphas[t].reshape(14, 14)  # attention is 14x14 from ResNet
        if smooth:
            import cv2
            alpha = cv2.GaussianBlur(alpha, (5, 5), 0)

        plt.imshow(alpha, alpha=0.6, extent=(0, 224, 224, 0), cmap='viridis')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Load image
image_path = "/kaggle/input/coco-2017-dataset/coco2017/val2017/000000391895.jpg"
image = Image.open(image_path).convert("RGB")

# Transform image
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0).to(device)

# Encode image
encoder_out = encoder(image_tensor)
caption, alphas = generate_caption_with_attention(decoder, encoder_out, word_map)

# Visualize
visualize_attention(image_path, caption, alphas)