<h1>Let's see what's inside the Transformer...</h1>

<h2>Attention is all you need</h2>

* The core of the Transformer model is the attention mechanism.
* The intuition behind attention is that, rather than compressing the input into a single fixed representation, it might be better for the model to revisit the input sequence at every step.
* Rather than always seeing the same representation of the input, one might imagine that the decoder should selectively focus on particular parts of the input sequence at particular decoding steps.
* The encoder could produce a representation of length equal to the original input sequence. Then, at decoding time, the decoder can (via some control mechanism) receive as input a context vector consisting of a weighted sum of the representations of the input at each position. Intuitively, the weights determine the extent to which each step’s context “focuses” on each input token, and the key is to make this process for assigning the weights differentiable so that it can be learned along with all of the other neural network parameters.

<h2>But what does Attention actually do?</h2>

Let's break down a simple example. Consider a dict of values:

In [27]:
d = {
    "apple": "green",
    "banana": "yellow",
    "orange": "orange",
    "pickles": "green",
    "pineapple": "brown",
    "raspberry": "crimson"
}

In [28]:
# name -> height
d = {
    "Ann": 167,
    "Mike": 186,
    "Michelle": 170,
    "Andrew": 176
}

In [29]:
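# d["Anna"] raises a KeyError: there is no exact match for this key

# Soft lookup: weight every value by how similar its key is to the query
# "Anna": 0.8 * d["Ann"] + 0.05 * d["Mike"] + 0.05 * d["Michelle"] + 0.1 * d["Andrew"] = 169.0
# in general: \sum_i \alpha(q, k_i) * v_i

# Special case: uniform weights \alpha(q, k_i) = 1 / len(d) give the plain average
# \sum_i \alpha(q, k_i) * v_i = 1/4 * 167 + 1/4 * 186 + 1/4 * 170 + 1/4 * 176 = 174.75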
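
The commented sketch above can be made runnable. The following is a minimal illustration, assuming hand-picked similarity weights for the query "Anna": the numbers and the helper name `soft_lookup` are assumptions made for this example, chosen so that the weights sum to 1. It also covers the uniform-weight case, which reduces to a plain average.

```python
# name -> height, as in the cell above
heights = {"Ann": 167, "Mike": 186, "Michelle": 170, "Andrew": 176}


def soft_lookup(weights, table):
    """Return sum_i alpha(q, k_i) * v_i for the given per-key weights."""
    return sum(weights[key] * value for key, value in table.items())


# Hypothetical weights for the query "Anna": most mass on the closest key "Ann"
anna_weights = {"Ann": 0.8, "Mike": 0.05, "Michelle": 0.05, "Andrew": 0.1}

# Uniform weights alpha(q, k_i) = 1 / len(d): every key counts equally
uniform_weights = {key: 1 / len(heights) for key in heights}

anna_estimate = soft_lookup(anna_weights, heights)
average_height = soft_lookup(uniform_weights, heights)

print(round(anna_estimate, 2))   # 169.0
print(average_height)            # 174.75
```

Unlike `d["Anna"]`, which fails because the key is missing, the soft lookup always returns an answer, and that answer changes smoothly with the weights.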

The first dict $D$ (fruit to color) has keys $k$ and values $v$. We can operate on $D$, for instance with the exact query $q$ for “pickles”, which would return the value “green”. If (“pickles”, “green”) were not a record in $D$, there would be no valid answer. If we also allowed for approximate matches, we would retrieve (“pineapple”, “brown”) instead. This simple example nonetheless teaches us a number of useful things:

* We can design queries $q$ that operate on ($k$, $v$) pairs in such a manner as to be valid regardless of the database size.

* The same query can receive different answers, according to the contents of the database.
* The “code” being executed for operating on a large state space (the database) can be quite simple (e.g., exact match, approximate match, top-k).

* There is no need to compress or simplify the database to make the operations effective.
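
The query types named in the list above can be sketched in a few lines. This is a minimal illustration, assuming string similarity from the standard library's `difflib.SequenceMatcher` as the notion of an approximate match. The helper names are made up for this example.

```python
from difflib import SequenceMatcher

colors = {
    "apple": "green",
    "banana": "yellow",
    "orange": "orange",
    "pickles": "green",
    "pineapple": "brown",
    "raspberry": "crimson",
}


def similarity(query, key):
    """String similarity in [0, 1]; 1.0 means the strings are identical."""
    return SequenceMatcher(None, query, key).ratio()


def exact_match(query, table):
    """Return the value for an identical key, or None if there is none."""
    return table.get(query)


def top_k(query, table, k):
    """Return the k (key, value) pairs whose keys are most similar to the query."""
    ranked = sorted(table, key=lambda key: similarity(query, key), reverse=True)
    return [(key, table[key]) for key in ranked[:k]]


def approximate_match(query, table):
    """Return the single most similar (key, value) pair."""
    return top_k(query, table, 1)[0]


print(exact_match("pickles", colors))        # green
print(exact_match("pickle", colors))         # None
print(approximate_match("pickle", colors))   # ('pickles', 'green')
```

None of these functions depends on the number of records, and the same query returns a different answer as soon as the contents of `colors` change.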

For example, we would like to design something that measures the "closeness" of the query $q$ to the keys $k$ in $D$. Suppose we have a function $\alpha(q, k)$ that computes some sort of query-key similarity and returns one weight per key, so that $n$ keys yield a weight vector in $\mathbb{R}^n$. E.g., for $q=\text{"pi"}$ and the six fruit keys it could be something like $(0.05,\ 0.075,\ 0.075,\ 0.3,\ 0.4,\ 0.1)$.

We also want to associate this similarity with the values in the dataset. For example, we can multiply each weight by the corresponding value and sum the results.

This gives us a basic intuition of what attention actually is:

$$\text{A} = \sum\limits_{d\in D}{\alpha(q, k_d)*v_d}$$

The operation itself is typically referred to as attention pooling. The name attention derives from the fact that the operation pays particular attention to the terms for which the weight $\alpha$ is significant (i.e., large). As such, the attention over $D$ generates a linear combination of the values contained in the database. This covers several special cases:

* The weights are nonnegative, $\alpha(q, k) \geq 0$: in this case the output of the attention mechanism is contained in the convex cone spanned by the values $v$.

* The weights $\alpha(q, k) \geq 0$ form a convex combination, i.e., $\sum\limits_{d\in D}{\alpha(q, k_d)}=1$. This is the most common setting in deep learning.

* Exactly one of the weights is 1, while all others are 0. This is akin to a traditional database query.

* All weights are equal: this amounts to averaging across the entire database, also called average pooling in deep learning.

A common strategy for ensuring that the weights sum up to 1 is to normalize them via

$$\alpha(q, k_i) = \frac{\alpha(q, k_i)}{\sum_j\alpha(q, k_j)}$$

In particular, to ensure that the weights are also nonnegative, one can resort to exponentiation. This means that we can now pick any function $a(q, k)$ and then apply the softmax operation used for multinomial models to it via

$$\alpha(q, k_i) = \frac{\exp a(q, k_i)}{\sum_j\exp a(q, k_j)}$$
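
The softmax normalization is short enough to write out by hand. The sketch below is pure Python and uses arbitrary example scores. Subtracting the maximum score before exponentiating is an implementation detail added here for numerical stability. It is not part of the formula, and it does not change the result.

```python
import math


def softmax(scores):
    """Map arbitrary real scores to nonnegative weights that sum to 1."""
    # Subtracting the maximum leaves the result unchanged but avoids overflow in exp
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Arbitrary scores a(q, k_i); negative values are fine
weights = softmax([2.0, -1.0, 0.5, 0.5])
print(all(w >= 0 for w in weights), round(sum(weights), 6))  # True 1.0

# Equal scores give uniform weights (average pooling)
print(softmax([3.0, 3.0, 3.0, 3.0]))  # [0.25, 0.25, 0.25, 0.25]

# One dominant score approaches a one-hot lookup
print([round(w, 3) for w in softmax([50.0, 0.0, 0.0, 0.0])])  # [1.0, 0.0, 0.0, 0.0]
```

The last two calls recover two of the special cases listed earlier: averaging across the database and something close to an exact database query.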
 


![title](qkv.svg)
![title](attention-output.svg)

<h2>Let's code one of them: scaled dot product</h2>

Assume that both the query and the key have the same vector length, say $d$. A mismatch can be addressed easily by replacing $q^T k$ with $q^T Mk$, where $M$ is a matrix suitably chosen for translating between both spaces. For now assume that the dimensions match.

In practice, we often think of minibatches for efficiency, such as computing attention for $n$ queries and $m$ key-value pairs, where queries and keys are of length $d$ and values are of length $v$. The scaled dot product attention of queries $Q\in \mathbb{R}^{n\times d}$, keys $K\in \mathbb{R}^{m\times d}$, and values $V\in \mathbb{R}^{m\times v}$ thus can be written as

$$\text{softmax}\Big(\frac{QK^T}{\sqrt{d}}\Big)V \in \mathbb{R}^{n\times v}$$

In [30]:
import torch
from torch import nn
import numpy as np
from typing import List

In [31]:
class DotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, queries, keys, values):
        # queries: (B, N, D)
        # keys: (B, M, D)
        # values: (B, M, V)

        assert queries.shape[0] == keys.shape[0]
        assert queries.shape[0] == values.shape[0]
        assert queries.shape[-1] == keys.shape[-1]
        assert keys.shape[1] == values.shape[1]

        d = keys.shape[-1]

        # scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        scores = queries @ keys.transpose(1, 2) / np.sqrt(d)
        attention_weights = nn.functional.softmax(scores, dim=-1)

        return attention_weights @ values

In [32]:
queries = torch.zeros((2, 1, 2))
keys = torch.zeros((2, 10, 2))
values = torch.zeros((2, 10, 4))

In [33]:
att = DotProductAttention()
att(queries, keys, values).shape

torch.Size([2, 1, 4])
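
Checking only the output shape does not tell us whether the numbers are right. Here is a self-contained sanity check, written as a sketch: the function name `sdpa_sketch` is made up and mirrors the `forward` above rather than replacing it. When all queries and keys are zero, every score is zero, the softmax weights are uniform, and the output should equal the mean of the values.

```python
import math

import torch


def sdpa_sketch(queries, keys, values):
    """softmax(Q K^T / sqrt(d)) V for inputs of shape (B, N, D), (B, M, D), (B, M, V)."""
    d = keys.shape[-1]
    scores = queries @ keys.transpose(1, 2) / math.sqrt(d)
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1 over the M keys
    return weights @ values, weights


queries = torch.zeros((2, 1, 2))
keys = torch.zeros((2, 10, 2))
values = torch.arange(80, dtype=torch.float32).reshape(2, 10, 4)

output, weights = sdpa_sketch(queries, keys, values)

# Zero scores -> uniform weights 1/10 -> output equals the mean over the 10 values
print(torch.allclose(weights, torch.full((2, 1, 10), 0.1)))       # True
print(torch.allclose(output, values.mean(dim=1, keepdim=True)))   # True
```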

<h2>Multihead attention</h2>

Before providing the implementation of multi-head attention, let’s formalize this model mathematically. Given a query $q\in \mathbb{R}^{d_q}$, a key $k\in \mathbb{R}^{d_k}$, and a value $v\in \mathbb{R}^{d_v}$, each attention head $h_i$ ($i = 1, \ldots, h$) is computed as

$$h_i = f(W_i^q q, W_i^k k, W_i^v v) \in \mathbb{R}^{p_v}$$

where $W_i^q \in \mathbb{R}^{p_q \times d_q}, W_i^k \in \mathbb{R}^{p_k \times d_k}, W_i^v \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters and $f$ is attention pooling, such as additive attention or scaled dot product attention. The multi-head attention output is another linear transformation via learnable parameters $W_o \in \mathbb{R}^{p_o \times hp_v}$ of the concatenation of the $h$ heads:

$$W_o [h_1, ..., h_h]^T \in \mathbb{R}^{p_o}$$
 
 
Based on this design, each head may attend to different parts of the input. More sophisticated functions than the simple weighted average can be expressed.

To avoid significant growth of computational cost and parametrization cost, let's set $p_q = p_k = p_v = p_o / h$. Note that the $h$ heads can be computed in parallel if we set the number of outputs of the linear transformations for the query, key, and value to $p_q h = p_k h = p_v h = p_o$. In the following implementation, $p_o$ is specified via the argument num_hiddens.
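
As a worked example of this bookkeeping, the per-head width follows directly from $p_o$ and $h$. The sketch assumes that $p_o$ is divisible by $h$, which splitting one fused projection into heads requires. The helper name and the input width `d_in` are illustrative.

```python
def per_head_dim(num_hiddens, num_heads):
    """Return p_q = p_k = p_v = p_o / h, assuming p_o is divisible by h."""
    if num_hiddens % num_heads != 0:
        raise ValueError("num_hiddens must be divisible by num_heads")
    return num_hiddens // num_heads


p_o, h = 100, 5
p_v = per_head_dim(p_o, h)
print(p_v)       # 20
print(p_v * h)   # 100, i.e. concatenating the h heads restores width p_o

# With a fused projection the weight count does not depend on h:
# one (d_in x p_o) matrix instead of h matrices of shape (d_in x p_o / h)
d_in = 5
print(d_in * p_o == h * (d_in * p_v))  # True
```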

In [34]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_hiddens, num_heads):
        super().__init__()

        self.num_heads = num_heads
        self.num_hiddens = num_hiddens

        self.attention = DotProductAttention()

        self.W_q = nn.LazyLinear(num_hiddens)
        self.W_k = nn.LazyLinear(num_hiddens)
        self.W_v = nn.LazyLinear(num_hiddens)

        self.W_o = nn.LazyLinear(num_hiddens)

    def transpose_qkv(self, X):
        """Transposition for parallel computation of multiple attention heads."""
        # Shape of input X: (batch_size, no. of queries or key-value pairs,
        # num_hiddens). Shape of output X: (batch_size, no. of queries or
        # key-value pairs, num_heads, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # Shape of output X: (batch_size, num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # Shape of output: (batch_size * num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])
    
    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values):
        # queries: (B, N, d_q), keys: (B, M, d_k), values: (B, M, d_v)

        q = self.transpose_qkv(self.W_q(queries))
        k = self.transpose_qkv(self.W_k(keys))
        v = self.transpose_qkv(self.W_v(values))

        output = self.transpose_output(self.attention(q, k, v))

        return self.W_o(output)

        

In [35]:
num_hiddens = 100
num_heads = 5

mha = MultiHeadAttention(num_hiddens, num_heads)

X = torch.ones((16, 5, 5))
mha(X, X, X).shape

torch.Size([16, 5, 100])
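
The two reshaping helpers are easiest to trust after checking that they undo each other. Below is a standalone sketch that restates them as free functions and verifies the round trip and the head layout. The function names and the explicit `num_heads` argument are assumptions for this example.

```python
import torch


def split_heads(X, num_heads):
    """(B, N, H * D) -> (B * H, N, D): give every head its own batch entry."""
    B, N, _ = X.shape
    X = X.reshape(B, N, num_heads, -1).permute(0, 2, 1, 3)
    return X.reshape(B * num_heads, N, -1)


def merge_heads(X, num_heads):
    """(B * H, N, D) -> (B, N, H * D): inverse of split_heads."""
    _, N, D = X.shape
    X = X.reshape(-1, num_heads, N, D).permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], N, -1)


B, N, H, D = 2, 3, 5, 4
X = torch.arange(B * N * H * D, dtype=torch.float32).reshape(B, N, H * D)

heads = split_heads(X, H)
print(heads.shape)                               # torch.Size([10, 3, 4])
# Head 1 of batch item 0 holds feature columns D..2D-1 of the input
print(torch.equal(heads[1], X[0, :, D:2 * D]))   # True
print(torch.equal(merge_heads(heads, H), X))     # True
```

Folding the heads into the batch dimension is what lets a single call to the attention module process all heads at once.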

<h2>Positional encoding</h2>

The dominant approach for preserving information about the order of tokens is to represent this to the model as an additional input associated with each token. These inputs are called positional encodings, and they can either be learned or fixed a priori. We now describe a simple scheme for fixed positional encodings based on sine and cosine functions (Vaswani et al., 2017).

Suppose that the input representation $X \in \mathbb{R}^{n\times d}$ contains the $d$-dimensional embeddings for $n$ tokens of a sequence. The positional encoding outputs $X + P$ using a positional embedding matrix $P \in \mathbb{R}^{n\times d}$ of the same shape, whose element on the $i$-th row and the $(2j)$-th or the $(2j+1)$-th column is

 $$ p_{i, 2j} = \sin \frac{i}{10000^{2j/d}} $$
 $$ p_{i, 2j+1} = \cos \frac{i}{10000^{2j/d}} $$
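
To make the formula concrete, here is a small pure-Python sketch that evaluates single entries of $P$. It assumes the same frequency $1/10000^{2j/d}$ for the sine and cosine columns of a pair, as in Vaswani et al. (2017). The function name is made up for this example.

```python
import math


def positional_entry(i, col, d):
    """Return p_{i, col} for position i and column col of a d-dimensional encoding."""
    j = col // 2                      # columns 2j and 2j+1 share one frequency
    angle = i / (10000 ** (2 * j / d))
    return math.sin(angle) if col % 2 == 0 else math.cos(angle)


d = 32
# Position 0: every sine column is 0 and every cosine column is 1
print(positional_entry(0, 0, d), positional_entry(0, 1, d))   # 0.0 1.0

# Position 1, first pair: sin(1) and cos(1)
print(round(positional_entry(1, 0, d), 4))   # 0.8415
print(round(positional_entry(1, 1, d), 4))   # 0.5403

# Later pairs oscillate more slowly: the last pair barely moves per position
print(positional_entry(1, d - 2, d) < 1e-3)  # True
```

Low column indices therefore change quickly from one position to the next, while high column indices change slowly, which gives every position a distinct pattern.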

In [36]:
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, max_tokens=1000):
        super().__init__()
        P = torch.zeros((1, max_tokens, num_hiddens))
        X = torch.arange(max_tokens, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        P[:, :, 0::2] = torch.sin(X)
        P[:, :, 1::2] = torch.cos(X)
        # A non-persistent buffer follows .to(device) but is neither trained nor saved
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :]
        return X

In [37]:
encoding_dim, num_steps = 32, 60
pe = PositionalEncoding(encoding_dim)
X = torch.zeros((1, num_steps, encoding_dim))
XPE = pe(X)

print(X.shape, XPE.shape)

torch.Size([1, 60, 32]) torch.Size([1, 60, 32])


In [38]:
XPE

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,
           1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,
           3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,
           1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,
           1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,
           1.0492e-02,  9.9994e-01]]])

<h2>Bringing everything together</h2>

![title](transformer.svg)

In [39]:
class FFN(nn.Module):
    def __init__(self, ffn_dim, num_hiddens):
        super().__init__()
        self.w1 = nn.LazyLinear(ffn_dim)
        self.w2 = nn.LazyLinear(num_hiddens)
        self.relu = nn.ReLU()
        
    def forward(self, X):
        return self.w2(self.relu(self.w1(X)))

class AddNorm(nn.Module):
    def __init__(self, num_hiddens):
        super().__init__()
        self.ln = nn.LayerNorm(num_hiddens)

    def forward(self, res, X):
        return self.ln(X + res)

In [40]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, num_hiddens, ffn_dim, num_heads):

        super().__init__()

        self.num_hiddens = num_hiddens
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads

        self.attention = MultiHeadAttention(num_hiddens, num_heads)
        self.ffn = FFN(ffn_dim, num_hiddens)
        self.addnorm1 = AddNorm(num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens)

    def forward(self, X):
        Y = self.addnorm1(X, self.attention(X, X, X))
        return self.addnorm2(Y, self.ffn(Y))

In [41]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, num_hiddens, num_blocks, ffn_dim, num_heads):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_hiddens = num_hiddens
        
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.positional_encoding  = PositionalEncoding(num_hiddens)

        self.blocks = nn.ModuleList()

        for i in range(num_blocks):
            self.blocks.append(TransformerEncoderBlock(num_hiddens, ffn_dim, num_heads))

    def forward(self, X):
        X = self.embedding(X)
        X = self.positional_encoding(X)

        for i, block in enumerate(self.blocks):
            X = block(X)

        return X

In [42]:
encoder = TransformerEncoder(1000, 300, 6, 100, 5)

In [43]:
inp = torch.zeros((2, 100)).long()
encoder(inp).shape

torch.Size([2, 100, 300])

<h2>Your supertask!</h2>

Using this encoder implementation, try to implement ViT. Have fun! :)

In [44]:
class Patcher(nn.Module):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size
    
    def forward(self, image):
        b, c, h, w = image.shape

        assert h % self.patch_size == 0 and w % self.patch_size == 0

        p = self.patch_size
        # Slide a p x p window over height, then width: (b, c, h/p, w/p, p, p)
        patches = image.unfold(-2, p, p).unfold(-2, p, p)
        # Flatten each patch and merge channels into it: (b, h/p * w/p, c * p * p)
        patches = patches.reshape(b, c, -1, p ** 2).transpose(1, 2).reshape(b, -1, c * p ** 2)
        return patches

In [45]:
patcher = Patcher(16)
img = torch.rand((2, 3, 256, 256))
patcher(img).shape

torch.Size([2, 256, 768])
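
The output shape follows from simple counting. Here is a short sketch for non-overlapping square patches (the helper name is hypothetical):

```python
def patch_shape(channels, height, width, patch_size):
    """Return (number of patches, values per flattened patch)."""
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")
    num_patches = (height // patch_size) * (width // patch_size)
    patch_dim = channels * patch_size ** 2
    return num_patches, patch_dim


print(patch_shape(3, 256, 256, 16))   # (256, 768)
```

A 256 x 256 image holds 16 x 16 = 256 patches of side 16, and each patch flattens to 3 x 16 x 16 = 768 values, matching the shape printed above.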

In [46]:
class ViTEncoder(nn.Module):
    def __init__(self, patch_size, channels, num_blocks, ffn_dim, num_heads):
        super().__init__()
        
        self.num_hiddens = patch_size ** 2 * channels
        
        self.patcher = Patcher(patch_size)
        self.positional_encoding = PositionalEncoding(self.num_hiddens)

        self.blocks = nn.ModuleList()

        for i in range(num_blocks):
            self.blocks.append(TransformerEncoderBlock(self.num_hiddens, ffn_dim, num_heads))

    def forward(self, X):
        X = self.patcher(X)
        X = self.positional_encoding(X)

        for i, block in enumerate(self.blocks):
            X = block(X)

        return X

In [47]:
encoder = ViTEncoder(16, 3, 6, 100, 6)
inp = torch.rand((2, 3, 256, 256))
encoder(inp).shape

torch.Size([2, 256, 768])

In [48]:
def get_activation(name: str, **kwargs):
    activations = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'silu': nn.SiLU,
        'softplus': nn.Softplus,
        'leakyrelu': nn.LeakyReLU
    }
    # Normalize once so that the membership test and the lookup agree
    name = name.lower()
    if name in activations:
        return activations[name](**kwargs)
    else:
        raise KeyError('No such activation')

In [49]:
class ViTHead(nn.Module):
    def __init__(self, neurons: List[int], activation: str = 'relu', dropout_rate: float = 0.0):
        super().__init__()
        self.n_layers = len(neurons)
        
        self.blocks = nn.ModuleList()
        for n in neurons[:-1]:
            self.blocks.append(
                nn.Sequential(
                    nn.Dropout(dropout_rate),
                    nn.LazyLinear(n),
                    nn.BatchNorm1d(n),
                    get_activation(activation)
                ))
        self.blocks.append(
            nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.LazyLinear(neurons[-1])
            ))
        
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

In [50]:
class ViT(nn.Module):
    def __init__(
            self,
            patch_size,
            channels,
            num_blocks,
            ffn_dim,
            num_heads,
            head_neurons: List[int],
            head_activation: str = 'relu',
            head_dropout_rate: float = 0.0):
        super().__init__()

        self.encoder = ViTEncoder(patch_size, channels, num_blocks, ffn_dim, num_heads)
        self.head = ViTHead(head_neurons, head_activation, head_dropout_rate)

    def forward(self, X):
        X = self.encoder(X)
        X = X.view(X.shape[0], -1)
        return self.head(X)

In [53]:
model = ViT(16, 3, 6, 100, 6, [1000, 1000, 10])
img = torch.rand((2, 3, 256, 256))
model(img).shape

torch.Size([2, 10])