# Machine translation (English → German) with LSTM encoder–decoder models, with and without attention

In [132]:
import csv
import time
import spacy
import torch
import random
import numpy as np
import tqdm
from torch import nn
import torch.nn.functional as F
from torchtext.legacy import data
from torchtext.legacy import datasets
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, GloVe
from string import punctuation

In [133]:
# Fix every RNG the notebook touches (PyTorch CPU/CUDA, NumPy, Python) so the
# data split, weight initialisation and teacher-forcing coin flips repeat.
SEED = 456

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Ask cuDNN for deterministic kernels (may be slower).
torch.backends.cudnn.deterministic = True

In [134]:
BATCH_SIZE = 1
# Fraction of the Multi30k training split actually used; kept tiny so epochs are
# fast. The remaining examples are held in `temp`.
TRAIN_FRACTION = 0.02
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

spacy_tokenizer_en = get_tokenizer('spacy', language='en_core_web_sm')
spacy_tokenizer_de = get_tokenizer('spacy', language='de_core_news_sm')

# Lower-cased token fields wrapped in <sos>/<eos>, batch dimension first.
DEFIELD = data.Field(tokenize=spacy_tokenizer_de,
                     init_token='<sos>',
                     eos_token='<eos>',
                     lower=True,
                     batch_first=True)
ENFIELD = data.Field(tokenize=spacy_tokenizer_en,
                     init_token='<sos>',
                     eos_token='<eos>',
                     lower=True,
                     batch_first=True)

# exts/fields order: German is the `src` side of each example, English the `trg` side.
train_data, val_data, test_data = datasets.Multi30k.splits(exts=('.de', '.en'), fields=(DEFIELD, ENFIELD))

# `random.seed(...)` returns None, so it cannot serve as `random_state`; pass the
# explicit RNG state that seeding with SEED produces.
train_data, temp = train_data.split(split_ratio=TRAIN_FRACTION,
                                    random_state=random.Random(SEED).getstate())

# Vocabularies come from the reduced training subset only.
DEFIELD.build_vocab(train_data, min_freq=1)
ENFIELD.build_vocab(train_data, min_freq=1)

train_dataloader, val_dataloader, test_dataloader = data.BucketIterator.splits(
    (train_data, val_data, test_data),
    batch_size=BATCH_SIZE,
    device=device)

In [135]:
class Encoder(nn.Module):
  """LSTM encoder: embeds token ids and runs them through a (stacked) LSTM.

  Args:
    vocab_size: number of rows in the input embedding table.
    embed_dim: embedding width.
    num_hiddens: LSTM hidden size.
    num_layers: number of stacked LSTM layers.
  """
  def __init__(self, vocab_size, embed_dim, num_hiddens, num_layers=1):
    super(Encoder, self).__init__()
    self.num_hiddens = num_hiddens
    self.num_layers = num_layers

    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.lstm = nn.LSTM(embed_dim, num_hiddens, num_layers, batch_first=True)

  def forward(self, inputs, hidden):
    """Encode `inputs`, a (batch, seq) tensor of token ids.

    Args:
      hidden: initial (h, c), each (num_layers, batch, num_hiddens) — see init_hidden.
    Returns:
      output: per-step LSTM outputs, (batch, seq, num_hiddens).
      hidden: the LSTM's *final* (h, c) state, used to initialise a decoder.
    """
    embeddings = self.embedding(inputs)
    # Return the updated state; handing back the `hidden` argument would give
    # the decoder the zero initial state and throw the encoding away.
    output, final_hidden = self.lstm(embeddings, hidden)
    return output, final_hidden

  def init_hidden(self, batch_size=1):
    """Zero (h, c) initial state on the same device/dtype as this module's weights."""
    param = next(self.parameters())
    shape = (self.num_layers, batch_size, self.num_hiddens)
    return (torch.zeros(shape, device=param.device, dtype=param.dtype),
            torch.zeros(shape, device=param.device, dtype=param.dtype))

In [136]:
class Decoder(nn.Module):
  """Attention-free LSTM decoder that produces one target token per call (batch size 1).

  Args:
    embed_dim: target embedding width.
    num_hiddens: LSTM hidden size (must match the encoder's).
    output_size: target vocabulary size.
    num_layers: stacked LSTM layers (must match the encoder's state).
  """
  def __init__(self, embed_dim, num_hiddens, output_size, num_layers=1):
    super(Decoder, self).__init__()
    self.num_hiddens = num_hiddens
    self.output_size = output_size
    self.num_layers = num_layers

    self.embedding = nn.Embedding(self.output_size, embed_dim)
    # Honour num_layers so a multi-layer encoder state can be passed straight in.
    self.lstm = nn.LSTM(embed_dim, self.num_hiddens, self.num_layers, batch_first=True)
    self.classifier = nn.Linear(self.num_hiddens, self.output_size)

  def forward(self, inputs, hidden, encoder_outputs):
    """Run one decoding step.

    Args:
      inputs: tensor holding the single previous token id.
      hidden: (h, c), each (num_layers, 1, num_hiddens).
      encoder_outputs: ignored; accepted for interface parity with the attention decoders.
    Returns:
      (log_probs, hidden, log_probs): log-probabilities of shape (1, output_size),
      the new LSTM state, and the log-probs again in the slot where attention
      decoders return their weights.
    """
    # Flatten the single token's embedding to (batch=1, seq=1, embed_dim).
    embeddings = self.embedding(inputs).view(1, 1, -1)

    output, hidden = self.lstm(embeddings, hidden)
    output = F.log_softmax(self.classifier(output[0]), dim=1)
    return output, hidden, output

In [137]:
class DecoderBatched(nn.Module):
  """Attention-free LSTM decoder operating on a whole batch at once.

  Args:
    embed_dim: target embedding width.
    num_hiddens: LSTM hidden size (must match the encoder's).
    output_size: target vocabulary size.
    num_layers: stacked LSTM layers (must match the encoder's state).
  """
  def __init__(self, embed_dim, num_hiddens, output_size, num_layers=1):
    super(DecoderBatched, self).__init__()
    self.num_hiddens = num_hiddens
    self.output_size = output_size
    self.num_layers = num_layers

    self.embedding = nn.Embedding(self.output_size, embed_dim)
    # Honour num_layers so a multi-layer encoder state can be passed straight in.
    self.lstm = nn.LSTM(embed_dim, self.num_hiddens, self.num_layers, batch_first=True)
    self.classifier = nn.Linear(self.num_hiddens, self.output_size)

  def forward(self, inputs, hidden, encoder_outputs):
    """Decode `inputs`, a (batch, steps) tensor of token ids (the training loop feeds one step).

    Args:
      hidden: (h, c), each (num_layers, batch, num_hiddens).
      encoder_outputs: ignored; accepted for interface parity with the attention decoders.
    Returns:
      (log_probs, hidden, log_probs): log-probabilities (batch, steps, output_size),
      the new LSTM state, and the log-probs again in the attention-weights slot.
    """
    embeddings = self.embedding(inputs)

    output, hidden = self.lstm(embeddings, hidden)
    output = F.log_softmax(self.classifier(output), dim=2)
    return output, hidden, output

In [138]:
class BahdanauDecoder(nn.Module):
  """Additive-attention decoder producing one target token per call (batch size 1).

  Each step scores every encoder position with v . tanh(W_h h + W_e e), softmaxes
  the scores into attention weights, builds a context vector, and feeds
  [embedding ; context] into the LSTM.

  Args:
    embed_dim: target embedding width.
    num_hiddens: hidden size (must match the encoder's).
    output_size: target vocabulary size.
    num_layers: stacked LSTM layers (must match the encoder's state).
  """
  def __init__(self, embed_dim, num_hiddens, output_size, num_layers=1):
    super(BahdanauDecoder, self).__init__()
    self.num_hiddens = num_hiddens
    self.output_size = output_size
    self.num_layers = num_layers

    self.embedding = nn.Embedding(self.output_size, embed_dim)

    self.fc_hidden = nn.Linear(self.num_hiddens, self.num_hiddens, bias=False)
    self.fc_encoder = nn.Linear(self.num_hiddens, self.num_hiddens, bias=False)
    # torch.FloatTensor(shape) is uninitialised memory (possibly NaN/inf); give the
    # scoring vector real initial values.
    self.weight = nn.Parameter(torch.empty(1, num_hiddens))
    nn.init.xavier_uniform_(self.weight)
    # NOTE: not used in forward(); kept so parameter names / state_dict keys are unchanged.
    self.attn_combine = nn.Linear(self.num_hiddens * 2, self.num_hiddens)
    self.lstm = nn.LSTM(self.num_hiddens + embed_dim, self.num_hiddens, self.num_layers, batch_first=True)
    self.classifier = nn.Linear(self.num_hiddens, self.output_size)

  def forward(self, inputs, hidden, encoder_outputs):
    """Run one decoding step.

    Args:
      inputs: tensor holding the single previous token id.
      hidden: (h, c), each (num_layers, 1, num_hiddens).
      encoder_outputs: assumed (1, src_len, num_hiddens), as produced by Encoder
        for a batch of one.
    Returns:
      (log_probs (1, output_size), new hidden, attn_weights (1, src_len)).
    """
    # Drop only the batch dim: a plain squeeze() would also drop src_len when the
    # source has one token and break the bmm below.
    encoder_outputs = encoder_outputs.squeeze(0)              # (src_len, H)
    embeddings = self.embedding(inputs).view(1, -1)           # (1, E)

    # Query with the top LSTM layer's hidden state, kept 3-D: (1, 1, H).
    query = hidden[0][-1:]
    x = torch.tanh(self.fc_hidden(query) + self.fc_encoder(encoder_outputs))  # (1, src_len, H)
    alignment_scores = x.bmm(self.weight.unsqueeze(2))        # (1, src_len, 1)
    attn_weights = F.softmax(alignment_scores.view(1, -1), dim=1)
    context_vector = torch.bmm(attn_weights.unsqueeze(0),
                               encoder_outputs.unsqueeze(0))  # (1, 1, H)
    output = torch.cat((embeddings, context_vector[0]), 1).unsqueeze(0)
    output, hidden = self.lstm(output, hidden)
    output = F.log_softmax(self.classifier(output[0]), dim=1)
    return output, hidden, attn_weights

In [139]:
class BahdanauDecoderBatched(nn.Module):
  """Additive-attention decoder operating on a whole batch, one step per call.

  Args:
    embed_dim: target embedding width.
    num_hiddens: hidden size (must match the encoder's).
    output_size: target vocabulary size.
    num_layers: stacked LSTM layers (must match the encoder's state).
  """
  def __init__(self, embed_dim, num_hiddens, output_size, num_layers=1):
    super(BahdanauDecoderBatched, self).__init__()
    self.num_hiddens = num_hiddens
    self.output_size = output_size
    self.num_layers = num_layers

    self.embedding = nn.Embedding(self.output_size, embed_dim)

    self.fc_hidden = nn.Linear(self.num_hiddens, self.num_hiddens, bias=False)
    self.fc_encoder = nn.Linear(self.num_hiddens, self.num_hiddens, bias=False)
    # torch.FloatTensor(shape) is uninitialised memory; initialise the scoring vector explicitly.
    self.weight = nn.Parameter(torch.empty(num_hiddens, 1))
    nn.init.xavier_uniform_(self.weight)
    # NOTE: not used in forward(); kept so parameter names / state_dict keys are unchanged.
    self.attn_combine = nn.Linear(self.num_hiddens * 2, self.num_hiddens)
    self.lstm = nn.LSTM(self.num_hiddens + embed_dim, self.num_hiddens, self.num_layers, batch_first=True)
    self.classifier = nn.Linear(self.num_hiddens, self.output_size)

  def forward(self, inputs, hidden, encoder_outputs):
    """Run one decoding step for a batch.

    Args:
      inputs: (batch, 1) previous-token ids.
      hidden: (h, c), each (num_layers, batch, num_hiddens).
      encoder_outputs: (batch, src_len, num_hiddens).
    Returns:
      (log_probs (batch, 1, output_size), new hidden, attn_weights (batch, src_len, 1)).

    NOTE(review): padded source positions are not masked, so with batch size > 1
    attention can land on <pad> outputs.
    """
    embeddings = self.embedding(inputs)                                  # (B, 1, E)
    # Query with the top LSTM layer: (B, H) -> (B, 1, H) to broadcast over src_len.
    query = self.fc_hidden(hidden[0][-1]).unsqueeze(1)
    x = torch.tanh(query + self.fc_encoder(encoder_outputs))             # (B, src_len, H)
    alignment_scores = x.matmul(self.weight)                             # (B, src_len, 1)
    attn_weights = F.softmax(alignment_scores, dim=1)
    context_vector = torch.bmm(attn_weights.squeeze(2).unsqueeze(1), encoder_outputs)  # (B, 1, H)
    output = torch.cat((embeddings, context_vector), 2)
    output, hidden = self.lstm(output, hidden)
    output = F.log_softmax(self.classifier(output), dim=2)
    return output, hidden, attn_weights

In [None]:
class LuongDecoder(nn.Module):
  """Luong-style decoder: advance the LSTM first, then attend over the encoder outputs.

  `attention` is an injected scorer (e.g. an `Attention` module) called as
  attention(lstm_out, encoder_outputs); its scores must be viewable as (1, src_len).
  Decodes one token at a time with batch size 1.
  """
  def __init__(self, hidden_size, output_size, attention, n_layers=1, drop_prob=0.1):
    super(LuongDecoder, self).__init__()
    self.hidden_size, self.output_size = hidden_size, output_size
    self.n_layers, self.drop_prob = n_layers, drop_prob
    self.attention = attention

    # Keep submodule creation order unchanged so seeded initialisation is reproducible.
    self.embedding = nn.Embedding(output_size, hidden_size)
    self.dropout = nn.Dropout(drop_prob)
    self.lstm = nn.LSTM(hidden_size, hidden_size)
    # The classifier sees [lstm_out ; context], hence twice the hidden width.
    self.classifier = nn.Linear(2 * hidden_size, output_size)

  def forward(self, inputs, hidden, encoder_outputs):
    # One token -> (seq=1, batch=1, hidden_size); this LSTM is not batch_first.
    token_vec = self.dropout(self.embedding(inputs).view(1, 1, -1))
    lstm_out, hidden = self.lstm(token_vec, hidden)

    scores = self.attention(lstm_out, encoder_outputs)
    attn_weights = F.softmax(scores.view(1, -1), dim=1)
    context = attn_weights.unsqueeze(0).bmm(encoder_outputs)
    combined = torch.cat([lstm_out, context], dim=-1)
    log_probs = F.log_softmax(self.classifier(combined[0]), dim=1)
    return log_probs, hidden, attn_weights
  
class Attention(nn.Module):
  """Alignment-score functions for LuongDecoder: "dot", "general" or "concat".

  forward(decoder_hidden, encoder_outputs) returns unnormalised scores of shape
  (1, src_len); the caller applies the softmax. Written for batch size 1.

  Raises:
    ValueError: if `method` is not one of the supported names.
  """
  _METHODS = ("dot", "general", "concat")

  def __init__(self, hidden_size, method="dot"):
    super(Attention, self).__init__()
    if method not in self._METHODS:
      raise ValueError("Unknown attention method {!r}; expected one of {}".format(method, self._METHODS))
    self.method = method
    self.hidden_size = hidden_size

    if method == "general":
      self.fc = nn.Linear(hidden_size, hidden_size, bias=False)

    elif method == "concat":
      self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
      # torch.FloatTensor(shape) is uninitialised memory; initialise explicitly.
      self.weight = nn.Parameter(torch.empty(1, hidden_size))
      nn.init.xavier_uniform_(self.weight)

  def forward(self, decoder_hidden, encoder_outputs):
    if self.method == "dot":
      # score_s = h_s . h_t
      return encoder_outputs.bmm(decoder_hidden.view(1, -1, 1)).squeeze(-1)

    if self.method == "general":
      # score_s = h_s . (W h_t)
      out = self.fc(decoder_hidden)
      return encoder_outputs.bmm(out.view(1, -1, 1)).squeeze(-1)

    if self.method == "concat":
      # NOTE: this variant *adds* h_t and h_s before the projection rather than
      # concatenating them as in Luong et al.
      out = torch.tanh(self.fc(decoder_hidden + encoder_outputs))
      return out.bmm(self.weight.unsqueeze(-1)).squeeze(-1)

    raise ValueError("Unknown attention method {!r}".format(self.method))

In [161]:
# The encoder reads English token ids (ENFIELD); the decoders emit German (DEFIELD).
VOCAB_SIZE = len(ENFIELD.vocab)
OUTPUT_SIZE = len(DEFIELD.vocab)

# Model / optimisation hyper-parameters.
EMBED_DIM, NUM_HIDDENS, NUM_LAYERS = 300, 64, 1
EPOCHS, LR = 10, 0.001

# One shared encoder plus one decoder of each flavour. Creation order is kept
# fixed so the seeded weight initialisation is reproducible.
encoder = Encoder(VOCAB_SIZE, EMBED_DIM, NUM_HIDDENS, NUM_LAYERS).to(device)
decoder = Decoder(EMBED_DIM, NUM_HIDDENS, OUTPUT_SIZE, NUM_LAYERS).to(device)
decoderbatched = DecoderBatched(EMBED_DIM, NUM_HIDDENS, OUTPUT_SIZE, NUM_LAYERS).to(device)
bahdanaudecoder = BahdanauDecoder(EMBED_DIM, NUM_HIDDENS, OUTPUT_SIZE, NUM_LAYERS).to(device)
bahdanaudecoderbatched = BahdanauDecoderBatched(EMBED_DIM, NUM_HIDDENS, OUTPUT_SIZE, NUM_LAYERS).to(device)

# Plain SGD per module; the encoder's optimiser is reused by every training run below.
(encoder_optimizer,
 decoder_optimizer,
 decoderbatched_optimizer,
 bahdanaudecoder_optimizer,
 bahdanaudecoderbatched_optimizer) = (
    torch.optim.SGD(module.parameters(), lr=LR)
    for module in (encoder, decoder, decoderbatched, bahdanaudecoder, bahdanaudecoderbatched))

In [28]:
teacher_forcing_prob = 0.5
def train_batched(dataloader, decoder, encoder_optimizer, decoder_optimizer):
    """Train the global `encoder` jointly with a batched `decoder` for EPOCHS epochs.

    Each batch is decoded step by step. With probability `teacher_forcing_prob`
    (drawn once per batch) the gold previous token is fed back; otherwise the
    decoder's own argmax is used. Per-step NLL losses are summed and
    back-propagated through both modules. <pad> target positions are excluded
    from the loss.

    Args:
        dataloader: yields batches whose `.trg` (English, encoder input) and
            `.src` (German, decoder target) are (batch, seq) id tensors.
        decoder: a batched decoder (`DecoderBatched` / `BahdanauDecoderBatched`).
        encoder_optimizer, decoder_optimizer: optimisers over the two modules.
    """
    # `import tqdm` alone does not load the `tqdm.notebook` submodule.
    from tqdm.notebook import tqdm as notebook_tqdm

    pad_idx = DEFIELD.vocab.stoi[DEFIELD.pad_token]
    encoder.train()
    decoder.train()
    epoch_bar = notebook_tqdm(range(1, EPOCHS + 1), total=EPOCHS)
    for _epoch in epoch_bar:
        avg_loss = 0.
        batch_bar = notebook_tqdm(dataloader, total=len(dataloader), leave=False)
        for batch in batch_bar:
            loss = 0.
            en_inp, de_out = batch.trg, batch.src
            h = encoder.init_hidden(en_inp.shape[0])
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            encoder_outputs, h = encoder(en_inp, h)

            # Column 0 of every target sequence is <sos>: (batch, 1).
            decoder_input = de_out[:, 0].unsqueeze(1)
            decoder_hidden = h
            teacher_forcing = random.random() < teacher_forcing_prob

            for ii in range(1, de_out.shape[1]):
                decoder_output, decoder_hidden, _ = decoder(decoder_input,
                                                            decoder_hidden, encoder_outputs)
                if teacher_forcing:
                    decoder_input = de_out[:, ii].unsqueeze(1)
                else:
                    _, top_index = decoder_output.topk(1)
                    decoder_input = top_index.squeeze(2)

                loss += F.nll_loss(decoder_output.squeeze(1), de_out[:, ii], ignore_index=pad_idx)
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()
            avg_loss += loss.item() / len(dataloader)
        epoch_bar.set_postfix(loss=avg_loss)
        
train_batched(train_dataloader, bahdanaudecoderbatched, encoder_optimizer, bahdanaudecoderbatched_optimizer)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

In [38]:
train_batched(
    dataloader=train_dataloader,
    decoder=decoderbatched,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoderbatched_optimizer,
)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

In [None]:
# Reference, source and (first 25 tokens of the) prediction as text. Needs the
# globals `de_out`, `en_inp` and `output` from a greedy-decoding cell below.
print(f"Actual: {' '.join(DEFIELD.vocab.itos[tok] for tok in de_out[0])}")
print(f"English: {' '.join(ENFIELD.vocab.itos[tok] for tok in en_inp[0])}")
print(f"Predicted: {' '.join(DEFIELD.vocab.itos[tok] for tok in output[:25])}")

In [None]:
# Greedy-decode the first test sentence with the batched Bahdanau decoder.
MAX_DECODE_LEN = 50  # safety cap: an under-trained model may never emit <eos>

encoder.eval()
bahdanaudecoderbatched.eval()  # the module actually used for decoding below

batch = next(iter(test_dataloader))
en_inp, de_out = batch.trg, batch.src
eos_idx = DEFIELD.vocab.stoi["<eos>"]

output = []
attentions = []
with torch.no_grad():
    h = encoder.init_hidden(1)
    encoder_outputs, h = encoder(en_inp[0].unsqueeze(0), h)

    decoder_input = de_out[0, 0].unsqueeze(0).unsqueeze(0)  # <sos>, shape (1, 1)
    decoder_hidden = h
    for _ in range(MAX_DECODE_LEN):
        decoder_output, decoder_hidden, attn_weights = bahdanaudecoderbatched(
            decoder_input, decoder_hidden, encoder_outputs)
        _, top_index = decoder_output.topk(1)
        decoder_input = top_index.squeeze(2)
        if top_index.item() == eos_idx:
            break
        output.append(top_index.item())
        attentions.append(attn_weights.squeeze().cpu().numpy())

print("Actual: {}".format(' '.join(DEFIELD.vocab.itos[x] for x in de_out[0])))
print("English: {}".format(' '.join(ENFIELD.vocab.itos[x] for x in en_inp[0])))
print("Predicted: {}".format(' '.join(DEFIELD.vocab.itos[x] for x in output)))

In [162]:
teacher_forcing_prob = 0.5
def train(dataloader, decoder, encoder_optimizer, decoder_optimizer):
    """Train the global `encoder` with a one-token-per-call `decoder` (batch size 1).

    Per example, the target is decoded step by step. With probability
    `teacher_forcing_prob` (drawn once per example) the gold previous token is
    fed back; otherwise the decoder's own argmax is used. Per-step NLL losses
    are summed and back-propagated.

    Args:
        dataloader: yields batches of size 1 whose `.trg` (English, encoder
            input) and `.src` (German, decoder target) are (1, seq) id tensors.
        decoder: a single-step decoder (`Decoder` / `BahdanauDecoder`).
        encoder_optimizer, decoder_optimizer: optimisers over the two modules.
    """
    # `import tqdm` alone does not load the `tqdm.notebook` submodule.
    from tqdm.notebook import tqdm as notebook_tqdm

    sos_idx = DEFIELD.vocab.stoi['<sos>']
    encoder.train()
    decoder.train()
    epoch_bar = notebook_tqdm(range(1, EPOCHS + 1), total=EPOCHS)
    for _epoch in epoch_bar:
        avg_loss = 0.
        batch_bar = notebook_tqdm(dataloader, total=len(dataloader), leave=False)
        for batch in batch_bar:
            loss = 0.
            en_inp, de_out = batch.trg, batch.src
            assert de_out.shape[0] == 1, "train() supports batch size 1 only; use train_batched()"
            h = encoder.init_hidden()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            encoder_outputs, h = encoder(en_inp, h)

            decoder_input = torch.tensor([[sos_idx]], device=device)
            decoder_hidden = h
            teacher_forcing = random.random() < teacher_forcing_prob

            for ii in range(1, de_out.shape[1]):
                decoder_output, decoder_hidden, _ = decoder(decoder_input,
                                                            decoder_hidden, encoder_outputs)
                # Gold token as a 1-element tensor taken on-device (no .item() host sync).
                target = de_out[0, ii].view(1)
                if teacher_forcing:
                    decoder_input = target
                else:
                    decoder_input = decoder_output.topk(1)[1].view(1)

                loss += F.nll_loss(decoder_output.view(1, -1), target)
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()
            avg_loss += loss.item() / len(dataloader)
        epoch_bar.set_postfix(loss=avg_loss)
        
train(train_dataloader, bahdanaudecoder, encoder_optimizer, bahdanaudecoder_optimizer)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

In [None]:
# Greedy-decode the first test sentence with the single-step Bahdanau decoder
# that the training cell above optimises.
MAX_DECODE_LEN = 50  # safety cap: an under-trained model may never emit <eos>

eval_decoder = bahdanaudecoder
encoder.eval()
eval_decoder.eval()

batch = next(iter(test_dataloader))
en_inp, de_out = batch.trg, batch.src
eos_idx = DEFIELD.vocab.stoi["<eos>"]

output = []
attentions = []
with torch.no_grad():
    h = encoder.init_hidden()
    encoder_outputs, h = encoder(en_inp, h)

    decoder_input = torch.tensor([[DEFIELD.vocab.stoi['<sos>']]], device=device)
    decoder_hidden = h
    for _ in range(MAX_DECODE_LEN):
        decoder_output, decoder_hidden, attn_weights = eval_decoder(decoder_input, decoder_hidden, encoder_outputs)
        _, top_index = decoder_output.topk(1)
        decoder_input = top_index.view(1)
        if top_index.item() == eos_idx:
            break
        output.append(top_index.item())
        attentions.append(attn_weights.squeeze().cpu().numpy())

# en_inp / de_out are (1, seq): iterate over row 0 so each element is a single id.
print("Actual: {}".format(' '.join(DEFIELD.vocab.itos[x] for x in de_out[0])))
print("English: {}".format(' '.join(ENFIELD.vocab.itos[x] for x in en_inp[0])))
print("Predicted: {}".format(' '.join(DEFIELD.vocab.itos[x] for x in output)))