In [1]:
import torch
import torch.nn as nn

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of each input is split into ``heads`` chunks of
    ``head_dim`` features. Every chunk goes through the same bias-free
    ``head_dim -> head_dim`` projection (one projection each for values, keys
    and queries; its weights are shared by all heads), attention is computed
    independently per head, and the concatenated heads are mixed by ``fc_out``.

    Args:
        embedding_size: Size of the last dimension of the inputs.
        heads: Number of attention heads; must divide ``embedding_size``.
    """

    def __init__(self, embedding_size, heads):
        super().__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads

        assert(self.heads * self.head_dim == self.embedding_size), "Invalid number of heads"

        self.fc_values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.fc_out = nn.Linear(heads * self.head_dim, embedding_size)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` over ``keys`` and aggregate ``values``.

        Args:
            values: Tensor reshaped here as ``(N, value_len, heads, head_dim)``.
            keys: Tensor reshaped here as ``(N, key_len, heads, head_dim)``.
            queries: Tensor reshaped here as ``(N, query_len, heads, head_dim)``.
            mask: ``None``, or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``; positions where
                ``mask == 0`` receive no attention weight.

        Returns:
            Tensor of shape ``(N, query_len, embedding_size)``.
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into heads, then project each head.
        values = self.fc_values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.fc_keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.fc_queries(queries.reshape(N, query_len, self.heads, self.head_dim))

        # Dot product of every query with every key: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(d_k), where d_k is the per-head key size. The dot
        # product is taken over head_dim features, not the full embedding.
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the most negative finite value of the current dtype so the
            # fill is representable in reduced precision as well.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values per head, then merge the heads back together.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)
        out = out.reshape(N, query_len, self.heads * self.head_dim)

        return self.fc_out(out)

In [3]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer's output is added to its input, layer-normalised and then
    passed through dropout.

    Args:
        embedding_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        forward_expansion: Multiplier for the hidden width of the feed-forward
            network (``forward_expansion * embedding_size``).
        p: Dropout probability.
    """

    def __init__(self, embedding_size, heads, forward_expansion, p):
        super().__init__()
        hidden_size = forward_expansion * embedding_size

        self.attention = MultiHeadAttention(embedding_size, heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
        )
        self.norm2 = nn.LayerNorm(embedding_size)
        self.dropout = nn.Dropout(p)

    def forward(self, values, keys, queries, mask):
        """Run attention and the feed-forward network with residual connections.

        ``queries`` is the tensor added back after attention (the residual
        path). ``mask`` is forwarded unchanged to the attention module.
        """
        attended = self.attention(values, keys, queries, mask)
        residual = self.dropout(self.norm1(attended + queries))

        expanded = self.feed_forward(residual)
        return self.dropout(self.norm2(expanded + residual))

In [4]:
class Encoder(nn.Module):
    """Stack of ``TransformerBlock`` layers over token + learned position embeddings.

    Args:
        src_vocab_size: Number of rows in the token embedding table.
        embedding_size: Feature size of the embeddings and of every layer.
        num_layers: Number of ``TransformerBlock`` layers.
        heads: Number of attention heads per layer.
        forward_expansion: Feed-forward width multiplier per layer.
        max_length: Number of rows in the position embedding table; input
            sequences must not be longer than this.
        p: Dropout probability.
        device: Kept for backward compatibility and stored as ``self.device``.
            Tensors created in ``forward`` follow the input's device instead.
    """

    def __init__(self, src_vocab_size, embedding_size, num_layers, heads, 
                 forward_expansion, max_length, p, device):
        super().__init__()
        self.device = device

        self.word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.positional_embedding = nn.Embedding(max_length, embedding_size)
        
        self.layers = nn.ModuleList([TransformerBlock(embedding_size, heads, forward_expansion, p) for _ in range(num_layers)])
        self.dropout = nn.Dropout(p)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of shape ``(N, seq_len)``.
            mask: Attention mask passed unchanged to every layer.

        Returns:
            Tensor of shape ``(N, seq_len, embedding_size)``.
        """
        N, seq_len = x.shape
        # Build the position ids on the same device as the input so the module
        # keeps working after `.to(...)`, whatever `device` was given at
        # construction time.
        positions = torch.arange(0, seq_len, device=x.device).expand(N, seq_len)
        out = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        # Self-attention: values, keys and queries are all the running output.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [5]:
class DecoderBlock(nn.Module):
    """Masked self-attention followed by a ``TransformerBlock`` over external values/keys.

    Args:
        embedding_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        forward_expansion: Feed-forward width multiplier of the inner block.
        p: Dropout probability.
        device: Accepted for signature compatibility; not used by this block.
    """

    def __init__(self, embedding_size, heads, forward_expansion, p, device):
        super().__init__()
        self.attention = MultiHeadAttention(embedding_size, heads)
        self.norm = nn.LayerNorm(embedding_size)

        self.transformer_block = TransformerBlock(embedding_size, heads, forward_expansion, p)
        self.dropout = nn.Dropout(p)

    def forward(self, x, values, keys, src_mask, trg_mask):
        """Self-attend over ``x`` under ``trg_mask``, then attend over ``values``/``keys`` under ``src_mask``."""
        self_attended = self.attention(x, x, x, trg_mask)
        normed = self.norm(self_attended + x)
        queries = self.dropout(normed)

        # The self-attention result supplies the queries for the inner block.
        return self.transformer_block(values, keys, queries, src_mask)

In [6]:
class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers followed by a projection to vocabulary logits.

    Args:
        trg_vocab_size: Number of rows in the token embedding table and size
            of the output logits.
        embedding_size: Feature size of the embeddings and of every layer.
        num_layers: Number of ``DecoderBlock`` layers.
        heads: Number of attention heads per layer.
        forward_expansion: Feed-forward width multiplier per layer.
        max_length: Number of rows in the position embedding table; input
            sequences must not be longer than this.
        p: Dropout probability.
        device: Kept for backward compatibility; stored as ``self.device`` and
            forwarded to each ``DecoderBlock``. Tensors created in ``forward``
            follow the input's device instead.
    """

    def __init__(self, trg_vocab_size, embedding_size, num_layers, heads,
                 forward_expansion, max_length, p, device):
        super().__init__()
        self.device = device

        self.word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.positional_embedding = nn.Embedding(max_length, embedding_size)

        self.layers = nn.ModuleList([DecoderBlock(embedding_size, heads, forward_expansion, p, device) for _ in range(num_layers)])
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(p)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: Integer tensor of shape ``(N, seq_len)``.
            enc_out: Tensor used as both values and keys in every layer.
            src_mask: Mask passed to each layer for attention over ``enc_out``.
            trg_mask: Mask passed to each layer for self-attention over ``x``.

        Returns:
            Tensor of shape ``(N, seq_len, trg_vocab_size)`` (unnormalised scores).
        """
        N, seq_len = x.shape
        # Build the position ids on the same device as the input so the module
        # keeps working after `.to(...)`, whatever `device` was given at
        # construction time.
        positions = torch.arange(0, seq_len, device=x.device).expand(N, seq_len)
        x = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(x)

In [7]:
class Transformer(nn.Module):
    """Encoder-decoder model built from ``Encoder`` and ``Decoder``.

    Args:
        src_vocab_size: Vocabulary size passed to the encoder.
        trg_vocab_size: Vocabulary size passed to the decoder.
        src_pad_idx: Token id treated as padding when building the source mask.
        trg_pad_idx: Stored as ``self.trg_pad_idx``; not used when building masks.
        embedding_size: Feature size shared by encoder and decoder.
        num_layers: Number of layers in each of encoder and decoder.
        forward_expansion: Feed-forward width multiplier.
        heads: Number of attention heads.
        max_length: Size of the position embedding tables.
        p: Dropout probability.
        device: Kept for backward compatibility; stored and forwarded to the
            sub-modules. Masks are created on the device of the input tensors.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embedding_size=512,
                 num_layers=6, forward_expansion=4, heads=8, max_length=100, p=0, device="cpu"):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        self.encoder = Encoder(src_vocab_size, embedding_size, num_layers, heads, 
                               forward_expansion, max_length, p, device)
        
        self.decoder = Decoder(trg_vocab_size, embedding_size, num_layers, heads, 
                               forward_expansion, max_length, p, device)
        
    def get_src_mask(self, src):
        """Return a boolean mask of shape ``(N, 1, 1, src_len)``, False at padding tokens."""
        # The comparison result already lives on `src`'s device; moving it to
        # a separately configured device could desynchronise it from the model.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def get_trg_mask(self, trg):
        """Return a lower-triangular mask of ones, shape ``(N, 1, trg_len, trg_len)``.

        Position ``i`` may attend to positions ``0..i`` only.
        """
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """Encode ``src``, decode ``trg`` against it, and return the decoder output."""
        src_mask = self.get_src_mask(src)
        trg_mask = self.get_trg_mask(trg)

        enc_out = self.encoder(src, src_mask)
        return self.decoder(trg, enc_out, src_mask, trg_mask)

In [8]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Toy problem setup: token id 0 is padding in both vocabularies,
# and both vocabularies contain 10 ids.
src_pad_idx = trg_pad_idx = 0
src_vocab_size = trg_vocab_size = 10

In [9]:
# Show which device was selected above.
device

device(type='cpu')

In [10]:
# Toy batch of three source sequences, right-padded with 0 to a common length.
src = torch.tensor(
    [
        [1, 5, 7, 3, 4, 6, 8, 9, 2, 0, 0],
        [1, 4, 3, 3, 7, 8, 9, 5, 3, 2, 0],
        [1, 8, 9, 4, 5, 8, 3, 6, 6, 9, 2],
    ],
    device=device,
)

# Matching batch of three target sequences, also right-padded with 0.
trg = torch.tensor(
    [
        [1, 5, 4, 3, 9, 4, 2, 0],
        [1, 6, 9, 8, 3, 6, 8, 2],
        [1, 9, 8, 7, 6, 2, 0, 0],
    ],
    device=device,
)

In [11]:
# Both batches hold three sequences; source and target lengths differ.
src.shape, trg.shape

(torch.Size([3, 11]), torch.Size([3, 8]))

In [12]:
# Pass the selected device to the model as well as moving its parameters.
# With the default device="cpu", the tensors the model creates internally
# would stay on the CPU while the parameters are on the GPU when CUDA is available.
net = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)

In [13]:
# Print the module tree to inspect the layer structure.
net

Transformer(
  (encoder): Encoder(
    (word_embedding): Embedding(10, 512)
    (positional_embedding): Embedding(100, 512)
    (layers): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadAttention(
          (fc_values): Linear(in_features=64, out_features=64, bias=False)
          (fc_keys): Linear(in_features=64, out_features=64, bias=False)
          (fc_queries): Linear(in_features=64, out_features=64, bias=False)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0, inplace=False)
      )
      (1): TransformerBlock(
        (attention): MultiHeadAttention(


In [14]:
# Feed the target sequences without their final token. The result holds one
# vector of trg_vocab_size scores for each remaining target position.
out = net(src, trg[:, :-1])
out, out.shape

(tensor([[[ 1.8968e-01,  4.6380e-01,  5.1210e-01,  8.0332e-01,  4.1375e-01,
           -1.5518e-02, -3.1026e-01, -1.0664e+00, -5.0201e-01,  1.4194e-01],
          [-7.5107e-01, -6.0959e-03,  8.5103e-01, -2.7004e-01,  1.4141e-01,
            6.5634e-01, -4.3093e-01, -2.4749e-01, -7.1317e-01,  6.8862e-02],
          [-5.7888e-01, -1.1531e-01,  5.0558e-01, -4.4965e-03,  5.6342e-01,
           -2.3258e-01,  2.5970e-02, -8.3897e-01, -3.6586e-01,  2.1711e-01],
          [-1.6662e-01,  8.8083e-01,  7.3900e-01,  1.4125e-01, -1.2750e-01,
            2.5745e-01, -1.1150e-01, -7.0900e-01,  1.0390e-01,  5.8816e-01],
          [-4.2889e-01,  1.8837e-01,  8.5323e-01,  3.0196e-01,  2.9685e-01,
           -6.1869e-02,  1.3138e-01, -1.3593e+00, -3.5308e-01, -7.7775e-01],
          [-6.5435e-01, -2.1466e-01,  2.1622e-01, -7.4070e-01,  5.3855e-01,
            2.5033e-03, -5.9645e-01, -8.0153e-01,  3.9680e-01,  2.5530e-02],
          [ 1.9032e-01, -2.9666e-02,  1.3042e-01, -3.1520e-01, -5.4711e-01,
      