## Step 1: Training Setup

In [1]:
import torch
import torch.nn as nn
import math
import torch.utils.data as data
import numpy as np
import os
import requests
import time
import sys

from torchvision import transforms
from pycocotools.coco import COCO

from utils.data_loader import get_loader
from utils.model import EncoderCNN, DecoderRNN

In [2]:
## Hyper-parameters and paths for this training run (tune here).
# --- data ---------------------------------------------------------------
batch_size = 32           # images (with their captions) per training batch
vocab_threshold = 5       # minimum word count for a word to enter the vocabulary
vocab_from_file = True    # True: load the existing vocab file instead of rebuilding it
# --- model --------------------------------------------------------------
embed_size = 512          # dimensionality of the image and word embeddings
hidden_size = 512         # size of the decoder RNN's hidden state
# --- training / logging -------------------------------------------------
num_epochs = 3            # number of epochs, each made of `total_step` batches
save_every = 100          # checkpoint model weights every this many training steps
print_every = 100         # print the current step's loss/perplexity on its own line every this many steps
log_file = 'logs/training_log.txt'  # per-step loss and perplexity are written to this file

In [3]:
# Extra training hyper-parameters:
#   clip_value -> gradients are clamped element-wise to [-clip_value, clip_value] before each update
#   num_layers -> number of stacked LSTM layers in the decoder
clip_value, num_layers = 2, 3

In [4]:
# Training-time image pipeline: random crop + horizontal flip for augmentation,
# then tensor conversion and normalization with the ImageNet channel statistics
# expected by the pre-trained model.
transform_train = transforms.Compose(
    [
        transforms.Resize(256),             # shorter edge -> 256 px, aspect ratio kept
        transforms.RandomCrop(224),         # 224x224 window at a random location
        transforms.RandomHorizontalFlip(),  # mirror horizontally with probability 0.5
        transforms.ToTensor(),              # PIL image -> tensor
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ]
)

In [5]:
# Training data loader over the captioned images, using the augmenting
# transform above and the vocabulary settings from the config cell.
data_loader = get_loader(
    mode='train',
    transform=transform_train,
    batch_size=batch_size,
    vocab_threshold=vocab_threshold,
    vocab_from_file=vocab_from_file,
)

Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=0.64s)
creating index...
index created!
Obtaining caption lengths...


HBox(children=(FloatProgress(value=0.0, max=414113.0), HTML(value='')))




In [6]:
# Number of tokens in the data loader's vocabulary. This is the number of output
# classes the decoder and the loss work over (9955 in the recorded run).
vocab_size = len(
    data_loader.dataset.vocab
)

In [7]:
# Build the two halves of the captioning model: the CNN encoder (sized by
# embed_size) and the LSTM decoder (num_layers stacked layers, vocab_size outputs).
# Construction order is kept: encoder first, then decoder.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size,
                     vocab_size, num_layers)

In [8]:
# Pick the GPU when one is available, otherwise fall back to the CPU,
# and move both models there (Module.to works in place).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
encoder.to(device)
decoder.to(device)  # left as the last expression so the decoder architecture is displayed

DecoderRNN(
  (embedding): Embedding(9955, 512)
  (lstm): LSTM(512, 512, num_layers=3, batch_first=True)
  (fc): Linear(in_features=512, out_features=9955, bias=True)
)

In [9]:
# Cross-entropy over the vocabulary, placed on the same `device` as the models.
# Reusing `device` (rather than a separate torch.cuda.is_available() / .cuda()
# check) keeps the loss on the same device as the models if `device` is changed.
criterion = nn.CrossEntropyLoss().to(device)

In [10]:
# Parameters handed to the optimizer: all decoder parameters followed by those of
# the encoder's `embed` submodule. Other encoder parameters are not optimized here
# (presumably a frozen pre-trained backbone — confirm in utils/model.py).
params = [*decoder.parameters(), *encoder.embed.parameters()]

In [11]:
# Adam over the learnable parameter list with a fixed learning rate of 3e-3.
optimizer = torch.optim.Adam(params=params, lr=3e-3)

In [12]:
# Steps per epoch = ceil(number of captions / batch size). Batches are drawn by
# random sampling in the training loop, so an "epoch" is a fixed step count,
# not a guaranteed pass over every caption.
total_step = math.ceil(
    len(data_loader.dataset.caption_lengths)
    / data_loader.batch_sampler.batch_size
)

## Step 2: Train your Model

In [13]:
# Open (and truncate) the training log. open() does not create missing parent
# directories, so create the log's directory first; on a fresh checkout without
# logs/ this line would otherwise raise FileNotFoundError.
os.makedirs(os.path.dirname(log_file) or '.', exist_ok=True)
f = open(log_file, 'w')

In [None]:
# Train for num_epochs epochs of total_step steps each. Every step:
#   1. samples a batch of same-length captions,
#   2. runs encoder + decoder,
#   3. computes cross-entropy over all caption tokens,
#   4. backpropagates, clips gradients, and updates the parameters.
# Stats go to the console (overwritten in place, with a fresh line every
# print_every steps) and to the log file every step.
checkpoint_dir = './models'
# torch.save does not create directories; without this the first checkpoint
# (after save_every steps of training) would crash with FileNotFoundError.
os.makedirs(checkpoint_dir, exist_ok=True)

try:
    for epoch in range(1, num_epochs+1):

        for i_step in range(1, total_step+1):

            # Randomly sample a caption length, and sample indices with that length.
            indices = data_loader.dataset.get_train_indices()
            # Create and assign a batch sampler to retrieve a batch with the sampled indices.
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            data_loader.batch_sampler.sampler = new_sampler

            # Obtain the batch.
            images, captions = next(iter(data_loader))

            # Move batch of images and captions to the selected device.
            images = images.to(device)
            captions = captions.to(device)

            # Zero the gradients.
            decoder.zero_grad()
            encoder.zero_grad()

            # Pass the inputs through the CNN-RNN model.
            features = encoder(images)
            outputs = decoder(features, captions)

            # Calculate the batch loss. reshape (unlike view) also works if the
            # decoder ever returns a non-contiguous tensor.
            loss = criterion(outputs.reshape(-1, vocab_size), captions.reshape(-1))

            # Backward pass.
            loss.backward()

            # Clip every parameter the optimizer updates (decoder AND encoder.embed),
            # not just the decoder's, so all updated weights get the same treatment.
            torch.nn.utils.clip_grad_value_(params, clip_value)

            # Update the parameters in the optimizer.
            optimizer.step()

            # Get training statistics (read the scalar loss once).
            loss_value = loss.item()
            stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (epoch, num_epochs, i_step, total_step, loss_value, np.exp(loss_value))

            # Print training statistics (on same line).
            print('\r' + stats, end="")
            sys.stdout.flush()

            # Print training statistics to file.
            f.write(stats + '\n')
            f.flush()

            # Print training statistics (on different line).
            if i_step % print_every == 0:
                print('\r' + stats)

            # Save the weights every save_every steps AND on the last step of the
            # epoch. Otherwise the final total_step % save_every steps of each
            # epoch would never reach the checkpoint.
            if i_step % save_every == 0 or i_step == total_step:
                torch.save(decoder.state_dict(), os.path.join(checkpoint_dir, 'decoder-%d.pkl' % epoch))
                torch.save(encoder.state_dict(), os.path.join(checkpoint_dir, 'encoder-%d.pkl' % epoch))
finally:
    # Close the training log file even if training raises or is interrupted.
    f.close()

Epoch [1/3], Step [100/12942], Loss: 4.5127, Perplexity: 91.1654
Epoch [1/3], Step [200/12942], Loss: 4.4598, Perplexity: 86.47393
Epoch [1/3], Step [300/12942], Loss: 4.3214, Perplexity: 75.29472
Epoch [1/3], Step [400/12942], Loss: 3.6646, Perplexity: 39.04097
Epoch [1/3], Step [500/12942], Loss: 3.9778, Perplexity: 53.40171
Epoch [1/3], Step [600/12942], Loss: 3.9036, Perplexity: 49.5790
Epoch [1/3], Step [700/12942], Loss: 4.3871, Perplexity: 80.4065
Epoch [1/3], Step [800/12942], Loss: 3.1655, Perplexity: 23.70127
Epoch [1/3], Step [900/12942], Loss: 4.0322, Perplexity: 56.3853
Epoch [1/3], Step [1000/12942], Loss: 3.3926, Perplexity: 29.7434
Epoch [1/3], Step [1100/12942], Loss: 3.6995, Perplexity: 40.4253
Epoch [1/3], Step [1200/12942], Loss: 3.1487, Perplexity: 23.30663
Epoch [1/3], Step [1300/12942], Loss: 3.0449, Perplexity: 21.00890
Epoch [1/3], Step [1400/12942], Loss: 2.8780, Perplexity: 17.7790
Epoch [1/3], Step [1500/12942], Loss: 3.1136, Perplexity: 22.5011
Epoch [1/3],