# Seq2Seq
Unlike the previous two notebooks, the material in this one will likely be new to you unless you've worked closely with NLP or a related field. The main novelty is simple to state, though: we use RNNs in an encoder/decoder architecture.

Encoder/decoders come up *all* the time in deep learning, so it's useful to become familiar with the concept now: the implementation details of the encoder and decoder sides differ from model to model, but the design pattern itself is very common. Let's load in the data.

The data loading is more involved this time around, but we'll ignore the details for now. We will ultimately need to see how this work translates to the AlphaFold case, but let's focus on the model first:

In [263]:
import re
import os
import random
import pandas as pd
import numpy as np
import urllib.request
import unicodedata
import zipfile
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

torch.manual_seed(0)

<torch._C.Generator at 0x1056a8f50>

In [264]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    with open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as f:
        lines = f.read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

if not os.path.exists("data.zip"):
    url = "https://download.pytorch.org/tutorial/data.zip"
    urllib.request.urlretrieve(url, "data.zip")

with zipfile.ZipFile("data.zip", 'r') as zip_ref:
    zip_ref.extractall(".")
    
input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))

def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long).view(-1, 1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

tensor_pairs = [tensorsFromPair(pair) for pair in pairs]

Reading lines...
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Counting words...
Counted words:
fra 4345
eng 2803
['nous ne sommes pas maries .', 'we re not married .']
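
To make the preprocessing concrete, here is a small self-contained sketch of what `normalizeString` and the index encoding do to one sentence pair. It reimplements the two helpers with the same regular expressions as above; the `encode` helper and the tiny hand-built vocabulary are illustrative assumptions, not part of the notebook.

```python
import re
import unicodedata

EOS_token = 1

def unicode_to_ascii(s):
    # Decompose accented characters and drop the combining marks
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)        # split punctuation off as its own token
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)    # any other non-letter run becomes one space
    return s

def encode(word2index, sentence):
    # Hypothetical helper: word indices followed by EOS, as tensorFromSentence does
    return [word2index[w] for w in sentence.split(' ')] + [EOS_token]

french = normalize_string("Nous ne sommes pas mari\u00e9s.")
english = normalize_string("We're not married.")

# Toy vocabulary: indices 0 and 1 are reserved for SOS and EOS
word2index = {w: i + 2 for i, w in enumerate(english.split(' '))}
encoded = encode(word2index, english)

print(french)    # nous ne sommes pas maries .
print(english)   # we re not married .
print(encoded)   # [2, 3, 4, 5, 6, 1]
```

The apostrophe in "We're" is replaced by a space, which is why prefixes such as `"we re "` appear in `eng_prefixes`.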


The `Encoder` maps the input sentence to a context vector, and the `Decoder` maps that context to a prediction of the output sentence, one word at a time:

![](https://pytorch.org/tutorials/_images/seq2seq.png)
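
As a rough sketch of the data flow in the figure, the snippet below pushes one toy sentence through a GRU encoder and then runs a single decoder step from the resulting context. The sizes (`vocab_in`, `vocab_out`, `H`) and the toy indices are made-up values for illustration only; the point is the shapes, not the numbers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_in, vocab_out, H = 20, 15, 8   # assumed toy sizes
SOS_token = 0

# Encoder: embed the source indices and keep only the final hidden state
enc_embedding = nn.Embedding(vocab_in, H)
enc_gru = nn.GRU(H, H, batch_first=True)

src = torch.tensor([[3, 7, 5, 1]])            # 1 sentence, 4 tokens
_, context = enc_gru(enc_embedding(src))      # context: 1 x 1 x H

# Decoder: start from SOS, with the context as the initial hidden state
dec_embedding = nn.Embedding(vocab_out, H)
dec_gru = nn.GRU(H, H, batch_first=True)
dec_out = nn.Linear(H, vocab_out)

prev_word = torch.tensor([[SOS_token]])       # 1 x 1
step_output, hidden = dec_gru(dec_embedding(prev_word), context)
logits = dec_out(step_output[:, -1, :])       # 1 x vocab_out scores for the next word

print(context.shape)   # torch.Size([1, 1, 8])
print(logits.shape)    # torch.Size([1, 15])
```

However long the source sentence is, the decoder only ever sees the fixed-size `context`, which is the defining constraint of this plain encoder/decoder design.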

In [268]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        
    def forward(self, x, hidden):
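        # x : N x T x 1 tensor of word indices
        N, T, _ = x.shape
        # The embedding returns N x T x 1 x H; fold away the singleton axis
        embedded = self.embedding(x).view(N, T, -1)
        # hn : 1 x N x H, the final hidden state (the "context")
        _, hn = self.gru(embedded, hidden)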
        return hn

In [270]:
N = 1
H = 128

encoder = Encoder(input_lang.n_words, H)
# The Decoder class is defined in the next cell, so it is instantiated there.

input_tensor = torch.stack([tensor_pairs[k][0] for k in range(10, 10 + N)])
init_hidden = torch.zeros((1, N, H))

encoded = encoder(input_tensor, init_hidden)

print(encoded)

tensor([[[-1.4702e-01,  4.1379e-01,  3.6360e-01,  6.2382e-02, -1.8443e-02,
           2.7656e-01, -2.2305e-01,  3.3506e-01, -6.0423e-02, -5.3020e-01,
          -1.1673e-01,  1.4360e-03, -4.9969e-01, -1.2383e-01,  3.8583e-02,
           4.5550e-04,  2.7990e-01,  4.5993e-01,  3.0449e-01,  2.7023e-01,
          -1.2698e-01, -2.5040e-01, -1.6139e-01, -5.4098e-01, -3.0216e-01,
          -1.0756e-01,  1.3527e-02, -1.3165e-01, -1.8842e-03,  3.8053e-02,
           2.4843e-02,  4.5525e-02,  2.6044e-01, -4.8941e-01, -5.0603e-02,
          -7.3428e-01, -2.4093e-01,  2.3459e-01,  8.3031e-02, -7.1811e-02,
           3.8970e-01,  2.3943e-01, -4.1332e-01, -5.9104e-02, -2.5092e-02,
           2.5156e-01, -3.5193e-02,  3.2591e-01,  7.0028e-02,  2.9215e-01,
          -3.4554e-01,  2.7159e-01,  8.5373e-02, -2.6173e-02,  7.2862e-02,
           8.1261e-02, -1.1153e-01, -3.4511e-01,  2.2750e-02,  5.2740e-02,
          -2.0742e-01, -1.5062e-01, -3.6496e-01,  7.8150e-02,  1.1871e-01,
           1.9632e-01, -6

In [271]:
class Decoder(nn.Module):
    def __init__(self, output_size, hidden_size):
        super(Decoder, self).__init__()
        
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        
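    def forward(self, x, hidden):
        # x : N x T x 1 tensor of word indices; hidden : 1 x N x H
        N, T, _ = x.shape
        embedded = self.embedding(x).view(N, T, -1)
        output, hn = self.gru(embedded, hidden)
        # Flatten to (N*T) x H so the logits come out as (N*T) x output_size,
        # the shape nn.CrossEntropyLoss expects
        output = self.out(output.reshape(N * T, -1))
        return output, hn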

In [272]:
N = 1
H = 128

encoder = Encoder(input_lang.n_words, H)
decoder = Decoder(output_lang.n_words, H)

input_tensor = torch.stack([tensor_pairs[k][0] for k in range(10, 10 + N)])
output_tensor = torch.stack([tensor_pairs[k][1] for k in range(10, 10 + N)])
init_hidden = torch.zeros((1, N, H))

encoded = encoder(input_tensor, init_hidden)
decoded, decoded_hn = decoder(output_tensor[:,0:1,:], encoded)

In [273]:
epochs = 100
N = 1
batch_size = 128
learning_rate = 0.01
teacher_forcing_prob = 0.5

# tensor_pairs stays a plain Python list: the tensors have different
# lengths, so packing them into a NumPy array gives a ragged array,
# which NumPy warns about or rejects depending on the version.

encoder = Encoder(input_lang.n_words, H)
decoder = Decoder(output_lang.n_words, H)

loss_criterion = nn.CrossEntropyLoss()
encoder_optim = optim.SGD(encoder.parameters(), lr=learning_rate)
decoder_optim = optim.SGD(decoder.parameters(), lr=learning_rate)

losses = []

for epoch in range(epochs):
    # Draw a random mini-batch of sentence pairs without replacement
    idx = random.sample(range(len(tensor_pairs)), batch_size)
    sample = [tensor_pairs[i] for i in idx]
    
    encoder_optim.zero_grad()
    decoder_optim.zero_grad()
    
    loss = 0
    
    for pair in sample:
        input_sentence = pair[0]
        output_sentence = pair[1]
        
        # context becomes first hidden state of decoder
        init_hidden = torch.zeros((1, N, H))
        hidden_enc = encoder(input_sentence.unsqueeze(0), init_hidden)
        
        prev_word = torch.tensor([[[SOS_token]]])
        hidden_dec = hidden_enc
        for word in output_sentence:
            predicted_word_probs, hidden_dec = decoder(prev_word, hidden_dec)
            
            # Teacher forcing: feed the ground-truth word with probability
            # teacher_forcing_prob, otherwise the model's own prediction
            if random.random() < teacher_forcing_prob:
                prev_word_value = word
            else:
                prev_word_value = torch.argmax(predicted_word_probs)
            # Build the next 1 x 1 x 1 input directly from the tensor
            prev_word = prev_word_value.detach().view(1, 1, 1)
            
            # Debug: logits are 1 x vocab_size, the target is a length-1 index tensor
            print(predicted_word_probs.shape)
            print(word.shape)
            loss += loss_criterion(predicted_word_probs, word)
            
            if prev_word_value.item() == EOS_token:
                break
    
    loss.backward()
    
    encoder_optim.step()
    decoder_optim.step()
    
    losses.append(loss.detach().numpy())
    
    print(f"Epoch : {epoch}/{epochs} -- Loss : {loss}")

torch.Size([1, 2803])
torch.Size([1])

KeyboardInterrupt: 
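
The `teacher_forcing_prob` branch in the loop above decides, at every step, whether the decoder's next input is the ground-truth word or its own prediction. The following sketch isolates just that decision with plain Python lists; `next_inputs` and the hard-coded `predictions` are hypothetical stand-ins for the decoder, chosen so the effect is easy to see.

```python
import random

def next_inputs(targets, predictions, teacher_forcing_prob, rng):
    """Return the token fed to the decoder after each step.

    targets[t] is the ground-truth word at step t and predictions[t] is
    what the (stand-in) model predicted at step t.
    """
    chosen = []
    for target, predicted in zip(targets, predictions):
        if rng.random() < teacher_forcing_prob:
            chosen.append(target)       # teacher forcing: feed the truth
        else:
            chosen.append(predicted)    # free running: feed the model's guess
    return chosen

targets = [4, 9, 2, 1]         # toy target indices, ending in EOS (1)
predictions = [4, 7, 7, 7]     # a stand-in model that drifts after the first word

rng = random.Random(0)
always = next_inputs(targets, predictions, 1.0, rng)
never = next_inputs(targets, predictions, 0.0, rng)
mixed = next_inputs(targets, predictions, 0.5, rng)

print(always)  # identical to targets
print(never)   # identical to predictions
print(mixed)   # each position comes from one list or the other
```

With probability 1.0 the decoder always conditions on a correct history; with 0.0 it conditions on its own outputs, which is the situation it faces at inference time. A value such as 0.5 mixes the two.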

In [None]:
plt.plot(range(len(losses)), losses)
plt.show()

In [None]:
pair = tensor_pairs[500] # 10, 15 : i'm
input_sentence = pair[0]
output_sentence = pair[1]

input_sentence_words = [input_lang.index2word[word_idx.item()] for word_idx in input_sentence]
output_sentence_words = [output_lang.index2word[word_idx.item()] for word_idx in output_sentence]
predicted_output_sentence = []

# Greedy decoding: no gradients needed, and no teacher forcing
with torch.no_grad():
    # context becomes first hidden state of decoder
    init_hidden = torch.zeros((1, N, H))
    hidden_enc = encoder(input_sentence.unsqueeze(0), init_hidden)

    prev_word = torch.tensor([[[SOS_token]]])
    hidden_dec = hidden_enc

    # Decode at most as many words as the reference translation has
    for _ in output_sentence:
        predicted_word_probs, hidden_dec = decoder(prev_word, hidden_dec)
        prev_word_value = torch.argmax(predicted_word_probs)
        prev_word = prev_word_value.view(1, 1, 1)

        predicted_output_sentence.append(output_lang.index2word[prev_word_value.item()])

        if prev_word_value.item() == EOS_token:
            break

print(input_sentence_words)
print(output_sentence_words)
print(predicted_output_sentence)
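
Stripped of the tensor bookkeeping, the loop above is greedy decoding: feed the previous word, take the highest-scoring next word, and stop at `EOS` or at a length limit. Here is a minimal sketch of that control flow, assuming a hypothetical `step(prev_word, state)` function that returns a list of scores and a new state; the toy `toy_step` below just follows a fixed lookup table.

```python
SOS_token, EOS_token = 0, 1

def greedy_decode(step, state, max_len):
    """Repeatedly pick the argmax word until EOS or max_len words."""
    prev_word = SOS_token
    decoded = []
    for _ in range(max_len):
        scores, state = step(prev_word, state)
        # argmax over the vocabulary
        prev_word = max(range(len(scores)), key=scores.__getitem__)
        decoded.append(prev_word)
        if prev_word == EOS_token:
            break
    return decoded

# Toy stand-in for the decoder: after word w, the best next word is NEXT[w]
NEXT = {0: 3, 3: 2, 2: 4, 4: 1}
VOCAB_SIZE = 5

def toy_step(prev_word, state):
    scores = [0.0] * VOCAB_SIZE
    scores[NEXT[prev_word]] = 1.0
    return scores, state + 1    # state just counts the steps taken

print(greedy_decode(toy_step, 0, max_len=10))  # [3, 2, 4, 1]
print(greedy_decode(toy_step, 0, max_len=2))   # [3, 2]
```

The length limit matters in practice: an undertrained decoder may never emit `EOS`, so the loop needs some other stopping rule.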

In [276]:
# ------ Version 2 ------ #
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        
        #src = [src len, batch size]
        
        embedded = self.dropout(self.embedding(src))
        
        #embedded = [src len, batch size, emb dim]
        
        outputs, (hidden, cell) = self.rnn(embedded)
        
        #outputs = [src len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #outputs are always from the top hidden layer
        
        return hidden, cell

In [277]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, cell):
        
        #input = [batch size]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #n directions in the decoder will always be 1, therefore:
        #hidden = [n layers, batch size, hid dim]
        #cell = [n layers, batch size, hid dim]
        
        input = input.unsqueeze(0)
        
        #input = [1, batch size]
        
        embedded = self.dropout(self.embedding(input))
        
        #embedded = [1, batch size, emb dim]
                
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        
        #output = [seq len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #seq len and n directions will always be 1 in the decoder, therefore:
        #output = [1, batch size, hid dim]
        #hidden = [n layers, batch size, hid dim]
        #cell = [n layers, batch size, hid dim]
        
        prediction = self.fc_out(output.squeeze(0))
        
        #prediction = [batch size, output dim]
        
        return prediction, hidden, cell


In [278]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
        assert encoder.hid_dim == decoder.hid_dim, \
            "Hidden dimensions of encoder and decoder must be equal!"
        assert encoder.n_layers == decoder.n_layers, \
            "Encoder and decoder must have equal number of layers!"
        
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        
        #src = [src len, batch size]
        #trg = [trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
        
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        #tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        #last hidden state of the encoder is used as the initial hidden state of the decoder
        hidden, cell = self.encoder(src)
        
        #first input to the decoder is the <sos> tokens
        input = trg[0,:]
        
        for t in range(1, trg_len):
            
            #insert input token embedding, previous hidden and previous cell states
            #receive output tensor (predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            
            #place predictions in a tensor holding predictions for each token
            outputs[t] = output
            
            #decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            
            #get the highest predicted token from our predictions
            top1 = output.argmax(1) 
            
            #if teacher forcing, use actual next token as next input
            #if not, use predicted token
            input = trg[t] if teacher_force else top1
        
        return outputs
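
Note that `forward` reads `trg[0]` as the start token and only ever writes `outputs[1:]`, so it expects a target that begins with `SOS`, whereas `tensorFromSentence` above only appends `EOS`. One way to line things up, sketched below with plain lists, is to prepend `SOS` before calling the model and to score only positions 1 onward. The helper name `decoder_inputs_and_labels` is made up for this illustration.

```python
SOS_token, EOS_token = 0, 1

def decoder_inputs_and_labels(target):
    """Split a target sentence into per-step decoder inputs and labels.

    target is a list of word indices ending in EOS (no SOS), in the format
    tensorFromSentence produces.
    """
    trg = [SOS_token] + target   # what Seq2Seq.forward expects as trg
    inputs = trg[:-1]            # token fed at each step under full teacher forcing
    labels = trg[1:]             # token each step's output is scored against
    return trg, inputs, labels

target = [12, 45, 7, EOS_token]  # toy indices
trg, inputs, labels = decoder_inputs_and_labels(target)

print(trg)     # [0, 12, 45, 7, 1]
print(inputs)  # [0, 12, 45, 7]
print(labels)  # [12, 45, 7, 1]
```

In tensor terms this corresponds to computing the loss on `outputs[1:]` against `trg[1:]`. Position 0 of `outputs` stays all zeros, and the argmax of an all-zero row is index 0 (`SOS`), so it should be left out of both the loss and any decoded sentence.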


In [None]:
device = "cpu"

INPUT_DIM = input_lang.n_words
OUTPUT_DIM = output_lang.n_words
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)

model = Seq2Seq(enc, dec, device).to(device)
model.apply(init_weights)

loss_criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

losses = []

for epoch in range(epochs):
    # tensor_pairs holds variable-length tensors, so sample by index
    # rather than converting it to a NumPy array
    idx = random.sample(range(len(tensor_pairs)), batch_size)
    sample = [tensor_pairs[i] for i in idx]
    
    optimizer.zero_grad()
    
    loss = 0
    
    for pair in sample:
        input_sentence = pair[0]
        output_sentence = pair[1]
        
        predicted_word_probs = model(input_sentence, output_sentence).view(-1, output_lang.n_words)
        sentence_len = predicted_word_probs.shape[0]
        for word_idx in range(sentence_len):
            loss += loss_criterion(predicted_word_probs[word_idx:word_idx+1, :], output_sentence[word_idx, :])
            
    loss.backward()
    
    optimizer.step()
    losses.append(loss.detach().numpy())
    
    print(f"Epoch : {epoch}/{epochs} -- Loss : {loss}")

Epoch : 0/100 -- Loss : 6865.9345703125
Epoch : 1/100 -- Loss : 6446.55419921875
Epoch : 2/100 -- Loss : 12435.83203125
Epoch : 3/100 -- Loss : 12583.98828125
Epoch : 4/100 -- Loss : 25489.0625
Epoch : 5/100 -- Loss : 65054.875
Epoch : 6/100 -- Loss : 25454.34375
Epoch : 7/100 -- Loss : 51668.37890625
Epoch : 8/100 -- Loss : 58970.375
Epoch : 9/100 -- Loss : 39113.125
Epoch : 10/100 -- Loss : 38274.78515625
Epoch : 11/100 -- Loss : 42296.50390625
Epoch : 12/100 -- Loss : 56096.71484375
Epoch : 13/100 -- Loss : 64049.4453125
Epoch : 14/100 -- Loss : 50648.62109375
Epoch : 15/100 -- Loss : 54449.42578125
Epoch : 16/100 -- Loss : 54283.6171875
Epoch : 17/100 -- Loss : 46803.60546875
Epoch : 18/100 -- Loss : 53224.8359375
Epoch : 19/100 -- Loss : 38135.640625
Epoch : 20/100 -- Loss : 27146.1484375
Epoch : 21/100 -- Loss : 16775.2578125
Epoch : 22/100 -- Loss : 45361.890625
Epoch : 23/100 -- Loss : 39904.90625
Epoch : 24/100 -- Loss : 29011.95703125
Epoch : 25/100 -- Loss : 18885.234375
Epo
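
A quick sanity check on these numbers, under the assumption that an untrained model predicts roughly uniformly over the 2803 English words: each word then costs about ln(2803) in cross-entropy, so the first summed loss implies how many target words were in the batch. Because the loss is summed rather than averaged, its gradient grows with that count, which may contribute to the instability seen here; a per-word figure is the easier one to compare across runs. The snippet only redoes this arithmetic.

```python
import math

vocab_size = 2803               # English vocabulary size reported above
first_loss = 6865.9345703125    # summed loss printed for epoch 0
batch_size = 128

# Cross-entropy of a uniform guess over the vocabulary, per word
uniform_loss_per_word = math.log(vocab_size)

# Number of target words implied by the summed loss (an estimate)
implied_words = first_loss / uniform_loss_per_word
words_per_sentence = implied_words / batch_size

print(round(uniform_loss_per_word, 2))   # 7.94
print(round(implied_words))              # 865
print(round(words_per_sentence, 1))      # 6.8
```

Roughly seven scored positions per sentence is plausible given `MAX_LENGTH = 10`, so the starting loss looks consistent with an untrained model. Dividing `loss` by the number of scored words before calling `backward()` would be one way to make the step size independent of batch size and sentence length.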

In [283]:
pair = tensor_pairs[500] # 10, 15 : i'm
input_sentence = pair[0]
output_sentence = pair[1]

# Note: this call uses the default teacher_forcing_ratio=0.5 and the model is
# still in training mode (dropout active), so it is not a pure greedy decode.
predicted_output = model(input_sentence, output_sentence).view(-1, output_lang.n_words)

input_sentence_words = [input_lang.index2word[word_idx.item()] for word_idx in input_sentence]
output_sentence_words = [output_lang.index2word[word_idx.item()] for word_idx in output_sentence]
predicted_output_sentence = [output_lang.index2word[word_logits.argmax().item()] for word_logits in predicted_output]

print(input_sentence_words)
print(output_sentence_words)
print(predicted_output_sentence)

['tu', 'es', 'en', 'securite', '.', 'EOS']
['you', 're', 'safe', '.', 'EOS']
['SOS', 'm', 'm', 'm', 'm']
