In [29]:
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [30]:
class SingleAttentionHead(nn.Module):
    def __init__(self, dim):
        super(SingleAttentionHead, self).__init__() # init
        self.dim = dim # dimension of model
        self.q = nn.Linear(dim, dim) # query weights matrix
        self.k = nn.Linear(dim, dim) # keys weights matrix
        self.v = nn.Linear(dim, dim) # values weight matrix

    def scaled_dot_product_attention(self, q, k, v):
        # transpose (-2, 1)
        attention = q.bmm(k.transpose(1, 2)) / math.sqrt(q.size(-1)) # matrix multiplying Q and K, scaling it
        softmax = torch.softmax(attention, dim=-1) # softmax in order to normalize and make into probabilities
        output = softmax.bmm(v) # multiply values to attention scores to get final outputs

        return output

    def forward(self, q, k, v):
        attentions = self.scaled_dot_product_attention(q, k, v)
        return attentions

In [31]:
class MultiHeadedAttention(nn.Module):
  def __init__(self, dim, num_heads):
      super(MultiHeadedAttention, self).__init__()
      assert dim % num_heads == 0, "model dimension must be divisible by number of heads"

      self.num_heads = num_heads # number of heads
      self.dim_head = dim // num_heads # dim_model/num_heads, this is the dimension of each head

      self.single_heads = nn.ModuleList([SingleAttentionHead(dim).to(device) for i in range(num_heads)])
      self.linear = nn.Linear(num_heads * dim, dim)

  def forward(self, q, k, v):
    individual_head_results = [head(q, k, v) for head in self.single_heads] # Get individual head results
    out = torch.cat(individual_head_results, dim=-1) # concatenate individual head results
    return self.linear(out)

In [32]:
class FeedForward(nn.Module):
    def __init__(self, dim, dim_ff):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(dim, dim_ff)
        self.fc2 = nn.Linear(dim_ff, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x))) # Densely connected layer

In [33]:
import math

In [34]:
# Implements exactly positional encoding formula from "Attention is all you need" paper
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_seq_length, d_model).to(device)
        for pos in range(max_seq_length):
          for i in range((lambda x : x // 2 + 1 if x % 2 == 1 else x // 2)(d_model)):
            pe[pos, 2*i] = math.sin((pos/(pow(10000, 2*i / d_model))))
            pe[pos, 2*i + 1] = math.cos((pos/pow(10000, 2*i / d_model)))
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


In [35]:
class EncoderLayer(nn.Module):
    def __init__(self, dim, num_heads, dim_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadedAttention(dim, num_heads).to(device) # Attention layer
        self.feed_forward = FeedForward(dim, dim_ff).to(device) # FF(Dense) layer
        self.norm1 = nn.LayerNorm(dim) # Normalize
        self.norm2 = nn.LayerNorm(dim) # Normalize
        self.dropout = nn.Dropout(dropout) # Dropout

    def forward(self, x):
        attn_output = self.self_attn(x, x, x) # First get attention outputs
        x = self.norm1(x + self.dropout(attn_output)) # First add and norm
        ff_output = self.feed_forward(x) # Feed forward
        x = self.norm2(x + self.dropout(ff_output)) # Second add and norm
        return x

In [36]:
class DecoderLayer(nn.Module):
    def __init__(self, dim, num_heads, dim_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadedAttention(dim, num_heads).to(device) # First attention, just like the others
        self.cross_attn = MultiHeadedAttention(dim, num_heads).to(device)  # Cross attention, only difference is keys and queries come from encoder output
        self.feed_forward = FeedForward(dim, dim_ff).to(device) # FF
        self.norm1 = nn.LayerNorm(dim) # Norm1
        self.norm2 = nn.LayerNorm(dim) # Norm2
        self.norm3 = nn.LayerNorm(dim) # Norm3
        self.dropout = nn.Dropout(dropout) # Dropout

    def forward(self, x, enc_output):
        attn_output = self.self_attn(x, x, x) # Simply uses the current sequence to generate attention output
        x = self.norm1(x + self.dropout(attn_output)) # Add + Norm
        # x.size() = 99, 512
        # enc_output.size() = 100, 512
        attn_output = self.cross_attn(x, enc_output, enc_output) # Now, uses keys and queries from encoder and values from decoder
        x = self.norm2(x + self.dropout(attn_output)) # Add + Norm
        ff_output = self.feed_forward(x) # FF
        x = self.norm3(x + self.dropout(ff_output)) # Add + Norm
        return x

In [37]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, dim, num_heads, num_layers, dim_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()

        self.encoder_embedding = nn.Embedding(src_vocab_size, dim) # Encoder embedding ()
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, dim) # Decoder embedding
        self.positional_encoding = PositionalEncoding(dim, max_seq_length).to(device) # Positional encoding

        self.encoder_layers = [EncoderLayer(dim, num_heads, dim_ff, dropout).to(device) for i in range(num_layers)]
        self.decoder_layers = [DecoderLayer(dim, num_heads, dim_ff, dropout).to(device) for i in range(num_layers)]


        self.fc = nn.Linear(dim, tgt_vocab_size) # FC
        self.dropout = nn.Dropout(dropout) # Dropout

    def forward(self, src, tgt):
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src))) # Encoder EMBEDDING-->Positional Encoding-->Dropout
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt))) # Decoder EMBEDDING-->Positional Encoding-->Dropout


        enc_output = src_embedded # Because this is simply the encoder output (or will be)
         # This is where we use the mask, just run through enc layers
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output)

        dec_output = tgt_embedded # Decoder output (will be)
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output) # Pass in both decoder and encoder output as well as src

        output = self.fc(dec_output)
        return output

In [38]:
inp_vocab_size = 5000
target_vocab_size = 5000
dim = 512
num_heads = 8
num_layers = 6
dim_ff = 1024
max_len = 150
dropout = 0.1

transformer = Transformer(inp_vocab_size, target_vocab_size, dim, num_heads, num_layers, dim_ff, inp_vocab_size, dropout)

# Random data
inp_dat = torch.randint(1, inp_vocab_size, (64, max_len)).to(device)  # (batch_size, seq_length)
targets = torch.randint(1, inp_vocab_size, (64, max_len)).to(device)  # (batch_size, seq_length)

In [39]:
from tqdm import tqdm
transformer.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

num_epochs = 500

for epoch in tqdm(range(num_epochs)):
    optimizer.zero_grad()
    output = transformer(inp_dat, targets[:, :-1])
    loss = criterion(output.contiguous().view(-1, target_vocab_size), targets[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
      print("\nLoss: " + str(loss.item()) + " Epoch: " + str(epoch))

  0%|          | 1/500 [00:00<03:54,  2.13it/s]


Loss: 8.684022903442383 Epoch: 0


 10%|█         | 51/500 [00:22<03:20,  2.23it/s]


Loss: 7.7932963371276855 Epoch: 50


 20%|██        | 101/500 [00:45<02:59,  2.23it/s]


Loss: 7.108153343200684 Epoch: 100


 30%|███       | 151/500 [01:07<02:37,  2.22it/s]


Loss: 6.446406841278076 Epoch: 150


 40%|████      | 201/500 [01:30<02:14,  2.22it/s]


Loss: 5.812021732330322 Epoch: 200


 50%|█████     | 251/500 [01:52<01:51,  2.23it/s]


Loss: 5.187049388885498 Epoch: 250


 60%|██████    | 301/500 [02:15<01:29,  2.22it/s]


Loss: 4.577675819396973 Epoch: 300


 70%|███████   | 351/500 [02:37<01:06,  2.23it/s]


Loss: 3.9871010780334473 Epoch: 350


 80%|████████  | 401/500 [03:00<00:44,  2.22it/s]


Loss: 3.430532693862915 Epoch: 400


 90%|█████████ | 451/500 [03:22<00:22,  2.22it/s]


Loss: 2.931042194366455 Epoch: 450


100%|██████████| 500/500 [03:44<00:00,  2.22it/s]
