In [42]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from torchtext.vocab import Vocab
import pandas as pd
import numpy as np
from collections import Counter
import random
from tqdm.notebook import tqdm

In [43]:
# Tab-separated corpus with three unnamed columns; only the English and
# Bengali sentence columns are kept, the third ('attr') column is dropped.
df = pd.read_csv(
    'ben.txt',
    sep='\t',
    names=['en', 'ben', 'attr'],
    usecols=['en', 'ben'],
)
df.head()

Unnamed: 0,en,ben
0,Go.,যাও।
1,Go.,যান।
2,Go.,যা।
3,Run!,পালাও!
4,Run!,পালান!


In [44]:
def tokenize(row):
    """Split one English/Bengali sentence pair into word-level tokens.

    Args:
        row: mapping (e.g. a DataFrame row) with string entries under the
            keys 'en' and 'ben'.

    Returns:
        Tuple ``(en_tokens, ben_tokens)``. English tokens are lower-cased;
        Bengali tokens are wrapped in '<bos>' ... '<eos>'. A trailing
        punctuation mark is emitted as its own token in both languages.
    """
    import unicodedata  # local import: keeps the cell runnable on its own

    def _split_trailing_punct(text):
        # Only detach the final character when it really is punctuation
        # (Unicode category P*, which covers '.', '?', '!' and the danda).
        # Detaching it unconditionally would cut the last letter or digit
        # off sentences that do not end with a punctuation mark, and would
        # raise IndexError on an empty string.
        if text and unicodedata.category(text[-1]).startswith('P'):
            return text[:-1] + ' ' + text[-1]
        return text

    en = _split_trailing_punct(row['en']).lower().split()
    ben = ('<bos> ' + _split_trailing_punct(row['ben']) + ' <eos>').split()

    return (en, ben)

In [45]:
# Drop incomplete pairs *before* tokenizing: tokenize() slices the strings,
# so a missing (NaN) sentence would raise before a later dropna could help.
df.dropna(inplace=True)
df['en_ben_tokenized'] = df.apply(tokenize, axis=1)
df.head()

Unnamed: 0,en,ben,en_ben_tokenized
0,Go.,যাও।,"([go, .], [<bos>, যাও, ।, <eos>])"
1,Go.,যান।,"([go, .], [<bos>, যান, ।, <eos>])"
2,Go.,যা।,"([go, .], [<bos>, যা, ।, <eos>])"
3,Run!,পালাও!,"([run, !], [<bos>, পালাও, !, <eos>])"
4,Run!,পালান!,"([run, !], [<bos>, পালান, !, <eos>])"


In [46]:
train_pct = 0.8
SEED = 42  # fixed seed so Restart & Run All reproduces the same split

train_size = int(df.shape[0] * train_pct)
# Shuffle all rows once, then cut: first `train_size` rows train, rest validate.
df = df.sample(frac=1, random_state=SEED)
train_data = df['en_ben_tokenized'].iloc[:train_size]
val_data = df['en_ben_tokenized'].iloc[train_size:]

train_data.shape[0], val_data.shape[0]

(3465, 867)

In [47]:
# Transpose the (english_tokens, bengali_tokens) pairs into two parallel tuples.
train_en_data, train_beng_data = zip(*train_data)
train_en_data[:5], train_beng_data[:5]

((['do', 'you', 'have', 'time', '?'],
  ['my', 'father', 'is', 'very', 'good', 'at', 'fishing', '.'],
  ['the', 'museum', 'is', 'closed', 'sundays', '.'],
  ['do', 'you', 'understand', 'what', 'i', 'mean', '?'],
  ['please', 'speak', 'slowly', '.']),
 (['<bos>', 'আপনার', 'হাতে', 'সময়', 'আছে', '?', '<eos>'],
  ['<bos>', 'আমার', 'বাবা', 'মাছ', 'ধরাতে', 'খুব', 'ভাল', '।', '<eos>'],
  ['<bos>', 'জাদুঘরটা', 'রবিবার', 'বন্ধ', 'থাকে', '।', '<eos>'],
  ['<bos>',
   'আমি',
   'যা',
   'বলতে',
   'চাইছি',
   'তুমি',
   'কি',
   'তা',
   'বুঝতে',
   'পারছো',
   '?',
   '<eos>'],
  ['<bos>', 'অনুগ্রহ', 'করে', 'ধিরে', 'কথা', 'বলুন', '।', '<eos>']))

In [48]:
# Same transpose for the validation split: pairs -> two parallel tuples.
val_en_data, val_beng_data = zip(*val_data)
val_en_data[:5], val_beng_data[:5]

((['a', 'car', 'hit', 'tom', '.'],
  ['one',
   'million',
   'people',
   'lost',
   'their',
   'lives',
   'in',
   'the',
   'war',
   '.'],
  ["tom's", 'alive', '.'],
  ['is', 'it', 'still', 'raining', '?'],
  ['he', 'was', 'very', 'patient', '.']),
 (['<bos>', 'একটা', 'গাড়ী', 'টমকে', 'ধাক্কা', 'মারল', '।', '<eos>'],
  ['<bos>',
   'যুদ্ধে',
   'দশ',
   'লক্ষ',
   'মানুষ',
   'তাদের',
   'প্রাণ',
   'হারিয়েছিলেন',
   '।',
   '<eos>'],
  ['<bos>', 'টম', 'বেঁচে', 'আছে', '।', '<eos>'],
  ['<bos>', 'এখনো', 'বৃষ্টি', 'পরছে', '?', '<eos>'],
  ['<bos>', 'ও', 'খুব', 'ধৈর্যশীল', 'ছিলো', '।', '<eos>']))

## Create Vocabs

In [49]:
# Token frequencies are counted over the training split only, so words that
# occur solely in validation sentences are not part of the vocabulary.
counter = Counter(token for sent in train_en_data for token in sent)
en_vocab = Vocab(counter)

# Index used to pad English batches and as the embedding padding_idx.
en_PAD_IDX = en_vocab['<pad>']
en_PAD_IDX

1

In [50]:
# Peek at the tokens stored at indices 0-9 of the English vocabulary.
en_vocab.itos[:10]

['<unk>', '<pad>', '.', '?', 'i', 'tom', 'you', 'to', 'is', 'the']

In [51]:
# Bengali token frequencies, again from the training split only; the
# '<bos>' / '<eos>' markers added by tokenize() are counted like any token.
counter = Counter(token for sent in train_beng_data for token in sent)
beng_vocab = Vocab(counter)

# Index used to pad Bengali batches and ignored by the training loss.
beng_PAD_IDX = beng_vocab['<pad>']
beng_PAD_IDX

1

In [52]:
# Peek at the tokens stored at indices 0-9 of the Bengali vocabulary.
beng_vocab.itos[:10]

['<unk>', '<pad>', '<bos>', '<eos>', '।', '?', 'আমি', 'টম', 'আমার', 'কি']

## Dataset and DataLoader

In [53]:
class CustomDataset(Dataset):
    """Parallel corpus of tokenized English / Bengali sentences.

    Each item is converted to vocabulary indices on access using the
    module-level ``en_vocab`` and ``beng_vocab``.
    """

    def __init__(self, en_data, beng_data):
        # Two index-aligned sequences of token lists.
        self.en_data, self.beng_data = en_data, beng_data

    def __len__(self):
        return len(self.en_data)

    def __getitem__(self, index):
        """Return ``(en_ids, beng_decoder_input, beng_decoder_target)``.

        The decoder input is the Bengali sequence without its last token
        and the target is the same sequence without its first token, i.e.
        the target is the input shifted left by one position.
        """
        en_ids = [en_vocab[token] for token in self.en_data[index]]
        en_tensor = torch.LongTensor(en_ids)

        beng_ids = [beng_vocab[token] for token in self.beng_data[index]]
        beng_tensor = torch.LongTensor(beng_ids)

        return (en_tensor, beng_tensor[:-1], beng_tensor[1:])

In [54]:
def collate_fn(batch_data):
    """Pad a list of dataset items into three batch-first tensors.

    Args:
        batch_data: list of ``(en, ben_input, ben_target)`` tensor triples.

    Returns:
        Tuple of padded tensors in the same order. English sequences are
        padded with ``en_PAD_IDX``, both Bengali sequences with
        ``beng_PAD_IDX``.
    """
    en_seqs = [sample[0] for sample in batch_data]
    ben_input_seqs = [sample[1] for sample in batch_data]
    ben_target_seqs = [sample[2] for sample in batch_data]

    padded_en = pad_sequence(en_seqs, batch_first=True, padding_value=en_PAD_IDX)
    padded_input = pad_sequence(ben_input_seqs, batch_first=True, padding_value=beng_PAD_IDX)
    padded_target = pad_sequence(ben_target_seqs, batch_first=True, padding_value=beng_PAD_IDX)
    return (padded_en, padded_input, padded_target)

In [55]:
# list(...) already builds a fresh list from each tuple, so no extra .copy().
train_ds = CustomDataset(list(train_en_data), list(train_beng_data))
val_ds = CustomDataset(list(val_en_data), list(val_beng_data))

BATCH_SIZE = 16
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2)
# Validation is not shuffled: the reported loss is a mean of per-batch means,
# so a fixed batch composition keeps it comparable from epoch to epoch.
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2)

## Model

In [56]:
class Translator(nn.Module):
    """LSTM encoder-decoder translator without attention.

    The English sentence is encoded by a 2-layer LSTM whose final hidden
    state seeds a 2-layer Bengali decoder LSTM. The encoder's cell state is
    discarded; the decoder's initial cell state comes from ``init_hidden``.
    Padding indices are read from the module-level ``en_PAD_IDX`` and
    ``beng_PAD_IDX``.
    """

    def __init__(self, en_vocab_len, beng_vocab_len, en_emb_size, beng_emb_size, hidden_size):
        super(Translator, self).__init__()

        # Remembered so init_hidden() can size the cell state to match the
        # LSTMs instead of assuming one particular configuration.
        self.hidden_size = hidden_size
        self.num_layers = 2

        self.en_emb_layer = nn.Embedding(num_embeddings=en_vocab_len, embedding_dim=en_emb_size, padding_idx=en_PAD_IDX)
        self.en_lstm_layer = nn.LSTM(input_size=en_emb_size, hidden_size=hidden_size, num_layers=self.num_layers, batch_first=True)
        self.beng_emb_layer = nn.Embedding(num_embeddings=beng_vocab_len, embedding_dim=beng_emb_size, padding_idx=beng_PAD_IDX)
        self.beng_lstm_layer = nn.LSTM(input_size=beng_emb_size, hidden_size=hidden_size, num_layers=self.num_layers, batch_first=True)
        self.beng_linear = nn.Linear(hidden_size, beng_vocab_len)
        self.dropout = nn.Dropout(0.3)

    def forward(self, xb, yb_i, h, c):
        """Decode ``yb_i`` conditioned on ``xb`` (or on a carried-over state).

        Args:
            xb: batch-first English token indices; only read when ``h`` is None.
            yb_i: batch-first Bengali decoder-input token indices.
            h: decoder hidden state, or None to run the encoder first.
            c: decoder cell state.

        Returns:
            ``(logits, h, c)`` where ``logits`` has one row per decoder
            position (batch and time flattened together) and
            ``beng_vocab_len`` columns.
        """
        if h is None:
            # First call for this sentence: encode the source and start the
            # decoder from the encoder's final hidden state.
            out = self.en_emb_layer(xb)
            out = self.dropout(out)
            _, (h, _) = self.en_lstm_layer(out)

        out = self.beng_emb_layer(yb_i)
        out = self.dropout(out)
        out, (h, c) = self.beng_lstm_layer(out, (h, c))
        # Flatten (batch, time, hidden) -> (batch*time, hidden) so the
        # logits line up with a flattened target vector.
        out = out.reshape((-1, out.shape[-1]))
        out = self.beng_linear(out)

        return out, h, c

    def init_hidden(self, bs):
        """Return ``(None, zero cell state)`` for a batch of size ``bs``.

        ``h`` is None so that ``forward`` runs the encoder. The cell state
        is created on the model's own device with the model's own layer
        count and hidden size.
        """
        device = next(self.parameters()).device
        h = None
        c = torch.zeros((self.num_layers, bs, self.hidden_size), device=device)

        return h, c

In [57]:
class AttentionEncoder(nn.Module):
    """English-side encoder: embedding -> dropout -> LSTM.

    Returns every per-position LSTM output (for attention) together with
    the final hidden state; the final cell state is discarded.
    """

    def __init__(self, en_vocab_len, en_emb_size, hidden_size, num_layers):
        super().__init__()

        self.en_emb_layer = nn.Embedding(en_vocab_len, en_emb_size, padding_idx=en_PAD_IDX)
        self.en_lstm_layer = nn.LSTM(en_emb_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, xb):
        """Encode batch-first token indices ``xb``.

        Returns:
            ``(enc_states, last_hidden)``: the LSTM output sequence and the
            LSTM's final hidden state.
        """
        embedded = self.dropout(self.en_emb_layer(xb))
        enc_states, (last_hidden, _last_cell) = self.en_lstm_layer(embedded)
        return enc_states, last_hidden

In [58]:
class AttentionDecoder(nn.Module):
    """Single-step Bengali decoder fed with an attention context vector.

    At each step the context vector is concatenated in front of the
    embedded previous token, so the LSTM input width is
    ``hidden_size + beng_emb_size``.
    """

    def __init__(self, beng_vocab_len, beng_emb_size, hidden_size, num_layers):
        super().__init__()

        self.beng_emb_layer = nn.Embedding(beng_vocab_len, beng_emb_size, padding_idx=beng_PAD_IDX)
        self.beng_lstm_layer = nn.LSTM(hidden_size + beng_emb_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.beng_linear = nn.Linear(hidden_size, beng_vocab_len)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, yb_i, context_vector, h, c):
        """Run one decoding step.

        Args:
            yb_i: token indices for the current step, one per batch element.
            context_vector: attention context with a length-1 time axis at dim 1.
            h, c: decoder LSTM hidden and cell state.

        Returns:
            ``(logits, h, c)`` with ``logits`` flattened to two dimensions.
        """
        # Insert a length-1 time axis so the tokens look like a sequence.
        step_tokens = yb_i.unsqueeze(1)
        embedded = self.dropout(self.beng_emb_layer(step_tokens))

        lstm_input = torch.cat((context_vector, embedded), dim=2)
        lstm_out, (h, c) = self.beng_lstm_layer(lstm_input, (h, c))
        logits = self.beng_linear(lstm_out.reshape(-1, lstm_out.size(-1)))

        return logits, h, c

In [59]:
class AttentionSeq2Seq(nn.Module):
    """Encoder-decoder translator with a learned attention over encoder states.

    For every decoder step a score is computed for each source position by
    a ``Linear(2 * hidden_size -> 1)`` + ReLU applied to the concatenation
    of that position's encoder output and the decoder's hidden state. The
    scores are softmax-normalised and used to average the encoder outputs
    into a context vector that is fed to the decoder.
    """

    def __init__(self, en_vocab_len, beng_vocab_len, en_emb_size, beng_emb_size, hidden_size, num_layers=1):
        super(AttentionSeq2Seq, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.encoder = AttentionEncoder(en_vocab_len, en_emb_size, hidden_size, num_layers)
        self.decoder = AttentionDecoder(beng_vocab_len, beng_emb_size, hidden_size, num_layers)
        self.attention = nn.Sequential(
            nn.Linear(hidden_size*2, 1),
            nn.ReLU()
        )

    def get_attention_context(self, enc_hidden_states, h):
        """Return the attention-weighted sum of the encoder states.

        Args:
            enc_hidden_states: encoder outputs, shape (bs, N, hidden_size).
            h: decoder hidden state, shape (num_layers, bs, hidden_size).
                Only the last (top) layer is used as the attention query, so
                any ``num_layers`` works.

        Returns:
            Context vector of shape (bs, 1, hidden_size).
        """
        # NOTE(review): scores are normalised over all N source positions;
        # padded positions are not masked out here.
        bs, seq_len, _ = enc_hidden_states.shape

        # Top-layer state broadcast along the source axis -> (bs, N, hidden_size).
        query = h[-1].unsqueeze(1).expand(-1, seq_len, -1)
        att_out = self.attention(torch.cat((enc_hidden_states, query), dim=2))  # (bs, N, 1)
        att_score = torch.softmax(att_out.reshape(bs, seq_len), dim=1)  # (bs, N)

        # Weight every encoder state by its score and sum over positions.
        context_vector = att_score.unsqueeze(2) * enc_hidden_states  # (bs, N, hidden_size)
        return torch.sum(context_vector, dim=1, keepdim=True)  # (bs, 1, hidden_size)

    def forward(self, xb, yb_i, enc_hidden_states, h, c):
        """Decode ``yb_i`` step by step with teacher forcing.

        Args:
            xb: batch-first English token indices; only read when ``h`` is None.
            yb_i: batch-first Bengali decoder-input token indices, (bs, T).
            enc_hidden_states: cached encoder outputs, replaced when ``h`` is None.
            h: decoder hidden state, or None to run the encoder first.
            c: decoder cell state.

        Returns:
            ``(logits, enc_hidden_states, h, c)`` with ``logits`` flattened
            to (bs * T, vocab).

        Raises:
            ValueError: if ``yb_i`` has no time steps.
        """
        if h is None:
            enc_hidden_states, h = self.encoder(xb)

        if yb_i.shape[1] == 0:
            raise ValueError('yb_i must contain at least one decoding step')

        # Collect per-step logits and stack once, rather than re-copying a
        # growing tensor with torch.cat at every step.
        step_outputs = []
        for i in range(yb_i.shape[1]):
            context_vector = self.get_attention_context(enc_hidden_states, h)
            out, h, c = self.decoder(yb_i[:, i], context_vector, h, c)
            step_outputs.append(out)

        result = torch.stack(step_outputs, dim=1)  # (bs, T, vocab)
        result = result.reshape(-1, result.shape[-1])
        return result, enc_hidden_states, h, c

    def init_hidden(self, bs, ):
        """Return ``(None, None, zero cell state)`` for a batch of size ``bs``.

        The two None values make ``forward`` run the encoder; the cell state
        is allocated on the model's own device.
        """
        device = next(self.parameters()).device
        enc_hidden_states = None
        h = None
        c = torch.zeros((self.num_layers, bs, self.hidden_size), device=device)

        return enc_hidden_states, h, c

## Training

In [61]:
def _forward_loss(model, batch, criterion, attention, device):
    """Move one batch to `device`, run the model from a fresh state, return the loss."""
    xb, yb_i, yb_t = (tensor.to(device) for tensor in batch)
    bs = xb.shape[0]

    if attention:
        enc_hidden_states, h, c = model.init_hidden(bs)
        y_hat, _, _, _ = model(xb, yb_i, enc_hidden_states, h, c)
    else:
        h, c = model.init_hidden(bs)
        y_hat, _, _ = model(xb, yb_i, h, c)

    # The model returns logits flattened over (batch, time); flatten targets to match.
    return criterion(y_hat, yb_t.reshape(-1))


def _mean_loss(losses):
    """Average a list of scalar loss tensors into a float (NaN when the list is empty)."""
    if not losses:
        return float('nan')
    return torch.stack(losses).mean().cpu().item()


def fit(model, epochs, lr, attention, checkpoint_path='bestmodel.pth'):
    """Train `model` on the module-level `train_dl`, validating on `val_dl`.

    After every epoch the mean validation loss is compared with the best
    one seen so far; when it is at least as good, the model's state dict is
    written to `checkpoint_path`. One summary line is printed per epoch.

    Args:
        model: network exposing ``init_hidden(bs)`` and the forward
            signature matching `attention` (see below).
        epochs: number of passes over `train_dl`.
        lr: Adam learning rate.
        attention: True when the model takes and returns an extra
            ``enc_hidden_states`` value (attention model), False otherwise.
        checkpoint_path: file the best state dict is saved to.

    Returns:
        None.
    """
    # Batches are moved to whatever device the model already lives on.
    device = next(model.parameters()).device
    # Padded target positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=beng_PAD_IDX)
    opt = Adam(model.parameters(), lr=lr)
    best_val_loss = None

    for epoch in tqdm(range(epochs), leave=True):
        train_losses = []
        val_losses = []

        model.train()
        for batch in tqdm(train_dl, 'Training', leave=False):
            loss = _forward_loss(model, batch, criterion, attention, device)
            opt.zero_grad()
            loss.backward()
            opt.step()

            # Keep detached tensors so no device sync is forced per step.
            train_losses.append(loss.detach())

        model.eval()
        with torch.no_grad():
            for batch in tqdm(val_dl, 'Validation', leave=False):
                loss = _forward_loss(model, batch, criterion, attention, device)
                val_losses.append(loss.detach())

        mean_batch_train_loss = _mean_loss(train_losses)
        mean_batch_val_loss = _mean_loss(val_losses)
        if best_val_loss is None or best_val_loss >= mean_batch_val_loss:
            torch.save(model.state_dict(), checkpoint_path)
            best_val_loss = mean_batch_val_loss

        print(f'epoch - {epoch+1} | train loss - {mean_batch_train_loss} | val loss - {mean_batch_val_loss}')

In [62]:
# Baseline run: plain LSTM encoder-decoder, then restore the checkpoint that
# fit() saved for the best validation loss.
epochs, lr = 30, 0.001
lstm_model = Translator(
    en_vocab_len=len(en_vocab),
    beng_vocab_len=len(beng_vocab),
    en_emb_size=100,
    beng_emb_size=100,
    hidden_size=256,
).to('cuda')
fit(lstm_model, epochs, lr, attention=False)
lstm_model.load_state_dict(torch.load('bestmodel.pth'))

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 1 | train loss - 5.120729446411133 | val loss - 4.736519813537598


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 2 | train loss - 4.486807346343994 | val loss - 4.556915283203125


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 3 | train loss - 4.171319484710693 | val loss - 4.367337226867676


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 4 | train loss - 3.8595850467681885 | val loss - 4.144284248352051


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 5 | train loss - 3.560281991958618 | val loss - 4.00666618347168


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 6 | train loss - 3.2371394634246826 | val loss - 3.7940568923950195


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 7 | train loss - 2.956050157546997 | val loss - 3.658188581466675


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 8 | train loss - 2.693800210952759 | val loss - 3.5964508056640625


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 9 | train loss - 2.464211940765381 | val loss - 3.538093090057373


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 10 | train loss - 2.249990224838257 | val loss - 3.4762260913848877


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 11 | train loss - 2.064470052719116 | val loss - 3.464493989944458


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 12 | train loss - 1.888898253440857 | val loss - 3.389040231704712


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 13 | train loss - 1.7309865951538086 | val loss - 3.359445333480835


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 14 | train loss - 1.586377739906311 | val loss - 3.3805203437805176


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 15 | train loss - 1.4574626684188843 | val loss - 3.340514659881592


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 16 | train loss - 1.3489375114440918 | val loss - 3.3396050930023193


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 17 | train loss - 1.234498143196106 | val loss - 3.3601388931274414


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 18 | train loss - 1.1381703615188599 | val loss - 3.3553214073181152


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 19 | train loss - 1.0530673265457153 | val loss - 3.3787901401519775


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 20 | train loss - 0.9739530086517334 | val loss - 3.344196319580078


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 21 | train loss - 0.9034404158592224 | val loss - 3.3719370365142822


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 22 | train loss - 0.8394720554351807 | val loss - 3.3523449897766113


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 23 | train loss - 0.7879330515861511 | val loss - 3.3593904972076416


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 24 | train loss - 0.730983316898346 | val loss - 3.325222969055176


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 25 | train loss - 0.6847972869873047 | val loss - 3.372664213180542


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 26 | train loss - 0.6390566229820251 | val loss - 3.368350028991699


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 27 | train loss - 0.6036423444747925 | val loss - 3.3767971992492676


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 28 | train loss - 0.5621861219406128 | val loss - 3.440016269683838


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 29 | train loss - 0.5302238464355469 | val loss - 3.4356515407562256


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 30 | train loss - 0.50298672914505 | val loss - 3.403416156768799



<All keys matched successfully>

In [63]:
# Attention run: same hyper-parameters as the baseline but a single LSTM
# layer; afterwards restore the best checkpoint saved by fit().
epochs, lr = 30, 0.001
att_lstm_model = AttentionSeq2Seq(
    en_vocab_len=len(en_vocab),
    beng_vocab_len=len(beng_vocab),
    en_emb_size=100,
    beng_emb_size=100,
    hidden_size=256,
    num_layers=1,
).to('cuda')
fit(att_lstm_model, epochs, lr, attention=True)
att_lstm_model.load_state_dict(torch.load('bestmodel.pth'))

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 1 | train loss - 4.97560977935791 | val loss - 4.471499919891357


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 2 | train loss - 4.0816802978515625 | val loss - 4.113191604614258


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 3 | train loss - 3.583799362182617 | val loss - 3.854275703430176


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 4 | train loss - 3.157088279724121 | val loss - 3.5914077758789062


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 5 | train loss - 2.7576401233673096 | val loss - 3.4340643882751465


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 6 | train loss - 2.409083127975464 | val loss - 3.3165605068206787


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 7 | train loss - 2.0927743911743164 | val loss - 3.1831958293914795


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 8 | train loss - 1.8271565437316895 | val loss - 3.1087534427642822


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 9 | train loss - 1.5847270488739014 | val loss - 3.028921365737915


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 10 | train loss - 1.3857178688049316 | val loss - 2.9932665824890137


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 11 | train loss - 1.2057846784591675 | val loss - 2.9453229904174805


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 12 | train loss - 1.0570112466812134 | val loss - 2.8937249183654785


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 13 | train loss - 0.9288049936294556 | val loss - 2.9013781547546387


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 14 | train loss - 0.8146975636482239 | val loss - 2.920964479446411


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 15 | train loss - 0.7213859558105469 | val loss - 2.913313388824463


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 16 | train loss - 0.6438480615615845 | val loss - 2.8840112686157227


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 17 | train loss - 0.5690960884094238 | val loss - 2.8955283164978027


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 18 | train loss - 0.507536768913269 | val loss - 2.8835315704345703


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 19 | train loss - 0.4639299809932709 | val loss - 2.906773090362549


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 20 | train loss - 0.4175020456314087 | val loss - 2.9280011653900146


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 21 | train loss - 0.3801088035106659 | val loss - 2.90811824798584


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 22 | train loss - 0.34220683574676514 | val loss - 3.0159852504730225


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 23 | train loss - 0.3203802704811096 | val loss - 2.938331127166748


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 24 | train loss - 0.2974872589111328 | val loss - 2.952138900756836


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 25 | train loss - 0.27735069394111633 | val loss - 2.9582724571228027


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 26 | train loss - 0.2587052881717682 | val loss - 3.0338428020477295


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 27 | train loss - 0.2467300444841385 | val loss - 3.006430149078369


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 28 | train loss - 0.22967708110809326 | val loss - 2.9816954135894775


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 29 | train loss - 0.2193228006362915 | val loss - 2.999763011932373


HBox(children=(FloatProgress(value=0.0, description='Training', max=217.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=55.0, style=ProgressStyle(description_wi…

epoch - 30 | train loss - 0.21508552134037018 | val loss - 3.0640957355499268



<All keys matched successfully>

## Inference

In [71]:
def predict(model, en_text, attention, max_len=15):
    """Greedily translate one English sentence into Bengali.

    The sentence is lower-cased and whitespace-tokenized, with a trailing
    punctuation mark split off as its own token. Decoding starts from
    '<bos>' and feeds the arg-max token back in until '<eos>' is produced
    or `max_len` words have been emitted. The model is left in eval mode.

    Args:
        model: trained translator exposing ``init_hidden(bs)``.
        en_text: English sentence as a plain string.
        attention: True when `model` takes and returns the extra
            ``enc_hidden_states`` value.
        max_len: maximum number of Bengali words to generate.

    Returns:
        The generated words, each preceded by a single space (so a
        non-empty result starts with a space).

    Raises:
        ValueError: if `en_text` contains no tokens.
    """
    import unicodedata  # local import: keeps the cell runnable on its own

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Detach the last character only when it really is punctuation;
    # otherwise the final letter of the sentence would become its own token.
    if en_text and unicodedata.category(en_text[-1]).startswith('P'):
        en_text = en_text[:-1] + ' ' + en_text[-1]
    en_ids = [en_vocab[token] for token in en_text.lower().split()]
    if not en_ids:
        raise ValueError('en_text must contain at least one token')

    en_tensor = torch.LongTensor(en_ids).view(1, -1).to(device)
    beng_prev_word_tensor = torch.LongTensor([beng_vocab['<bos>']]).view(1, -1).to(device)

    if attention:
        enc_hidden_states, h, c = model.init_hidden(1)
    else:
        h, c = model.init_hidden(1)

    beng_words = []
    model.eval()
    with torch.no_grad():
        for _ in range(max_len):
            if attention:
                y_hat, enc_hidden_states, h, c = model(en_tensor, beng_prev_word_tensor, enc_hidden_states, h, c)
            else:
                y_hat, h, c = model(en_tensor, beng_prev_word_tensor, h, c)

            # Softmax is monotonic, so arg-max over the raw logits picks the same token.
            next_id = torch.argmax(y_hat)
            beng_word = beng_vocab.itos[next_id.item()]

            if beng_word == '<eos>':
                break

            beng_words.append(beng_word)
            beng_prev_word_tensor = next_id.view(1, -1)

    return ''.join(' ' + word for word in beng_words)

In [75]:
# Translate the same sample sentences with both models and show them side by side.
models = [lstm_model, att_lstm_model]
sample_en_texts = [
    'who are you?',
    'how are you?',
    'try hard.',
    'this is my daughter.',
    'can you read?',
    'he is actually not the manager.',
    'i asked tom to come with me.',
    'how far is it from here?',
    'are you mad?',
]
model_translations = []

for model_no, model in enumerate(models):
    # Position 1 in `models` is the attention model; position 0 is the baseline.
    attention = model_no == 1
    translation = [predict(model, text, attention) for text in sample_en_texts]
    model_translations.append(translation)

pd.DataFrame({
    'English': sample_en_texts,
    'Bengali without Attention': model_translations[0],
    'Bengali with Attention': model_translations[1],
})

Unnamed: 0,English,Bengali without Attention,Bengali with Attention
0,who are you?,তুমি কী এনেছো ?,আপনি কে ?
1,how are you?,তুমি কেমন আছো ?,আপনি কেমন আছেন ?
2,try hard.,একটু তারাতারি করো ।,আরও চেষ্টা করুন ।
3,this is my daughter.,এটা আমার ছোটো বোন ।,এটা আমার মেয়ে ।
4,can you read?,তুমি কি ওখানে জন্মেছিলে ?,তুমি পড়তে পারো ?
5,he is actually not the manager.,তিনি একজন শিক্ষক এবং ঔপন্যাসিক ।,উনি আসলে ম্যানেজারই নন ।
6,i asked tom to come with me.,আমি টমকে জিজ্ঞাসা করলাম যে সে কোথায় গেছিলো ।,আমি চাই টম আমাকে আমার সাথে আসো ।
7,how far is it from here?,এয়ারপোর্ট কত দুর এখান থেকে ?,এখান থেকে কতটা কতটা দূরে ?
8,are you mad?,তুমি কি চিন্তিত ?,আপনি কি চিন্তিত ?
