In [1]:
!mkdir data
!mkdir models

In [3]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer."""
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]      # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        
    def forward(self, images):
        """Extract feature vectors from input images."""
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seg_length = max_seq_length
        
    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generates captions."""
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True) 
        hiddens, _ = self.lstm(packed)
        outputs = self.linear(hiddens[0])
        return outputs
    
    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for i in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)          # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))            # outputs:  (batch_size, vocab_size)
            _, predicted = outputs.max(1)                        # predicted: (batch_size)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted)                       # inputs: (batch_size, embed_size)
            inputs = inputs.unsqueeze(1)                         # inputs: (batch_size, 1, embed_size)
        sampled_ids = torch.stack(sampled_ids, 1)                # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids


In [4]:
class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)


In [5]:
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main():
    # Hyperparameters & paths
    model_path = '/kaggle/working/models/'
    crop_size = 224
    vocab_path = '/kaggle/input/tokenized-vocab/vocab.pkl'
    image_dir = '/kaggle/input/resized-imagescoco/resized2014'
    caption_path = '/kaggle/input/coco-image-caption/annotations_trainval2014/annotations/captions_train2014.json'
    log_step = 10
    save_step = 1000
    
    embed_size = 256
    hidden_size = 512
    num_layers = 1
    
    num_epochs = 5
    batch_size = 128
    num_workers = 2
    learning_rate = 0.001
    
    # Create model directory
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    
    # Image preprocessing
    transform = transforms.Compose([ 
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])
    
    # Load vocabulary
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    
    # Build data loader
    data_loader = get_loader(image_dir, caption_path, vocab, 
                             transform, batch_size,
                             shuffle=True, num_workers=num_workers) 

    # Build the models
    encoder = EncoderCNN(embed_size).to(device)
    decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers).to(device)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=learning_rate)
    
    # Training loop
    total_step = len(data_loader)
    for epoch in range(num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):
            
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            
            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            if i % log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, num_epochs, i, total_step, loss.item(), np.exp(loss.item()))) 
                
            # Save checkpoints
            if (i+1) % save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    model_path, f'decoder-{epoch+1}-{i+1}.ckpt'))
                torch.save(encoder.state_dict(), os.path.join(
                    model_path, f'encoder-{epoch+1}-{i+1}.ckpt'))

if __name__ == '__main__':
    main()

Epoch [1/5], Step [80/3236], Loss: 2.1820, Perplexity: 8.8639
Epoch [1/5], Step [90/3236], Loss: 2.2243, Perplexity: 9.2473
Epoch [1/5], Step [100/3236], Loss: 2.1687, Perplexity: 8.7466
Epoch [1/5], Step [110/3236], Loss: 2.0501, Perplexity: 7.7687
Epoch [1/5], Step [120/3236], Loss: 2.1077, Perplexity: 8.2289
Epoch [1/5], Step [130/3236], Loss: 2.1296, Perplexity: 8.4117
Epoch [1/5], Step [140/3236], Loss: 2.1856, Perplexity: 8.8957
Epoch [1/5], Step [150/3236], Loss: 2.1648, Perplexity: 8.7125
Epoch [1/5], Step [160/3236], Loss: 2.0680, Perplexity: 7.9091
Epoch [1/5], Step [170/3236], Loss: 2.1382, Perplexity: 8.4840
Epoch [1/5], Step [180/3236], Loss: 2.2571, Perplexity: 9.5550
Epoch [1/5], Step [190/3236], Loss: 2.1295, Perplexity: 8.4107
Epoch [1/5], Step [200/3236], Loss: 2.2033, Perplexity: 9.0551
Epoch [1/5], Step [210/3236], Loss: 2.2001, Perplexity: 9.0261
Epoch [1/5], Step [220/3236], Loss: 2.1454, Perplexity: 8.5452
Epoch [1/5], Step [230/3236], Loss: 2.1323, Perplexity: 8

2025-08-24 12:41:12.253810
