In [40]:
import os
from random import random
import pandas as pd
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

root_path = 'data/transcriptions'
data_path = os.path.join(root_path,'train.csv')

In [41]:
class Vocab:
    def __init__(self, counter,for_encoder=False, min_freq=None):
        self.sos = "<sos>"
        self.eos = "<eos>"
        self.pad = "<pad>"
        self.unk = "<unk>"
        
        self.pad_idx = 0
        self.unk_idx = 1
        self.sos_idx = 2
        self.eos_idx = 3
        
        if for_encoder:
            self._token2idx = {
                self.pad:self.pad_idx,
                self.unk:self.unk_idx,
            }
        else:
            self._token2idx = {
                self.sos: self.sos_idx,
                self.eos: self.eos_idx,
                self.pad: self.pad_idx,
                self.unk: self.unk_idx,
            }
        self._idx2token = {idx:token for token, idx in self._token2idx.items()}
        
        
        idx = len(self._token2idx)
        min_freq = 0 if min_freq is None else min_freq
        
        for token, count in counter.items():
            if count > min_freq:
                self._token2idx[token] = idx
                self._idx2token[idx]   = token
                idx += 1
        
        self.vocab_size = len(self._token2idx)
        self.tokens     = list(self._token2idx.keys())
    
    def token2idx(self, token):
        return self._token2idx.get(token, self.pad_idx)
    
    def idx2token(self, idx):
        return self._idx2token.get(idx, self.pad)
    
    def sent2idx(self, sent):
        return [self.token2idx(i) for i in sent]
    
    def idx2sent(self, idx):
        return [self.idx2token(i) for i in idx]
    
    def __len__(self):
        return len(self._token2idx)
    
    def __repr__(self):
        
        return '{}'.format(self._token2idx)

In [69]:
class CharactersDataset(Dataset):
    
    def __init__(self,csv_file_path,transform = None):
        self.file = pd.read_csv(csv_file_path,'r')
        self.transform = transform
        self.data = []
        self.characters_vocab = None
        self.transcripts_vocab = None
        self.non_needed_symbols = '\'#$?\\_({)}-:\";!%.1234567890'
        
        self.make_dataset()
       
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        
        x = self.data[idx]['x']
        y = self.data[idx]['y']
        data = {'x':x,'y':y}
        if self.transform:
            data = self.transform(data)
    
        return data
    
    def make_dataset(self):
        characters = set()
        transcripts = set()
        for idx in range(len(self.file)):
            item = str(self.file.iloc[idx][0]).split(',')
            
            x = item[1].strip()
            for symbol in self.non_needed_symbols:
                x = x.replace(symbol,'')
            y = item[2].replace(' ','')
            self.data.append({'x':x,'y':y})
            for character in x:
                characters.add(character)
            for transcript in y:
                transcripts.add(transcript)
        
        self.characters_vocab = Vocab({v:k for k,v in dict(enumerate(characters,start=2)).items()},for_encoder=True)
        self.transcripts_vocab = Vocab({v:k for k,v in dict(enumerate(transcripts,start=4)).items()})
        
            
    def collate_fn(self, batch): 
        x_values = []
        y_values_in = []
        x_lengths = []
        y_lengths = []
        for item in batch:
            
            x_values.append([self.characters_vocab.token2idx(ch) for ch in item['x']])
            y_values_in.append([self.transcripts_vocab.token2idx(tr) for tr in item['y']])
        
        x_values = sorted(x_values,key=len,reverse=True)
        y_values_in = sorted(y_values_in,key=len,reverse=True)
        
        max_x = len(x_values[0])
        max_y = len(y_values_in[0])
        
        for word_index in range(len(x_values)):
            
            x_lengths.append(len(x_values[word_index]))
            y_lengths.append(len(y_values_in[word_index]))
            
            for _ in range(1+ max_x - len(x_values[word_index])):
                x_values[word_index].append(0)
            for _ in range(1+ max_y - len(y_values_in[word_index])):
                y_values_in[word_index].append(0)
            
            y_values_in[word_index].insert(0,2)
            
        x_values = torch.tensor(x_values)
        y_values_in_tensor = torch.tensor(y_values_in)
        
        y_values_out = y_values_in        
        for arr_index in range(len(y_values_out)):
            index_of_first_zero = y_values_out[arr_index].index(0)
            y_values_out[arr_index][index_of_first_zero] = 3
            y_values_out[arr_index] = y_values_out[arr_index][1:] +[0]
            
        y_values_out_tensor = torch.tensor(y_values_out)
        
        return x_values,y_values_in_tensor,y_values_out_tensor

In [71]:
dataset = CharactersDataset(data_path)
dataloader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=2,collate_fn = dataset.collate_fn)

In [72]:
for kek in dataloader:
    kek = kek
    break

In [76]:
kek[2][0],kek[1][0]

(tensor([15,  9, 26, 17, 25, 25,  4, 21, 25,  9,  8, 25,  9, 26, 13,  3,  0]),
 tensor([ 2, 15,  9, 26, 17, 25, 25,  4, 21, 25,  9,  8, 25,  9, 26, 13,  0]))

In [77]:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.2 * dataset_size))

np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, 
                                           sampler=train_sampler,collate_fn = dataset.collate_fn)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                sampler=valid_sampler,collate_fn = dataset.collate_fn)

In [106]:
class EncoderLSTM(nn.Module):
    def __init__(self,embed_dim,hidden_size,output_size,n_layers = 1,dropout=0):
        super(EncoderLSTM,self).__init__()
        
        self.embedding = nn.Embedding(output_size,embed_dim,padding_idx = 0)
        self.LSTM = nn.LSTM(embed_dim,hidden_size,n_layers,dropout=(0 if n_layers == 1 else dropout),batch_first=True)
    
    def forward(self,input_seq,hidden=None):
        embedded = self.embedding(input_seq)
        #packed = nn.utils.rnn.pack_padded_sequence(embedded,input_lengths)
        outputs,(hidden,cell) = self.LSTM(embedded)
        #outputs,_ = nn.utils.rnn.pad_packed_sequence(outputs)
        return hidden,cell

In [125]:
class DecoderLSTM(nn.Module):
    def __init__(self,embed_dim,hidden_size,output_size,n_layers=1,dropout=0):
        super(DecoderLSTM,self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        #layers
        self.embedding = nn.Embedding(output_size,embed_dim, padding_idx = 0)
        self.LSTM = nn.LSTM(embed_dim,hidden_size,n_layers,dropout = (0 if n_layers == 1 else self.dropout),batch_first=True)
        self.out = nn.Linear(hidden_size,output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
        
    def forward(self,input_step,last_hidden, last_cell):
        #input_step (1,batch_size)
        input_step = input_step.unsqueeze(1)

        embedded = self.embedding(input_step)
        #embedded(1,batch_size,hidden_dim)

        output,(hidden,cell) = self.LSTM(embedded,(last_hidden, last_cell))
        #output(batch_size,seq_len,hidden_dim)
        #seq_len = 1 if we using teacher forcing
        output = output.squeeze(1) #(batch_size,hidden_dim)
        
        prediction = self.out(output)
        #prediction(batch_size,output_dim)
        
        return prediction,hidden,cell

In [126]:
encoder = EncoderLSTM(32,64,len(dataset.characters_vocab)).to(device)

In [127]:
decoder = DecoderLSTM(32,64,len(dataset.transcripts_vocab)).to(device)
len(dataset.characters_vocab)

28

In [128]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [129]:
import random
class seq2seq(nn.Module):
    def __init__(self,encoder,decoder,device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
    
    def forward(self,x,y,teacher_forcing_ratio = 0.3):
        batch_size = y.shape[0]
        seq_len = y.shape[1]
        outputs = torch.zeros(seq_len,batch_size,self.decoder.output_size).to(self.device)
        
            
        hidden,cell = self.encoder(x)
        input_token = y[:,0]
        
        for t in range(1,seq_len):
            output,hidden,cell = self.decoder(input_token,hidden,cell)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.max(1)[1]
            input_token = (y[:,t] if teacher_force else top1)
        
        return outputs
    
    def predict(self,x):

        #batch_size = 1!
        hidden,cell = self.encoder(x)
        out_input = torch.LongTensor([[2]]).to(device)
        preds = []
        while True:
            output, hidden, cell = self.decoder(out_input, hidden, cell)
            output = torch.argmax(output)
            our_value = output.item()
            if our_value == 3:
                break
            preds.append(our_value)
            out_input = output.unsqueeze(0)
        return preds

In [130]:
model = seq2seq(encoder,decoder,device).to(device)

In [131]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
model.apply(init_weights)

seq2seq(
  (encoder): EncoderLSTM(
    (embedding): Embedding(28, 32, padding_idx=0)
    (LSTM): LSTM(32, 64, batch_first=True)
  )
  (decoder): DecoderLSTM(
    (embedding): Embedding(28, 32, padding_idx=0)
    (LSTM): LSTM(32, 64, batch_first=True)
    (out): Linear(in_features=64, out_features=28, bias=True)
    (softmax): LogSoftmax()
  )
)

In [132]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 53,788 trainable parameters


In [133]:
optimizer = optim.Adam(model.parameters())

In [134]:
dataset.transcripts_vocab

{'<sos>': 2, '<eos>': 3, '<pad>': 0, '<unk>': 1, 'M': 4, 'J': 5, 'W': 6, 'D': 7, 'T': 8, 'H': 9, 'L': 10, 'C': 11, 'R': 12, 'S': 13, 'Y': 14, 'I': 15, 'Z': 16, 'K': 17, 'O': 18, 'F': 19, 'G': 20, 'P': 21, 'V': 22, 'U': 23, 'E': 24, 'A': 25, 'N': 26, 'B': 27}

In [135]:
criterion = nn.CrossEntropyLoss(ignore_index = 0)


In [144]:
def train(model, iterator, optimizer, criterion, clip,epoch,train_loss_list):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        x = batch[0].to(device)
        y_in = batch[1].to(device)
        y_out = batch[2].to(device)
        
        optimizer.zero_grad()
        
        output = model(x, y_in)
        #output dim (y_seq_len,batch_size,output_dim)
        output = output.view(output.shape[0]*output.shape[1],-1)
        y_out = y_out.view(-1)
        
        loss = criterion(output, y_out)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        train_loss_list.append(loss.item())

        if i % 500 == 0:
            plot(epoch, i, train_loss_list)
    return epoch_loss / len(iterator)

In [145]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            x = batch[0].to(device)
            y_in = batch[1].to(device)
            y_out = batch[2].to(device)
            
            output = model(x,y_in,0) #turn off teacher forcing
            
            output = output.view(output.shape[0]*output.shape[1],-1)
            y_out = y_out.view(-1)

            loss = criterion(output, y_out)
            
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)



In [146]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
def plot(epoch, step, train_losses):
    clear_output()
    plt.title(f'Epochs {epoch}, step {step}')
    plt.plot(train_losses)
    plt.show()

In [None]:
import math

N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')
train_losses = []
for epoch in range(N_EPOCHS):
    
    train_loss = train(model, train_loader, optimizer, criterion, CLIP,epoch,train_losses)
    valid_loss = evaluate(model, validation_loader, criterion)
    
    
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
        
    
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')

In [None]:
torch.save(encoder.state_dict(),'encoder_weights')
torch.save(decoder.state_dict(),'decoder_weights')
torch.save(model.state_dict(),'model_weights')

In [None]:
encoder = EncoderLSTM(len(dataset.characters_vocab),128).to(device)
decoder = DecoderLSTM(128,len(dataset.transcripts_vocab)).to(device)


In [None]:
encoder.load_state_dict(torch.load('encoder_weights'))
decoder.load_state_dict(torch.load('decoder_weights'))
model = seq2seq(encoder,decoder,device)
model.load_state_dict(torch.load('model_weights'))

In [27]:
for batch in validation_loader:
    kek = batch
    break

In [None]:
x = kek[0][0][1:].unsqueeze(0).to(device)
y_pred = model.predict(x)

In [None]:
output_token = ''
for pred in y_pred:
    output_token += dataset.transcripts_vocab.idx2token(pred)

In [None]:
input_token = ''
for x_val in x.squeeze(0):
    input_token += dataset.characters_vocab.idx2token(x_val.item())


In [None]:
input_token,output_token