In [1]:
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim
import torch.nn.utils as utils
import seaborn as sns
import matplotlib.pyplot as plt
import time
import random
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
import os
import glob
from tqdm import tqdm

import Levenshtein
import torchaudio

import wandb



In [2]:
# These are the various characters in the transcripts of the datasetW
VOCAB = ['<sos>',   
         'A',   'B',    'C',    'D',    
         'E',   'F',    'G',    'H',    
         'I',   'J',    'K',    'L',       
         'M',   'N',    'O',    'P',    
         'Q',   'R',    'S',    'T', 
         'U',   'V',    'W',    'X', 
         'Y',   'Z',    "'",    ' ', 
         '<eos>']

VOCAB_MAP = {VOCAB[i]:i for i in range(0, len(VOCAB))}

SOS_TOKEN = VOCAB_MAP["<sos>"]
EOS_TOKEN = VOCAB_MAP["<eos>"]

BATCH_SIZE = 96

In [3]:
class MFCCDataset:
    def __init__(self, data_path, vocab_map, val = False, cep_norm = True):
        """
        Let's access the datapaths for the input and the labels in this sections 
        x: MFCC path
        y: Transcripts 

        1) Load all the data a-priori in the init for faster training. 
        2) Cepstral normalization :  
        """
        self.val = val
        self.cep_norm = cep_norm 
        if self.val:
            self.x =  str(data_path)+"\\dev-clean\\mfcc\\*.npy" 
            self.y =  str(data_path)+"\\dev-clean\\transcript\\raw\\*.npy"
        else: 
            self.x = str(data_path)+"\\train-clean-100\\mfcc\\*.npy"
            self.y = str(data_path)+"\\train-clean-100\\transcript\\raw\\*.npy"

        

        self.mfcc_list = sorted(glob.glob(self.x))[:100]
        self.transcript_list = sorted(glob.glob(self.y))[:100]
        self.alphabets = vocab_map
       
    def __len__(self):
        
        return len(self.mfcc_list)

    def __getitem__(self, index):
        """
        cepstral normalization performed here for higher SNR 
        """

        if self.val:
            mf_temp = np.load(self.mfcc_list[index], allow_pickle= True)
            tr_temp = np.load(self.transcript_list[index], allow_pickle= True)
            tr_temp = [self.alphabets[ele] for ele in tr_temp]
            if self.cep_norm:
                mf_temp = (mf_temp - np.mean(mf_temp, axis = 0))/ np.std(mf_temp)
            
            return torch.tensor(mf_temp),torch.tensor(tr_temp)
        
        else: 
            mf_temp = np.load(self.mfcc_list[index], allow_pickle= True)
            tr_temp = np.load(self.transcript_list[index], allow_pickle= True)
            
            # Converting the alphabets in the labels to integers using the pre-defined map provided 
            tr_temp = [self.alphabets[ele] for ele in tr_temp]

            if self.cep_norm:
                mf_temp = (mf_temp - np.mean(mf_temp, axis = 0))/ np.std(mf_temp)
            return torch.tensor(mf_temp), torch.tensor(tr_temp)

#Collate function for uniform padding of the input sequences 
def collate_train(data): 
    
    time_mask = torchaudio.transforms.TimeMasking(80)
    frequency_mask =torchaudio.transforms.FrequencyMasking(5)
    

    (xx, yy) = zip(*data)
    x_lens = [len(x) for x in xx]
    y_lens = [len(y) for y in yy]

    xx_pad = pad_sequence(xx,batch_first=True)
    yy_pad = pad_sequence(yy,batch_first=True)
    batch_mfcc_pad = np.transpose(xx_pad,(0,2,1))
    batch_mfcc_pad = time_mask(xx_pad)
    batch_mfcc_pad = frequency_mask(xx_pad)
    batch_mfcc_pad = np.transpose(xx_pad,(0,2,1))
    x_lens = np.asarray(x_lens)
    y_lens = np.asarray(y_lens)
    # Some augmentation and masking here may help the network converge better. 

        
    return xx_pad, yy_pad, torch.tensor(x_lens), torch.tensor(y_lens)

def collate_val(data): 
    
    (xx, yy) = zip(*data)
    x_lens = [len(x) for x in xx]
    y_lens = [len(y) for y in yy]

    xx_pad = pad_sequence(xx,batch_first=True)
    yy_pad = pad_sequence(yy,batch_first=True)

    x_lens = np.asarray(x_lens)
    y_lens = np.asarray(y_lens)
    # Some augmentation and masking here may help the network converge better. 

        
    return xx_pad, yy_pad, torch.tensor(x_lens), torch.tensor(y_lens)


In [4]:
# Dataset and dataloader sections 
data_path = "C:\\Users\\thopa\Desktop\\Assignments\\11685\\HW4\\2022Implementation\\11-785-f22-hw4p2\\hw4p2"
train_data = MFCCDataset(data_path, vocab_map= VOCAB_MAP)
val_data = MFCCDataset(data_path,vocab_map= VOCAB_MAP, val = True)

train_loader = DataLoader(train_data ,batch_size = 8 , collate_fn= collate_train , shuffle = True)
val_loader = DataLoader(val_data, batch_size = 8, collate_fn = collate_val, shuffle= False)


In [5]:
for i, (x,y,lx,ly) in enumerate(val_loader):
    print(i)

0
1
2
3
4
5
6
7
8
9
10
11
12


## Neural Network and Training 

### Silly notes for reference 
When training RNN (LSTM or GRU or vanilla-RNN), it is difficult to batch the variable length sequences. For example: if the length of sequences in a size 8 batch is [4,6,8,5,4,3,7,8], you will pad all the sequences and that will result in 8 sequences of length 8. You would end up doing 64 computations (8x8), but you needed to do only 45 computations. Moreover, if you wanted to do something fancy like using a bidirectional-RNN, it would be harder to do batch computations just by padding and you might end up doing more computations than required.

Instead, PyTorch allows us to pack the sequence, internally packed sequence is a tuple of two lists. One contains the elements of sequences. Elements are interleaved by time steps (see example below) and other contains the size of each sequence the batch size at each step. This is helpful in recovering the actual sequences as well as telling RNN what is the batch size at each time step. This has been pointed by @Aerin. This can be passed to RNN and it will internally optimize the computations.

In [6]:
class PBLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(PBLSTM,self).__init__()

        self.blstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size, num_layers = 2, batch_first = True, bidirectional = True, dropout = 0.3)
    
    def reshape(self, x, x_lens):
        # Reshaping for concatenation / reducing dimensions
        batch, rows, cols = x.shape[0], x.shape[1], x.shape[2]

        if (rows % 2 != 0):
            x = x[:,:-1,:]
        x = x.reshape(batch, int(rows/2), cols*2)
        x_lens = x_lens//2
    
        return x, x_lens
    
    def forward(self,x):
        """
        Computational savings and original sequence recovery using the pack padded and pad packed routine 
        """
        x_pad, x_pad_lens = pad_packed_sequence(x, batch_first=True)
        x, x_lens = self.reshape(x_pad, x_pad_lens.to("cuda"))
        input = rnn_utils.pack_padded_sequence(x, lengths = x_lens.cpu(), batch_first= True, enforce_sorted= False)
        rnn_out, _ = self.blstm(input)
        #output, lens = rnn_utils.pad_packed_sequence(rnn_out, batch_first= True)
        
        return rnn_out

#### Locked Dropout Regularization 
LockedDropout can be used to apply the same dropout mask to every time step. 

Benefits of locked dropout: 
- Reduces variance like any typical regularization process

In [7]:
class LockedDropout(nn.Module):
    def __init__(self, prob):
        super(LockedDropout, self).__init__()
        self.prob = prob 
    
    def forward(self, x):
        output, output_len = pad_packed_sequence(x, batch_first= True)
        x = output
        x = x.clone()
        mask = x.new_empty(1, x.size(1), x.size(2), requires_grad= False).bernoulli(1- self.prob)
        mask = mask.div_(1 - self.prob)
        mask = mask.expand_as(x)
        x_masked = x * mask 
        x_masked = pack_padded_sequence(x_masked, output_len.cpu(), batch_first= True,enforce_sorted=False)
        return x_masked



In [8]:
class Encoder(nn.Module):
    def __init__(self, input_size, encoder_hidden_size):
        super(Encoder, self).__init__()
        """
        The encoder used is a pyramidal-BiLSTM for matching the input rate and the speech transcription rate which is about 8:1. 
        This model is sigificantly influenced by the LAS paper 
        LAS: Chan, William, et al. "Listen, attend and spell." arXiv preprint arXiv:1508.01211 (2015).

        [REF: B. Raj, Deep Learning Carnegie Mellon University]
        The pBLSTM is a variant of Bi-LSTMs that downsamples sequences by a factor of 2 by concatenating
        adjacent pairs of inputs before running a conventional Bi-LSTM on the reduced-length sequence. So, given
        an input vector sequence X0, X1, X2, X3, . . . XN−1, the pBLSTM first concatenates adjacent pairs of vectors
        as [X0, X1], [X2, X3], . . . [XN−2, XN−1], and then computes a regular BiLSTM on the reshaped input.

        -) Initial Bi-LSTM 
        -) 3x Pyramidal Bi-LSTM 
        -) Locked dropout regularization : 
        
        """    
        self.base_lstm = nn.LSTM(input_size = input_size, hidden_size = encoder_hidden_size, num_layers = 1, batch_first = True, bidirectional = True, dropout = 0.1)
        self.pblstm = nn.Sequential(PBLSTM(4*encoder_hidden_size,encoder_hidden_size), LockedDropout(0.25), PBLSTM(4*encoder_hidden_size,encoder_hidden_size), LockedDropout(0.25), PBLSTM(4*encoder_hidden_size,encoder_hidden_size), LockedDropout(0.25))
        
    def forward(self, x, x_lens):
        pack_padd_out = pack_padded_sequence(x, x_lens.to('cpu'),batch_first=True, enforce_sorted=False)
        #print(type(pack_padd_out))
        out_lstm, _  = self.base_lstm(pack_padd_out)
        encoder_outputs = self.pblstm(out_lstm)
        encoder_outputs, encoder_lens = pad_packed_sequence(encoder_outputs, batch_first=True)

        return encoder_outputs, encoder_lens




In [9]:
# from torchsummaryX import summary

# for data in train_loader:
#     x, y, lx, ly = data
#     print(x.shape, y.shape, lx.shape, ly.shape)
#     break 

# encoder = Encoder(15,256)# TODO: Initialize Listener
# out, lens = encoder.forward(x, lx)
# del encoder

In [10]:
# Attention block 
"""
Possible Efficiencies with the attention mechanism (d2l book)
1) In general, it requires that both the query and the key have the 
same vector length, say d, even though this can be addressed easily by replacing 
q⊤k with q⊤Mk where M is a suitably chosen matrix to translate
between both spaces. For now assume that the dimensions match.
2) Adding dropout weights also helps 
"""
class Attention(nn.Module):
    def __init__(self, encoder_output_size, decoder_output_size, projection):
        super(Attention, self).__init__()
        self.key_layer = nn.Linear(encoder_output_size, projection)
        self.value_layer = nn.Linear(encoder_output_size, projection)
        self.query_layer = nn.Linear(decoder_output_size, projection)
    
    def key_value_calc(self, encoder_output, encoder_len):
        _ ,encoder_max_seq_len, _ = encoder_output.shape 
        self.key = self.key_layer(encoder_output)
        self.value = self.value_layer(encoder_output)
        # Attention mask 
        # Removing the influence of padding in the raw weights, we create a boolean mask of (batchsize, timesteps)
        self.mask = (torch.arange(encoder_max_seq_len)[None, :] < encoder_len[:, None]).to("cuda")
        
    def forward(self, decoder_output_embeddings):
        self.query = self.query_layer(decoder_output_embeddings)
        
        energy = torch.bmm(self.key, self.query.unsqueeze(2))
        
        energy = torch.squeeze(energy, dim = 2)

        #What should the mask least value be? 
        energy.masked_fill_(self.mask, -1e9)
        
        attention = torch.nn.functional.softmax(energy, dim = 1)
        context = torch.bmm(torch.permute(self.value,[0,2,1]),attention.unsqueeze(2)).squeeze(2)
        return context, attention

In [11]:
# Decoder ~ according to the speller of the LAS paper 

class Decoder(nn.Module):
    def __init__(self, embed_dim, projection, vocab_size, decoder_hidden_size, decoder_output_size, encoder_output_size):
        super(Decoder, self).__init__()
        
        """
        A simple lookup table that stores embeddings of a fixed dictionary and size.
        This module is often used to store word embeddings and retrieve them using indices. 
        The input to the module is a list of indices, and the output is the corresponding word embeddings.
        """
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx = 0 ).cuda()
        self.lstm_cells = nn.Sequential(nn.LSTMCell(embed_dim + projection , decoder_hidden_size) , nn.LSTMCell(decoder_hidden_size , decoder_output_size))
        self.vocab_size = vocab_size
        self.attention = Attention(2*encoder_output_size , decoder_output_size, projection)
        self.character_prob = nn.Linear(512, vocab_size)
        self.device = "cuda"
        
    def forward(self, encoder_output, encoder_len, y = None, mode = "train", teacherForcingRate = 0.1, isGumbel = False ):
       
        # batch, key_seq_max_len, key_value_size = key.shape
        batch, encoder_max_seq, _ = encoder_output.shape

        # # Attention mask for making the system autoregressive 
        # mask = torch.arange(key_seq_max_len).unsqueeze(0)>=encoder_len.unsqueeze(1)
        # mask = mask.to(self.device)

        # List to store output attention plots 
        predictions, attention_plot = [], []
        prediction = torch.full((batch,1), fill_value = 0 ,device= self.device)
        
        # Hidden states
        hidden_states= [None]*len(self.lstm_cells)
        self.attention.key_value_calc(encoder_output, encoder_len)
        
        context = self.attention.value[:,0,:]

        if mode == "train":
            max_len = y.shape[1]
            char_embedding = self.embedding(y)
        else: 
            max_len = 600

        for i in range(max_len):
            if mode == "train":
                # Teacher Forcing regime ~ Assigned and picked randomly 
                teacher_forcing = True if random.random() > teacherForcingRate else False 
                if not teacher_forcing:
                    if i != 0 : # use Gumbel noise to add noise to add variety to phoneme
                        char_embed = torch.nn.functional.gumbel_softmax(prediction).mm(self.embedding.weight)
                    else:
                        char_embed = self.embedding(prediction.argmax(dim=-1))
                else:
                    if i == 0:
                        char_embed = self.embedding(torch.zeros(batch, dtype = torch.long).fill_(VOCAB_MAP['<sos>']).to(self.device)) 
                    else: 
                        char_embed = char_embedding[:,i-1,:] # ground truth teacher forcing 
            # Validation mode 
            else: 
                if i == 0: 
                    char_embed = self.embedding(torch.zeros(batch, dtype = torch.long).fill_(VOCAB_MAP['<sos>']).to(self.device)) 
                else: 
                    char_embed = self.embedding(prediction.argmax(dim = -1)) # feed in the previous prediction as input 
            
            # Input to the decoder (prev embedding + context from attention mechanism) 
            decoder_input_embedding = torch.cat([char_embed, context.squeeze(1)], dim = 1)
           
            for i in range(len(self.lstm_cells)):
                # An LSTM Cell returns (h,c) -> h = hidden state, c = cell memory state
                # Using 2 LSTM Cells is akin to a 2 layer LSTM looped through t timesteps 
                # The second LSTM Cell takes in the output hidden state of the first LSTM Cell (from the current timestep) as Input, along with the hidden and cell states of the cell from the previous timestep
                hidden_states[i] = self.lstm_cells[i](decoder_input_embedding, hidden_states[i]) 
                decoder_input_embedding = hidden_states[i][0]

            decoder_output_embedding = hidden_states[-1][0]
            # What is the query? (same len as the key)
            # Hidden state of the LSTM 
            # 8x768 and 128x30
            # decoder_output_embeddings, mask
            context, attention = self.attention(decoder_output_embedding)
            attention_plot.append(attention[0].detach().cpu())
            
            output_context = torch.cat([self.attention.query, context], dim = 1)
            prediction = self.character_prob(output_context)
            predictions.append(prediction.unsqueeze(1))
        attentions = torch.stack(attention_plot, dim = 0)
        predictions = torch.cat(predictions, dim = 1 )

        return predictions, attentions

In [12]:
"""
Combining the pipelines of the Seq2Seq model
"""
class Seq2Seq(nn.Module):
    def __init__(self, input_size, encoder_hidden_size, vocab_size, embed_size, decoder_hidden_size, decoder_output_size, projection_size = 128 ):
        super(Seq2Seq, self).__init__()
        """
        Parameters of each of the model classes 
        Encoder : input_size, encoder_hidden_size
        Decoder : embed_dim, projection, vocab_size, decoder_hidden_size, decoder_output_size, encoder_output_size
        """

        self.encoder =  Encoder(input_size = input_size, encoder_hidden_size = encoder_hidden_size)
        #self.attention = Attention(encoder_output_size= 2*encoder_hidden_size, decoder_output_size=decoder_output_size, projection= projection_size)
        self.decoder =  Decoder(embed_dim = embed_size, projection = projection_size, vocab_size = vocab_size, decoder_hidden_size = decoder_hidden_size, decoder_output_size=decoder_output_size,encoder_output_size = encoder_hidden_size)
    def forward(self, x, x_lens, y = None, mode = "none"):
        
        encoder_outputs, encoder_lens = self.encoder(x, x_lens)
        # encoder_output, key, value, encoder_len, y = None, mode = "train", teacherForcingRate = 0.1, isGumbel = False 
        predictions, attention_map = self.decoder(encoder_outputs, encoder_lens , y, mode = mode)

        return predictions, attention_map 


In [13]:
"""
Model initialization 
"""
DEVICE = "cuda"
model  = Seq2Seq(input_size=15,encoder_hidden_size=512,vocab_size=len(VOCAB),
            embed_size=512,decoder_hidden_size=512,decoder_output_size=128,projection_size=256)
model.to(DEVICE)
# print(model)



Seq2Seq(
  (encoder): Encoder(
    (base_lstm): LSTM(15, 512, batch_first=True, dropout=0.1, bidirectional=True)
    (pblstm): Sequential(
      (0): PBLSTM(
        (blstm): LSTM(2048, 512, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
      )
      (1): LockedDropout()
      (2): PBLSTM(
        (blstm): LSTM(2048, 512, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
      )
      (3): LockedDropout()
      (4): PBLSTM(
        (blstm): LSTM(2048, 512, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
      )
      (5): LockedDropout()
    )
  )
  (decoder): Decoder(
    (embedding): Embedding(30, 512, padding_idx=0)
    (lstm_cells): Sequential(
      (0): LSTMCell(768, 512)
      (1): LSTMCell(512, 128)
    )
    (attention): Attention(
      (key_layer): Linear(in_features=1024, out_features=256, bias=True)
      (value_layer): Linear(in_features=1024, out_features=256, bias=True)
      (query_layer): Linear(in_features=128, out_f

In [14]:
# We have given you this utility function which takes a sequence of indices and converts them to a list of characters
def indices_to_chars(indices, vocab):
    tokens = []
    for i in indices: # This loops through all the indices
        if vocab[int(i)] == vocab[SOS_TOKEN]: # If SOS is encountered, dont add it to the final list
            continue
        elif vocab[int(i)] == vocab[EOS_TOKEN]: # If EOS is encountered, stop the decoding process
            break
        else:
            tokens.append(vocab[i])
    return tokens

# To make your life more easier, we have given the Levenshtein distantce / Edit distance calculation code
def calc_edit_distance(predictions, y, ly, vocab= VOCAB, print_example= False):

    dist                = 0
    batch_size, seq_len = predictions.shape

    for batch_idx in range(batch_size): 

        y_sliced    = indices_to_chars(y[batch_idx,0:ly[batch_idx]], vocab)
        pred_sliced = indices_to_chars(predictions[batch_idx], vocab)

        # Strings - When you are using characters from the AudioDataset
        y_string    = ''.join(y_sliced)
        pred_string = ''.join(pred_sliced)
        
        dist        += Levenshtein.distance(pred_string, y_string)
        # Comment the above abd uncomment below for toy dataset 
        # dist      += Levenshtein.distance(y_sliced, pred_sliced)

    if print_example: 
        # Print y_sliced and pred_sliced if you are using the toy dataset
        # print("Ground Truth : ", y_string)
        # print("Prediction   : ", pred_string)
        print("Ground Truth : ", y_sliced)
        print("Prediction   : ", pred_sliced)
        
    dist/=batch_size
    return dist

### Hyper-parameters under consideration 
- Optimizer
    - Learning Rate
    - Weight Decay
- Learning Rate Scheduler
    - reduction 
- Loss function 
    - Reduction
    - Factor 
    - Patience 

In [15]:
optimizer = torch.optim.Adam(model.parameters(),lr = 2e-3, weight_decay = 5e-6 )
criterion = nn.CrossEntropyLoss(reduction = "none")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.4, patience = 2)
scaler = torch.cuda.amp.GradScaler()

### Train and evaluate 

In [17]:
epochs = 100 
best_lev_dist = float("inf")
tf_rate = 0.5
for epoch in tqdm(range(epochs)): 
    print("\nEpoch: {}/{}".format(epoch+1, epochs))
    curr_lr = float(optimizer.param_groups[0]['lr'])


    model.train()

    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train')
    running_loss = 0

    # Levenstein distance debug at random indices 
    
    for i,(x, y, x_len, y_len) in enumerate(train_loader):
        optimizer.zero_grad()
        x, x_len, y, y_len = x.to(DEVICE), x_len, y.to(DEVICE), y_len
        pred, attn = model(x = x, x_lens = x_len, y = y, mode = "train")
    

        max_len = torch.max(torch.tensor(y_len))
        lst = torch.arange(0,max_len).repeat(y_len.size(0),1)

        seq_len = y_len.unsqueeze(1).expand(y_len.size(0),max_len)
        mask = (lst<seq_len).int().cuda() 
        loss = criterion(pred.view(-1, pred.size(2)), y.view(-1))
        masked_loss = torch.sum(loss * mask.view(-1)) / torch.sum(mask)

        masked_loss.backward()
        running_loss+=masked_loss
        optimizer.step()
        idx = "{}_{}".format(epoch, i)
        torch.cuda.empty_cache()
        batch_bar.set_postfix(
            loss="{:.04f}".format(running_loss/(i+1)),
            lr="{:.08f}".format(float(optimizer.param_groups[0]['lr'])))
        batch_bar.update()

        del x, y, x_len, y_len
        torch.cuda.empty_cache()

    # Validate 
    model.eval()
    batch_bar = tqdm(total = len(val_loader), dynamic_ncols= True, position = 0 ,leave = False, desc = "Val")
    debug_ind = np.random.randint(0,len(val_loader))
    running_lev_dist = 0.0

    for i, (x,y,lx,ly) in enumerate(val_loader):
        x,y,lx,ly = x.to(DEVICE), y.to(DEVICE), lx, ly
        with torch.inference_mode():
            predictions, attention = model(x, lx, y = None)
        greedy_predictions = predictions.argmax(dim = -1)

        if i == debug_ind: 
            running_lev_dist += calc_edit_distance(greedy_predictions, y ,ly, VOCAB, print_example = True)
        else: 
            running_lev_dist += calc_edit_distance(greedy_predictions, y ,ly, VOCAB, print_example = False)
        
        batch_bar.set_postfix(
            dist="{:.04f}".format(running_lev_dist/(i+1)))
        batch_bar.update()
        del x, y, lx, ly
        torch.cuda.empty_cache()
    batch_bar.close()
    running_lev_dist /= len(val_loader)
    val_dist = running_lev_dist

    #wandb.log({"train loss": running_loss, "lev_dist": val_dist, "tf_rate": tf_rate, "lr":curr_lr})
    scheduler.step(val_dist)

    if val_dist <= best_lev_dist:
        best_lev_dist = val_dist
        print("Saving model")
        torch.save({'model_state_dict':model.state_dict(),
                    'optimizer_state_dict':optimizer.state_dict(),
                    'val_dist': val_dist, 
                    'epoch': epoch}, 'C:\\Users\\thopa\\Desktop\\Assignments\\11685\\HW4\\2022Implementation\\complete_implementation\\ckpt\\seq2seq_{}.pth'.format(epoch))
        if val_dist<28:
            tf_rate = tf_rate - 0.05*tf_rate
            tf_rate = max(0.5, tf_rate)



  0%|          | 0/100 [00:00<?, ?it/s]


Epoch: 1/100


  x_lens = x_lens//2
  max_len = torch.max(torch.tensor(y_len))
                                                                                  

Ground Truth :  ['W', 'H', 'A', 'T', ' ', 'A', 'L', 'T', 'E', 'R', 'N', 'A', 'T', 'I', 'V', 'E', ' ', 'W', 'A', 'S', ' ', 'T', 'H', 'E', 'R', 'E', ' ', 'F', 'O', 'R', ' ', 'H', 'E', 'R']
Prediction   :  ['H', 'E', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', '

  1%|          | 1/100 [00:48<1:19:15, 48.03s/it]


Epoch: 2/100


                                                                                  

Ground Truth :  ['W', 'H', 'A', 'T', ' ', 'A', 'L', 'T', 'E', 'R', 'N', 'A', 'T', 'I', 'V', 'E', ' ', 'W', 'A', 'S', ' ', 'T', 'H', 'E', 'R', 'E', ' ', 'F', 'O', 'R', ' ', 'H', 'E', 'R']
Prediction   :  ['H', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'A', 'N', 'D', '

  2%|▏         | 2/100 [01:36<1:18:55, 48.32s/it]


Epoch: 3/100


Val:  69%|██████▉   | 9/13 [00:11<00:05,  1.32s/it, dist=553.9583]                

Ground Truth :  ['O', ' ', 'I', 'F', ' ', 'Y', 'O', 'U', ' ', 'P', 'L', 'A', 'Y', ' ', 'U', 'S', ' ', 'A', ' ', 'R', 'O', 'U', 'N', 'D', 'E', 'L', ' ', 'S', 'I', 'N', 'G', 'E', 'R', ' ', 'H', 'O', 'W', ' ', 'C', 'A', 'N', ' ', 'T', 'H', 'A', 'T', ' ', 'H', 'A', 'R', 'M', ' ', 'T', 'H', 'E', ' ', 'E', 'M', 'P', 'E', 'R', 'O', 'R', "'", 'S', ' ', 'D', 'A', 'U', 'G', 'H', 'T', 'E', 'R']
Prediction   :  ['T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', ' ', 'T', 'H', 'E', '

                                                                   

Saving model


  3%|▎         | 3/100 [02:24<1:17:51, 48.16s/it]


Epoch: 4/100


Val:  54%|█████▍    | 7/13 [00:09<00:08,  1.34s/it, dist=540.2679]                

Ground Truth :  ['I', ' ', 'S', 'A', 'W', ' ', 'T', 'H', 'E', ' ', 'L', 'A', 'D', 'Y', ' ', 'W', 'H', 'O', ' ', 'E', 'R', 'E', 'W', 'H', 'I', 'L', 'E', ' ', 'A', 'P', 'P', 'E', 'A', 'R', 'E', 'D', ' ', 'V', 'E', 'I', 'L', 'E', 'D', ' ', 'U', 'N', 'D', 'E', 'R', 'N', 'E', 'A', 'T', 'H', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'G', 'E', 'L', 'I', 'C', ' ', 'F', 'E', 'S', 'T', 'I', 'V', 'A', 'L', ' ', 'D', 'I', 'R', 'E', 'C', 'T', ' ', 'H', 'E', 'R', ' ', 'E', 'Y', 'E', 'S', ' ', 'T', 'O', ' ', 'M', 'E', ' ', 'A', 'C', 'R', 'O', 'S', 'S', ' ', 'T', 'H', 'E', ' ', 'R', 'I', 'V', 'E', 'R']
Prediction   :  ['H', 'A', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', '

                                                                   

Saving model


  4%|▍         | 4/100 [03:12<1:17:05, 48.19s/it]


Epoch: 5/100


Val:  31%|███       | 4/13 [00:05<00:11,  1.27s/it, dist=559.8750]                

Ground Truth :  ['T', 'H', 'O', 'U', ' ', 'M', 'A', 'K', 'E', 'S', 'T', ' ', 'M', 'E', ' ', 'R', 'E', 'M', 'E', 'M', 'B', 'E', 'R', ' ', 'W', 'H', 'E', 'R', 'E', ' ', 'A', 'N', 'D', ' ', 'W', 'H', 'A', 'T', ' ', 'P', 'R', 'O', 'S', 'E', 'R', 'P', 'I', 'N', 'A', ' ', 'T', 'H', 'A', 'T', ' ', 'M', 'O', 'M', 'E', 'N', 'T', ' ', 'W', 'A', 'S', ' ', 'W', 'H', 'E', 'N', ' ', 'L', 'O', 'S', 'T', ' ', 'H', 'E', 'R', ' ', 'M', 'O', 'T', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'S', 'H', 'E', ' ', 'H', 'E', 'R', 'S', 'E', 'L', 'F', ' ', 'T', 'H', 'E', ' ', 'S', 'P', 'R', 'I', 'N', 'G']
Prediction   :  ['H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', '

  5%|▌         | 5/100 [04:00<1:15:55, 47.95s/it]                  


Epoch: 6/100


Val:  62%|██████▏   | 8/13 [00:10<00:07,  1.42s/it, dist=546.4375]                

Ground Truth :  ['W', 'H', 'E', 'N', 'C', 'E', ' ', 'S', 'H', 'E', ' ', 'T', 'O', ' ', 'M', 'E', ' ', 'I', 'N', ' ', 'T', 'H', 'O', 'S', 'E', ' ', 'D', 'E', 'S', 'I', 'R', 'E', 'S', ' ', 'O', 'F', ' ', 'M', 'I', 'N', 'E', ' ', 'W', 'H', 'I', 'C', 'H', ' ', 'L', 'E', 'D', ' ', 'T', 'H', 'E', 'E', ' ', 'T', 'O', ' ', 'T', 'H', 'E', ' ', 'L', 'O', 'V', 'I', 'N', 'G', ' ', 'O', 'F', ' ', 'T', 'H', 'A', 'T', ' ', 'G', 'O', 'O', 'D', ' ', 'B', 'E', 'Y', 'O', 'N', 'D', ' ', 'W', 'H', 'I', 'C', 'H', ' ', 'T', 'H', 'E', 'R', 'E', ' ', 'I', 'S', ' ', 'N', 'O', 'T', 'H', 'I', 'N', 'G', ' ', 'T', 'O', ' ', 'A', 'S', 'P', 'I', 'R', 'E', ' ', 'T', 'O']
Prediction   :  ['T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'H', 'E', 'R', ' ', 'T', 'H', 'E', '

  6%|▌         | 6/100 [04:47<1:14:39, 47.65s/it]                  


Epoch: 7/100


Val:  38%|███▊      | 5/13 [00:06<00:10,  1.25s/it, dist=541.6250]                

Ground Truth :  ['N', 'O', 'R', ' ', 'E', 'V', 'E', 'N', ' ', 'T', 'H', 'U', 'S', ' ', 'O', 'U', 'R', ' ', 'W', 'A', 'Y', ' ', 'C', 'O', 'N', 'T', 'I', 'N', 'U', 'E', 'D', ' ', 'F', 'A', 'R', ' ', 'B', 'E', 'F', 'O', 'R', 'E', ' ', 'T', 'H', 'E', ' ', 'L', 'A', 'D', 'Y', ' ', 'W', 'H', 'O', 'L', 'L', 'Y', ' ', 'T', 'U', 'R', 'N', 'E', 'D', ' ', 'H', 'E', 'R', 'S', 'E', 'L', 'F', ' ', 'U', 'N', 'T', 'O', ' ', 'M', 'E', ' ', 'S', 'A', 'Y', 'I', 'N', 'G', ' ', 'B', 'R', 'O', 'T', 'H', 'E', 'R', ' ', 'L', 'O', 'O', 'K', ' ', 'A', 'N', 'D', ' ', 'L', 'I', 'S', 'T', 'E', 'N']
Prediction   :  ['T', 'H', 'E', ' ', 'P', 'R', 'O', 'V', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'P', 'R', 'O', 'V', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'P', 'R', 'O', 'V', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'P', 'R', 'O', 'V', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'P', 'R', 'O', 'V', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'P', 'R', 'O', 'V', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'P', 'R', 'O', 'V', 'E', 'R', ' ', 'T', 'H', 'E', ' ', '

                                                                   

Saving model


  7%|▋         | 7/100 [05:35<1:14:09, 47.84s/it]


Epoch: 8/100


Val:  77%|███████▋  | 10/13 [00:13<00:04,  1.34s/it, dist=523.2375]               

Ground Truth :  ['T', 'H', 'E', ' ', 'L', 'A', 'D', 'I', 'E', 'S']
Prediction   :  ['T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'M', 'O', 'R', 'L', 'A', 'N', 'D', ' ', 'T', 'H', 'E', '

                                                                   

Saving model


  8%|▊         | 8/100 [06:23<1:13:23, 47.87s/it]


Epoch: 9/100


Val:  31%|███       | 4/13 [00:05<00:11,  1.24s/it, dist=539.8750]                

Ground Truth :  ['T', 'H', 'O', 'U', ' ', 'M', 'A', 'K', 'E', 'S', 'T', ' ', 'M', 'E', ' ', 'R', 'E', 'M', 'E', 'M', 'B', 'E', 'R', ' ', 'W', 'H', 'E', 'R', 'E', ' ', 'A', 'N', 'D', ' ', 'W', 'H', 'A', 'T', ' ', 'P', 'R', 'O', 'S', 'E', 'R', 'P', 'I', 'N', 'A', ' ', 'T', 'H', 'A', 'T', ' ', 'M', 'O', 'M', 'E', 'N', 'T', ' ', 'W', 'A', 'S', ' ', 'W', 'H', 'E', 'N', ' ', 'L', 'O', 'S', 'T', ' ', 'H', 'E', 'R', ' ', 'M', 'O', 'T', 'H', 'E', 'R', ' ', 'H', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'S', 'H', 'E', ' ', 'H', 'E', 'R', 'S', 'E', 'L', 'F', ' ', 'T', 'H', 'E', ' ', 'S', 'P', 'R', 'I', 'N', 'G']
Prediction   :  ['H', 'A', 'D', ' ', 'N', 'O', 'T', ' ', 'A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'W', 'A', 'S', ' ', 'A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'W', 'A', 'S', ' ', 'A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'W', 'A', 'S', ' ', 'A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'W', 'A', 'S', ' ', 'A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'W', 'A', 'S', ' ', 'A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', '

  9%|▉         | 9/100 [07:10<1:12:14, 47.63s/it]                  


Epoch: 10/100


Val:  92%|█████████▏| 12/13 [00:16<00:01,  1.49s/it, dist=542.3333]               

Ground Truth :  ['I', 'F', ' ', 'W', 'E', ' ', 'H', 'A', 'D', ' ', 'B', 'E', 'E', 'N', ' ', 'B', 'R', 'O', 'T', 'H', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'S', 'I', 'S', 'T', 'E', 'R', ' ', 'I', 'N', 'D', 'E', 'E', 'D', ' ', 'T', 'H', 'E', 'R', 'E', ' ', 'W', 'A', 'S', ' ', 'N', 'O', 'T', 'H', 'I', 'N', 'G']
Prediction   :  ['A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', '

 10%|█         | 10/100 [07:57<1:11:08, 47.43s/it]                 


Epoch: 11/100


Val:  85%|████████▍ | 11/13 [00:15<00:02,  1.42s/it, dist=525.2273]               

Ground Truth :  ['I', 'T', ' ', 'I', 'S', ' ', 'T', 'H', 'E', ' ', 'E', 'X', 'P', 'R', 'E', 'S', 'S', 'I', 'O', 'N', ' ', 'O', 'F', ' ', 'L', 'I', 'F', 'E', ' ', 'U', 'N', 'D', 'E', 'R', ' ', 'C', 'R', 'U', 'D', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'M', 'O', 'R', 'E', ' ', 'R', 'I', 'G', 'I', 'D', ' ', 'C', 'O', 'N', 'D', 'I', 'T', 'I', 'O', 'N', 'S', ' ', 'T', 'H', 'A', 'N', ' ', 'O', 'U', 'R', 'S', ' ', 'L', 'I', 'V', 'E', 'D', ' ', 'B', 'Y', ' ', 'P', 'E', 'O', 'P', 'L', 'E', ' ', 'W', 'H', 'O', ' ', 'L', 'O', 'V', 'E', 'D', ' ', 'A', 'N', 'D', ' ', 'H', 'A', 'T', 'E', 'D', ' ', 'M', 'O', 'R', 'E', ' ', 'N', 'A', 'I', 'V', 'E', 'L', 'Y', ' ', 'A', 'G', 'E', 'D', ' ', 'S', 'O', 'O', 'N', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'D', 'I', 'E', 'D', ' ', 'Y', 'O', 'U', 'N', 'G', 'E', 'R', ' ', 'T', 'H', 'A', 'N', ' ', 'W', 'E', ' ', 'D', 'O']
Prediction   :  ['A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', '

 11%|█         | 11/100 [08:45<1:10:21, 47.44s/it]                 


Epoch: 12/100


Val:  38%|███▊      | 5/13 [00:06<00:10,  1.28s/it, dist=537.6500]                

Ground Truth :  ['N', 'O', 'R', ' ', 'E', 'V', 'E', 'N', ' ', 'T', 'H', 'U', 'S', ' ', 'O', 'U', 'R', ' ', 'W', 'A', 'Y', ' ', 'C', 'O', 'N', 'T', 'I', 'N', 'U', 'E', 'D', ' ', 'F', 'A', 'R', ' ', 'B', 'E', 'F', 'O', 'R', 'E', ' ', 'T', 'H', 'E', ' ', 'L', 'A', 'D', 'Y', ' ', 'W', 'H', 'O', 'L', 'L', 'Y', ' ', 'T', 'U', 'R', 'N', 'E', 'D', ' ', 'H', 'E', 'R', 'S', 'E', 'L', 'F', ' ', 'U', 'N', 'T', 'O', ' ', 'M', 'E', ' ', 'S', 'A', 'Y', 'I', 'N', 'G', ' ', 'B', 'R', 'O', 'T', 'H', 'E', 'R', ' ', 'L', 'O', 'O', 'K', ' ', 'A', 'N', 'D', ' ', 'L', 'I', 'S', 'T', 'E', 'N']
Prediction   :  ['T', 'H', 'E', ' ', 'C', 'O', 'U', 'L', 'D', ' ', 'N', 'O', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'D', '

 12%|█▏        | 12/100 [09:33<1:09:48, 47.60s/it]                 


Epoch: 13/100


Val:  85%|████████▍ | 11/13 [00:15<00:02,  1.45s/it, dist=537.4318]               

Ground Truth :  ['I', 'T', ' ', 'I', 'S', ' ', 'T', 'H', 'E', ' ', 'E', 'X', 'P', 'R', 'E', 'S', 'S', 'I', 'O', 'N', ' ', 'O', 'F', ' ', 'L', 'I', 'F', 'E', ' ', 'U', 'N', 'D', 'E', 'R', ' ', 'C', 'R', 'U', 'D', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'M', 'O', 'R', 'E', ' ', 'R', 'I', 'G', 'I', 'D', ' ', 'C', 'O', 'N', 'D', 'I', 'T', 'I', 'O', 'N', 'S', ' ', 'T', 'H', 'A', 'N', ' ', 'O', 'U', 'R', 'S', ' ', 'L', 'I', 'V', 'E', 'D', ' ', 'B', 'Y', ' ', 'P', 'E', 'O', 'P', 'L', 'E', ' ', 'W', 'H', 'O', ' ', 'L', 'O', 'V', 'E', 'D', ' ', 'A', 'N', 'D', ' ', 'H', 'A', 'T', 'E', 'D', ' ', 'M', 'O', 'R', 'E', ' ', 'N', 'A', 'I', 'V', 'E', 'L', 'Y', ' ', 'A', 'G', 'E', 'D', ' ', 'S', 'O', 'O', 'N', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'D', 'I', 'E', 'D', ' ', 'Y', 'O', 'U', 'N', 'G', 'E', 'R', ' ', 'T', 'H', 'A', 'N', ' ', 'W', 'E', ' ', 'D', 'O']
Prediction   :  ['A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'E', 'N', 'E', 'R', 'A', 'L', ' ', 'T', 'O', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', '

 13%|█▎        | 13/100 [10:21<1:09:16, 47.78s/it]                 


Epoch: 14/100


Val:  38%|███▊      | 5/13 [00:06<00:10,  1.25s/it, dist=534.3250]                

Ground Truth :  ['N', 'O', 'R', ' ', 'E', 'V', 'E', 'N', ' ', 'T', 'H', 'U', 'S', ' ', 'O', 'U', 'R', ' ', 'W', 'A', 'Y', ' ', 'C', 'O', 'N', 'T', 'I', 'N', 'U', 'E', 'D', ' ', 'F', 'A', 'R', ' ', 'B', 'E', 'F', 'O', 'R', 'E', ' ', 'T', 'H', 'E', ' ', 'L', 'A', 'D', 'Y', ' ', 'W', 'H', 'O', 'L', 'L', 'Y', ' ', 'T', 'U', 'R', 'N', 'E', 'D', ' ', 'H', 'E', 'R', 'S', 'E', 'L', 'F', ' ', 'U', 'N', 'T', 'O', ' ', 'M', 'E', ' ', 'S', 'A', 'Y', 'I', 'N', 'G', ' ', 'B', 'R', 'O', 'T', 'H', 'E', 'R', ' ', 'L', 'O', 'O', 'K', ' ', 'A', 'N', 'D', ' ', 'L', 'I', 'S', 'T', 'E', 'N']
Prediction   :  ['A', 'N', 'D', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'R', 'E', 'A', 'T', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'R', 'E', 'A', 'T', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'R', 'E', 'A', 'T', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'R', 'E', 'A', 'T', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'R', 'E', 'A', 'T', '

 14%|█▍        | 14/100 [11:08<1:08:14, 47.61s/it]                 


Epoch: 15/100


Val:  92%|█████████▏| 12/13 [00:17<00:01,  1.55s/it, dist=525.3542]               

Ground Truth :  ['I', 'F', ' ', 'W', 'E', ' ', 'H', 'A', 'D', ' ', 'B', 'E', 'E', 'N', ' ', 'B', 'R', 'O', 'T', 'H', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'S', 'I', 'S', 'T', 'E', 'R', ' ', 'I', 'N', 'D', 'E', 'E', 'D', ' ', 'T', 'H', 'E', 'R', 'E', ' ', 'W', 'A', 'S', ' ', 'N', 'O', 'T', 'H', 'I', 'N', 'G']
Prediction   :  ['A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'F', 'A', 'T', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', '

 15%|█▌        | 15/100 [11:56<1:07:27, 47.62s/it]                 


Epoch: 16/100


Val:  54%|█████▍    | 7/13 [00:09<00:08,  1.49s/it, dist=525.2143]                

Ground Truth :  ['I', ' ', 'S', 'A', 'W', ' ', 'T', 'H', 'E', ' ', 'L', 'A', 'D', 'Y', ' ', 'W', 'H', 'O', ' ', 'E', 'R', 'E', 'W', 'H', 'I', 'L', 'E', ' ', 'A', 'P', 'P', 'E', 'A', 'R', 'E', 'D', ' ', 'V', 'E', 'I', 'L', 'E', 'D', ' ', 'U', 'N', 'D', 'E', 'R', 'N', 'E', 'A', 'T', 'H', ' ', 'T', 'H', 'E', ' ', 'A', 'N', 'G', 'E', 'L', 'I', 'C', ' ', 'F', 'E', 'S', 'T', 'I', 'V', 'A', 'L', ' ', 'D', 'I', 'R', 'E', 'C', 'T', ' ', 'H', 'E', 'R', ' ', 'E', 'Y', 'E', 'S', ' ', 'T', 'O', ' ', 'M', 'E', ' ', 'A', 'C', 'R', 'O', 'S', 'S', ' ', 'T', 'H', 'E', ' ', 'R', 'I', 'V', 'E', 'R']
Prediction   :  ['A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'E', 'N', 'E', 'R', 'A', 'L', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'E', 'N', 'E', 'R', 'A', 'L', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'E', 'N', 'E', 'R', 'A', 'L', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'G', 'E', 'N', 'E', 'R', 'A', 'L', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', '

 16%|█▌        | 16/100 [12:43<1:06:29, 47.49s/it]                 


Epoch: 17/100


Val:  15%|█▌        | 2/13 [00:02<00:14,  1.35s/it, dist=546.1875]                

Ground Truth :  ['B', 'U', 'T', ' ', 'C', 'A', 'N', ' ', 'H', 'E', ' ', 'U', 'N', 'D', 'E', 'R', 'S', 'T', 'A', 'N', 'D', ' ', 'Y', 'O', 'U', ' ', 'Y', 'E', 'S']
Prediction   :  ['A', 'N', 'D', ' ', 'H', 'E', 'R', ' ', 'F', 'A', 'T', 'H', 'E', 'R', ' ', 'T', 'H', 'E', ' ', 'G', 'E', 'N', 'E', 'R', 'A', 'L', ' ', 'S', 'E', 'E', 'N', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', 'T', 'H', 'E', ' ', 'P', 'E', 'R', 'S', 'E', 'N', 'T', ' ', 'A', 'N', 'D', ' ', '

 17%|█▋        | 17/100 [13:31<1:05:55, 47.66s/it]                 


Epoch: 18/100


 17%|█▋        | 17/100 [14:10<1:09:14, 50.05s/it], dist=522.1429]                


KeyboardInterrupt: 