In [None]:
from google.colab import drive
drive.mount('/content/drive/')


import os
# Absolute path so the cell is idempotent: a relative chdir would fail on a
# second run because the working directory has already moved.
os.chdir("/content/drive/My Drive/Advanced NLP/Exam")

import pickle
import random
import time
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch
import torch.nn as nn
from torch import optim

# NOTE: the backend is intentionally left at the notebook default. Switching to
# the non-interactive 'agg' backend would stop `showPlot` / `showAttention`
# figures from being displayed inline.
import numpy as np
from tqdm import tqdm

from models import EncoderGRU, AttnDecoderGRU, EncoderLSTM, DecoderLSTM, AttnDecoderLSTM
from utils import Lang, tensorsFromPair, timeSince, showPlot


Mounted at /content/drive/


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG used below (pair sampling, teacher forcing, weight init) so a
# "Restart & Run All" reproduces the same numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset_path = '../SCAN-master'
task_name = 'add_prim'  # or 'length'
primitive = 'turn_left'  # only used by the 'add_prim' split

# SCAN names its add-primitive files tasks_<train|test>_addprim_<primitive>.txt,
# while other splits use tasks_<train|test>_<task_name>.txt
# (e.g. length_split/tasks_train_length.txt) — TODO confirm for splits other
# than add_prim / length.
if task_name == 'add_prim':
    split_suffix = 'addprim' + '_' + primitive
else:
    split_suffix = task_name

train_file_name = '{}_split/tasks_{}_{}.txt'.format(task_name, 'train', split_suffix)
test_file_name = '{}_split/tasks_{}_{}.txt'.format(task_name, 'test', split_suffix)
train_file_path = os.path.join(dataset_path, train_file_name)
test_file_path = os.path.join(dataset_path, test_file_name)

# Special token indices shared by the encoder/decoder loops below.
SOS_token = 0
EOS_token = 1

# Vocabularies; populated by `dataloader` as files are read.
command_le = Lang('command')
action_le = Lang('action')


def dataloader(path, command_lang=None, action_lang=None):
    """Read a SCAN split file and register its tokens in the vocabularies.

    Each non-blank line is expected to look like
    ``IN: <command tokens> OUT: <action tokens>``. The first token (``IN:``)
    is dropped; everything after ``OUT:`` is the action sequence.

    Args:
        path: path of the split file to read.
        command_lang: object whose ``addSentence`` receives every command token
            list; defaults to the module-level ``command_le``.
        action_lang: object whose ``addSentence`` receives every action token
            list; defaults to the module-level ``action_le``.

    Returns:
        ``(input_commands, output_actions, pairs)``, where
        ``pairs[i] == [input_commands[i], output_actions[i]]`` and every entry
        is a list of string tokens.

    Raises:
        ValueError: if a non-blank line has no ``OUT:`` marker.
    """
    if command_lang is None:
        command_lang = command_le
    if action_lang is None:
        action_lang = action_le

    pairs = []
    with open(path, 'r') as f:
        for raw_line in f:
            tokens = raw_line.split()
            if not tokens:
                # Tolerate blank / trailing lines instead of crashing on .index().
                continue
            split_index = tokens.index('OUT:')
            inp = tokens[1:split_index]
            outp = tokens[split_index + 1:]
            command_lang.addSentence(inp)
            action_lang.addSentence(outp)
            pairs.append([inp, outp])

    # Plain comprehensions rather than np.transpose: the token lists are ragged,
    # which NumPy either rejects or (when lengths happen to match) reshapes
    # into a 3-D array and transposes the wrong axes.
    input_commands = [inp for inp, _ in pairs]
    output_actions = [outp for _, outp in pairs]
    return input_commands, output_actions, pairs


commands_train, actions_train, pairs_train = dataloader(train_file_path)
commands_test, actions_test, pairs_test = dataloader(test_file_path)

# Decoding cap used by `evaluate`: the longest test action sequence plus one
# extra step for the EOS token.
MAX_LENGTH = 1 + max(len(action_seq) for action_seq in actions_test)

# Probability that a training step feeds the gold token back to the decoder.
teacher_forcing_ratio = 0.5


def train(input_tensor, target_tensor, encoder, decoder,
          encoder_optimizer, decoder_optimizer, criterion,
          model='lstm', attention=False):
    """Run one optimisation step on a single (command, action) example.

    The encoder consumes ``input_tensor`` one step at a time. The decoder is
    then unrolled for at most ``target_length`` steps. With probability
    ``teacher_forcing_ratio`` it is fed the gold token at every step. Otherwise
    it is fed its own greedy prediction and stops as soon as it predicts
    ``EOS_token``.

    Args:
        input_tensor: command token indices, one time step per index of dim 0.
        target_tensor: action token indices, one time step per index of dim 0.
        encoder: module with ``initHidden()`` and ``hidden_size``, called as
            ``encoder(token, hidden) -> (output, hidden)``.
        decoder: called as ``decoder(token, hidden)`` -> ``(output, hidden)``, or,
            when ``attention`` is True, as
            ``decoder(token, hidden, encoder_hiddens)`` ->
            ``(output, hidden, attention_weights)``.
        encoder_optimizer: optimiser for the encoder; stepped once here.
        decoder_optimizer: optimiser for the decoder; stepped once here.
        criterion: per-step loss between decoder output and gold token.
        model: ``'lstm'`` or ``'gru'``; selects how the encoder hidden state is
            read when building the attention memory.
        attention: whether to build and pass the encoder memory to the decoder.

    Returns:
        ``(loss, correct)``:
            - ``loss``: the summed loss divided by the number of decoder steps
              actually run.
            - ``correct``: True iff the greedy predictions equal the full target
              sequence.
    """
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_hiddens = torch.zeros(input_length, encoder.hidden_size, device=device)

    for ei in range(input_length):
        _, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        if attention:
            if model == "lstm":
                # Presumably an (h, c) tuple: keep h of the first layer.
                encoder_hiddens[ei] = encoder_hidden[0][0, 0]
            elif model == "gru":
                encoder_hiddens[ei] = encoder_hidden[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden

    loss = 0
    preds = []
    # Exactly one draw per call, so the RNG stream is unchanged.
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    for di in range(target_length):
        if attention:
            decoder_output, decoder_hidden, _ = decoder(
                decoder_input, decoder_hidden, encoder_hiddens)
        else:
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        _, topi = decoder_output.topk(1)
        pred_index = topi.squeeze().item()
        preds.append(pred_index)
        loss += criterion(decoder_output, target_tensor[di])

        if use_teacher_forcing:
            decoder_input = target_tensor[di]
        else:
            # Feed back the model's own (non-differentiable) prediction.
            decoder_input = topi.squeeze().detach()
            if pred_index == EOS_token:
                target_length = di + 1
                break

    # Compare plain int lists. This avoids float-vs-long tensor comparison and
    # the 0-d result of squeezing a length-1 target.
    correct = preds == target_tensor.reshape(-1).tolist()

    loss.backward()

    torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5.0)
    torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=5.0)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length, correct


def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100,
               learning_rate=0.001, model='gru', attention=False):
    """Train ``encoder``/``decoder`` with Adam for ``n_iters`` single-pair steps.

    All training pairs are sampled (with replacement) from ``pairs_train`` up
    front, then fed one at a time to ``train``. Every ``print_every`` steps the
    window's mean loss and exact-match accuracy are printed. Every
    ``plot_every`` steps they are recorded, and at the end they are passed to
    ``showPlot``.

    Returns:
        Exact-match training accuracy of the last completed ``print_every``
        window. If no window completed (``n_iters < print_every``), returns the
        accuracy over all iterations run (0 when ``n_iters`` is 0).
    """
    start = time.time()

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(pairs_train), command_le, action_le)
                      for _ in range(n_iters)]
    criterion = nn.NLLLoss()

    accuracy = None
    total_correct = 0
    plot_losses, plot_accs = [], []
    print_loss_total = 0
    print_correct_total = 0
    plot_loss_total = 0
    plot_correct_total = 0

    for step, (input_tensor, target_tensor) in enumerate(training_pairs, start=1):
        loss, correct = train(input_tensor, target_tensor, encoder, decoder,
                              encoder_optimizer, decoder_optimizer, criterion,
                              model, attention=attention)
        total_correct += int(correct)
        print_correct_total += int(correct)
        plot_correct_total += int(correct)
        print_loss_total += loss
        plot_loss_total += loss

        if step % print_every == 0:
            print_acc_avg = print_correct_total / print_every
            print_loss_avg = print_loss_total / print_every
            accuracy = print_acc_avg
            print_correct_total = 0
            print_loss_total = 0
            print('%s (%d %d%%) loss: %.4f acc: %.4f' % (timeSince(start, step / n_iters),
                                                         step, step / n_iters * 100, print_loss_avg, print_acc_avg))

        if step % plot_every == 0:
            plot_losses.append(plot_loss_total / plot_every)
            plot_accs.append(plot_correct_total / plot_every)
            plot_loss_total = 0
            plot_correct_total = 0

    if accuracy is None:
        # No print window completed: report accuracy over everything seen
        # rather than a misleading 0.
        accuracy = total_correct / n_iters if n_iters else 0

    showPlot(plot_losses, plot_accs)
    return accuracy


def evaluate(encoder, decoder, pair, model='gru', attention=False, max_length=MAX_LENGTH):
    """Greedily decode the action sequence for one (command, action) pair.

    Only the command half of ``pair`` is used; decoding runs for at most
    ``max_length`` steps and stops early once ``EOS_token`` is predicted.

    Returns:
        ``(decoded_words, attention_rows)``:
            - ``decoded_words`` holds the predicted action tokens, ending with
              ``'<EOS>'`` only if EOS was produced.
            - ``attention_rows`` has one row per decoding step taken; rows stay
              zero unless ``attention`` is True.
    """
    with torch.no_grad():
        source, _ = tensorsFromPair(pair, command_le, action_le)
        n_source = source.size()[0]
        hidden = encoder.initHidden()

        memory = torch.zeros(n_source, encoder.hidden_size, device=device)

        for pos in range(n_source):
            _, hidden = encoder(source[pos], hidden)
            if not attention:
                continue
            if model == 'lstm':
                memory[pos] += hidden[0][0, 0]
            elif model == 'gru':
                memory[pos] += hidden[0, 0]

        step_input = torch.tensor([[SOS_token]], device=device)  # SOS
        dec_hidden = hidden

        words = []
        attn_rows = torch.zeros(max_length, n_source)

        for step in range(max_length):
            if attention:
                out, dec_hidden, attn = decoder(step_input, dec_hidden, memory)
                attn_rows[step] = attn.squeeze().data
            else:
                out, dec_hidden = decoder(step_input, dec_hidden)
            best = out.data.topk(1)[1]
            best_index = best.item()
            if best_index == EOS_token:
                words.append('<EOS>')
                break
            words.append(action_le.index2word[best_index])
            step_input = best.squeeze().detach()

        return words, attn_rows[:step + 1]


def evaluate_model(encoder, decoder, test_pairs, model='gru', attention=False):
    """Exact-match accuracy of the model on ``test_pairs``.

    A prediction counts as correct only if the decoder terminated with EOS and
    the tokens before it equal the gold action sequence.

    Returns:
        ``(command_acc, action_acc, overall_acc)``:
            - ``command_acc`` maps command length to the accuracy on pairs with
              that command length.
            - ``action_acc`` maps action length to the accuracy on pairs with
              that action length.
            - ``overall_acc`` is the fraction of all pairs predicted exactly.
    """
    command_cnt = Counter(len(command) for command, _ in test_pairs)
    action_cnt = Counter(len(action) for _, action in test_pairs)
    command_correct_cnt = Counter()
    action_correct_cnt = Counter()
    correct = 0

    for pair in tqdm(test_pairs):
        preds, _ = evaluate(encoder, decoder, pair, model=model, attention=attention)
        # Strip the terminator only when it is actually there. If decoding hit
        # max_length without EOS, the last element is a real (extra) token;
        # dropping it could make an unterminated output look correct.
        if preds and preds[-1] == '<EOS>':
            preds = preds[:-1]
        else:
            continue
        if preds == pair[1]:
            command_correct_cnt[len(pair[0])] += 1
            action_correct_cnt[len(pair[1])] += 1
            correct += 1

    # Counter lookups return 0 for lengths that never scored.
    command_acc = {length: command_correct_cnt[length] / cnt
                   for length, cnt in command_cnt.items()}
    action_acc = {length: action_correct_cnt[length] / cnt
                  for length, cnt in action_cnt.items()}

    return command_acc, action_acc, correct / len(test_pairs)


def evaluateRandomly(encoder, decoder, model='gru', n=10, attention=False):
    """Print ``n`` random test commands with their gold and predicted actions.

    Output lines: ``>`` command, ``=`` gold actions, ``<`` decoded actions.
    Pass ``attention=True`` for attention decoders so ``evaluate`` builds and
    feeds the encoder memory, matching how the model was trained.
    """
    for _ in range(n):
        pair = random.choice(pairs_test)
        print('>', pair[0])
        print('=', pair[1])
        output_words, _ = evaluate(encoder, decoder, pair, model=model, attention=attention)
        print('<', output_words)
        print('')
        

def showAttention(input_sentence, output_words, attentions):
    """Plot a decoder attention matrix as a heat map.

    Args:
        input_sentence: list of command tokens; an ``'<EOS>'`` column label is
            appended after them.
        output_words: decoded tokens, one per row of ``attentions``.
        attentions: tensor with one row per entry of ``output_words`` and,
            presumably, ``len(input_sentence) + 1`` columns (input incl. EOS) —
            TODO confirm against ``tensorsFromPair``.
    """
    x_labels = input_sentence + ['<EOS>']
    y_labels = list(output_words)

    fig, ax = plt.subplots(figsize=(16, 8))
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    # Fix one tick per cell *before* labelling. Calling set_*ticklabels on an
    # automatic locator (the old [''] + labels trick) triggers matplotlib's
    # FixedFormatter warning and can shift labels off their cells.
    ax.set_xticks(range(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)
    ax.set_xlabel('input command')
    ax.set_ylabel('decoded actions')

    plt.show()


def evaluateAndShowAttention(encoder, decoder, pair, model='gru', attention=True):
    """Decode ``pair``, print input/output and plot the attention weights.

    ``attention`` defaults to True. With it False, ``evaluate`` never fills the
    attention matrix, and the plot would be all zeros.
    """
    output_words, attentions = evaluate(
        encoder, decoder, pair, model=model, attention=attention)
    print('input =', pair[0])
    print('output =', output_words)
    showAttention(pair[0], output_words, attentions)
    


def evaluateAndShowAttentionExample(encoder, decoder, model='gru', attention=True):
    """Plot attention for the first test pair the model predicts exactly.

    A prediction counts as exact only if it ended with ``'<EOS>'`` and the
    tokens before it equal the gold actions. The same ``model`` and
    ``attention`` settings are used for searching and for display, and the
    found pair is plotted from that single evaluation (no second decode).
    """
    for pair in pairs_test:
        output_words, attentions = evaluate(
            encoder, decoder, pair, model=model, attention=attention)
        terminated = bool(output_words) and output_words[-1] == '<EOS>'
        if terminated and output_words[:-1] == pair[1]:
            print('input =', pair[0])
            print('output =', output_words)
            showAttention(pair[0], output_words, attentions)
            break
        
        
def main_run(hidden_size, num_iter, num_runs, model, dropout=0.5, num_layers=1, attention=True, experiment="3a"):
    """Train and evaluate ``num_runs`` freshly initialised seq2seq models.

    Each run builds an encoder/decoder of the requested ``model`` family, trains
    it for ``num_iter`` steps with ``trainIters`` and scores it on
    ``pairs_test`` with ``evaluate_model``, using the same ``attention`` setting
    in both. The state dicts of the run with the best overall test accuracy are
    saved, and the per-run results are pickled.

    Returns:
        ``(command_accs, action_accs, overall_accs, train_accs)``, each a list
        with one entry per run.

    Raises:
        ValueError: if ``num_runs < 1``, ``model`` is not ``'lstm'``/``'gru'``,
            or a GRU without attention is requested (no plain GRU decoder is
            imported).
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")
    if model not in ("lstm", "gru"):
        raise ValueError(f"unknown model {model!r}; expected 'lstm' or 'gru'")
    if model == "gru" and not attention:
        raise ValueError("no non-attention GRU decoder is available; use attention=True")

    input_size = command_le.n_words
    output_size = action_le.n_words

    best_encoder = None
    best_decoder = None
    best_acc = 0

    command_accs, action_accs, overall_accs, train_accs = [], [], [], []

    for _ in range(num_runs):
        if model == "lstm":
            encoder = EncoderLSTM(input_size, hidden_size, num_layers=num_layers, dropout=dropout).to(device)
            if attention:
                decoder = AttnDecoderLSTM(hidden_size, output_size, dropout=dropout).to(device)
            else:
                decoder = DecoderLSTM(hidden_size, output_size, num_layers=num_layers, dropout=dropout).to(device)
        else:
            encoder = EncoderGRU(input_size, hidden_size, num_layers=num_layers, dropout=dropout).to(device)
            # Honour the caller's dropout (a hard-coded 0.5 would ignore it).
            decoder = AttnDecoderGRU(hidden_size, output_size, dropout=dropout).to(device)

        accuracy_train = trainIters(encoder, decoder, num_iter, print_every=1000, model=model, attention=attention)
        # Evaluate through the same decoding path used in training.
        acc_command, acc_action, acc_overall = evaluate_model(
            encoder, decoder, pairs_test, model=model, attention=attention)

        # The `is None` check keeps a model even if every run scores 0, so the
        # save below never dereferences None.
        if best_encoder is None or acc_overall > best_acc:
            best_encoder = encoder
            best_decoder = decoder
            best_acc = acc_overall

        command_accs.append(acc_command)
        action_accs.append(acc_action)
        overall_accs.append(acc_overall)
        train_accs.append(accuracy_train)

    # Save results.
    # NOTE(review): the decoder file name lacks the "_" after "decoder", and the
    # pickle name omits `experiment` (runs of different experiments overwrite
    # each other). Both names are kept unchanged for existing loaders.
    torch.save(best_encoder.state_dict(), f"encoder_{experiment}_{model}_{attention}.pt")
    torch.save(best_decoder.state_dict(), f"decoder{experiment}_{model}_{attention}.pt")

    with open(f'train_results_{model}_{attention}.pickle', 'wb') as f:
        pickle.dump([command_accs, action_accs, overall_accs, train_accs], f)

    return command_accs, action_accs, overall_accs, train_accs


def calculate_mean_std(acc_dict):
    """Aggregate per-key accuracies across runs.

    Args:
        acc_dict: list of per-run dicts (e.g. length -> accuracy); the keys of
            the first run define which keys are aggregated.

    Returns:
        ``(means, std_errors, keys)``: ``keys`` sorted ascending; for each key,
        the mean across runs and the population standard deviation divided by
        ``sqrt(number of runs)``.
    """
    sorted_keys = sorted(acc_dict[0])
    sqrt_runs = np.sqrt(len(acc_dict))
    per_key_values = [[run[key] for run in acc_dict] for key in sorted_keys]
    means = [np.mean(values) for values in per_key_values]
    std_errors = [np.std(values) / sqrt_runs for values in per_key_values]
    return means, std_errors, sorted_keys

In [None]:
# Top-performing configuration: single-layer GRU encoder with an attention decoder.
main_run(
    model='gru',
    attention=True,
    hidden_size=100,
    num_layers=1,
    dropout=0.1,
    num_iter=10000,
    num_runs=1,
    experiment="3a1",
)

0m 32s (- 4m 56s) (1000 10%) loss: 1.0930 acc: 0.0980
1m 6s (- 4m 24s) (2000 20%) loss: 0.8344 acc: 0.1140
1m 39s (- 3m 51s) (3000 30%) loss: 0.6458 acc: 0.1500
2m 11s (- 3m 17s) (4000 40%) loss: 0.4686 acc: 0.2890
2m 44s (- 2m 44s) (5000 50%) loss: 0.4111 acc: 0.3510
3m 16s (- 2m 11s) (6000 60%) loss: 0.3080 acc: 0.4510
3m 49s (- 1m 38s) (7000 70%) loss: 0.3233 acc: 0.4400
4m 21s (- 1m 5s) (8000 80%) loss: 0.2974 acc: 0.4940
4m 55s (- 0m 32s) (9000 90%) loss: 0.2884 acc: 0.5040
5m 29s (- 0m 0s) (10000 100%) loss: 0.2338 acc: 0.5830


100%|██████████| 1208/1208 [00:11<00:00, 109.71it/s]


([{7: 0.3008474576271186,
   5: 0.5294117647058824,
   8: 0.259375,
   6: 0.3716216216216216,
   4: 0.8125,
   3: 0.0}],
 [{25: 0.0625,
   3: 0.6,
   8: 0.7916666666666666,
   6: 0.35294117647058826,
   9: 0.017857142857142856,
   26: 0.25,
   11: 0.4027777777777778,
   5: 0.3719512195121951,
   19: 0.0,
   10: 0.3888888888888889,
   7: 0.15833333333333333,
   2: 0.6521739130434783,
   27: 0.21875,
   4: 0.5579710144927537,
   13: 0.0,
   12: 0.0,
   17: 0.0625,
   18: 0.25,
   14: 0.0,
   15: 0.0}],
 [0.33278145695364236],
 [0.583])

In [None]:
# Overall best configuration: 2-layer LSTM encoder/decoder without attention.
main_run(
    model='lstm',
    attention=False,
    hidden_size=200,
    num_layers=2,
    dropout=0.5,
    num_iter=10000,
    num_runs=1,
    experiment="3a2",
)

0m 22s (- 3m 20s) (1000 10%) loss: 1.1823 acc: 0.1010
0m 45s (- 3m 0s) (2000 20%) loss: 0.7575 acc: 0.1140
1m 5s (- 2m 33s) (3000 30%) loss: 0.5888 acc: 0.1650
1m 27s (- 2m 10s) (4000 40%) loss: 0.4850 acc: 0.2200
1m 48s (- 1m 48s) (5000 50%) loss: 0.4210 acc: 0.2380
2m 10s (- 1m 27s) (6000 60%) loss: 0.3641 acc: 0.3050
2m 33s (- 1m 5s) (7000 70%) loss: 0.3281 acc: 0.3800
2m 54s (- 0m 43s) (8000 80%) loss: 0.2709 acc: 0.4180
3m 15s (- 0m 21s) (9000 90%) loss: 0.2538 acc: 0.4650
3m 37s (- 0m 0s) (10000 100%) loss: 0.2364 acc: 0.4940


100%|██████████| 1208/1208 [00:06<00:00, 174.35it/s]


([{7: 0.125,
   5: 0.30392156862745096,
   8: 0.1375,
   6: 0.2702702702702703,
   4: 0.25,
   3: 0.0}],
 [{25: 0.0,
   3: 0.4117647058823529,
   8: 0.5138888888888888,
   6: 0.13725490196078433,
   9: 0.125,
   26: 0.0,
   11: 0.2638888888888889,
   5: 0.23170731707317074,
   19: 0.0,
   10: 0.1111111111111111,
   7: 0.16666666666666666,
   2: 0.2608695652173913,
   27: 0.0,
   4: 0.17391304347826086,
   13: 0.0,
   12: 0.0,
   17: 0.03125,
   18: 0.0,
   14: 0.25,
   15: 0.0}],
 [0.1804635761589404],
 [0.494])