## ref : https://nlp.seas.harvard.edu/2018/04/03/attention.html

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import math, copy, time
import matplotlib.pyplot as plt
import seaborn
# Select seaborn's 'talk' plotting context; this is a global seaborn setting.
seaborn.set_context(context='talk')

# Print the installed library versions so the run can be reproduced.
# Versions recorded from the original run:
# torch : 1.9.0+cu102  |  np : 1.19.5  |  seaborn : 0.11.1
print(f'torch : {torch.__version__}  |  np : {np.__version__}  |  seaborn : {seaborn.__version__}')

torch : 1.9.0+cu102  |  np : 1.19.5  |  seaborn : 0.11.1


## Model Architecture

In [2]:
class EncoderDecoder(nn.Module):
    """
    Generic encoder-decoder wrapper.

    Holds five collaborating components and wires them together:
    ``src_embed`` -> ``encoder`` produces a memory, then
    ``tgt_embed`` -> ``decoder`` consumes that memory together with the
    source and target masks. ``generator`` is only stored on the instance;
    this class never calls it.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        """Store the components (assignment order fixes registration order)."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode the masked source, then decode the masked target against it."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` and run the encoder on it with ``src_mask``."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder on it with the memory and both masks."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)


class Generator(nn.Module):
    """Linear projection from ``d_model`` features to ``vocab`` outputs, followed by log-softmax."""

    def __init__(self, d_model, vocab):
        """Create the ``d_model -> vocab`` linear layer, stored as ``proj``."""
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities over the last dimension of ``proj(x)``."""
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)