# Breaking Down a Multi-Headed Self-Attention Transformer

In [ ]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import Variable
from tqdm import tqdm, trange

The following figure depicts the different layers a Multi-Headed Self-Attention transformer is divided into. 

<a href="https://imgur.com/GmgyTLk"><img src="https://i.imgur.com/GmgyTLk.png" title="source: imgur.com" alt = 'transformer in layers' height = '800' width = '100%' /></a>

Let's check the flow of data inside the network

<a href="https://imgur.com/qIo7euI"><img src="https://i.imgur.com/qIo7euI.png" title="source: imgur.com" /></a>

We will take every block one by one and see how a batch of input get's transformed through it.

# Self-Attention Block

The following figure, depicts the functions inside a **Self-Attention** block

<a href="https://imgur.com/0uNKIPA"><img src="https://i.imgur.com/0uNKIPA.png" title="source: imgur.com" /></a>

Before starting with the exploration of data flow and processing we need to set some parameters for like,

    1. Sequence Length of the input: **seq_len**
    2. Embedding Dimension: **emb_dim**
    3. Token Embedding
    4. Position Embeddings
    5. Total Number of heads: **heads = 8**
    6. Number of tokens: **total number of tokens in a vocabulary**

Here we are considering a single input, so there is only one batch with a single input, hence if we can assume `batch_size = 1`.

In [ ]:
seq_length = 128
emb_dim = 256
torch.manual_seed(1111)
num_tokens = 1000
batch_size = 1

In [ ]:
token_embedding = nn.Embedding(num_embeddings = num_tokens, embedding_dim = emb_dim)
position_embedding = nn.Embedding(num_embeddings = seq_length, embedding_dim = emb_dim)

In [ ]:
input_ = torch.randint(1, 128, (batch_size, seq_length))

In [ ]:
input_.size()

torch.Size([1, 128])

The summation of Token Embedding and Position embedding is being transferred to the Self-Attention Block

In [ ]:
input_tok = token_embedding(input_)

In [ ]:
input_tok.size()

torch.Size([1, 128, 256])

In [ ]:
b, t, e = input_tok.size()

In [ ]:
input_pos = position_embedding(torch.arange(t))[None, :, :].expand(b, t, e)

In [ ]:
input_tnsr = input_tok + input_pos

In [ ]:
input_tnsr.size()

torch.Size([1, 128, 256])

### Now, let's come to the **Self-Attention** block

The attention function is mapping a _query_ to a set _key-value_ pairs to an output, and query, keys, and values are all vectors

In [ ]:
toqueries = nn.Linear(emb_dim, emb_dim * heads)
tokeys = nn.Linear(emb_dim, emb_dim * heads)
tovalues = nn.Linear(emb_dim, emb_dim * heads)

In [ ]:
keys = tokeys(input_tnsr).view(b, t, heads, e)
values = tovalues(input_tnsr).view(b, t, heads, e)
queries = toqueries(input_tnsr).view(b, t, heads, e)

In [ ]:
keys.size(), values.size(), queries.size()

(torch.Size([1, 128, 8, 256]),
 torch.Size([1, 128, 8, 256]),
 torch.Size([1, 128, 8, 256]))

Folding heads into the batch dimension

In [ ]:
keys = keys.transpose(1, 2).contiguous().view(b * heads, t, e)
values = values.transpose(1, 2).contiguous().view(b * heads, t, e)
queries = queries.transpose(1, 2).contiguous().view(b * heads, t, e)

In [ ]:
keys.size(), values.size(), queries.size()

(torch.Size([8, 128, 256]),
 torch.Size([8, 128, 256]),
 torch.Size([8, 128, 256]))

In [ ]:
queries = queries / (e ** (1/4))
keys = keys / (e ** (1/4))

Scaled dot-product attention

In [ ]:
dot_p = torch.bmm(queries, keys.transpose(1, 2))

In [ ]:
dot_p.size()

torch.Size([8, 128, 128])

In [ ]:
dot_p

tensor([[[ 5.1956e-01,  6.0915e-01, -6.4917e-01,  ...,  8.6953e-01,
          -6.6157e-02,  7.3565e-01],
         [ 6.6372e-01,  3.4092e-01, -8.4028e-01,  ..., -1.5064e-01,
           5.8384e-01,  2.6677e-01],
         [ 5.0197e-01, -1.5173e+00, -6.3352e-01,  ..., -3.9618e-01,
          -3.3248e-01, -1.4102e+00],
         ...,
         [ 5.8578e-01, -2.9937e-01, -3.8499e-01,  ...,  7.2306e-01,
          -6.8306e-01, -1.1674e+00],
         [-1.3212e+00,  3.1136e-01,  1.0709e+00,  ...,  4.4921e-01,
          -2.3707e-01,  5.8058e-01],
         [ 1.0418e-02,  3.7939e-01, -2.1556e+00,  ..., -3.2824e-01,
          -2.8023e-01,  3.1092e-01]],

        [[-1.0443e+00, -2.6809e-01,  4.0622e-01,  ..., -7.8090e-02,
           1.4918e+00, -1.8269e-01],
         [-3.7174e-01, -1.4060e+00, -9.4674e-01,  ...,  2.1743e-02,
          -1.0852e+00, -8.4526e-01],
         [-5.1599e-01, -3.8519e-01, -5.2806e-01,  ..., -2.7864e-01,
          -2.5623e-01,  3.9407e-01],
         ...,
         [-1.5574e+00,  4

Masking the upper half of the self-attention dot-product excluding the diagonal 

In [ ]:
batch, height, width = dot_p.size()

In [ ]:
indexes = torch.triu_indices(height, width, offset = 0)

In [ ]:
maskval = float('-inf')

In [ ]:
dot_p[:, indexes[0], indexes[1]] = maskval

In [ ]:
dot_p.size()

torch.Size([8, 128, 128])

In [ ]:
# dot_p

In [ ]:
dot = F.softmax(dot_p, dim = 2)

In [ ]:
dot.size()

torch.Size([8, 128, 128])

The computed self-attention is now being applied to the values vector

In [ ]:
out = torch.bmm(dot, values).view(b, heads, t, e)

In [ ]:
out.size()

torch.Size([1, 8, 128, 256])

All the _eight_ heads are back and now we need to unify them using a linear layer 

We need to swap back the heads with the batch size to unify all the heads

In [ ]:
out = out.transpose(1, 2).contiguous().view(b, t, heads * emb_dim)

In [ ]:
out.size()

torch.Size([1, 128, 2048])

In [ ]:
unifyheads = nn.Linear(emb_dim * heads, emb_dim)

In [ ]:
input_tnsr_tfmb = unifyheads(out)

In [ ]:
# output from the self-attention block
input_tnsr_tfmb.size()

torch.Size([1, 128, 256])

# Transformer Block

The **Transformer Block** consists of _n_ **Transformers** in a _Sequential_ order. 

Here, we will look at a single such **Transformer**

<a href="https://imgur.com/xSJD2TR"><img src="https://i.imgur.com/xSJD2TR.png" title="source: imgur.com" /></a>

A **Transformer** contains the following things,

    1. Self-Attention (The one explored above)
    2. LayerNorm-1
    3. LayerNorm-2
    4. Feed-Forward Network

The feed-forward network contains a **`feed_forward_multiplication_factor`**

In [ ]:
feed_forward_mult = 3
norm_1 = nn.LayerNorm(emb_dim)
norm_2 = nn.LayerNorm(emb_dim)
feed_forward = nn.Sequential(
    nn.Linear(emb_dim, feed_forward_mult * emb_dim),
    nn.ReLU(),
    nn.Linear(feed_forward_mult*emb_dim, emb_dim)

)

In [ ]:
attention_output = input_tnsr_tfmb

In [ ]:
attention_output.size()

torch.Size([1, 128, 256])

In [ ]:
op_ = norm_1(attention_output)

In [ ]:
op_.size()

torch.Size([1, 128, 256])

In [ ]:
op_ = feed_forward(op_)

In [ ]:
op_.size()

torch.Size([1, 128, 256])

In [ ]:
op_ = norm_2(op_)

In [ ]:
op_.size()

torch.Size([1, 128, 256])

In [ ]:
# Output from transformer block
op_

tensor([[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [-1.0195,  0.0225,  0.9978,  ..., -0.9235, -1.5130, -0.2702],
         [-0.4327, -1.2847,  0.7194,  ...,  0.2320, -1.9380, -0.3782],
         ...,
         [-0.8535,  0.1518,  2.2584,  ..., -0.1849, -1.5483,  1.2892],
         [-1.4051,  0.1749,  1.5635,  ..., -0.1841, -1.1818,  0.5500],
         [-0.4399,  0.8001,  2.0931,  ..., -0.5775, -0.9920,  0.6277]]],
       grad_fn=<NativeLayerNormBackward>)

# Generator Block

The **Generator Block** contains the following, 

    1. Token Embedding (Used earlier in the Self-Attention)
    2. Position Embedding (Used earlier in the Self-Attention)
    3. TransformerBlock (Here we will be using only one Transformer and this can be adjusted with the **depth** parameter)
    4. Probabilities Linear


<a href="https://imgur.com/nJf7hYk"><img src="https://i.imgur.com/nJf7hYk.png" title="source: imgur.com" /></a>

Here, we are taking `seq_length=128` i.e. _128_ tokens will be generated. Therefore, we map the output from the **Transformer Block** using a _Linear Layer_ to 1000 neurons. 

Now we take a `log_softmax` over the output of the _Linear Layer_ on the second dimension. 

In [ ]:
toprobabilities = nn.Linear(emb_dim, num_tokens)d

In [ ]:
out_probs = toprobabilities(op_.view(b*t, e))

In [ ]:
out_probs.size()

torch.Size([128, 1000])

In [ ]:
out_probs = out_probs.view(b, t, num_tokens)

In [ ]:
out_probs.size()

torch.Size([1, 128, 1000])

In [ ]:
all_op_tokens = F.log_softmax(out_probs, dim = 2)

In [ ]:
all_op_tokens.size()

torch.Size([1, 128, 1000])

In [ ]:
len((torch.topk(all_op_tokens, k = 1, dim = 2))[1].flatten().tolist())

128

# Classification Block

<a href="https://imgur.com/n0KeriK"><img src="https://i.imgur.com/n0KeriK.png" title="source: imgur.com" /></a>

In [ ]:
max_pool = True

In [ ]:
NUM_CLASSES = 20

In [ ]:
toprobabilities = nn.Linear(emb_dim, NUM_CLASSES)

In [ ]:
tt_ = op_.max(dim=1)[0] if max_pool else op_.mean(dim=1)

In [ ]:
tt_.size()

torch.Size([1, 256])

In [ ]:
tt_ = toprobabilities(tt_)

In [ ]:
tt_.size()

torch.Size([1, 20])

In [ ]:
tt_ = F.log_softmax(tt_, dim = 1)

In [ ]:
tt_.size()

torch.Size([1, 20])

In [ ]:
# max_pool = Flase 
torch.topk(tt_, 3)

torch.return_types.topk(
values=tensor([[nan, nan, nan]], grad_fn=<TopkBackward>),
indices=tensor([[12, 14, 13]]))

In [ ]:
# max_pool = True
torch.topk(tt_, 3)

torch.return_types.topk(
values=tensor([[nan, nan, nan]], grad_fn=<TopkBackward>),
indices=tensor([[12, 14, 13]]))