In [14]:
import ast
import warnings
import re
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import random
warnings.filterwarnings("ignore")

# Reserved vocabulary indices: 0 marks the start of a sequence and 1 the end.
# Both vocabularies built later in this notebook start real words at index 2.
SOS_token = 0
EOS_token = 1

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
def load_text_data():
    """Read the paired text files and return their normalised contents.

    ``../../data/x.txt`` is read as one source string per line and
    ``../../data/y.txt`` as one Python literal per line (parsed with
    ``ast.literal_eval``; each parsed value is indexed and assigned to like a
    list of strings). Lines are paired by position, so surplus lines in the
    longer file are ignored.

    Normalisation removes every character that is neither a word character
    nor whitespace, then lower-cases the text.

    Returns:
        tuple: ``(x_clean, y_clean)`` - a list of cleaned source strings and a
        list, of equal length, holding the cleaned target sequences.
    """
    not_word_or_space = r'[^\w\s]'

    with open("../../data/x.txt") as f:
        source_lines = f.readlines()
    with open("../../data/y.txt") as f:
        target_lines = f.readlines()

    x_list_clean = []
    y_list_clean = []

    for source_line, target_line in zip(source_lines, target_lines):
        # Keep only what precedes a trailing ", \n" separator (if present).
        sentence = source_line.split(", \n")[0]
        x_list_clean.append(re.sub(not_word_or_space, '', sentence).lower())

        # Keep only what precedes a trailing " \n", then parse the literal.
        captions = ast.literal_eval(target_line.split(" \n")[0])
        # Clean each entry in place so the parsed container itself is returned.
        for pos, caption in enumerate(captions):
            captions[pos] = re.sub(not_word_or_space, '', caption).lower()
        y_list_clean.append(captions)

    return x_list_clean, y_list_clean

# Load the cleaned corpus once; the following cells build the vocabularies from these two lists.
x_text, y_text = load_text_data()

In [16]:
# Source-side vocabulary. Indices 0 and 1 are reserved for SOS / EOS, so the
# first real word receives index 2.
x_word2idx = {}
x_idx2word = {0: "SOS", 1: "EOS"}
x_word2count = {}
x_n_words = 2  # SOS and EOS are already counted

for sentence in x_text:
    for token in sentence.split():
        if token in x_word2idx:
            x_word2count[token] += 1
            continue
        # First occurrence: assign the next free index and start its count.
        x_word2idx[token] = x_n_words
        x_idx2word[x_n_words] = token
        x_word2count[token] = 1
        x_n_words += 1

# Encode every source sentence as its word indices followed by EOS.
x_lookup = [
    [x_word2idx[token] for token in sentence.split()] + [EOS_token]
    for sentence in x_text
]

In [17]:
# Target-side vocabulary, built over every target sequence of every example.
# Indices 0 and 1 are reserved for SOS / EOS.
y_word2idx = {}
y_idx2word = {0: "SOS", 1: "EOS"}
y_word2count = {}
y_n_words = 2  # SOS and EOS are already counted

for captions in y_text:
    for caption in captions:
        for token in caption.split():
            if token in y_word2idx:
                y_word2count[token] += 1
                continue
            # First occurrence: assign the next free index and start its count.
            y_word2idx[token] = y_n_words
            y_idx2word[y_n_words] = token
            y_word2count[token] = 1
            y_n_words += 1

# For each example: one index list per target sequence, each terminated by EOS.
y_lookup = [
    [
        [y_word2idx[token] for token in caption.split()] + [EOS_token]
        for caption in captions
    ]
    for captions in y_text
]

In [18]:
# Positional 80% / 10% / 10% split into train / validation / test.
# The cut points are computed separately for x and y, from their own lengths.
x_train_end = int(len(x_lookup) * 0.8)
x_val_end = int(len(x_lookup) * 0.9)
x_lookup_train, x_lookup_val, x_lookup_test = (
    x_lookup[:x_train_end],
    x_lookup[x_train_end:x_val_end],
    x_lookup[x_val_end:],
)

y_train_end = int(len(y_lookup) * 0.8)
y_val_end = int(len(y_lookup) * 0.9)
y_lookup_train, y_lookup_val, y_lookup_test = (
    y_lookup[:y_train_end],
    y_lookup[y_train_end:y_val_end],
    y_lookup[y_val_end:],
)

In [19]:
def get_lookup_tensors(index):
    '''Build the index tensors for one example.

    Args:
        index: Position of the example in the module-level ``x_lookup`` and
            ``y_lookup`` lists.

    Returns:
        tuple: ``(x, ys)`` where ``x`` is a tensor of shape ``(seq_len, 1)``
        holding the source indices and ``ys`` is a list with one
        ``(seq_len, 1)`` tensor per target sequence of that example.
    '''
    x_lookup_tensor = torch.tensor(x_lookup[index], device=device).view(-1, 1)
    # Read the targets of the requested example. The previous code always read
    # y_lookup[1], so every example was paired with example 1's targets.
    y_lookup_tensor_list = [
        torch.tensor(caption, device=device).view(-1, 1)
        for caption in y_lookup[index]
    ]
    return x_lookup_tensor, y_lookup_tensor_list

In [20]:
class EncoderRNN(nn.Module):
    """GRU encoder that consumes one token index per forward call.

    Args:
        input_size: Number of rows in the embedding table.
        hidden_size: Width of both the embedding vectors and the GRU state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=hidden_size)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)

    def forward(self, input, hidden):
        """Embed ``input``, reshape it to ``(1, 1, -1)`` and run one GRU call.

        Returns:
            tuple: ``(output, hidden)`` exactly as returned by the GRU.
        """
        step_input = self.embedding(input).view(1, 1, -1)
        return self.gru(step_input, hidden)

    def initHidden(self):
        """Return an all-zero state of shape ``(1, 1, hidden_size)`` on ``device``."""
        return torch.zeros((1, 1, self.hidden_size), device=device)

In [21]:
class DecoderRNN(nn.Module):
    """GRU decoder that emits log-probabilities for one output token per call.

    Args:
        hidden_size: Width of both the embedding vectors and the GRU state.
        output_size: Number of rows in the embedding table and number of
            output classes of the final linear layer.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(num_embeddings=output_size, embedding_dim=hidden_size)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)
        self.out = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding call.

        The embedded input is reshaped to ``(1, 1, -1)``, passed through ReLU
        and the GRU; the first element of the GRU output is projected by the
        linear layer and normalised with log-softmax over dimension 1.

        Returns:
            tuple: ``(log_probs, hidden)``.
        """
        step_input = F.relu(self.embedding(input).view(1, 1, -1))
        gru_out, hidden = self.gru(step_input, hidden)
        log_probs = self.softmax(self.out(gru_out[0]))
        return log_probs, hidden

    def initHidden(self):
        """Return an all-zero state of shape ``(1, 1, hidden_size)`` on ``device``."""
        return torch.zeros((1, 1, self.hidden_size), device=device)

In [22]:
# Width shared by the embeddings and GRU states of both networks.
hidden_size = 256

# The encoder is constructed before the decoder; keep this order so parameter
# initialisation draws from the RNG in the same sequence.
encoder = EncoderRNN(input_size=x_n_words, hidden_size=hidden_size).to(device)
decoder = DecoderRNN(hidden_size=hidden_size, output_size=y_n_words).to(device)

# One plain SGD optimiser per network, both with learning rate 0.01.
encoder_optimizer, decoder_optimizer = (
    optim.SGD(network.parameters(), lr=0.01) for network in (encoder, decoder)
)

# Negative log-likelihood loss, applied to the decoder's log-softmax output.
criterion = nn.NLLLoss()

In [23]:
def batch_train_step(x_batch, y_batch, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    """Run one optimisation step on a mini-batch and return its mean loss.

    For every example the source indices are fed to the encoder one at a
    time; the final encoder state seeds the decoder. The decoder is then
    trained on each target sequence of the example in turn. Its input and
    hidden state are carried over from one target sequence to the next (they
    are not reset). Teacher forcing is chosen with probability 0.5 once per
    target sequence.

    Gradient handling: gradients are zeroed once at the start, every example
    contributes ``loss / len(x_batch)`` via ``backward()``, and both
    optimisers step exactly once at the end. Previously ``backward()`` and
    ``step()`` were called per example without zeroing in between, so each
    update also contained the gradients of the earlier examples in the batch.

    Args:
        x_batch: List of source index sequences.
        y_batch: List, parallel to ``x_batch``, of lists of target index
            sequences.
        encoder: Module called as ``encoder(token, hidden)`` and providing
            ``initHidden()``.
        decoder: Module called as ``decoder(token, hidden)`` returning
            ``(log_probs, hidden)``.
        encoder_optimizer: Optimiser for the encoder parameters.
        decoder_optimizer: Optimiser for the decoder parameters.
        criterion: Loss applied to ``(log_probs, target)`` at every step.

    Returns:
        float: Mean over the batch of the per-example loss, where a
        per-example loss is the summed token loss divided by the number of
        target sequences. ``0.0`` for an empty batch.
    """
    n_samples = len(x_batch)
    if n_samples == 0:
        return 0.0

    # Start from clean gradients so the function is correct on its own.
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    batch_loss = 0.0

    for x, y in zip(x_batch, y_batch):
        encoder_hidden = encoder.initHidden()
        for token in x:
            _, encoder_hidden = encoder(torch.tensor(token, device=device), encoder_hidden)

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden
        loss = 0

        for caption in y:
            teacher_forcing = random.random() < 0.5

            for word_index in caption:
                target = torch.tensor(word_index, device=device)
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                loss = loss + criterion(decoder_output, target.view(-1))

                if teacher_forcing:
                    # Feed the ground-truth token as the next input.
                    decoder_input = target
                else:
                    # Feed the model's own greedy prediction, cut from the graph.
                    decoder_input = decoder_output.topk(1)[1].squeeze().detach()

        if not torch.is_tensor(loss):
            # No target tokens for this example: nothing to back-propagate.
            continue

        loss = loss / len(y)
        batch_loss += loss.item()
        # Scale so the accumulated gradient is that of the mean batch loss.
        (loss / n_samples).backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return batch_loss / n_samples

In [24]:
def val_step(x_val, y_val, encoder, decoder, criterion, target_length=50):
    """Compute the mean greedy-decoding loss over a validation set.

    Each source sequence is encoded token by token. For every target sequence
    of the example the decoder is then run for exactly ``target_length``
    steps, feeding back its own top-1 prediction. Positions beyond the end of
    a target sequence are scored against ``EOS_token``. Decoder input and
    state carry over from one target sequence to the next.

    This function does not disable gradient tracking itself.

    Args:
        x_val: List of source index sequences.
        y_val: List, parallel to ``x_val``, of lists of target index sequences.
        encoder: Called as ``encoder(token, hidden)``; provides ``initHidden()``.
        decoder: Called as ``decoder(token, hidden)``.
        criterion: Loss applied to ``(decoder_output, target)`` at each step.
        target_length: Number of decoding steps per target sequence.

    Returns:
        float: Mean over examples of (summed step loss / number of target
        sequences).
    """
    total_loss = 0

    for idx, source in enumerate(x_val):
        captions = y_val[idx]

        hidden = encoder.initHidden()
        for token in source:
            _, hidden = encoder(torch.tensor(token, device=device), hidden)

        decoder_input = torch.tensor([[SOS_token]], device=device)
        sample_loss = 0

        for caption in captions:
            for step in range(target_length):
                decoder_output, hidden = decoder(decoder_input, hidden)
                # Pad short targets with EOS up to target_length.
                if step < len(caption):
                    target = caption[step]
                else:
                    target = EOS_token
                sample_loss += criterion(decoder_output, torch.tensor(target, device=device).view(-1))
                # Greedy feedback: the next input is the most likely token.
                decoder_input = decoder_output.topk(1)[1].squeeze().detach()

        total_loss += (sample_loss / len(captions)).item()

    return total_loss / len(x_val)

In [None]:
num_epochs = 100
batch_size = 2
# Integer division: a trailing partial batch is never visited.
num_batches = len(x_lookup_train) // batch_size
train_loss_list = []
val_loss_list = []

for epoch in range(num_epochs):
    epoch_loss = 0

    for batch in range(num_batches):
        x_batch = x_lookup_train[batch*batch_size:(batch+1)*batch_size]
        y_batch = y_lookup_train[batch*batch_size:(batch+1)*batch_size]

        # Clear gradients left over from the previous batch.
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        batch_loss = batch_train_step(x_batch, y_batch, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        # Accumulate the running mean of the batch losses for this epoch.
        epoch_loss += batch_loss / num_batches
        print(f"Epoch: {epoch+1}/{num_epochs} | Batch: {batch+1}/{num_batches} | Loss: {batch_loss}")

    # Validation only needs forward passes, so gradient tracking is switched off.
    with torch.no_grad():
        val_loss = val_step(x_lookup_val, y_lookup_val, encoder, decoder, criterion)
        # Report the 1-based epoch number, matching the per-batch log line above
        # (the summary previously printed the 0-based index).
        print(f"Epoch: {epoch+1} / {num_epochs} \t Epoch Loss: {epoch_loss:.3f} \t Val Loss: {val_loss:.3f}")

    train_loss_list.append(epoch_loss)
    val_loss_list.append(val_loss)

In [33]:
def evaluate(x, y, x_idx2word, y_idx2word, encoder, decoder, target_length=50):
    """Greedy-decode every example and print input, first target and first prediction.

    For each example the source indices are encoded token by token, then the
    decoder is run greedily. Decoding proceeds in rounds of at most
    ``target_length`` steps, one round per target sequence of the example,
    and stops altogether as soon as ``EOS_token`` is predicted. Only the
    first decoded sentence is printed.

    The whole procedure runs under ``torch.no_grad()`` because no gradients
    are needed for inference.

    Args:
        x: List of source index sequences.
        y: List, parallel to ``x``, of lists of target index sequences.
        x_idx2word: Mapping from source index to word.
        y_idx2word: Mapping from target index to word.
        encoder: Called as ``encoder(token, hidden)``; provides ``initHidden()``.
        decoder: Called as ``decoder(token, hidden)``.
        target_length: Maximum number of decoding steps per round.
    """
    with torch.no_grad():
        for i in range(len(x)):
            x_i = x[i]
            y_i = y[i]
            n_captions = len(y_i)

            encoder_hidden = encoder.initHidden()
            for token in x_i:
                _, encoder_hidden = encoder(torch.tensor(token, device=device), encoder_hidden)

            decoder_input = torch.tensor([[SOS_token]], device=device)
            decoder_hidden = encoder_hidden
            decoded_sentences = []

            for _ in range(n_captions):
                decoded_words = []
                # Initialised here so it is defined even when target_length is 0.
                is_eos = False

                for _ in range(target_length):
                    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    topi = decoder_output.topk(1)[1]
                    decoder_input = topi.squeeze().detach()
                    predicted_index = topi.item()
                    decoded_words.append(y_idx2word[predicted_index])
                    if predicted_index == EOS_token:
                        is_eos = True
                        break

                decoded_sentences.append(' '.join(decoded_words))
                if is_eos:
                    # The model ended its output; skip the remaining rounds.
                    break

            input_text = ' '.join([x_idx2word[word] for word in x_i])
            target_text = ' '.join([y_idx2word[word] for word in y_i[0]])

            print(f"input: {input_text}")
            print(f"target: {target_text}")
            print(f"predicted: {decoded_sentences[0]}")
            print()

In [32]:
def evaluate_randomly(x_idx2word, y_idx2word, encoder, decoder, n=1):
    """Call ``evaluate`` on ``n`` randomly chosen validation examples.

    Examples are drawn with replacement from the module-level
    ``x_lookup_val`` / ``y_lookup_val`` lists.

    Args:
        x_idx2word: Mapping from source index to word.
        y_idx2word: Mapping from target index to word.
        encoder: Encoder passed through to ``evaluate``.
        decoder: Decoder passed through to ``evaluate``.
        n: Number of examples to draw.

    Raises:
        ValueError: If the validation set is empty.
    """
    for _ in range(n):
        # The standard-library generator is used because numpy is never imported
        # in this notebook; the earlier np.random.randint call raised NameError.
        rand_int = random.randrange(len(x_lookup_val))
        x_i = x_lookup_val[rand_int]
        y_i = y_lookup_val[rand_int]

        evaluate([x_i], [y_i], x_idx2word, y_idx2word, encoder, decoder)

In [31]:
def test(x_test, y_test, encoder, decoder, x_idx2word, y_idx2word, target_length=50):
    test_loss = 0
    test_results = []

    for i in range(len(x_test)):
        test_loss_i = 0
        x = x_test[i]
        y = y_test[i]
        input_length = len(x)
        n_captions = len(y)
        encoder_hidden = encoder.initHidden()

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(torch.tensor(x[ei], device=device), encoder_hidden)

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden
        decoded_sentences = []

        for di in range(n_captions):
            decoded_words = []
            for dj in range(target_length):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                y_di_dj = y[di][dj] if dj < len(y[di]) else EOS_token
                test_loss_i += criterion(decoder_output, torch.tensor(y_di_dj, device=device).view(-1))
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()
                
                if topi.item() == EOS_token:
                    is_eos = True
                    decoded_words.append('<EOS>')
                    break
                else:
                    is_eos = False
                    decoded_words.append(y_idx2word[topi.item()])
                
            predicted_sentence = ' '.join(decoded_words)
            decoded_sentences.append(predicted_sentence)
            if is_eos:
                break

        test_loss_i = test_loss_i / n_captions
        test_loss += test_loss_i.item()

        input_text = ' '.join([x_idx2word[word] for word in x])
        target_text = ' '.join([y_idx2word[word] for word in y[0]])
        predicted_text = decoded_sentences[0]

        test_result_i = (input_text, target_text, predicted_text)
        test_results.append(test_result_i)

    return test_loss / len(x_test), test_results

In [None]:
# Print input, first target and first prediction for every validation example.
evaluate(x_lookup_val, y_lookup_val, x_idx2word, y_idx2word, encoder, decoder)

In [None]:
# Same report for a single randomly drawn validation example (n defaults to 1).
evaluate_randomly(x_idx2word, y_idx2word, encoder, decoder)

In [None]:
# test() returns a (mean_loss, results) pair; unpack it so test_loss holds the
# scalar loss rather than the whole tuple.
test_loss, test_results = test(x_lookup_test, y_lookup_test, encoder, decoder, x_idx2word, y_idx2word)