In [1]:
from google.colab import drive
drive.mount('/content/drive/')


import os
os.chdir("drive/My Drive/Advanced NLP/Exam")

import pickle
import random
import time
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch
import torch.nn as nn
from torch import optim

plt.switch_backend('agg')
import numpy as np
from tqdm import tqdm

from models import EncoderGRU, AttnDecoderGRU, EncoderLSTM, DecoderLSTM, AttnDecoderLSTM
from utils import Lang, tensorsFromPair, timeSince, showPlot


MessageError: ignored

In [None]:
os.getcwd()

'/content/drive/My Drive/Advanced NLP/Exam'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset_path = '../SCAN-master'
task_name = 'add_prim'  # or 'length'
primitive = 'jump'

train_file_name = '{}_split/tasks_{}_{}.txt'.format(task_name, 'train', 'addprim'+'_'+primitive)
test_file_name = '{}_split/tasks_{}_{}.txt'.format(task_name, 'test', 'addprim'+'_'+primitive)
train_file_path = os.path.join(dataset_path, train_file_name)
test_file_path = os.path.join(dataset_path, test_file_name)

# train_file_path, test_file_path

SOS_token = 0
EOS_token = 1

command_le = Lang('command')
action_le = Lang('action')


def dataloader(path):
    with open(path, 'r') as f:
        dataset = f.readlines()

    def preprocess_data(line):
        line = line.strip().split()
        split_index = line.index('OUT:')
        inp = line[1: split_index]
        outp = line[split_index + 1:]
        command_le.addSentence(inp)
        action_le.addSentence(outp)
        return [inp, outp]

    pairs = list(map(preprocess_data, dataset))
    input_commands, output_actions = np.transpose(pairs).tolist()
    return input_commands, output_actions, pairs


commands_train, actions_train, pairs_train = dataloader(train_file_path)
commands_test, actions_test, pairs_test = dataloader(test_file_path)

MAX_LENGTH = max([len(action) for action in actions_test]) + 1

teacher_forcing_ratio = 0.5


def train(input_tensor, target_tensor, encoder, decoder,
          encoder_optimizer, decoder_optimizer, criterion,
          model='lstm', attention=True):
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_hiddens = torch.zeros(input_length, encoder.hidden_size, device=device)

    loss = 0
    gold_pred = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        if attention:
            if model == "lstm":
                encoder_hiddens[ei] = encoder_hidden[0][0,0]
            elif model == "gru":
                encoder_hiddens[ei] = encoder_hidden[0,0]

    decoder_input = torch.tensor([[SOS_token]], device=device)

    decoder_hidden = encoder_hidden
    
    preds = []
    
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    
    if use_teacher_forcing:
        for di in range(target_length):
            if attention:
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_hiddens)
            else:
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            pred = topi.squeeze()
            preds.append(topi.squeeze().item())
            
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]
    

    else:
        for di in range(target_length):
            if attention:
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_hiddens)
            else:
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            pred = topi.squeeze()
            preds.append(topi.squeeze().item())
            decoder_input = topi.squeeze().detach()

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                target_length = di + 1
                break
                
    correct = torch.equal(torch.Tensor(preds).to(device), target_tensor.squeeze())

    loss.backward()

    torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5.0)
    torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=5.0)
    
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length, correct


def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100,
               learning_rate=0.001, model='gru', attention=True):
    start = time.time()

    accuracy = 0 
    plot_losses = []
    plot_accs = []
    print_loss_total = 0
    plot_loss_total = 0
    print_pred_total = 0
    print_label_total = 0
    plot_pred_total = 0
    plot_label_total = 0

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(pairs_train), command_le, action_le)
                      for i in range(n_iters)]
    criterion = nn.NLLLoss()

    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]
        # target_length = target_tensor.size(0)

        loss, correct = train(input_tensor, target_tensor, encoder, decoder, 
                              encoder_optimizer, decoder_optimizer, criterion, 
                              model, attention=attention)
        print_pred_total += int(correct)
        plot_pred_total += int(correct)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:            
            print_acc_avg = print_pred_total / print_every
            accuracy = print_acc_avg
            print_loss_avg = print_loss_total / print_every
            print_pred_total = 0
            print_label_total = 0
            print_loss_total = 0
            print('%s (%d %d%%) loss: %.4f acc: %.4f' % (timeSince(start, iter / n_iters),
                                                         iter, iter / n_iters * 100, print_loss_avg, print_acc_avg))

        if iter % plot_every == 0:
            plot_acc_avg = plot_pred_total / plot_every
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_accs.append(plot_acc_avg)
            plot_loss_total = 0
            plot_pred_total = 0
            plot_label_total = 0

    showPlot(plot_losses, plot_accs)
    return accuracy


def evaluate(encoder, decoder, pair, model='gru', attention=True, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor, target_tensor = tensorsFromPair(pair, command_le, action_le)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_hiddens = torch.zeros(input_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            
            if attention:
                if model == 'lstm':
                    encoder_hiddens[ei] += encoder_hidden[0][0,0]
                elif model == 'gru':
                    encoder_hiddens[ei] += encoder_hidden[0,0]


        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, input_length)

        for di in range(max_length):
            if attention:
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_hiddens)
                decoder_attentions[di] = decoder_attention.squeeze().data
            else:
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(action_le.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di+1]


def evaluate_model(encoder, decoder, test_pairs, model='gru', attention=True):
    command_cnt = Counter([len(test_pair[0]) for test_pair in test_pairs])
    action_cnt = Counter([len(test_pair[1]) for test_pair in test_pairs])
    command_correct_cnt = defaultdict(int)
    action_correct_cnt = defaultdict(int)
    correct = 0

    for pair in tqdm(test_pairs):
        preds, attentions = evaluate(encoder, decoder, pair, model=model, attention=attention)
        preds = preds[:-1]
        target_output = pair[1]
        if preds == target_output:
            command_correct_cnt[len(pair[0])] += 1
            action_correct_cnt[len(pair[1])] += 1
            correct += 1
            
    command_correct_cnt = dict(command_correct_cnt)
    action_correct_cnt = dict(action_correct_cnt)
    command_cnt = dict(command_cnt)
    action_cnt = dict(action_cnt)

    command_acc = {}
    for command_length, cnt in command_cnt.items():
        command_acc[command_length] = command_correct_cnt.get(
            command_length, 0) / cnt

    action_acc = {}
    for action_length, cnt in action_cnt.items():
        action_acc[action_length] = action_correct_cnt.get(
            action_length, 0) / cnt
            
    return command_acc, action_acc, correct / len(test_pairs)


def evaluateRandomly(encoder, decoder, model='gru', n=10):
    for i in range(n):
        pair = random.choice(pairs_test)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = evaluate(encoder, decoder, pair, model=model)
        output_sentence = output_words
        print('<', output_sentence)
        print('')
        

def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure(figsize=(16,8))
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_sentence +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(encoder, decoder, pair, model='gru'):
    output_words, attentions = evaluate(
        encoder, decoder, pair, model=model)
    print('input =', pair[0])
    print('output =', output_words)
    showAttention(pair[0], output_words, attentions)
    


def evaluateAndShowAttentionExample(encoder, decoder, model='gru'):
    for i in range(len(pairs_test)):
        preds, attentions = evaluate(encoder, decoder, pairs_test[i], model=model)
        preds = preds[:-1]
        target_output = pairs_test[i][1]
        if preds == target_output:
            evaluateAndShowAttention(encoder, decoder, pairs_test[i])
            break
        
        
def main_run(hidden_size, num_iter, num_runs, model, dropout=0.5, num_layers=1, attention=True, experiment="3b1"):

    input_size = command_le.n_words
    output_size = action_le.n_words

    best_encoder = None
    best_decoder = None
    best_acc = 0

    command_accs, action_accs, overall_accs, train_accs = [], [], [], []

    for i in range(num_runs):
        if model == "lstm":
            encoder = EncoderLSTM(input_size, hidden_size, num_layers=num_layers, dropout=dropout).to(device)
        elif model == "gru":
            encoder = EncoderGRU(input_size, hidden_size, num_layers=num_layers, dropout=dropout).to(device)

        if attention:
            if model == "lstm":
                decoder = AttnDecoderLSTM(hidden_size, output_size, dropout=dropout).to(device)
            elif model == "gru":
                decoder = AttnDecoderGRU(hidden_size, output_size, dropout=0.5).to(device)

        else:
            if model == "lstm":
                decoder = DecoderLSTM(hidden_size, output_size, num_layers=num_layers, dropout=dropout).to(device)

        accuracy_train = trainIters(encoder, decoder, num_iter, print_every=1000, model=model, attention=attention)
        acc_command, acc_action, acc_overall = evaluate_model(encoder, decoder, pairs_test, model=model, attention=attention)
            

        if acc_overall > best_acc:
            best_encoder = encoder
            best_decoder = decoder
            best_acc = acc_overall

        command_accs.append(acc_command)
        action_accs.append(acc_action)
        overall_accs.append(acc_overall)
        train_accs.append(accuracy_train)
        
    # Save results          
    torch.save(best_encoder.state_dict(), f"encoder_{experiment}_{model}_{attention}.pt")
    torch.save(best_decoder.state_dict(), f"decoder{experiment}_{model}_{attention}.pt")

    with open(f'train_results_{model}_{attention}.pickle', 'wb') as f:
        pickle.dump([command_accs, action_accs,overall_accs, train_accs], f)
   
    return command_accs, action_accs, overall_accs, train_accs


def calculate_mean_std(acc_dict):
    mean = []
    error = []
    keys = sorted(acc_dict[0])
    num_runs = len(acc_dict)
    
    for key in keys:
        t = []
        for d in acc_dict:
            t.append(d[key])
        mean.append(np.mean(t))
        error.append(np.std(t) / np.sqrt(num_runs))
    return mean, error, keys

  result = getattr(asarray(obj), method)(*args, **kwds)


In [None]:
# top-performing model: LSTM with Attention
main_run(hidden_size=100, num_layers = 1, dropout=0.1, num_iter=50000, model='lstm', num_runs=1, attention=True)

0m 33s (- 27m 33s) (1000 2%) loss: 0.9811 acc: 0.1010


KeyboardInterrupt: ignored

In [None]:
# overall best model: LSTM
main_run(hidden_size=200, num_layers=2, dropout=0.5, num_iter=75000, num_runs=10, model='lstm', attention=False, experiment ="3b2")

0m 24s (- 30m 44s) (1000 1%) loss: 1.0924 acc: 0.0900
0m 45s (- 27m 33s) (2000 2%) loss: 0.6981 acc: 0.1300
1m 5s (- 26m 7s) (3000 4%) loss: 0.5245 acc: 0.1860
1m 24s (- 25m 6s) (4000 5%) loss: 0.4675 acc: 0.2250
1m 46s (- 24m 44s) (5000 6%) loss: 0.4126 acc: 0.2810
2m 6s (- 24m 10s) (6000 8%) loss: 0.3674 acc: 0.3090
2m 26s (- 23m 42s) (7000 9%) loss: 0.3210 acc: 0.3640
2m 46s (- 23m 15s) (8000 10%) loss: 0.2681 acc: 0.4510
3m 6s (- 22m 47s) (9000 12%) loss: 0.2338 acc: 0.5420
3m 26s (- 22m 24s) (10000 13%) loss: 0.2233 acc: 0.5520
3m 46s (- 21m 59s) (11000 14%) loss: 0.1736 acc: 0.6220
4m 6s (- 21m 35s) (12000 16%) loss: 0.1292 acc: 0.6790
4m 27s (- 21m 15s) (13000 17%) loss: 0.1070 acc: 0.6670
4m 48s (- 20m 57s) (14000 18%) loss: 0.1206 acc: 0.7150
5m 8s (- 20m 35s) (15000 20%) loss: 0.0849 acc: 0.7520
5m 29s (- 20m 15s) (16000 21%) loss: 0.0812 acc: 0.7540
5m 49s (- 19m 50s) (17000 22%) loss: 0.0792 acc: 0.8140
6m 8s (- 19m 26s) (18000 24%) loss: 0.0807 acc: 0.8150
6m 28s (- 19m 5s

100%|██████████| 7706/7706 [00:49<00:00, 155.44it/s]


0m 24s (- 29m 45s) (1000 1%) loss: 1.1238 acc: 0.0950
0m 44s (- 26m 46s) (2000 2%) loss: 0.7718 acc: 0.1170
1m 3s (- 25m 19s) (3000 4%) loss: 0.5679 acc: 0.1870
1m 22s (- 24m 31s) (4000 5%) loss: 0.4862 acc: 0.2220
1m 42s (- 23m 55s) (5000 6%) loss: 0.4219 acc: 0.2540
2m 2s (- 23m 27s) (6000 8%) loss: 0.3634 acc: 0.3330
2m 22s (- 23m 4s) (7000 9%) loss: 0.3394 acc: 0.3530
2m 42s (- 22m 41s) (8000 10%) loss: 0.3082 acc: 0.3670
3m 3s (- 22m 25s) (9000 12%) loss: 0.2627 acc: 0.4510
3m 23s (- 22m 0s) (10000 13%) loss: 0.2672 acc: 0.4780
3m 43s (- 21m 37s) (11000 14%) loss: 0.2075 acc: 0.5320
4m 3s (- 21m 16s) (12000 16%) loss: 0.1785 acc: 0.5600
4m 23s (- 20m 55s) (13000 17%) loss: 0.1575 acc: 0.6000
4m 43s (- 20m 35s) (14000 18%) loss: 0.1378 acc: 0.6240
5m 3s (- 20m 13s) (15000 20%) loss: 0.1131 acc: 0.6840
5m 23s (- 19m 52s) (16000 21%) loss: 0.1245 acc: 0.7010
5m 43s (- 19m 32s) (17000 22%) loss: 0.0928 acc: 0.7530
6m 4s (- 19m 13s) (18000 24%) loss: 0.0956 acc: 0.7590
6m 24s (- 18m 53

100%|██████████| 7706/7706 [00:51<00:00, 150.32it/s]


0m 24s (- 29m 46s) (1000 1%) loss: 1.1018 acc: 0.1000
0m 43s (- 26m 44s) (2000 2%) loss: 0.6814 acc: 0.1180
1m 4s (- 25m 38s) (3000 4%) loss: 0.5209 acc: 0.1670
1m 25s (- 25m 10s) (4000 5%) loss: 0.4690 acc: 0.2150
1m 44s (- 24m 26s) (5000 6%) loss: 0.4048 acc: 0.2780
2m 4s (- 23m 55s) (6000 8%) loss: 0.3549 acc: 0.3040
2m 25s (- 23m 30s) (7000 9%) loss: 0.3151 acc: 0.3670
2m 45s (- 23m 5s) (8000 10%) loss: 0.3228 acc: 0.4260
3m 5s (- 22m 42s) (9000 12%) loss: 0.2594 acc: 0.4410
3m 25s (- 22m 18s) (10000 13%) loss: 0.2180 acc: 0.5210
3m 46s (- 21m 55s) (11000 14%) loss: 0.2008 acc: 0.5540
4m 6s (- 21m 35s) (12000 16%) loss: 0.1718 acc: 0.5700
4m 27s (- 21m 15s) (13000 17%) loss: 0.1608 acc: 0.6150
4m 47s (- 20m 53s) (14000 18%) loss: 0.1358 acc: 0.6510
5m 7s (- 20m 31s) (15000 20%) loss: 0.1116 acc: 0.6580
5m 28s (- 20m 10s) (16000 21%) loss: 0.1130 acc: 0.6930
5m 48s (- 19m 49s) (17000 22%) loss: 0.0881 acc: 0.7580
6m 8s (- 19m 27s) (18000 24%) loss: 0.0787 acc: 0.7720
6m 28s (- 19m 5

100%|██████████| 7706/7706 [00:51<00:00, 150.42it/s]


0m 24s (- 30m 39s) (1000 1%) loss: 1.1031 acc: 0.0800
0m 44s (- 27m 12s) (2000 2%) loss: 0.7021 acc: 0.1320
1m 4s (- 25m 47s) (3000 4%) loss: 0.5386 acc: 0.1760
1m 24s (- 25m 0s) (4000 5%) loss: 0.4563 acc: 0.2160
1m 44s (- 24m 25s) (5000 6%) loss: 0.4234 acc: 0.2570
2m 4s (- 23m 53s) (6000 8%) loss: 0.3644 acc: 0.3200
2m 25s (- 23m 31s) (7000 9%) loss: 0.3260 acc: 0.3670
2m 45s (- 23m 4s) (8000 10%) loss: 0.3269 acc: 0.3630
3m 5s (- 22m 41s) (9000 12%) loss: 0.2696 acc: 0.4370
3m 26s (- 22m 19s) (10000 13%) loss: 0.2541 acc: 0.4750
3m 45s (- 21m 54s) (11000 14%) loss: 0.2243 acc: 0.5200
4m 5s (- 21m 30s) (12000 16%) loss: 0.2121 acc: 0.5390
4m 26s (- 21m 9s) (13000 17%) loss: 0.1711 acc: 0.5860
4m 46s (- 20m 48s) (14000 18%) loss: 0.1728 acc: 0.6100
5m 6s (- 20m 26s) (15000 20%) loss: 0.1241 acc: 0.6330
5m 27s (- 20m 7s) (16000 21%) loss: 0.1239 acc: 0.6590
5m 47s (- 19m 46s) (17000 22%) loss: 0.1117 acc: 0.6740
6m 7s (- 19m 24s) (18000 24%) loss: 0.1209 acc: 0.7040
6m 27s (- 19m 2s) 

100%|██████████| 7706/7706 [00:51<00:00, 149.16it/s]


0m 24s (- 30m 32s) (1000 1%) loss: 1.0923 acc: 0.0930
0m 44s (- 27m 15s) (2000 2%) loss: 0.7113 acc: 0.1270
1m 5s (- 26m 0s) (3000 4%) loss: 0.5423 acc: 0.1660
1m 25s (- 25m 14s) (4000 5%) loss: 0.4579 acc: 0.2330
1m 45s (- 24m 35s) (5000 6%) loss: 0.4114 acc: 0.2770
2m 5s (- 24m 5s) (6000 8%) loss: 0.3563 acc: 0.3160
2m 25s (- 23m 34s) (7000 9%) loss: 0.3153 acc: 0.3880
2m 45s (- 23m 7s) (8000 10%) loss: 0.3155 acc: 0.4160
3m 7s (- 22m 53s) (9000 12%) loss: 0.2753 acc: 0.4460
3m 27s (- 22m 27s) (10000 13%) loss: 0.2375 acc: 0.5030
3m 47s (- 22m 3s) (11000 14%) loss: 0.2419 acc: 0.5160
4m 7s (- 21m 39s) (12000 16%) loss: 0.2108 acc: 0.5580
4m 27s (- 21m 15s) (13000 17%) loss: 0.1795 acc: 0.6090
4m 47s (- 20m 53s) (14000 18%) loss: 0.1544 acc: 0.6170
5m 8s (- 20m 32s) (15000 20%) loss: 0.1416 acc: 0.6380
5m 28s (- 20m 9s) (16000 21%) loss: 0.1138 acc: 0.7260
5m 48s (- 19m 48s) (17000 22%) loss: 0.1206 acc: 0.7410
6m 9s (- 19m 30s) (18000 24%) loss: 0.1065 acc: 0.7690
6m 29s (- 19m 8s) (

100%|██████████| 7706/7706 [01:00<00:00, 126.35it/s]


0m 24s (- 30m 1s) (1000 1%) loss: 1.0798 acc: 0.0940
0m 43s (- 26m 44s) (2000 2%) loss: 0.6674 acc: 0.1360
1m 3s (- 25m 29s) (3000 4%) loss: 0.5452 acc: 0.1880
1m 23s (- 24m 45s) (4000 5%) loss: 0.4484 acc: 0.2400
1m 43s (- 24m 14s) (5000 6%) loss: 0.4066 acc: 0.2850
2m 3s (- 23m 40s) (6000 8%) loss: 0.3838 acc: 0.3350
2m 23s (- 23m 18s) (7000 9%) loss: 0.3082 acc: 0.3700
2m 45s (- 23m 5s) (8000 10%) loss: 0.2642 acc: 0.4430
3m 6s (- 22m 45s) (9000 12%) loss: 0.2524 acc: 0.4750
3m 26s (- 22m 19s) (10000 13%) loss: 0.2497 acc: 0.5260
3m 46s (- 21m 57s) (11000 14%) loss: 0.1716 acc: 0.5560
4m 6s (- 21m 33s) (12000 16%) loss: 0.1722 acc: 0.6270
4m 26s (- 21m 8s) (13000 17%) loss: 0.1352 acc: 0.6390
4m 46s (- 20m 48s) (14000 18%) loss: 0.1261 acc: 0.6490
5m 6s (- 20m 27s) (15000 20%) loss: 0.0924 acc: 0.7180
5m 27s (- 20m 9s) (16000 21%) loss: 0.1024 acc: 0.6880
5m 48s (- 19m 49s) (17000 22%) loss: 0.0869 acc: 0.7210
6m 9s (- 19m 28s) (18000 24%) loss: 0.0752 acc: 0.7640
6m 29s (- 19m 8s) 

100%|██████████| 7706/7706 [00:53<00:00, 144.02it/s]


0m 24s (- 30m 35s) (1000 1%) loss: 1.1103 acc: 0.0860
0m 46s (- 28m 11s) (2000 2%) loss: 0.6949 acc: 0.1190
1m 6s (- 26m 45s) (3000 4%) loss: 0.5362 acc: 0.1640
1m 27s (- 25m 51s) (4000 5%) loss: 0.4570 acc: 0.2100
1m 47s (- 25m 9s) (5000 6%) loss: 0.3902 acc: 0.2800
2m 8s (- 24m 39s) (6000 8%) loss: 0.3495 acc: 0.3140
2m 29s (- 24m 8s) (7000 9%) loss: 0.3225 acc: 0.3480
2m 49s (- 23m 35s) (8000 10%) loss: 0.3121 acc: 0.4160
3m 8s (- 23m 4s) (9000 12%) loss: 0.2652 acc: 0.4800
3m 29s (- 22m 44s) (10000 13%) loss: 0.2113 acc: 0.5150
3m 50s (- 22m 21s) (11000 14%) loss: 0.1909 acc: 0.5660
4m 10s (- 21m 56s) (12000 16%) loss: 0.1881 acc: 0.6250
4m 31s (- 21m 34s) (13000 17%) loss: 0.1618 acc: 0.6500
4m 51s (- 21m 9s) (14000 18%) loss: 0.1363 acc: 0.6850
5m 12s (- 20m 48s) (15000 20%) loss: 0.1314 acc: 0.6930
5m 32s (- 20m 26s) (16000 21%) loss: 0.1107 acc: 0.7400
5m 53s (- 20m 5s) (17000 22%) loss: 0.1025 acc: 0.7220
6m 12s (- 19m 41s) (18000 24%) loss: 0.0845 acc: 0.8030
6m 34s (- 19m 22

100%|██████████| 7706/7706 [00:53<00:00, 145.27it/s]


0m 24s (- 30m 2s) (1000 1%) loss: 1.0948 acc: 0.0880
0m 44s (- 26m 57s) (2000 2%) loss: 0.6956 acc: 0.1310
1m 4s (- 25m 55s) (3000 4%) loss: 0.5603 acc: 0.1570
1m 24s (- 25m 6s) (4000 5%) loss: 0.4596 acc: 0.2170
1m 44s (- 24m 29s) (5000 6%) loss: 0.4144 acc: 0.2720
2m 5s (- 24m 1s) (6000 8%) loss: 0.3778 acc: 0.3090
2m 26s (- 23m 41s) (7000 9%) loss: 0.3287 acc: 0.3390
2m 46s (- 23m 15s) (8000 10%) loss: 0.3045 acc: 0.4150
3m 7s (- 22m 54s) (9000 12%) loss: 0.2813 acc: 0.4340
3m 27s (- 22m 30s) (10000 13%) loss: 0.2416 acc: 0.5070
3m 48s (- 22m 6s) (11000 14%) loss: 0.1693 acc: 0.5670
4m 7s (- 21m 41s) (12000 16%) loss: 0.1832 acc: 0.5900
4m 28s (- 21m 19s) (13000 17%) loss: 0.1545 acc: 0.6150
4m 48s (- 20m 56s) (14000 18%) loss: 0.1543 acc: 0.6390
5m 9s (- 20m 36s) (15000 20%) loss: 0.1363 acc: 0.6310
5m 30s (- 20m 18s) (16000 21%) loss: 0.1038 acc: 0.7130
5m 51s (- 19m 57s) (17000 22%) loss: 0.1104 acc: 0.7270
6m 11s (- 19m 35s) (18000 24%) loss: 0.0885 acc: 0.7590
6m 31s (- 19m 13s

100%|██████████| 7706/7706 [00:51<00:00, 150.33it/s]


0m 24s (- 30m 15s) (1000 1%) loss: 1.1323 acc: 0.0780
0m 44s (- 26m 56s) (2000 2%) loss: 0.7319 acc: 0.1160
1m 4s (- 25m 52s) (3000 4%) loss: 0.5637 acc: 0.1610
1m 24s (- 25m 5s) (4000 5%) loss: 0.4557 acc: 0.2120
1m 45s (- 24m 39s) (5000 6%) loss: 0.3874 acc: 0.2610
2m 6s (- 24m 14s) (6000 8%) loss: 0.3566 acc: 0.3260
2m 26s (- 23m 45s) (7000 9%) loss: 0.3190 acc: 0.3560
2m 46s (- 23m 17s) (8000 10%) loss: 0.2946 acc: 0.4030
3m 7s (- 22m 54s) (9000 12%) loss: 0.2737 acc: 0.4500
3m 27s (- 22m 28s) (10000 13%) loss: 0.2820 acc: 0.4470
3m 47s (- 22m 6s) (11000 14%) loss: 0.2033 acc: 0.5440
4m 8s (- 21m 43s) (12000 16%) loss: 0.2295 acc: 0.5610
4m 28s (- 21m 22s) (13000 17%) loss: 0.1675 acc: 0.6230
4m 49s (- 20m 59s) (14000 18%) loss: 0.1532 acc: 0.6610
5m 10s (- 20m 42s) (15000 20%) loss: 0.1288 acc: 0.7050
5m 30s (- 20m 18s) (16000 21%) loss: 0.1209 acc: 0.7540
5m 50s (- 19m 56s) (17000 22%) loss: 0.0910 acc: 0.7550
6m 11s (- 19m 35s) (18000 24%) loss: 0.0884 acc: 0.7430
6m 31s (- 19m 

100%|██████████| 7706/7706 [01:00<00:00, 127.37it/s]


0m 24s (- 30m 29s) (1000 1%) loss: 1.0671 acc: 0.0930
0m 44s (- 27m 16s) (2000 2%) loss: 0.7145 acc: 0.1260
1m 4s (- 25m 53s) (3000 4%) loss: 0.5338 acc: 0.1770
1m 24s (- 25m 1s) (4000 5%) loss: 0.4341 acc: 0.2320
1m 44s (- 24m 21s) (5000 6%) loss: 0.4015 acc: 0.2720
2m 5s (- 24m 3s) (6000 8%) loss: 0.3523 acc: 0.3160
2m 26s (- 23m 43s) (7000 9%) loss: 0.2957 acc: 0.3880
2m 46s (- 23m 10s) (8000 10%) loss: 0.2574 acc: 0.4430
3m 6s (- 22m 44s) (9000 12%) loss: 0.2616 acc: 0.4870
3m 26s (- 22m 22s) (10000 13%) loss: 0.2341 acc: 0.4900
3m 46s (- 21m 58s) (11000 14%) loss: 0.2178 acc: 0.5600
4m 6s (- 21m 35s) (12000 16%) loss: 0.2017 acc: 0.5910
4m 26s (- 21m 10s) (13000 17%) loss: 0.1692 acc: 0.6290
4m 47s (- 20m 51s) (14000 18%) loss: 0.1454 acc: 0.6290
5m 7s (- 20m 30s) (15000 20%) loss: 0.1765 acc: 0.6100
5m 28s (- 20m 12s) (16000 21%) loss: 0.1195 acc: 0.6630
5m 49s (- 19m 52s) (17000 22%) loss: 0.1141 acc: 0.6670
6m 9s (- 19m 31s) (18000 24%) loss: 0.1017 acc: 0.7140
6m 29s (- 19m 8s

  plt.figure()
  fig, ax = plt.subplots()
100%|██████████| 7706/7706 [00:50<00:00, 152.43it/s]


([{9: 0.0, 6: 0.0, 8: 0.0, 5: 0.0, 7: 0.0, 4: 0.0, 3: 0.0, 2: 0.0},
  {9: 0.0, 6: 0.0, 8: 0.0, 5: 0.0, 7: 0.0, 4: 0.0, 3: 0.0, 2: 0.0},
  {9: 0.0,
   6: 0.000744047619047619,
   8: 0.0,
   5: 0.005859375,
   7: 0.0,
   4: 0.0078125,
   3: 0.0,
   2: 0.0},
  {9: 0.0, 6: 0.0, 8: 0.0, 5: 0.0, 7: 0.0, 4: 0.0, 3: 0.0, 2: 0.0},
  {9: 0.0, 6: 0.0, 8: 0.0, 5: 0.001953125, 7: 0.0, 4: 0.0, 3: 0.0, 2: 0.25},
  {9: 0.0, 6: 0.0, 8: 0.0, 5: 0.0, 7: 0.0, 4: 0.0, 3: 0.0, 2: 0.0},
  {9: 0.0, 6: 0.0, 8: 0.0, 5: 0.001953125, 7: 0.0, 4: 0.0, 3: 0.0, 2: 0.0},
  {9: 0.0,
   6: 0.000744047619047619,
   8: 0.0,
   5: 0.001953125,
   7: 0.0,
   4: 0.0,
   3: 0.0,
   2: 0.0},
  {9: 0.0,
   6: 0.004464285714285714,
   8: 0.0,
   5: 0.0078125,
   7: 0.0,
   4: 0.0078125,
   3: 0.0,
   2: 0.0},
  {9: 0.0, 6: 0.0, 8: 0.0, 5: 0.005859375, 7: 0.0, 4: 0.0, 3: 0.0, 2: 0.0}],
 [{12: 0.0,
   11: 0.0,
   15: 0.0,
   3: 0.0,
   5: 0.0,
   14: 0.0,
   22: 0.0,
   13: 0.0,
   8: 0.0,
   17: 0.0,
   10: 0.0,
   4: 0.0,
   19: