In [1]:
import math
from model import EncoderCNN, DecoderRNN
from data_loader import get_loader
from data_loader_val import get_loader as val_loader
from pycocotools.coco import COCO
from torchvision import transforms
from tqdm.notebook import tqdm
import torch.nn as nn
import torch
import torch.utils.data as data
from collections import defaultdict
import json
import os
import sys
import numpy as np
from nlp_utils import clean_sentence, bleu_score

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/satvikahuja13/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Directory that is later handed to the data loader as `cocoapi_loc`.
cocoapi_dir = r"cocoapi/"
# List the directory through `cocoapi_dir` rather than a second copy of the
# literal, so the listing always reflects the path the notebook actually uses.
# os.listdir already returns a list, so no extra copy is needed.
folders = os.listdir(cocoapi_dir)
folders

['.DS_Store', 'images', 'annotations']

In [3]:
# --- Data loading (forwarded to get_loader) ---------------------------------
batch_size = 128
vocab_from_file = True
vocab_threshold = 5  # presumably a minimum word count for the vocabulary; confirm in data_loader

# --- Model dimensions (forwarded to EncoderCNN / DecoderRNN) ----------------
embed_size, hidden_size = 256, 512

# --- Training schedule and logging ------------------------------------------
num_epochs = 3                  # iterations of the outer training loop
print_every = 20                # print statistics every N steps
save_every = 1                  # save model weights every N epochs
log_file = "training_log.txt"   # per-step statistics are written here

In [4]:
# Training-time image pipeline: resize, random 224-pixel crop, random
# horizontal flip, tensor conversion, then per-channel normalization.
# The mean/std values match the channel statistics torchvision documents
# for its ImageNet-pretrained models.
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [5]:
# Build the training data loader from the configuration defined above.
# All arguments are passed by keyword, so their order is irrelevant.
data_loader = get_loader(
    cocoapi_loc=cocoapi_dir,
    mode="train",
    transform=transform_train,
    batch_size=batch_size,
    vocab_from_file=vocab_from_file,
    vocab_threshold=vocab_threshold,
)

Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=0.49s)
creating index...
index created!
Obtaining caption lengths...


100%|██████████| 591753/591753 [00:29<00:00, 20382.69it/s]


In [6]:
# Vocabulary size is passed to the decoder and later used to flatten the
# decoder outputs when computing the loss.
vocab_size = len(data_loader.dataset.vocab)
print("vocab size is :", vocab_size)

# Initialize the encoder and decoder.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Select the compute device. MPS (Apple-silicon GPU) is still preferred, as
# before, but fall back to CUDA and then CPU so the notebook also runs on
# machines where MPS is unavailable instead of failing in `.to(device)`.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the models to the selected device.
encoder.to(device)
decoder.to(device)

# Define the loss function.
criterion = nn.CrossEntropyLoss().to(device)

# Specify the learnable parameters of the model: every decoder parameter plus
# only the encoder's `embed` sub-module.
params = list(decoder.parameters()) + list(encoder.embed.parameters())

# Define the optimizer.
optimizer = torch.optim.Adam(params, lr=0.001)

# Set the total number of training steps per epoch.
total_step = math.ceil(len(data_loader.dataset) / data_loader.batch_sampler.batch_size)

vocab size is : 11543




In [7]:
# Show how many steps the training loop will run in each epoch.
print(total_step)

4624


## **Training the Model**

In [8]:
# Make sure the checkpoint directory exists before training starts, so that
# torch.save cannot fail on a missing directory after a full epoch of work.
os.makedirs("models", exist_ok=True)

# The context manager guarantees the training log is closed even if training
# raises or is interrupted part-way through.
with open(log_file, "w") as f:
    for epoch in range(1, num_epochs + 1):
        for i_step in range(1, total_step + 1):

            # Randomly sample a caption length, and sample indices with that length.
            indices = data_loader.dataset.get_train_indices()
            # Create and assign a batch sampler to retrieve a batch with the sampled indices.
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            data_loader.batch_sampler.sampler = new_sampler

            # Obtain the batch. A fresh iterator is created each step because
            # the sampler was just replaced above.
            images, captions = next(iter(data_loader))

            # Move the batch of images and captions to the selected device.
            images = images.to(device)
            captions = captions.to(device)

            # Zero the gradients.
            decoder.zero_grad()
            encoder.zero_grad()

            # Pass the inputs through the CNN-RNN model.
            features = encoder(images)
            outputs = decoder(features, captions)

            # Calculate the batch loss over the flattened outputs and targets.
            loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))

            # Backward pass.
            loss.backward()

            # Update the parameters registered with the optimizer.
            optimizer.step()

            # Gather training statistics; perplexity is exp(loss).
            stats = (
                f"Epoch [{epoch}/{num_epochs}], Step [{i_step}/{total_step}], "
                f"Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):.4f}"
            )

            # Write every step's statistics to the log file and flush at once
            # so the log stays current during a long run.
            f.write(stats + "\n")
            f.flush()

            # Print training statistics (on a different line) every `print_every` steps.
            if i_step % print_every == 0:
                print("\r" + stats)

        # Save the weights every `save_every` epochs.
        if epoch % save_every == 0:
            torch.save(
                decoder.state_dict(), os.path.join("models", "decoder-%d.pkl" % epoch)
            )
            torch.save(
                encoder.state_dict(), os.path.join("models", "encoder-%d.pkl" % epoch)
            )




Epoch [1/3], Step [20/4624], Loss: 5.1296, Perplexity: 168.9420
Epoch [1/3], Step [40/4624], Loss: 4.3825, Perplexity: 80.0355
Epoch [1/3], Step [60/4624], Loss: 4.0272, Perplexity: 56.1009
Epoch [1/3], Step [80/4624], Loss: 4.6615, Perplexity: 105.7949
Epoch [1/3], Step [100/4624], Loss: 3.9281, Perplexity: 50.8078
Epoch [1/3], Step [120/4624], Loss: 3.7151, Perplexity: 41.0629
Epoch [1/3], Step [140/4624], Loss: 3.5556, Perplexity: 35.0083
Epoch [1/3], Step [160/4624], Loss: 3.6330, Perplexity: 37.8246
Epoch [1/3], Step [180/4624], Loss: 3.5242, Perplexity: 33.9250
Epoch [1/3], Step [200/4624], Loss: 3.6510, Perplexity: 38.5131
Epoch [1/3], Step [220/4624], Loss: 3.5533, Perplexity: 34.9273
Epoch [1/3], Step [240/4624], Loss: 3.5534, Perplexity: 34.9325
Epoch [1/3], Step [260/4624], Loss: 3.2923, Perplexity: 26.9043
Epoch [1/3], Step [280/4624], Loss: 3.3987, Perplexity: 29.9254
Epoch [1/3], Step [300/4624], Loss: 3.2355, Perplexity: 25.4185
Epoch [1/3], Step [320/4624], Loss: 3.3826