In [1]:
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, TensorDataset
from torchvision import transforms, utils
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import pickle
from tqdm import tqdm

In [2]:
class ImageCaptionDataset(Dataset):
    """Dataset of (image, tokenized caption) pairs listed in a caption file.

    Each non-blank line of ``path_caption_file`` holds an image file name and
    its caption, separated by a tab. Captions are whitespace-tokenized, mapped
    through ``word_to_id`` and wrapped in ``<start>`` / ``<end>`` ids.

    Args:
        path_images: directory containing the image files.
        path_caption_file: path of the tab-separated caption file.
        word_to_id: mapping from word to integer id; must contain ``"<start>"``,
            ``"<end>"`` and every word appearing in the captions (unknown words
            raise ``KeyError``).
        transform: optional callable applied to each loaded PIL image.
    """

    def __init__(self, path_images, path_caption_file, word_to_id, transform=None):
        super().__init__()
        self.path_images = path_images
        self.path_caption_file = path_caption_file
        self.transform = transform
        # index -> [image_name, token ids]; keys stay contiguous (0..len-1)
        # because blank lines are skipped rather than consuming an index.
        self.idxtoImageCaption = {}
        self.word_to_id = word_to_id
        with open(path_caption_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # e.g. a trailing empty line at end of file
                # Split on the first tab only, so a stray tab inside the caption
                # cannot break the two-way unpacking.
                image_name, caption = line.split('\t', 1)
                self.idxtoImageCaption[len(self.idxtoImageCaption)] = [image_name, self.tokenize(caption)]

    def __len__(self):
        return len(self.idxtoImageCaption)

    def __getitem__(self, idx):
        """Return ``(ok, image, caption_ids, image_name)`` for sample ``idx``.

        If the image file cannot be opened or decoded, ``(False, 0, 0, 0)`` is
        returned so the training loop can skip the sample.
        """
        image_name, caption = self.idxtoImageCaption[idx]
        try:
            # The context manager closes the file handle deterministically.
            with Image.open(os.path.join(self.path_images, image_name)) as img:
                image = img.convert("RGB")
        except (OSError, ValueError):
            # Missing / unreadable / corrupt image. Catch only I/O-style errors:
            # a bare `except:` would also swallow KeyboardInterrupt and bugs.
            return False, 0, 0, 0

        if self.transform is not None:
            image = self.transform(image)

        return True, image, torch.tensor(caption), image_name

    def tokenize(self, caption):
        """Map ``caption`` to ``[<start>, word ids..., <end>]`` via ``word_to_id``."""
        ids = self.word_to_id
        return [ids["<start>"]] + [ids[word] for word in caption.split()] + [ids["<end>"]]
            

In [3]:
# Data locations. Set the IMAGE_CAPTION_DATA_DIR environment variable to point
# at a different data root instead of editing the absolute path below.
DATA_DIR = os.environ.get('IMAGE_CAPTION_DATA_DIR', '/home/student/Image_Captioning_Project/Data')
# The trailing '' keeps a trailing separator, so this directory also works when
# it is concatenated directly with an image file name.
path_images = os.path.join(DATA_DIR, 'Images', '')
path_caption_file_train = os.path.join(DATA_DIR, 'train_captions.txt')
path_caption_file_test = os.path.join(DATA_DIR, 'test_captions.txt')

# SECURITY: pickle.load can execute arbitrary code -- only load a trusted vocabulary file.
with open(os.path.join(DATA_DIR, 'word_to_id.pkl'), 'rb') as f:
    word_to_id = pickle.load(f)

In [84]:
# Training hyper-parameters.
num_epochs = 10
# One caption per loader step: captions have different lengths and failed image
# loads yield placeholder values, so the default collate cannot stack several
# samples (and CaptionGenerator.forward is written for a single caption). The
# effective batch size comes from gradient accumulation instead.
batch_size = 1
learning_rate = 0.001
embedding_dim = 256
hidden_dim = 256
acumulate_grad_steps = 50  # number of captions whose gradients are summed per optimizer step

# Resize, then normalise with the ImageNet mean/std used by torchvision's
# pretrained models (the encoder is a pretrained VGG-11).
transform_train = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset = ImageCaptionDataset(path_images, path_caption_file_train, word_to_id, transform=transform_train)
# Reshuffle every epoch so each accumulation window mixes different images
# instead of replaying the caption file in the same fixed order.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

In [6]:
def to_gpu(x):
    """Move ``x`` to the default CUDA device if one exists; otherwise hand it back untouched."""
    if not torch.cuda.is_available():
        return x
    return x.cuda()

In [90]:
class CaptionGenerator(nn.Module):
    """Captioning model: frozen VGG-11 image features + an LSTM word decoder.

    At every time step the LSTM output is concatenated with the 1000-dim VGG-11
    output and passed through a small MLP that scores the next word. Decoding is
    teacher-forced: the ground-truth token at step ``i`` is fed in to predict
    the token at step ``i + 1``.

    Args:
        vocabulary_size: number of entries in the word vocabulary.
        embedding_dim: size of the word embeddings fed to the LSTM.
        hidden_dim: size of the LSTM hidden / cell state.
    """

    def __init__(self, vocabulary_size, embedding_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # NOTE: newer torchvision releases deprecate `pretrained=` in favour of `weights=`.
        self.vgg11 = models.vgg11(pretrained=True)
        # The image encoder is used as a fixed feature extractor.
        for param in self.vgg11.parameters():
            param.requires_grad = False

        self.embedding = nn.Embedding(vocabulary_size, embedding_dim)
        self.decoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        # `+ 1000`: the VGG-11 output is concatenated to the LSTM output.
        self.dense = nn.Sequential(
            nn.Linear(hidden_dim + 1000, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vocabulary_size),
        )
        self.criterion = nn.NLLLoss()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, h=None, c=None):
        """Teacher-forced decoding of one batch of captions.

        Args:
            x: ``(image, caption)`` pair. ``image`` is whatever ``self.vgg11``
                accepts (presumably ``(batch, 3, H, W)``); ``caption`` is a
                ``(batch, seq_len)`` tensor of token ids.
            h, c: optional initial LSTM hidden / cell state, each
                ``(1, batch, hidden_dim)``; zeros are used for any that is ``None``.

        Returns:
            ``(loss, h, c, predictions)`` where ``loss`` is a shape-``(1,)``
            tensor holding the sum over steps of the batch-mean next-token NLL,
            ``h`` / ``c`` are the final LSTM states, and ``predictions`` has one
            ``(batch,)`` tensor of arg-max next-token ids per step.
        """
        image, caption = x
        batch_size, seq_len = caption.shape
        image_features = self.vgg11(image)

        # Constant zero initial state on the same device/dtype as the features;
        # it is not a parameter, so it does not need requires_grad.
        if h is None:
            h = image_features.new_zeros(1, batch_size, self.hidden_dim)
        if c is None:
            c = image_features.new_zeros(1, batch_size, self.hidden_dim)

        # Accumulate out of place: an in-place `+=` on a leaf tensor that
        # requires grad raises a RuntimeError.
        loss = image_features.new_zeros(1)
        predictions = []
        # The final token has no successor to predict, so stop one step early.
        for i in range(seq_len - 1):
            # Teacher forcing in both train and eval mode, so an eval-mode
            # forward yields a validation loss instead of an unset input.
            step_input = self.embedding(caption[:, i]).unsqueeze(1)  # (batch, 1, embedding_dim)
            lstm_out, (h, c) = self.decoder(step_input, (h, c))
            step_hidden = lstm_out.reshape(batch_size, self.hidden_dim)
            features = torch.cat([step_hidden, image_features], dim=1)
            log_probs = self.softmax(self.dense(features))
            loss = loss + self.criterion(log_probs, caption[:, i + 1])
            predictions.append(log_probs.argmax(dim=1))

        return loss, h, c, predictions

In [93]:
# Build the captioning model (on the GPU when one is available) and put it in
# training mode. Adam receives every parameter; the frozen VGG-11 weights have
# requires_grad=False, so they never get gradients to update.
model = to_gpu(CaptionGenerator(len(train_dataset.word_to_id), embedding_dim, hidden_dim))
model.train()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [94]:
# Train with teacher forcing, one caption per loader step, summing gradients
# over `acumulate_grad_steps` captions before each optimizer update.
model.zero_grad()
for epoch in range(num_epochs):
    model.train()
    loss_train_total = 0.0  # sum of un-scaled per-caption losses this epoch
    num_processed = 0       # captions whose image loaded successfully
    for flag, image, caption, image_names in tqdm(train_loader):
        # The dataset signals an unreadable image with flag=False; skip it.
        if not bool(flag.all()):
            continue

        image = to_gpu(image)
        caption = to_gpu(caption)

        # Captions are independent sequences: start each one from a zero LSTM
        # state rather than carrying over the previous caption's hidden/cell state.
        loss, h, c, predictions = model((image, caption))

        # Scale so the accumulated gradient is an average over the window. The
        # graph is not reused after this call, so retain_graph is not needed.
        (loss / acumulate_grad_steps).backward()
        loss_train_total += loss.item()
        num_processed += 1

        # Step after every *full* window; counting processed samples (not the
        # loader index) keeps skipped images from shifting the window.
        if num_processed % acumulate_grad_steps == 0:
            optimizer.step()
            model.zero_grad()

    # Apply the gradients of a trailing, partially filled window.
    if num_processed % acumulate_grad_steps != 0:
        optimizer.step()
        model.zero_grad()

    loss_train_total = loss_train_total / max(num_processed, 1)
    print("Epoch {} Completed,\tTrain Loss: {}".format(epoch + 1, loss_train_total))

 39%|███▊      | 14065/36415 [10:04<16:00, 23.26it/s]


KeyboardInterrupt: 