In [1]:
import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

In [2]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
class Vocab:
    def __init__(self, counter, sos, eos, pad, unk, min_freq=None):
        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.unk = unk
        
        self.pad_idx = 0
        self.unk_idx = 1
        self.sos_idx = 2
        self.eos_idx = 3
        
        self._token2idx = {
            self.sos: self.sos_idx,
            self.eos: self.eos_idx,
            self.pad: self.pad_idx,
            self.unk: self.unk_idx,
        }
        self._idx2token = {idx:token for token, idx in self._token2idx.items()}
        
        idx = len(self._token2idx)
        min_freq = 0 if min_freq is None else min_freq
        
        for token, count in counter.items():
            if count > min_freq:
                self._token2idx[token] = idx
                self._idx2token[idx]   = token
                idx += 1
        
        self.vocab_size = len(self._token2idx)
        self.tokens     = list(self._token2idx.keys())
    
    def token2idx(self, token):
        return self._token2idx.get(token, self.pad_idx)
    
    def idx2token(self, idx):
        return self._idx2token.get(idx, self.pad)
    
    def __len__(self):
        return len(self._token2idx)
    
def padding(sequences, pad_idx):
    '''
    Inputs:
        sequences: list of list of tokens
    '''
    max_length = max(map(len, sequences))
    
    return [seq + [pad_idx]*(max_length - len(seq)) for seq in sequences]



import csv
from collections import Counter

def words_tokenize(line):
    return list(line)

def trans_tokenize(line):
    return line.split()

class Dataset(object):
    def __init__(self, path):
        val_size = 0.1
        shuffle  = True

        with open(path, 'r') as f:
            reader = csv.reader(f)
            lines   = list(reader)

        _, words, trans = zip(*lines[1:])

        c = list(zip(words, trans))
        random.shuffle(c)
        words, trans = zip(*c)

        val_size = int(len(words) * val_size)
        train_words, val_words = words[val_size:], words[:val_size]
        train_trans, val_trans = trans[val_size:], trans[:val_size]
        
        words_counter = Counter()
        trans_counter = Counter()

        for line in train_words:
            tokens = words_tokenize(line)
            for token in tokens:
                words_counter[token] += 1

        for line in train_trans:
            tokens = trans_tokenize(line)
            for token in tokens:
                trans_counter[token] += 1
                
        sos = "<sos>"
        eos = "<eos>"
        pad = "<pad>"
        unk = "<unk>"

        self.words_vocab = Vocab(words_counter, 
                            sos, eos, pad, unk)

        self.trans_vocab = Vocab(trans_counter, 
                            sos, eos, pad, unk)
        
        self.train_words = [[self.words_vocab.token2idx(item) for item in words_tokenize(word)] for word in train_words]
        self.val_words   = [[self.words_vocab.token2idx(item) for item in words_tokenize(word)] for word in val_words]

        self.train_trans = [[self.trans_vocab.token2idx(item) for item in trans_tokenize(trans)] for trans in train_trans]
        self.val_trans   = [[self.trans_vocab.token2idx(item) for item in trans_tokenize(trans)] for trans in val_trans]
        
    def __len__(self):
        return len(self.train_trans)
        
    def get_batch(self, batch_size, sort=False, val=False):
        if val:
            words, trans = self.val_words,   self.val_trans
        else:
            words, trans = self.train_words, self.train_trans

        random_ids = np.random.randint(0, len(words), batch_size)
        batch_words = [words[idx] for idx in random_ids]
        batch_trans = [trans[idx] for idx in random_ids]

        batch_trans_in  = [[self.trans_vocab.sos_idx] + tran for tran in batch_trans]
        batch_trans_out = [tran + [self.trans_vocab.eos_idx] for tran in batch_trans]

        words_lens = list(map(len, batch_words))
        trans_lens = list(map(len, batch_trans_in))

        batch_words     = padding(batch_words,     pad_idx=self.words_vocab.pad_idx)
        batch_trans_in  = padding(batch_trans_in,  pad_idx=self.trans_vocab.pad_idx)
        batch_trans_out = padding(batch_trans_out, pad_idx=self.trans_vocab.pad_idx)


        batch_words     = torch.LongTensor(batch_words)
        batch_trans_in  = torch.LongTensor(batch_trans_in)
        batch_trans_out = torch.LongTensor(batch_trans_out)
        words_lens = torch.LongTensor(words_lens)
        trans_lens = torch.LongTensor(trans_lens)

        if sort:
            lens, indices   = torch.sort(words_lens, 0, True)
            batch_words     = batch_words[indices]
            batch_trans_in  = batch_trans_in[indices]
            batch_trans_out = batch_trans_out[indices]
            trans_lens = trans_lens[indices]
            words_lens = lens

        return batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens

In [7]:
dataset = Dataset('./data/transcriptions/train.csv')
batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens = dataset.get_batch(32, sort=True)
batch_words.size(), batch_trans_in.size(), batch_trans_out.size(), words_lens.size(), trans_lens.size()

(torch.Size([32, 13]),
 torch.Size([32, 12]),
 torch.Size([32, 12]),
 torch.Size([32]),
 torch.Size([32]))

In [8]:
class Encoder(nn.Module):
    def __init__(self, embedding_size, hidden_size, vocab_size,pad_idx):
        super(Encoder, self).__init__()
        
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        
        self.embedding = nn.Embedding(vocab_size,embedding_size,padding_idx = pad_idx)
        self.gru = nn.GRU(embedding_size,hidden_size,batch_first=True)

    def forward(self, batch_words, words_lens):

        batch_size = batch_words.size(0)
        embedded = self.embedding(batch_words)
        if words_lens is not None:
            embedded = pack(embedded, words_lens, batch_first=True)
        outputs, hidden = self.gru(embedded)
        
        if words_lens is not None:
            outs, _ = unpack(outputs, batch_first=True)
        return outs, hidden.squeeze(0)

In [9]:
class DotAttention(nn.Module):
    def __init__(self, hidden_size):
        super(DotAttention, self).__init__()
        self.linear1 = nn.Linear(2*hidden_size, hidden_size,bias=False)
        self.linear2 = nn.Linear(hidden_size, 1, bias=False)
        
    def forward(self, query, context, mask=None):
        '''
        Inputs:
            context: (batch x seq_len_enc x hidden_size) - outputs of encoder
            query: (batch x seq_len_dec x hidden_size) - outputs of decoder
            mask: (batch x seq_len_enc)
            
        Outputs:
            weigths: (batch x seq_len_dec x seq_len_enc)
            outputs: (batch x seq_len_dec x hidden_size)
        '''
        weights = torch.matmul(query, context.transpose(1,2))
        if mask is not None:
            mask = mask.unsqueeze(1)
            mask = mask.repeat(1,weights.size(1),1)
            weights.masked_fill_(mask, -float("inf"))
        weights = F.softmax(weights, dim = 2)
        outputs = torch.matmul(weights,context)
        return outputs

In [10]:
class ConAttention(nn.Module):
    def __init__(self, hidden_size):
        super(ConAttention, self).__init__()
        self.linear1 = nn.Linear(2*hidden_size, hidden_size,bias=False)
        self.linear2 = nn.Linear(hidden_size, 1, bias=False)
        
    def forward(self, context, query, mask):
        '''
        Inputs:
            context: (batch x seq_len_enc x hidden_size) - outputs of encoder
            query: (batch x seq_len_dec x hidden_size) - outputs of decoder
            mask: (batch x seq_len_enc)
            
        Outputs:
            weigths: (batch x seq_len_dec x seq_len_enc)
            outputs: (batch x seq_len_dec x hidden_size)
        '''
        batch_size = context.size(0)
        seq_len_enc = context.size(1)
        seq_len_dec = query.size(1)
        context_new = context.unsqueeze(1)
        context_new = context_new.repeat(1,seq_len_dec,1,1)
        query_new = query.unsqueeze(2)
        query_new = query_new.repeat(1,1, seq_len_enc, 1)
        
        weights = torch.cat((context_new, query_new),-1)
        weights = self.linear1(weights)
        weights = F.tanh(weights)
        weights = self.linear2(weights)
        weights=  weights.squeeze(-1)
        if mask is not None:
            mask = mask.unsqueeze(1)
            mask = mask.repeat(1,weights.size(1),1)
            weights.masked_fill_(mask, -float("inf"))
        weights = F.softmax(weights, dim = 2)
        
        outputs = torch.matmul(weights,context)
        return outputs

In [11]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, pad_idx, attention_type=None):
        super(Decoder, self).__init__()
        
        assert attention_type in ['dot','bahaganau','general',None]
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        
        self.embedding = nn.Embedding(vocab_size,emb_size,padding_idx=pad_idx)
        self.linear = nn.Linear(hidden_size,vocab_size)
        
        if attention_type == 'dot':
            self.attention = DotAttention(hidden_size)
        elif attention_type=='bahaganau':
            self.attention = ConAttention(hidden_size)
        elif attention_type=='general':
            self.attention = GenAttention(hidden_size)
        else:
            self.attention=None
            
        self.gru = nn.GRU(emb_size,hidden_size,batch_first=True)
        self.linear = nn.Linear(2*hidden_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, vocab_size)

    def forward(self,target,context,mask):
        embedded = self.embedding(target)
        query, hidden = self.gru(embedded)
        outputs = self.attention(context,query,mask)
        outputs = torch.cat((query,outputs),-1)
        outputs = self.linear(outputs)
        outputs = F.tanh(outputs)
        out = self.linear_out(outputs)
        return out

In [12]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
    def forward(self, batch_words, words_lens, batch_trans_in):
        context, hidden = self.encoder(batch_words, words_lens)
        mask = batch_words!=0
        mask = 1 - mask
        logits = self.decoder(batch_trans_in,context, mask)
        logits = logits.view(-1, dataset.trans_vocab.vocab_size)
        return logits
    
    def generate(self, bos_idx, eos_idx, batch_words, words_lens):
        inp = [bos_idx]
        mask = batch_words!=0
        mask = 1-mask
        context, hidden = self.encoder(batch_words, words_lens)

        for _ in range(100):
            inp_tensor = torch.LongTensor([[inp[-1]]]).to(batch_words.device)
            pred   = self.decoder(inp_tensor, context, mask)
            next_token = pred[-1].topk(1)[1].item()
            inp.append(next_token)
            if next_token == eos_idx:
                break
        return inp

In [13]:
def plot(epoch, batch_idx, train_losses, val_losses):
    clear_output(True)
    plt.figure(figsize=(20,5))
    plt.subplot(131)
    plt.title('epoch %s. | batch: %s | loss: %s' % (epoch, batch_idx, np.mean(train_losses[-100:])))
    plt.plot(train_losses)
    plt.subplot(132)
    plt.title('epoch %s. | loss: %s' % (epoch, np.mean(val_losses[-100:])))
    plt.plot(val_losses)
    plt.show()

In [14]:
batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens = dataset.get_batch(32, sort=True,val=False)

In [None]:
batch_words!=0

In [16]:
def one_step(batch_size, val):
    batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens = dataset.get_batch(batch_size, sort=True, val=val)
    batch_words     = batch_words.to(device)
    batch_trans_in  = batch_trans_in.to(device)
    batch_trans_out = batch_trans_out.to(device)
    words_lens      = words_lens.to(device)
    trans_lens      = trans_lens.to(device)
    
    logits = model(batch_words, words_lens, batch_trans_in)
    
    batch_trans_out = batch_trans_out.view(-1)
    mask = batch_trans_out != dataset.trans_vocab.pad_idx
    loss = criterion(logits[mask], batch_trans_out[mask])
    
    return loss

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb_size    = 32
hidden_size = 64
encoder = Encoder(emb_size, hidden_size,dataset.words_vocab.vocab_size, dataset.words_vocab.pad_idx).to(device)
decoder = Decoder(dataset.trans_vocab.vocab_size, emb_size, hidden_size, dataset.trans_vocab.pad_idx,'bahaganau').to(device)
model   = Model().to(device)
model.encoder = encoder
model.decoder = decoder

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

batch_size = 128
epoch      = 0
num_epochs = 10

train_losses = []
val_losses   = []

In [21]:
while epoch < num_epochs:
    for batch_idx in range(len(dataset.train_words) // batch_size):
        loss = one_step(batch_size, val=False)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        if batch_idx % 100 == 0:
            with torch.no_grad():
                loss = one_step(batch_size, val=True)
                val_losses.append(loss.item())

            plot(epoch, batch_idx, train_losses, val_losses)
            
    epoch += 1

In [19]:
def _print(val):
    batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens = dataset.get_batch(1, sort=True, val=val)
    batch_words     = batch_words.to(device)
    batch_trans_out = batch_trans_out.to(device)
    words_lens = words_lens.to(device)

    inp = model.generate(dataset.words_vocab.sos_idx, dataset.words_vocab.eos_idx, batch_words, words_lens)
            
    tokens = [dataset.trans_vocab.idx2token(idx) for idx in inp if idx not in [dataset.trans_vocab.sos_idx,
                                                                                     dataset.trans_vocab.eos_idx,
                                                                                     dataset.trans_vocab.pad_idx]]
    print('Src: ', ''.join([dataset.words_vocab.idx2token(idx) for idx in batch_words[0].tolist()]))
    print('Pred:', ''.join(tokens))
    print('Real:', ''.join([dataset.trans_vocab.idx2token(idx) for idx in batch_trans_out[0].tolist() if idx not in [dataset.trans_vocab.sos_idx,
                                                                            dataset.trans_vocab.eos_idx,
                                                                            dataset.trans_vocab.pad_idx]]))

In [20]:
for _ in range(10):
    _print(True)
    print()

Src:  MONSIEURS
Pred: MIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNSIHNS
Real: MAHSYERZ

Src:  SEMIPRECIOUS
Pred: SIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIHMIH
Real: SEHMIYPREHSHAHS

Src:  GRIBBINS
Pred: GRBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIHBIH
Real: GRIHBIHNZ

Src:  SHEILAH
Pred: SHLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYLEYL
Real: SHIYLAH

Src:  QUARTERMAN
Pred: KERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTERTER
Real: KWAORTERMAHN

Src:  PORATH
Pred: PAERAERAERAERAERAERAERAERAERAERAERAERAERAE