# Transformer

<center><img src="../images/attention_is_all_you_need.png" width=50% height=60% /></center>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

## Core functions

### Input/Output Embedding

1. A sequence of texts is converted into a sequence of token ids (i.e., the position of the word in the dictionary)
2. The sequence of token ids is converted into a matrix of one-hot vectors of shape *max_seq_len* $\times$ *vocab_size*.
3. The matrix is transformed into embeddings of shape *max_seq_len* $\times$ *emb_size* through a learnable weight matrix of shape *vocab_size* $\times$ *emb_size*.

In [6]:
class TokenEmbedding(nn.Module):
    """Learnable lookup table mapping token ids to dense embedding vectors."""

    def __init__(self, vocab_size, emb_size):
        """
        Args:
        - vocab_size: number of distinct token ids (rows of the embedding table)
        - emb_size: dimensionality of each embedding vector
        """
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise registering `self.emb` raises AttributeError.
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_size)

    def forward(self, x):
        """
        Args:
        - x: tokenized seqs of shape [batch_size, max_seq_len]

        Returns
        - embeddings: of shape [batch_size, max_seq_len, emb_size]
        """
        return self.emb(x)

### Positional Encoding

Any other positional encoding function may apply, but the original paper uses the following waveforms:

<center>
$
\begin{align*}
PE(pos,i_{even})&=\sin\left(\frac{pos}{10000^{\frac{i}{d_{model}}}}\right)\\
PE(pos,i_{odd})&=\cos\left(\frac{pos}{10000^{\frac{i-1}{d_{model}}}}\right)\\
\end{align*}
$
</center>

where
- $pos$: position
- $i$: dimension

The original authors mentioned the following reasons for using the waveforms:
- It would allow the model to easily learn to attend by relative positions because, for any fixed offset $k$, $PE(pos+k)$ can be written as a linear function of $PE(pos)$.
- It may allow the model to extrapolate to seq lengths longer than the ones encountered in training.

In [49]:
def PositionalEncoding(max_seq_len, emb_size):
    """Build the sinusoidal positional-encoding table.

    Even columns hold sin(pos / 10000^(i/emb_size)) and the following odd
    column holds the cosine of the same angle, where i is the even column index.

    Args:
    - max_seq_len: number of positions (rows)
    - emb_size: embedding dimensionality (columns)

    Returns
    - PE: float tensor of shape [max_seq_len, emb_size]
    """
    # even dimension indices; each odd dimension reuses the preceding even index
    i_even = torch.arange(0, emb_size, 2, dtype=torch.float)

    # positions as a column vector so the division broadcasts to 2d
    pos = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)

    # the variable inside sin and cos
    angles = pos / torch.pow(10000, i_even / emb_size)

    # interleave sin/cos pairs along the last axis, then merge the pair axis
    PE = torch.stack([torch.sin(angles), torch.cos(angles)], dim=2).flatten(start_dim=1)

    # with an odd emb_size the interleaving yields one extra cosine column; drop it
    return PE[:, :emb_size]

In [50]:
# Sanity check: encoding table for 5 positions and 4 embedding dimensions.
PositionalEncoding(5, 4)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0100,  0.9999],
        [ 0.9093, -0.4161,  0.0200,  0.9998],
        [ 0.1411, -0.9900,  0.0300,  0.9996],
        [-0.7568, -0.6536,  0.0400,  0.9992]])

### Scaled Dot-Product Attention

There is no magic behind the fancy name "self-attention". It literally means that each word in a sequence attends to every other word in the same sequence to measure how similar or different they are.

However, the real magic comes with the Scaled Dot-Product Attention. We all know how it works, but we do not know why it works.

$$
\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

1. Prepare 3 vectors for each word (i.e., 3 matrices):
    - Q (Query): What you are looking for
    - K (Key): The clues that guide you to your match
    - V (Value): The match
2. Calculate the attention scores between the queries and the keys using a dot product.
    - The query vector for a specific word is multiplied with the key vector of every single word in the sequence.
3. Scale the attention scores by the square root of the dimensionality of the key vectors to stabilize computation.
    - $\frac{QK^T}{\sqrt{d_k}}$ has a scale/variance similar to that of $Q$ and $K$, while the unscaled $QK^T$ has a much larger variance. This is easy to verify empirically.
4. Softmax the scaled attention scores into probabilities.
    - Each probability indicates how much the row word (query) attends to the column word (key).
5. Multiply/Weight the values.
    - For each row word, return a weighted sum of the values of the column words.

In [51]:
def ScaledDotProductAttention(Q, K, V, masked=False):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Works on plain 2d matrices as well as on inputs with any number of
    leading batch/head dimensions.

    Args:
    - Q: queries of shape [..., q_len, d_k]
    - K: keys of shape [..., k_len, d_k]
    - V: values of shape [..., k_len, d_v]
    - masked: if True, apply a lower-triangular (causal) mask so that
      position i only attends to positions <= i

    Returns
    - attention output of shape [..., q_len, d_v]
    """
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
    if masked:
        q_len, k_len = scores.shape[-2:]
        # build the mask on the same device as the scores so GPU inputs work
        visible = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device).tril()
        # -inf scores become exactly 0 probability after the softmax
        scores = scores.masked_fill(~visible, float("-inf"))
    # normalise over the key axis, then take the weighted sum of the values
    return torch.matmul(F.softmax(scores, dim=-1), V)

In [52]:
# Toy inputs: 5 positions, key/query width 4, value width 3, filled with uniform random numbers.
max_seq_len, d_k, d_v = 5,4,3
Q = torch.rand(max_seq_len, d_k)
K = torch.rand(max_seq_len, d_k)
V = torch.rand(max_seq_len, d_v)

In [55]:
# Run the attention with masked=True on the random Q, K, V defined in the previous cell.
sa = ScaledDotProductAttention(Q, K, V, masked=True)
sa

tensor([[0.3604, 0.0303, 0.5765],
        [0.2773, 0.1038, 0.6105],
        [0.3626, 0.3178, 0.6567],
        [0.3730, 0.4145, 0.5764],
        [0.3528, 0.3547, 0.4868]])

### Multi-Head Attention

If you have multiple modules that each perform the attention function above separately, and you concatenate their outputs, you have multi-head attention:
<center>
$
\begin{align}
\text{MultiHead(Q,K,V)}&=\text{Concat}(\text{head}_1,\cdots,\text{head}_h)W^O\\
\text{head}_i&=\text{Attention}(QW_i^Q,KW_i^K,VW_i^V)
\end{align}
$
</center>

In [70]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention over a batch of embedded sequences.

    A single linear layer projects the input to queries, keys and values for
    all heads at once; each head runs scaled dot-product attention
    independently, and the concatenated heads go through a final linear layer.
    """

    def __init__(self, emb_size, hidden_dim, n_heads, masked=False):
        """
        Args:
        - emb_size: size of the last dimension of the input
        - hidden_dim: total width across all heads (also the output width)
        - n_heads: number of attention heads; must divide hidden_dim
        - masked: if True, each position only attends to itself and earlier positions

        Raises
        - ValueError: if hidden_dim is not divisible by n_heads
        """
        super().__init__()
        if hidden_dim % n_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.emb_size = emb_size
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.masked = masked
        self.head_dim = hidden_dim // n_heads    # assume d_k == d_v
        self.qkv = nn.Linear(emb_size, 3*hidden_dim)
        self.linear = nn.Linear(hidden_dim, hidden_dim)

    def ScaledDotProductAttention(self, Q, K, V):
        """
        Args:
        - Q, K, V: tensors of shape [..., max_seq_len, head_dim]

        Returns
        - softmax(Q K^T / sqrt(head_dim)) V, of shape [..., max_seq_len, head_dim]
        """
        d_k = Q.shape[-1]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
        if self.masked:
            q_len, k_len = scores.shape[-2:]
            # build the mask on the same device as the scores so GPU inputs work
            visible = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device).tril()
            scores = scores.masked_fill(~visible, float("-inf"))
        # normalise over the key axis (last dim), not over the heads axis
        return torch.matmul(F.softmax(scores, dim=-1), V)

    def forward(self, x, mask=None):
        """
        Args:
        - x: input of shape [batch_size, max_seq_len, emb_size]
        - mask: accepted for interface compatibility but currently unused;
          causal masking is controlled by the `masked` constructor flag

        Returns
        - tensor of shape [batch_size, max_seq_len, hidden_dim]
        """
        batch_size, max_seq_len, _ = x.size()       # (batch_size, max_seq_len, emb_size)
        x = self.qkv(x)                             # (batch_size, max_seq_len, 3*hidden_dim)
        x = x.reshape(batch_size, max_seq_len, self.n_heads, 3*self.head_dim)
        x = x.permute(0, 2, 1, 3)                   # (batch_size, n_heads, max_seq_len, 3*head_dim)
        Q, K, V = x.chunk(3, dim=-1)                # (batch_size, n_heads, max_seq_len, head_dim)
        x = self.ScaledDotProductAttention(Q, K, V) # (batch_size, n_heads, max_seq_len, head_dim)
        x = x.permute(0, 2, 1, 3)                   # (batch_size, max_seq_len, n_heads, head_dim)
        x = x.reshape(batch_size, max_seq_len, self.n_heads*self.head_dim)
        return self.linear(x)                       # (batch_size, max_seq_len, hidden_dim)

In [71]:
# Smoke test: push a random batch through an unmasked MultiheadAttention.
# input_dim is passed as emb_size, d_model as hidden_dim, num_heads as n_heads.
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, input_dim) )

model = MultiheadAttention(input_dim, d_model, num_heads)
out = model(x)

torch.Size([30, 8, 5, 5])


In [65]:
# Inspect the shape of the multi-head attention output.
out.shape

torch.Size([30, 5, 512])

## Encoder