In [None]:
import json
from collections import Counter
import pickle
import torch_geometric
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GATConv, global_mean_pool
import matplotlib.pyplot as plt
import seaborn as sns
import math
from torch_geometric.data import Data, DataLoader
import random
import time
import nltk
from earlystopping import EarlyStopping
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA

In [None]:
# Caption file: tab-separated, with 'FileName' and 'Description' columns.
df = pd.read_csv('Flickr8k.token.txt', delimiter='\t')

In [None]:
# Precomputed image features for the training images (one row per image).
encoded_features = np.load(file='googlenet_1024_features.npy')

In [None]:
# 64 components: the reduced image vector is later passed to the decoder as its
# initial GRU hidden state, so this must equal HIDDEN_SIZE (64).
# random_state makes the fit reproducible if sklearn chooses its randomized SVD
# solver (which it may do automatically for data of this size).
pca = PCA(n_components=64, random_state=0)

In [None]:
# Fit the projection on the training features and replace them by their 64-d projection.
encoded_features = pca.fit_transform(X=encoded_features)

In [None]:
# NLTK >= 3.9 looks the perceptron tagger up as 'averaged_perceptron_tagger_eng';
# older releases use 'averaged_perceptron_tagger'. Fetch both so nltk.pos_tag
# works on either version (an id unknown to the installed release should just
# report an error and return False rather than raise).
for _resource in ('averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng'):
    nltk.download(_resource)

In [None]:
# Remove the five caption rows ("#0".."#4") that belong to '2258277193_586949ec62.jpg.1'.
# NOTE(review): presumably this entry has no matching image/feature row — confirm.
df = df[~df['FileName'].isin(['2258277193_586949ec62.jpg.1#' + str(i) for i in range(5)])]

In [None]:
# Characters replaced by a space before tokenising. The backslash is now escaped
# explicitly ("\\,"): the previous "\," was an invalid escape sequence that only
# worked by accident and raises a SyntaxWarning on Python 3.12+.
_PUNCTUATION = '''!()-[]{};:'"\\, <>./?@#$%^&*_~'''
_PUNCTUATION_TABLE = str.maketrans(_PUNCTUATION, ' ' * len(_PUNCTUATION))
# Penn Treebank tags whose words are dropped: IN (preposition / subordinating
# conjunction) and DT (determiner).
_BANNED_TAGS = ('IN', 'DT')


def preProcess(s):
    """Lower-case a caption, strip punctuation and drop IN/DT-tagged words.

    Parameters
    ----------
    s : str
        Raw caption text.

    Returns
    -------
    list of str
        The remaining tokens, in their original order.
    """
    words = s.lower().translate(_PUNCTUATION_TABLE).split()
    if not words:
        return []
    # Each word is still tagged on its own (no sentence context), exactly as
    # before, but via one pos_tag_sents call so the tagger is loaded once per
    # caption rather than once per word.
    tags = [tagged[0][1] for tagged in nltk.pos_tag_sents([[w] for w in words])]
    return [w for w, tag in zip(words, tags) if tag not in _BANNED_TAGS]

In [None]:
# Quick check: the tag assigned to a single word tagged out of context.
nltk.pos_tag(['is'])[0][-1]

In [None]:
# Sanity check: a lone preposition should be filtered out entirely.
preProcess(s="for")

In [None]:
# Replace every raw caption string by its list of kept tokens.
df['Description'] = df['Description'].map(preProcess)

In [None]:
df['Description']

In [None]:
# Group captions by image. FileName looks like "<image>#<k>", so dropping the
# last two characters (the "#k" suffix) gives the image name.
mp = {}
for file_name, caption_tokens in zip(df['FileName'], df['Description']):
    mp.setdefault(file_name[:-2], []).append(caption_tokens)

In [None]:
# One row per image; 'Description' now holds that image's list of token lists.
df = pd.DataFrame(list(mp.items()), columns=['FileName', 'Description'])

In [None]:
# Corpus-wide token frequencies over every caption of every image, kept as a
# plain dict whose insertion order is the order of first occurrence.
counts = dict(Counter(
    token
    for captions in df['Description']
    for caption in captions
    for token in caption
))

In [None]:

# The 5000 most frequent tokens (word, count), most frequent first.
common_words = Counter(counts).most_common(n=5000)

In [None]:
# Summarise instead of rendering the whole frequency dict (thousands of entries):
# vocabulary size plus the most frequent tokens.
len(counts), Counter(counts).most_common(20)

In [None]:
# Token -> id. Words get 1..len(common_words) in frequency order, followed by the
# special tokens UNK, SOF and EOF; id 0 is reserved for padding.
map_vocab = {word: rank for rank, (word, _freq) in enumerate(common_words, start=1)}
cnt = len(common_words) + 1
map_vocab['UNK'] = cnt
map_vocab["SOF"] = cnt + 1
map_vocab["EOF"] = cnt + 2
cnt += 2
map_vocab["Padding"] = 0

In [None]:
# id -> token (ids are unique, so the inversion is lossless).
inv_mapping = {token_id: token for token, token_id in map_vocab.items()}

In [None]:
# UNK has no count of its own; reuse the count of 'man' as a stand-in so UNK gets
# a loss weight in the next cell. NOTE(review): choice of 'man' looks arbitrary.
counts.update(UNK=counts['man'])

In [None]:
# Per-class loss weights: log(1/count) + 12, so rarer tokens weigh more.
# Index 0 (padding) gets 0; ids 1..UNK cover the vocabulary words plus UNK.
# The upper bound comes from map_vocab instead of the hard-coded 5002, so a
# vocabulary with fewer than 5000 words no longer hits a KeyError on SOF/EOF.
weights = [0]
for i in range(1, map_vocab['UNK'] + 1):
    x = np.log(1/counts[inv_mapping[i]]) + 12
    weights.append(x)

In [None]:
# SOF only ever appears at position 0, so it is never a prediction target -> 0.
# EOF gets the weight of a token seen 40000 times
# (NOTE(review): presumably ~ the number of captions — confirm).
weights.extend([0, np.log(1/40000) + 12])

In [None]:
# CrossEntropyLoss(weight=...) needs exactly one weight per output class, i.e.
# one per token id in map_vocab.
assert len(weights) == len(map_vocab), (len(weights), len(map_vocab))
len(weights)

In [None]:
# One embedding row / output logit per map_vocab entry (words, UNK, SOF, EOF and
# Padding=0). Derived instead of hard-coding 5004.
VOCAB_SIZE = len(map_vocab)
assert VOCAB_SIZE == max(map_vocab.values()) + 1, "token ids must be contiguous from 0"

# Maximum caption length in ids, SOF and EOF included; shorter captions are
# zero-padded to this length.
SEQ_LEN = 60

In [None]:
train_data = list()  # placeholder; rebuilt from df a few cells below

In [None]:
def convertVocab(x):
    """Encode each tokenised caption in ``x`` as a list of vocabulary ids.

    Each caption becomes ``[SOF, id, ..., EOF]``. Words missing from
    ``map_vocab`` map to the UNK id. The ids (SOF included) are cut to
    ``SEQ_LEN - 1`` before EOF is appended, so no result exceeds ``SEQ_LEN``.
    """
    def encode(caption):
        ids = [map_vocab['SOF']] + [
            map_vocab[word] if word in map_vocab else map_vocab['UNK']
            for word in caption
        ]
        return ids[:SEQ_LEN - 1] + [map_vocab['EOF']]

    return [encode(caption) for caption in x]

In [None]:
# Tokens -> ids for every caption of every image.
df['Description'] = df['Description'].map(convertVocab)

In [None]:
# One (image_feature, caption_ids) pair per caption. Caption ids are right-padded
# with 0 (Padding) up to SEQ_LEN.
# NOTE(review): assumes row idx of df lines up with row idx of encoded_features — confirm.
train_data = []
for idx, captions in zip(df.index, df['Description']):
    image_feature = torch.tensor(encoded_features[idx])
    for caption in captions:
        padded = caption + [0] * (SEQ_LEN - len(caption))
        train_data.append((image_feature, torch.tensor(padded, dtype=torch.long)))

In [None]:
# First 80 % of the (image-ordered) pairs go to training, the rest to validation.
split = len(train_data) * 80 // 100
train_data, validation_data = train_data[:split], train_data[split:]

In [None]:
BATCH_SIZE = 64
# Plain torch DataLoader: samples are (feature, caption) tensor tuples, not PyG
# graphs, and torch_geometric.data.DataLoader is deprecated in PyG 2.x.
# drop_last=True: the training loop skips batches smaller than BATCH_SIZE anyway.
train_loader = torch.utils.data.DataLoader(
    train_data, BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
# Validation is not shuffled. With shuffling plus a dropped partial batch, a
# different subset was evaluated each epoch, which made early stopping noisy.
val_loader = torch.utils.data.DataLoader(
    validation_data, BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

In [None]:
WORD_EMBEDDING_DIM = 300
HIDDEN_SIZE = 64

class Decoder(torch.nn.Module):
    """Single-step GRU caption decoder.

    Each call consumes one token id per caption plus the previous hidden state.
    It returns next-token logits and the new hidden state. In this notebook the
    image feature vector is passed as the initial hidden state, so
    ``hidden_size`` must match the image-feature dimensionality.
    """

    def __init__(self, vocab_size=None, embedding_dim=WORD_EMBEDDING_DIM,
                 hidden_size=HIDDEN_SIZE):
        """
        Parameters
        ----------
        vocab_size : int, optional
            Number of token ids; defaults to the global ``VOCAB_SIZE``.
        embedding_dim : int
            Word-embedding width.
        hidden_size : int
            GRU hidden-state width.
        """
        super(Decoder, self).__init__()
        if vocab_size is None:
            vocab_size = VOCAB_SIZE
        self.hidden_size = hidden_size
        # Attribute names are unchanged so existing checkpoints still load.
        self.embeddingLayer = nn.Embedding(vocab_size, embedding_dim)
        self.dropout = nn.Dropout(.5)
        self.GRU = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size,
                          batch_first=True, num_layers=1)
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, h_0):
        """Advance every caption in the batch by one token.

        Parameters
        ----------
        x : LongTensor
            Token ids, shape ``(B,)`` or ``(B, 1)``.
        h_0 : Tensor
            Previous hidden state, reshapeable to ``(1, B, hidden_size)``.

        Returns
        -------
        (Tensor, Tensor)
            Logits ``(B, vocab_size)`` and new hidden state ``(B, hidden_size)``.
        """
        # Follow the module's own device instead of a notebook-global ``device``.
        device = self.linear.weight.device
        x = self.embeddingLayer(x.to(device))
        x = self.dropout(x)
        # Take the batch size from the input rather than the global BATCH_SIZE.
        # The same model then handles any batch size, including 1 for inference.
        batch_size = x.size(0)
        x = x.view(batch_size, 1, -1)
        h_0 = h_0.to(device).view(1, batch_size, self.hidden_size)
        _, h_n = self.GRU(x, h_0)
        h_n = h_n.view(batch_size, -1)
        out = self.linear(h_n)
        return out, h_n

def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
model = Decoder()
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
# Disabled smoke test: steps the decoder through one pass of train_loader and
# evaluates the per-step loss, skipping batches smaller than BATCH_SIZE.
# NOTE(review): it calls `criterion`, which is only defined in a later cell, so
# it would fail if re-enabled at this position as-is. `cnt` is never incremented.
# cnt = 0
# for batch in train_loader:
#     x = batch[1]
#     h = batch[0].float()
#     if(x.shape[0] != BATCH_SIZE):
#         print(x.shape[0])
#         continue
#     for j in range(SEQ_LEN - 1):
#         out, h = model(x[:, j], h)
#         y_true = x[:, j + 1].to(device)
#         print(y_true.shape)
#         criterion(out, y_true)
# print(cnt)        

In [None]:
# Shift every weight up by 1 (padding and SOF move from 0 to 1; padding is
# excluded via ignore_index anyway).
weights = torch.tensor(weights) + 1

In [None]:
print(str(weights))

In [None]:
# The loss expects float32 class weights on the same device as the logits.
weights = weights.to(device=device, dtype=torch.float32)

In [None]:
optimizer = torch.optim.Adam(params=model.parameters())
# ignore_index=0: padding targets contribute nothing; weight: per-token-id class weights.
criterion = torch.nn.CrossEntropyLoss(weight=weights, ignore_index=0)

def _caption_loss(model, batch):
    """Teacher-forced loss of one (features, captions) batch, summed over steps.

    ``batch[0]`` holds the image features, used as the initial hidden state, and
    ``batch[1]`` holds the padded caption ids. Returns ``None`` when the batch is
    not exactly ``BATCH_SIZE`` long (such batches are skipped) or when no step
    has a non-padding target.
    """
    x = batch[1]
    h = batch[0].float()
    if x.shape[0] != BATCH_SIZE:
        return None
    loss = None
    for j in range(SEQ_LEN - 1):
        y_true = x[:, j + 1].to(device)
        # Padding (id 0) only follows EOF, so once every target at this step is
        # padding, all later steps are too. CrossEntropyLoss(ignore_index=0,
        # reduction='mean') divides by a zero total weight in that case and
        # returns NaN, which previously turned the summed loss (and every
        # gradient) into NaN. Stop here instead.
        if not bool((y_true != 0).any()):
            break
        out, h = model(x[:, j], h)
        step_loss = criterion(out, y_true)
        loss = step_loss if loss is None else loss + step_loss
    return loss


def train_model(model, patience = 3, n_epochs = 20):
    """Train ``model`` with early stopping and return the best checkpoint.

    Uses the notebook globals ``train_loader``, ``val_loader``, ``optimizer``,
    ``criterion``, ``device``, ``BATCH_SIZE``, ``SEQ_LEN`` and ``EarlyStopping``.

    Parameters
    ----------
    model : Decoder
        Model to train (already on ``device``).
    patience : int
        Epochs without validation improvement before stopping.
    n_epochs : int
        Maximum number of epochs.

    Returns
    -------
    tuple
        ``(model, avg_train_losses, avg_valid_losses)``. ``model`` has the
        weights reloaded from ``checkpoint.pt``; the lists hold the mean
        per-batch loss of every completed epoch.
    """
    avg_train_losses = []
    avg_valid_losses = []
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    epoch_len = len(str(n_epochs))

    for epoch in range(1, n_epochs + 1):
        # ---- train ----
        model.train()
        train_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _caption_loss(model, batch)
            if loss is None:
                continue
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # ---- validate ----
        model.eval()
        valid_losses = []
        # No gradients are needed here. Building the autograd graph for up to
        # SEQ_LEN - 1 unrolled steps per batch only wasted memory.
        with torch.no_grad():
            for batch in val_loader:
                loss = _caption_loss(model, batch)
                if loss is not None:
                    valid_losses.append(loss.item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        print(print_msg)

        # EarlyStopping tracks whether validation loss decreased and
        # checkpoints the model when it does.
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the best checkpoint saved during training.
    model.load_state_dict(torch.load('checkpoint.pt'))
    return model, avg_train_losses, avg_valid_losses

In [None]:
# Returns a (model, avg_train_losses, avg_valid_losses) tuple; unwrapped in the next cell.
model = train_model(model, patience=3, n_epochs=20)

In [None]:
# train_model returns (model, avg_train_losses, avg_valid_losses). Unpack only
# while `model` is still that tuple, so re-running this cell is harmless. The
# loss histories are kept instead of discarded.
if isinstance(model, tuple):
    model, avg_train_losses, avg_valid_losses = model

In [None]:
# The raw training features are reloaded because `encoded_features` was
# overwritten with its PCA projection earlier.
encoded_features = np.load('googlenet_1024_features.npy')
test_features = np.load('googlenet_test_features.npy')

In [None]:
# Test rows first, then the raw training rows.
test_features = np.vstack([test_features, encoded_features])

In [None]:
# Project with the PCA already fitted on the training features. Re-fitting here,
# as before, produced a different basis than the one the decoder was trained
# with, so the hidden states fed to it were inconsistent with training.
# Projection is row-wise, so only the first 5 rows (the test images, stacked
# first) are needed.
test_features = pca.transform(test_features[:5])
test_features.shape

In [None]:
test_features = torch.tensor(np.asarray(test_features))  # copy into a torch tensor

In [None]:
test_features.size()

In [None]:
# A single SOF token shaped (1, 1): one caption, one time step.
x = torch.tensor([[map_vocab['SOF']]])

In [None]:
# Decoder.forward may reshape using the global BATCH_SIZE. The greedy decoding
# below processes a single caption, so set it to 1.
BATCH_SIZE = 1

In [None]:
# Greedy decoding of the caption for one test image.
model.eval()
image_idx = 4
max_steps = 130
token = 'SOF'
cnt = 0
h = test_features[image_idx].view((1, -1)).float()
# Start from SOF inside this cell so re-running it always decodes from scratch.
x = torch.tensor([[map_vocab['SOF']]])
s = ""
with torch.no_grad():
    while token != 'EOF' and cnt < max_steps:
        cnt += 1
        out, h = model(x, h)
        char = torch.argmax(out)
        token = inv_mapping[char.item()]
        if token != 'EOF':
            s += token + " "
        # Feed the predicted word back as the next input. Previously x stayed
        # SOF on every step, so the decoder never saw its own output.
        x = char.view((1, -1))
print(s)