In [49]:
import json
import os
from collections import defaultdict

# Paths
DATA_DIR = '/kaggle/input/coco-2017-dataset/coco2017'
ANNOTATION_FILE = os.path.join(DATA_DIR, 'annotations', 'captions_train2017.json')
ANNOTATION_FILE2 = os.path.join(DATA_DIR, 'annotations', 'captions_val2017.json')
IMAGE_FOLDER = os.path.join(DATA_DIR, 'train2017')
IMAGE_FOLDER2 = os.path.join(DATA_DIR, 'val2017')

# Load annotations
with open(ANNOTATION_FILE, 'r') as f:
    annotations = json.load(f)

# Build a dictionary: image_id -> list of captions
captions_dict = defaultdict(list)
for ann in annotations['annotations']:
    image_id = ann['image_id']
    caption = ann['caption']
    captions_dict[image_id].append(caption)

# Check sample
sample_image_id = list(captions_dict.keys())[0]
print(f"Image ID: {sample_image_id}")
print("Captions:")
for cap in captions_dict[sample_image_id]:
    print("-", cap)

Image ID: 203564
Captions:
- A bicycle replica with a clock as the front wheel.
- The bike has a clock as a tire.
- A black metal bicycle with a clock inside the front wheel.
- A bicycle figurine in which the front wheel is replaced with a clock

- A clock with the appearance of the wheel of a bicycle 


In [50]:
import re
import nltk
nltk.download('punkt')  # for word_tokenize
from nltk.tokenize import word_tokenize

def clean_caption(caption):
    caption = caption.lower()                            # Lowercase
    caption = re.sub(r"[^a-z0-9\s]", "", caption)        # Remove punctuation
    caption = re.sub(r"\s+", " ", caption).strip()       # Trim extra spaces
    return caption

# Clean and tokenize all captions
cleaned_captions_dict = {}
for image_id, captions in captions_dict.items():
    cleaned_captions = []
    for cap in captions:
        clean_cap = clean_caption(cap)
        tokens = word_tokenize(clean_cap)
        cleaned_captions.append(tokens)
    cleaned_captions_dict[image_id] = cleaned_captions

# Check cleaned sample
print("Cleaned captions for image ID:", sample_image_id)
for cap in cleaned_captions_dict[sample_image_id]:
    print(cap)

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Cleaned captions for image ID: 203564
['a', 'bicycle', 'replica', 'with', 'a', 'clock', 'as', 'the', 'front', 'wheel']
['the', 'bike', 'has', 'a', 'clock', 'as', 'a', 'tire']
['a', 'black', 'metal', 'bicycle', 'with', 'a', 'clock', 'inside', 'the', 'front', 'wheel']
['a', 'bicycle', 'figurine', 'in', 'which', 'the', 'front', 'wheel', 'is', 'replaced', 'with', 'a', 'clock']
['a', 'clock', 'with', 'the', 'appearance', 'of', 'the', 'wheel', 'of', 'a', 'bicycle']


In [51]:
from collections import Counter

min_word_freq = 5  # You can tune this
word_freq = Counter()

# Count word frequencies
for captions in cleaned_captions_dict.values():
    for tokens in captions:
        word_freq.update(tokens)

# Filter words below the threshold
words = [word for word in word_freq if word_freq[word] >= min_word_freq]

# Special tokens
word_map = {
    '<pad>': 0,
    '<start>': 1,
    '<end>': 2,
    '<unk>': 3
}

# Add the remaining words
for i, word in enumerate(words, start=4):
    word_map[word] = i

# Reverse map
idx2word = {v: k for k, v in word_map.items()}

print(f"Vocabulary size: {len(word_map)}")
print("Sample word map entries:")
for i, (word, idx) in enumerate(list(word_map.items())[:10]):
    print(f"{word}: {idx}")

Vocabulary size: 10307
Sample word map entries:
<pad>: 0
<start>: 1
<end>: 2
<unk>: 3
a: 4
bicycle: 5
replica: 6
with: 7
clock: 8
as: 9


In [52]:
encoded_captions = {}

for image_id, captions in cleaned_captions_dict.items():
    encoded = []
    for tokens in captions:
        # Encode each word or use <unk> if not in vocab
        enc = [word_map.get(word, word_map['<unk>']) for word in tokens]
        # Add <start> and <end> tokens
        enc = [word_map['<start>']] + enc + [word_map['<end>']]
        encoded.append(enc)
    encoded_captions[image_id] = encoded

# Check sample
print("Encoded captions for image ID:", sample_image_id)
for cap in encoded_captions[sample_image_id]:
    print(cap)

Encoded captions for image ID: 203564
[1, 4, 5, 6, 7, 4, 8, 9, 10, 11, 12, 2]
[1, 10, 13, 14, 4, 8, 9, 4, 15, 2]
[1, 4, 16, 17, 5, 7, 4, 8, 18, 10, 11, 12, 2]
[1, 4, 5, 19, 20, 21, 10, 11, 12, 22, 23, 7, 4, 8, 2]
[1, 4, 8, 7, 10, 24, 25, 10, 12, 25, 4, 5, 2]


In [53]:
import json

# Save encoded captions
with open('encoded_captions.json', 'w') as f:
    json.dump({str(k): v for k, v in encoded_captions.items()}, f)

# Save word map
with open('word_map.json', 'w') as f:
    json.dump(word_map, f)

In [54]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import random
import json
import os

class CaptionDataset(Dataset):
    def __init__(self, image_folder, encoded_captions_file, word_map_file, transform=None):
        # Load encoded captions and word map
        with open(encoded_captions_file, 'r') as j:
            self.captions = json.load(j)
        with open(word_map_file, 'r') as j:
            self.word_map = json.load(j)

        self.image_folder = image_folder
        self.image_ids = list(self.captions.keys())
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_path = os.path.join(self.image_folder, f"{int(image_id):012}.jpg")
        
        # Load image
        img = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Randomly select one caption for the image
        caps = self.captions[image_id]
        caption = random.choice(caps)
        caption = torch.tensor(caption, dtype=torch.long)

        return img, caption

In [55]:
def caption_collate_fn(batch):
    """
    Custom collate function to handle batches of (image, caption) with variable-length captions.
    """
    images = []
    captions = []

    for img, cap in batch:
        images.append(img)
        captions.append(cap)

    # Stack images (they are all same size)
    images = torch.stack(images, dim=0)

    # Pad captions to the max length in the batch
    lengths = [len(cap) for cap in captions]
    max_len = max(lengths)
    padded_captions = torch.zeros(len(captions), max_len, dtype=torch.long)

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap[:end]

    return images, padded_captions, torch.tensor(lengths)

In [56]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # from ImageNet
                         std=[0.229, 0.224, 0.225])
])

# Dataset
dataset = CaptionDataset(
    image_folder='/kaggle/input/coco-2017-dataset/coco2017/train2017',
    encoded_captions_file='encoded_captions.json',
    word_map_file='word_map.json',
    transform=transform
)

# DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=caption_collate_fn
)

# Check sample batch
for images, captions, lengths in dataloader:
    print("Image batch shape:", images.shape)
    print("Caption batch shape:", captions.shape)
    print("Lengths:", lengths)
    break

Image batch shape: torch.Size([4, 3, 256, 256])
Caption batch shape: torch.Size([4, 13])
Lengths: tensor([12, 13, 12, 11])


In [57]:
# ----------- Process Validation Captions -----------
with open(ANNOTATION_FILE2, 'r') as f:
    val_annotations = json.load(f)

val_captions_dict = defaultdict(list)
for ann in val_annotations['annotations']:
    image_id = ann['image_id']
    caption = ann['caption']
    val_captions_dict[image_id].append(caption)

# Clean and tokenize validation captions
cleaned_val_captions_dict = {}
for image_id, captions in val_captions_dict.items():
    cleaned_captions = []
    for cap in captions:
        clean_cap = clean_caption(cap)
        tokens = word_tokenize(clean_cap)
        cleaned_captions.append(tokens)
    cleaned_val_captions_dict[image_id] = cleaned_captions

# Encode validation captions
encoded_val_captions = {}
for image_id, captions in cleaned_val_captions_dict.items():
    encoded = []
    for tokens in captions:
        enc = [word_map.get(word, word_map['<unk>']) for word in tokens]
        enc = [word_map['<start>']] + enc + [word_map['<end>']]
        encoded.append(enc)
    encoded_val_captions[image_id] = encoded

# Save encoded val captions
with open('encoded_captions_val.json', 'w') as f:
    json.dump({str(k): v for k, v in encoded_val_captions.items()}, f)

In [58]:
# Validation dataset
val_dataset = CaptionDataset(
    image_folder='/kaggle/input/coco-2017-dataset/coco2017/val2017',
    encoded_captions_file='encoded_captions_val.json',  # You need to create this
    word_map_file='word_map.json',
    transform=transform
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=caption_collate_fn
)

# Check sample batch
for images, captions, lengths in dataloader:
    print("Image batch shape:", images.shape)
    print("Caption batch shape:", captions.shape)
    print("Lengths:", lengths)
    break

Image batch shape: torch.Size([4, 3, 256, 256])
Caption batch shape: torch.Size([4, 15])
Lengths: tensor([12, 12, 15, 12])


## Attention Module

In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # Linear layer to transform encoder's output
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # Linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim, 1)  # Combine them and produce scalar energy
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # Softmax over the pixels

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out: encoded images, shape -> (batch_size, num_pixels, encoder_dim)
        decoder_hidden: previous decoder hidden state, shape -> (batch_size, decoder_dim)
        """
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)

        return attention_weighted_encoding, alpha

## Decoder with Attention

In [60]:
class DecoderWithAttention(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initialize hidden state
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initialize cell state
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # create a gating scalar
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # output layer

        self.init_weights()  # initialize weights

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.
        :param encoder_out: encoded images, shape (batch_size, num_pixels, encoder_dim)
        :param encoded_captions: encoded captions, shape (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, shape (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort input data by decreasing lengths
        # Corrected line
        caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors to hold word prediction scores and alphas
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(encoder_out.device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(encoder_out.device)

        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar
            attention_weighted_encoding = gate * attention_weighted_encoding

            input_lstm = torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1)
            h, c = self.decode_step(input_lstm, (h[:batch_size_t], c[:batch_size_t]))  # LSTM step
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [61]:
import torchvision.models as models

class Encoder(nn.Module):
    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.cnn = models.resnet101(pretrained=True)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune(fine_tune=False)

    def forward(self, images):
        x = self.cnn.conv1(images)
        x = self.cnn.bn1(x)
        x = self.cnn.relu(x)
        x = self.cnn.maxpool(x)

        x = self.cnn.layer1(x)
        x = self.cnn.layer2(x)
        x = self.cnn.layer3(x)
        x = self.cnn.layer4(x)  # Shape: (batch_size, 2048, 7, 7)
        
        x = self.adaptive_pool(x)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        x = x.permute(0, 2, 3, 1)  # (batch_size, encoded_size, encoded_size, 2048)
        x = x.view(x.size(0), -1, x.size(-1))  # (batch_size, num_pixels=encoded_size^2, 2048)
        return x

    def fine_tune(self, fine_tune=True):
        for p in self.cnn.parameters():
            p.requires_grad = fine_tune

In [62]:
# Test encoder-decoder integration
import torch

# Check if GPU is available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

encoder = Encoder().to(device)
decoder = DecoderWithAttention(
    attention_dim=512,
    embed_dim=512,
    decoder_dim=512,
    vocab_size=len(word_map),
    encoder_dim=2048,
    dropout=0.5
).to(device)

# Test forward pass
images, captions, lengths = next(iter(dataloader))
images = images.to(device)
captions = captions.to(device)

encoder_out = encoder(images)
predictions, _, _, _, _ = decoder(encoder_out, captions, torch.tensor(lengths))

print("Encoder output shape:", encoder_out.shape)  # Should be (batch_size, 196, 2048)
print("Predictions shape:", predictions.shape)     # Should be (batch_size, max_len, vocab_size)

Using device: cuda
Encoder output shape: torch.Size([4, 196, 2048])
Predictions shape: torch.Size([4, 10, 10307])


  predictions, _, _, _, _ = decoder(encoder_out, captions, torch.tensor(lengths))


In [63]:
class MaskedCrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=0)  # ignore <pad>

    def forward(self, predictions, targets, lengths):
        batch_size, max_len, vocab_size = predictions.shape

        predictions = predictions.view(-1, vocab_size)      # (batch_size * max_len, vocab_size)
        targets = targets.contiguous().view(-1)              # (batch_size * max_len)

        losses = self.criterion(predictions, targets)        # (batch_size * max_len)

        # Create mask
        mask = torch.arange(max_len).expand(batch_size, max_len).to(lengths.device)
        mask = (mask < lengths.unsqueeze(1)).float()         # (batch_size, max_len)
        mask = mask.view(-1)                                 # Flatten to (batch_size * max_len)

        losses = losses * mask
        return losses.sum() / mask.sum()


In [64]:
encoder = Encoder().to(device)
decoder = DecoderWithAttention(
    attention_dim=512,
    embed_dim=512,
    decoder_dim=512,
    vocab_size=len(word_map),
    encoder_dim=2048,
    dropout=0.5
).to(device)

# Only fine-tune the encoder's adaptive pool layer
encoder_params = list(encoder.adaptive_pool.parameters()) + list(encoder.cnn.layer4.parameters())
decoder_params = decoder.parameters()

optimizer = torch.optim.Adam(
    params=[
        {'params': encoder_params, 'lr': 1e-4},  # Lower LR for encoder
        {'params': decoder_params, 'lr': 4e-4}    # Higher LR for decoder
    ],
    weight_decay=1e-5
)

In [65]:
def train_epoch(encoder, decoder, dataloader, criterion, optimizer, device, grad_clip=5.0):
    encoder.train()
    decoder.train()
    total_loss = 0
    
    for i, (images, captions, lengths) in enumerate(dataloader):
        images = images.to(device)
        captions = captions.to(device)
        lengths = torch.tensor(lengths).to(device)
        
        # Forward pass
        encoder_out = encoder(images)
        predictions, _, decode_lengths, _, _ = decoder(encoder_out, captions, lengths)
        
        # Remove <start> token and truncate to actual lengths
        targets = captions[:, 1:]  # (batch_size, max_len-1)
        predictions = predictions[:, :max(decode_lengths), :]  # (batch_size, actual_max_len, vocab_size)

        # Calculate loss
        loss = criterion(predictions, targets, decode_lengths)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), grad_clip)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)
        
        optimizer.step()
        
        total_loss += loss.item()
        
        if i % 100 == 0:
            print(f"Batch [{i}/{len(dataloader)}] Loss: {loss.item():.4f}")
    
    return total_loss / len(dataloader)

In [66]:
def validate(encoder, decoder, val_loader, criterion, device):
    encoder.eval()
    decoder.eval()
    total_loss = 0
    
    with torch.no_grad():
        for images, captions, lengths in val_loader:
            images = images.to(device)
            captions = captions.to(device)
            lengths = torch.tensor(lengths).to(device)
            
            encoder_out = encoder(images)
            predictions, _, decode_lengths, _, _ = decoder(encoder_out, captions, lengths)
            
            targets = captions[:, 1:]
            predictions = predictions[:, :max(decode_lengths), :]
            
            loss = criterion(predictions, targets, decode_lengths)
            total_loss += loss.item()
    
    return total_loss / len(val_loader)

In [67]:
# Initialize components
criterion = MaskedCrossEntropyLoss().to(device)
num_epochs = 10  # For initial test

# Quick test with 1 batch
test_images, test_captions, test_lengths = next(iter(dataloader))
test_images = test_images.to(device)
test_captions = test_captions.to(device)
test_lengths = torch.tensor(test_lengths).to(device)

# Forward test
encoder_out = encoder(test_images)
predictions, _, decode_lengths, _, _ = decoder(encoder_out, test_captions, test_lengths)
targets = test_captions[:, 1:]

# Convert decode_lengths to tensor
# decode_lengths = torch.tensor(decode_lengths).to(device)

# loss = criterion(predictions, targets, decode_lengths)
# decode_lengths = torch.tensor(decode_lengths).to(device)
# loss = criterion(predictions, targets, decode_lengths)


print(f"Initial loss: {loss.item():.4f}")  # Should be ~log(vocab_size) = ~9.2 for vocab_size=10307
optimizer.step()  # Verify backprop works without errors

Initial loss: 8.6433


  test_lengths = torch.tensor(test_lengths).to(device)


In [68]:
for epoch in range(num_epochs):
    print(f"\n--- Epoch {epoch + 1} ---")

    train_loss = train_epoch(
        encoder, decoder, dataloader,
        criterion, optimizer, device
    )
    print(f"Train Loss: {train_loss:.4f}")

    val_loss = validate(
        encoder, decoder, val_dataloader,
        criterion, device
    )


--- Epoch 1 ---


  lengths = torch.tensor(lengths).to(device)


AttributeError: 'list' object has no attribute 'device'