In [1]:
import ast
import warnings
import re
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import random
from tqdm.notebook import tqdm_notebook
from torchmetrics.functional import bleu_score
warnings.filterwarnings("ignore")

SOS_token = 0
EOS_token = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def load_text_data():
    with open("../../data/x.txt") as f:
        x_file = f.readlines()

    with open("../../data/y.txt") as f:
        y_file = f.readlines()

    x_list = []
    y_list_temp = []

    for x_val, y_val in zip(x_file, y_file):
        x_list.append((x_val.split(", \n")[0]))
        y_list_temp.append((y_val.split(" \n")[0]))

    y_list = []
    for y in y_list_temp:
        x = y
        x = ast.literal_eval(x)
        y_list.append(x)

    x_list_clean = []
    y_list_clean = []

    for x in x_list:
        x = re.sub(r'[^\w\s]', '', x)
        x = x.lower()
        x_list_clean.append(x)

    for y in y_list:
        for i in range(len(y)):
            y[i] = re.sub(r'[^\w\s]', '', y[i])
            y[i] = y[i].lower()
        y_list_clean.append(y)

    return x_list_clean, y_list_clean

x_text, y_text = load_text_data()

In [3]:
x_word2idx = {}
x_idx2word = {0: "<SOS>", 1: "<EOS>"}
x_word2count = {}
x_n_words = 2 # Count SOS and EOS

for x in x_text:
    for word in x.split():
        if word not in x_word2idx:
            x_word2idx[word] = x_n_words
            x_idx2word[x_n_words] = word
            x_n_words += 1
            x_word2count[word] = 1
        else:
            x_word2count[word] += 1

x_lookup = []

for x in x_text:
    x_lookup_sentence = [x_word2idx[word] for word in x.split()]
    x_lookup_sentence.append(EOS_token)
    x_lookup.append(x_lookup_sentence)

In [4]:
y_word2idx = {}
y_idx2word = {0: "<SOS>", 1: "<EOS>"}
y_word2count = {}
y_n_words = 2 # Count SOS and EOS

for y in y_text:
    for i in range(len(y)):
        for word in y[i].split():
            if word not in y_word2idx:
                y_word2idx[word] = y_n_words
                y_idx2word[y_n_words] = word
                y_n_words += 1
                y_word2count[word] = 1
            else:
                y_word2count[word] += 1

y_lookup = []

for y in y_text:
    y_lookup_temp = []
    for i in range(len(y)):
        y_lookup_sentence = [y_word2idx[word] for word in y[i].split()]
        y_lookup_sentence.append(EOS_token)
        y_lookup_temp.append(y_lookup_sentence)
    y_lookup.append(y_lookup_temp)

In [5]:
x_lookup_train = x_lookup[:int(len(x_lookup)*0.8)]
x_lookup_val = x_lookup[int(len(x_lookup)*0.8):int(len(x_lookup)*0.9)]
x_lookup_test = x_lookup[int(len(x_lookup)*0.9):]

y_lookup_train = y_lookup[:int(len(y_lookup)*0.8)]
y_lookup_val = y_lookup[int(len(y_lookup)*0.8):int(len(y_lookup)*0.9)]
y_lookup_test = y_lookup[int(len(y_lookup)*0.9):]

In [6]:
def get_lookup_tensors(index, mode):
    if mode == 'train':
        x = x_lookup_train[index]
        y = y_lookup_train[index]
    elif mode == 'val':
        x = x_lookup_val[index]
        y = y_lookup_val[index]
    elif mode == 'test':
        x = x_lookup_test[index]
        y = y_lookup_test[index]

    # x = torch.tensor(x, dtype=torch.long, device=device).view(-1, 1)
    # y = torch.tensor(y, dtype=torch.long, device=device).view(-1, 1)
    return x, y

In [7]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [8]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [9]:
hidden_size = 256
encoder = EncoderRNN(x_n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, y_n_words).to(device)

encoder_optimizer = optim.SGD(encoder.parameters(), lr=0.01)
decoder_optimizer = optim.SGD(decoder.parameters(), lr=0.01)

criterion = nn.NLLLoss()

In [10]:
encoder_total_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
decoder_total_params = sum(p.numel() for p in decoder.parameters() if p.requires_grad)

print(f'Encoder total parameters: {encoder_total_params:,}')
print(f'Decoder total parameters: {decoder_total_params:,}')

Encoder total parameters: 444,160
Decoder total parameters: 4,150,938


In [24]:
def batch_train_step(x_batch, y_batch, encoder, decoder, criterion):
    multi_ex_loss_list = []
    for ex_index, (x, y) in enumerate(zip(x_batch, y_batch)):
        '''
        x: 1-D list of length N; N = Variable Length
        y: 2-D list of shape (5, N); N = Variable Length
        
        For each input tuple (x, y), we are generating 5 sentences for x.
        Where, each sentence is generated 1 word at a time.
        Each predicted sentence would have the same length as the y[i] -- the caption 
        that it is being compared with.
        
        Then, NLLLoss is applied on this output sentence vs y[i] -- 1 sentence at a time
        for 5 sentences for each example (x, y).
        Mean of these 5 values is taken to get the final loss per example (x, y).
        
        Now, per batch (x_batch, y_batch), this loss is calculated and the mean of this batch
        loss is used for BACKPROP. 
        This allows us to implement batch-wise training and helps converge faster.
        '''
        encoder_hidden = encoder.initHidden()
        
        for i in range(len(x)):
            encoder_output, encoder_hidden = encoder(torch.tensor(x[i], device=device), encoder_hidden)

        decoder_input = torch.tensor(SOS_token, device=device)
        decoder_hidden = encoder_hidden

        sentence_loss_list = []

        for i in range(len(y)):
            sentence_output = torch.empty(1, decoder.output_size, device=device)
            teacher_forcing = True if random.random() < 0.5 else False

            for j in range(len(y[i])):
                if teacher_forcing:
                    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    sentence_output = torch.cat((sentence_output, decoder_output), dim=0)
                    decoder_input = torch.tensor(y[i][j], device=device)
                else:
                    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    sentence_output = torch.cat((sentence_output, decoder_output), dim=0)
                    topv, topi = decoder_output.topk(1)
                    decoder_input = topi.squeeze().detach()

            sentence_output = sentence_output[1:]
            sentence_loss = criterion(sentence_output, torch.tensor(y[i], device=device))
            sentence_loss_list.append(sentence_loss)

        single_ex_loss = torch.mean(torch.stack(sentence_loss_list))
        multi_ex_loss_list.append(single_ex_loss)

    batch_loss = torch.mean(torch.stack(multi_ex_loss_list))

    return batch_loss

In [29]:
def val_step(x_val, y_val, encoder, decoder, criterion, y_idx2word):
    with torch.no_grad():
        bleu_score_list = []
        val_multi_ex_loss_list = []
        for val_ex_index, (x, y) in enumerate(zip(x_val, y_val)):
            encoder_hidden = encoder.initHidden()
            
            for i in range(len(x)):
                encoder_output, encoder_hidden = encoder(torch.tensor(x[i], device=device), encoder_hidden)

            decoder_hidden = encoder_hidden
            decoder_input = torch.tensor(SOS_token, device=device)
            sentence_loss_list = []
            predicted_sentences = []
            true_sentences = []

            for i in range(len(y)):
                sentence_output = torch.empty(1, decoder.output_size, device=device)
                predicted_words = []
                true_words = []

                for j in range(len(y[i])):
                    deocder_output, deocder_hidden = decoder(decoder_input, decoder_hidden)
                    sentence_output = torch.cat((sentence_output, deocder_output))
                    topv, topi = deocder_output.topk(1)
                    decoder_input = topi.squeeze().detach()
                    
                    predicted_words.append(y_idx2word[topi.item()])
                    true_words.append(y_idx2word[y[i][j]])
                
                sentence_output = sentence_output[1:]
                val_sentence_loss = criterion(sentence_output, torch.tensor(y[i], device=device))
                sentence_loss_list.append(val_sentence_loss)

                predicted_sentence_i = ' '.join(predicted_words)
                true_sentence_i = ' '.join(true_words)
                predicted_sentences.append(predicted_sentence_i)
                true_sentences.append(true_sentence_i)

            val_single_ex_loss = torch.mean(torch.stack(sentence_loss_list))
            val_multi_ex_loss_list.append(val_single_ex_loss)
                
            bleu_score_value = bleu_score(preds=predicted_sentences, target=true_sentences)
            bleu_score_list.append(bleu_score_value)
        
        val_bleu_score = torch.mean(torch.stack(bleu_score_list)).item()
        val_loss = torch.mean(torch.stack(val_multi_ex_loss_list)).item()

        return val_loss, val_bleu_score
            
# val_loss, val_bleu_score = val_step(x_lookup_val, y_lookup_val, encoder, decoder, criterion, y_idx2word)    

In [28]:
num_epochs = 1
batch_size = 10
num_batches = len(x_lookup_train) // batch_size
train_epoch_loss_list = []
val_epoch_loss_list = []
val_epoch_bleu_score_list = []

pre_train_val_loss, pre_train_val_bleu_score = val_step(x_lookup_val, y_lookup_val, encoder, decoder, criterion, y_idx2word)
print(f"pre_train_val_loss: {pre_train_val_loss:.3f}")

for epoch in tqdm_notebook(range(num_epochs), desc="Epoch"):
    train_batch_loss_list = []
    for batch in range(num_batches):
        x_batch = x_lookup_train[batch*batch_size:(batch+1)*batch_size]
        y_batch = y_lookup_train[batch*batch_size:(batch+1)*batch_size]

        batch_loss = batch_train_step(x_batch, y_batch, encoder, decoder, criterion)

        batch_loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()
        train_batch_loss_list.append(batch_loss)

    train_epoch_loss = torch.mean(torch.stack(train_batch_loss_list)).item()
    train_epoch_loss_list.append(train_epoch_loss)

    val_epoch_loss, val_bleu_score = val_step(x_lookup_val, y_lookup_val, encoder, decoder, criterion, y_idx2word)
    val_epoch_loss_list.append(val_epoch_loss)
    val_epoch_bleu_score_list.append(val_bleu_score)
    
    print(f"Epoch: {epoch} \t Loss: {train_epoch_loss:.3f} \t Val Loss: {val_epoch_loss:.3f} \t Val BLeU Score: {val_bleu_score}")

pre_train_val_loss: 8.586


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 0 Loss: 7.725 	 Val Loss: 7.503 	 Val BLeU Score: 0.0


In [30]:
def evaluate(x, y, x_idx2word, y_idx2word, encoder, decoder, target_length=50):
    for i in range(len(x)):
        x_i = x[i]
        y_i = y[i]
        input_length = len(x_i)
        n_captions = len(y_i)
        encoder_hidden = encoder.initHidden()

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(torch.tensor(x_i[ei], device=device), encoder_hidden)

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden
        decoded_sentences = []

        for di in range(n_captions):
            decoded_words = []
            for dj in range(target_length):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()
                decoded_words.append(y_idx2word[topi.item()])
                if topi.item() == EOS_token:
                    is_eos = True
                    break
                else:
                    is_eos = False

            predicted_sentence = ' '.join(decoded_words)
            decoded_sentences.append(predicted_sentence)
            if is_eos:
                break
            
        input_text = ' '.join([x_idx2word[word] for word in x_i])
        target_text = ' '.join([y_idx2word[word] for word in y_i[0]])

        print(f"input: {input_text}")
        print(f"target: {target_text}")
        print(f"predicted: {decoded_sentences[0]}")
        print()

In [35]:
def evaluate_randomly(x_idx2word, y_idx2word, encoder, decoder, n=1):
    for i in range(n):
        rand_int = random.randint(0, len(x_lookup_val))
        x_i = x_lookup_val[rand_int]
        y_i = y_lookup_val[rand_int]

        evaluate([x_i], [y_i], x_idx2word, y_idx2word, encoder, decoder)

In [42]:
def test(x_test, y_test, encoder, decoder, x_idx2word, y_idx2word, target_length=50):
    test_loss = 0
    test_results = []
    bleu_score_list = []

    for i in range(len(x_test)):
        test_loss_i = 0
        x = x_test[i]
        y = y_test[i]
        input_length = len(x)
        n_captions = len(y)
        encoder_hidden = encoder.initHidden()

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(torch.tensor(x[ei], device=device), encoder_hidden)

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden
        decoded_sentences = []
        true_sentences = []

        for di in range(n_captions):
            decoded_words = []
            true_words = []
            for dj in range(target_length):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                y_di_dj = y[di][dj] if dj < len(y[di]) else EOS_token
                test_loss_i += criterion(decoder_output, torch.tensor(y_di_dj, device=device).view(-1))
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()
                
                if dj < len(y[di]):
                    true_words.append(y_idx2word[y[di][dj]])

                if topi.item() == EOS_token:
                    is_eos = True
                    decoded_words.append('<EOS>')
                    break
                else:
                    is_eos = False
                    decoded_words.append(y_idx2word[topi.item()])
                
            predicted_sentence = ' '.join(decoded_words)
            true_sentence = ' '.join(true_words)
            decoded_sentences.append(predicted_sentence)
            true_sentences.append(true_sentence)
            if is_eos:
                break

        test_loss_i = test_loss_i / n_captions
        test_loss += test_loss_i.item()

        bleu_score_i = bleu_score(preds=decoded_sentences, target=true_sentences)
        bleu_score_list.append(bleu_score_i)

        input_text = ' '.join([x_idx2word[word] for word in x])
        target_text = ' '.join([y_idx2word[word] for word in y[0]])
        predicted_text = decoded_sentences[0]

        test_result_i = (input_text, target_text, predicted_text)
        test_results.append(test_result_i)

    test_bleu_score = torch.mean(torch.stack(bleu_score_list)).item()

    return test_loss / len(x_test), test_results, test_bleu_score

In [33]:
evaluate(x_lookup_val, y_lookup_val, x_idx2word, y_idx2word, encoder, decoder)

input: bus 009 038 087 08 <EOS>
target: a green and white bus is on the street <EOS>
predicted: a a a a a a a a <EOS>

input: person 015 025 057 093 person 041 008 048 013 surfboard 041 044 058 07 <EOS>
target: a woman in a wet suit carrying a surfboard into the ocean <EOS>
predicted: a a a a a a a a <EOS>

input: zebra 067 057 084 071 zebra 017 057 038 075 <EOS>
target: an antilope is eating grass in between two zebra <EOS>
predicted: a a a a a a a <EOS>

input: person 057 047 079 07 <EOS>
target: a man surfing beside a bird on a cloudy day <EOS>
predicted: a a a a a a a a <EOS>

input: person 079 062 084 078 person 065 062 069 083 person 075 061 078 074 person 087 062 089 074 person 063 062 064 066 person 092 061 095 07 <EOS>
target: people are skiing and snowboarding on a high mountain <EOS>
predicted: a a a a a a a a <EOS>

input: person 022 042 028 062 person 071 031 078 065 person 042 041 049 064 person 028 041 034 061 person 05 034 057 061 person 053 054 069 081 person 061 049 0

In [36]:
evaluate_randomly(x_idx2word, y_idx2word, encoder, decoder)

input: person 022 042 028 062 person 071 031 078 065 person 042 041 049 064 person 028 041 034 061 person 05 034 057 061 person 053 054 069 081 person 061 049 069 063 skis 00 073 025 083 <EOS>
target: people standing and sitting on snow with skis and mountain <EOS>
predicted: a a a a a a a a <EOS>



In [43]:
test_loss, test_results, test_bleu_score = test(x_lookup_test, y_lookup_test, encoder, decoder, x_idx2word, y_idx2word)