# Attention Mechanism

> Scaled dot-product attention and multi-head attention building blocks for the transformer.

In [1]:
#| default_exp transformer.attention

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| hide
import nbdev

# Write the `#| export` cells of the project's notebooks out to the library modules.
nbdev.nbdev_export()

In [4]:
#| export
import math

import torch
from torch import nn
import torch.nn.functional as F
from fastcore.foundation import docs

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def addition(a, b):
    """Return ``a + b`` for any pair of operands that support ``+``."""
    total = a + b
    return total

#| explain a+b

We take the sum of `a` and `b`, which is written in Python with the `+` symbol.

In [None]:
#| exports
@docs
class A:
    # Placeholder class. `_docs` is read by the `@docs` decorator (fastcore);
    # its "cls_doc" entry supplies the (currently empty) class docstring.
    _docs = {"cls_doc": ""}

    def __init__(self):
        # Nothing to set up.
        return None

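`A` is a minimal placeholder class: its constructor does nothing, and its (empty) class docstring is supplied through `_docs`.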

#| explain "pass"

`pass` is a no-op statement: the constructor body is intentionally empty, so creating an `A` does nothing.

In [None]:
#| exports
@docs
class PrepareForMultiHeadAttention(nn.Module):
    # Projects the last dimension of the input from `d_model` to `heads * d_k`
    # features, then splits it into `heads` separate vectors of size `d_k`.
    def __init__(self, d_model, heads, d_k, bias):
        super().__init__()
        # A single linear layer produces the features for all heads at once.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads  # number of attention heads
        self.d_k = d_k  # number of features per head

    def forward(self, x):
        # All leading dimensions of `x` are kept as-is; only the last
        # (feature) dimension is projected and then split across heads.
        lead_shape = x.shape[:-1]

        x = self.linear(x)
        # (..., heads * d_k) -> (..., heads, d_k)
        return x.view(*lead_shape, self.heads, self.d_k)

    _docs = dict(
        cls_doc="Linearly project the last dimension of the input and split it into `heads` vectors of size `d_k`",
        forward="Project `x` of shape `(..., d_model)` and reshape the result to `(..., heads, d_k)`",
    )

#| explain "self.heads = heads"

Store the number of heads so that `forward` can reshape the projected output into `(..., heads, d_k)`.

### Attention

$\operatorname{Attention}(q, k, v)=\operatorname{softmax}\left(\frac{q k^T}{\sqrt{d_k}}\right) v$

Given

- $q$: the query vector
- $k$: the key vector
- $v$: the value vector

In [None]:
#| exports
def _calculate_attention(q, k, v, mask=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q: queries, shape ``(..., n_q, d_k)``.
        k: keys, shape ``(..., n_k, d_k)``.
        v: values, shape ``(..., n_k, d_v)``.
        mask: optional tensor broadcastable to ``(..., n_q, n_k)``; positions
            where ``mask == 0`` receive (effectively) zero attention weight.

    Returns:
        Tensor of shape ``(..., n_q, d_v)``.
    """
    d_k = k.shape[-1]

    # Transpose only the last two dims so batched / multi-head inputs work;
    # `k.T` reverses *all* dims (and is deprecated for ndim != 2).
    score = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Use the dtype's most negative finite value (rather than a literal
        # like -1e9) so the fill is also representable in float16/bfloat16.
        score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

    score = F.softmax(score, dim=-1)
    return score @ v

#| explain "score = (q @ k.T) / math.sqrt(d_k)"

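Take the dot product of every query with every key ($q k^T$), then divide by $\sqrt{d_k}$, the square root of the key dimension, so that the scores do not grow with the vector size.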

**Example**

Suppose there are three words, and each word has its own query, key, and value vectors. All vectors have the same dimension: 5.

In [None]:
# Random 5-dimensional query, key and value vectors for each of three words
# (drawn in the same order: q, k, v for word 1, then word 2, then word 3).
q1, k1, v1 = (torch.randn(5) for _ in range(3))
q2, k2, v2 = (torch.randn(5) for _ in range(3))
q3, k3, v3 = (torch.randn(5) for _ in range(3))

In [None]:
q = torch.stack((q1, q2, q3))  # one row per word's query -> shape (3, 5)

In [None]:
# Show the stacked query matrix: row i is the query vector of word i.
q

tensor([[-1.0566,  0.6788,  0.4330,  0.0026, -1.6254],
        [-0.7342,  0.2160,  0.0464, -1.0750, -1.7149],
        [ 2.2512, -0.8535,  0.1335, -0.0290, -0.8533]])

In [None]:
k = torch.stack((k1, k2, k3))  # one row per word's key -> shape (3, 5)

In [None]:
v = torch.stack((v1, v2, v3))  # one row per word's value -> shape (3, 5)

In [None]:
# Attention output: one row per query word, each a weighted mix of the value rows.
_calculate_attention(q=q, k=k, v=v)

tensor([[ 0.0230, -1.6170, -1.1616,  0.6197,  0.1060],
        [ 0.0221, -1.5535, -1.1491,  0.5655,  0.1212],
        [ 0.1577, -0.6645, -1.3145,  0.2053,  0.0691]])

### Multi-head Attention

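In practice, we don't compute the attention for each word separately. Instead, we stack all the `key` vectors into one matrix $K$, and likewise the `query` vectors into $Q$ and the `value` vectors into $V$, so every score is computed at once with matrix multiplications (formula below). Multi-head attention goes one step further: it runs several of these attention computations ("heads") in parallel, each on its own learned projection of $Q$, $K$ and $V$, and then concatenates the results.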

$$\operatorname{Attention}(Q, K, V)=\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V$$

- `d_model`: the number of features in `query`, `key`, and `value` vectors.
- `n_head`: the number of attention heads.
- `d_k`: the dimension of each query/key/value vector within a single head (`d_model // n_head`).

In [None]:
# #| exports
# @docs
# class MultiHeadAttention(nn.Module):
#     def __init__(
#         self,
#         heads: int,
#         d_model: int,
#         dropout_prop: float=0.1,
#         bias: bool = True
#     ):
#         super().__init__()
#         self.d_k = d_model // heads
        
#         self.heads = heads
        
#         self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
#         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
#         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        
#         self.softmax = nn.Softmax(dim=1)
#         self.output = nn.Linear(d_model, d_model)
#         self.dropout = nn.Dropout(dropout_prop)
#         self.scale = 1 / math.sqrt(self.d_k)
        
#         self.attention = None
    
#     def get_scores(self, query, key):
#         return self.query @ self.key.T
        
#     _docs = dict(cls_doc="Calculate the multi-head attention",
#                  get_scores="Calculate the score")

#### Okay. Then where do we get those query, key... vectors?

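The query, key, and value vectors determine the attention scores, so we need a way to learn them. Each one is produced from the input by its own trainable linear layer (`nn.Linear(d_model, d_model)`).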

In [None]:
def _initialize_weight(d_model):
    """Create a fresh square ``nn.Linear`` layer mapping ``d_model`` features to ``d_model`` features (with bias)."""
    layer = nn.Linear(in_features=d_model, out_features=d_model)
    return layer

In [None]:
# #| exports
# def _split_by_heads(tensor, n_head):
#     batch_size, length, d_model = tensor.size()
#     d_tensor = d_model // n_head
    
#     return tensor.view(batch_size, length, n_head, d_tensor).transpose(1, 2)

#| explain "d_tensor = d_model // n_head"

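Each head gets an equal share of the model dimension, `d_tensor = d_model // n_head` features. The tensor is then reshaped to `(batch_size, n_head, length, d_tensor)`.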

In [None]:
# #| exports
# @docs
# class MultiHeadAttention(nn.Module):
#     def __init__(self, d_model, n_head):
#         super().__init__()
#         self.n_head = n_head
    
#         self.w_q = nn.Linear(d_model, d_model)
#         self.w_k = nn.Linear(d_model, d_model)
#         self.w_v = nn.Linear(d_model, d_model)
            
#     def split(self, tensor):
#         batch_size, length, d_model = tensor.size()
#         d_tensor = d_model // self.n_head
        
#         tensor = tensor.view(batch_size, length, self.n_head, d_tensor)
#         tensor = tensor.transpose(1, 2)
        
#         return tensor

#     def forward(self, q, k, v, mask = None):
#         p, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        
#         # 2. split tensor by number of heads
#         q, k, v = self.split(q), self.split(k), self.split(v)
        
#         # 3. do scale dot product to compute similarity
#         out, attention = self.attention(q, k, v, mask=mask)
        
#         # 4. concat and pass to linear
#         out = self.concat(out)
#         out = self.w_concat(out)
        
#         return out
    
#     _docs = dict(cls_doc="", split="", forward="")

https://github.com/hyunwoongko/transformer

In [None]:
# MultiHeadAttention(d_model=5, n_head=3)

### How multi-head attention works?

Consider the sentence: "Persistence is all you need"

In [None]:
text = " ".join(["Persistence", "is", "all", "you", "need"])

There are five words in this sentence.

Each word is represented by a vector of `5` numbers (i.e., dimension 5).

In [None]:
w1, w2 = (torch.randn(5) for _ in range(2))  # random stand-ins for word 1 and 2

In [None]:
w3, w4, w5 = (torch.randn(5) for _ in range(3))  # random stand-ins for words 3-5

In [None]:
# Show the (random) 5-dimensional vector standing in for the first word.
w1

tensor([-1.2619,  0.4211,  0.6655,  1.5499,  0.9246])

##### First we create a matrix

What I don't understand:
- why does each `forward` take `q`, `k`, `v` rather than the words themselves?
- what is `d_model`?
- why do we split?
- why do we concat?
- why do we apply `w_concat`?
- why does attention return the key?

### Attention

- `num_heads`: the number of heads used in the multi-head attention operation. Each head performs attention on its own learned projection of the queries, keys, and values.
- `d_model`: the dimensionality of the input and output tensors of the multi-head attention operation.
- `d_k`: the per-head dimensionality of the queries, keys, and values used in the multi-head attention operation.

##### `d_k`

`d_k` is the per-head dimensionality of the queries, keys, and values in the multi-head attention operation: `d_k = d_model // num_heads`.

For example, if `d_model` is 256 and `num_heads` is 4, `d_k` would be 64.

The reason for calculating `d_k` in this way is to ensure that the keys and values are split evenly among the different heads.

Keys and values need to be split evenly among the different heads in a multi-head attention operation because each head will use its own set of keys and values to calculate the attention weights.

##### `key` layer

In [None]:
# import ipdb

In [None]:
# class MultiHeadAttention(nn.Module):
#     def __init__(self, d_model, num_heads):
#         super().__init__()

#         # Save the number of heads and the dimensionality of the model.
#         self.num_heads = num_heads
#         self.d_model = d_model

#         # Calculate the dimensionality of the keys and values.
#         self.d_k = d_model // num_heads

#         # Create the linear layers for the keys, values, and queries.
#         self.key_layer = nn.Linear(d_model, num_heads * self.d_k)
#         self.value_layer = nn.Linear(d_model, num_heads * self.d_k)
#         self.query_layer = nn.Linear(d_model, num_heads * self.d_k)

#     def linear_layer(self, input_tensor):
#         # Apply the linear layer to the input tensor to get the output.
#         return self.linear_layer(input_tensor)

#     def split_heads(self, input_tensor):
#         # Split the input tensor into multiple heads along the last dimension.
#         return input_tensor.reshape(input_tensor.shape[0], -1, self.num_heads, self.d_k)

#     def dot_products(self, query, key):
#         # Calculate the dot product of the query with the key for each head.
#         return torch.einsum('bjhd,bkhd->bhjk', query, key)

#     def scale_dot_products(self, dot_products):
#         # Scale the dot products by the dimensionality of the keys.
#         return dot_products / self.d_k**0.5

#     def apply_weights(self, dot_products):
#         # Apply the attention mask and softmax to the dot products to get the weights.
#         return dot_products.softmax(dim=-1)

#     def weighted_sum(self, weights, value):
#         # Calculate the weighted sum of the values for each head.
#         return torch.einsum('bhjk,bkhd->bjhd', weights, value)

#     def concatenate_heads(self, output):
#         # Concatenate the outputs from each head along the last dimension.
#         return output.reshape(output.shape[0], -1, self.num_heads * self.d_k)

#     def forward(self, query, key, value):
#         # Transform the keys, values, and queries using the linear layers.
#         key = self.key_layer(key)
#         ipdb.set_trace()

#         value = self.value_layer(value)
#         query = self.query_layer(query)
        
#         ipdb.set_trace()
        
#         # Split the keys, values, and queries into multiple heads.
#         key = self.split_heads(key)
#         value = self.split_heads(value)
#         query = self.split_heads(query)
        
#         ipdb.set_trace()

#         # Calculate the dot product of the query with the key for each head
#         dot_products = self.dot_products(query, key)
        
#         ipdb.set_trace()

#         # Scale the dot products by the dimensionality of the keys.
#         dot_products = self.scale_dot_products(dot_products)
        
#         ipdb.set_trace()

#         # Apply the attention mask and softmax to the dot products to get the weights.
#         weights = self.apply_weights(dot_products)
        
#         ipdb.set_trace()

#         # Calculate the weighted sum of the values for each head.
#         output = self.weighted_sum(weights, value)

#         # Concatenate the outputs from each head along the last dimension.
#         output = self.concatenate_heads(output)

#         return output

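Suppose we have a batch of four different sentences (batch size 4). In the example below, each query sentence has 5 tokens, each key/value sentence has 7 tokens, and every token is a 256-dimensional vector.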

For example, if the book has 1000 words and the word you are looking for has 5 letters.

In [None]:
# # Create some random tensors for the query, key, and value.
# query_tensor = torch.randn(4, 5, 256)
# key_tensor = torch.randn(4, 7, 256)
# value_tensor = torch.randn(4, 7, 256)

In [None]:
# # Create a multi-head attention module with 4 heads and a dimensionality of 256.
# attention = MultiHeadAttention(256, 4)

# # Perform multi-head attention on some input tensors.
# output = attention(query_tensor, key_tensor, value_tensor)

In [None]:
# output.shape

In [None]:
# torch.randn(4, 5, 256) * torch.randn(256, 256)

### Multi-head Attention

In [5]:
#| export
class MultiHeadAttention(nn.Module):
    """Multi-head attention with an independent Q/K/V projection per head.

    Args:
        d_model: number of features in the query/key/value inputs and in the output.
        num_heads: number of attention heads; must be positive and divide ``d_model``.
        dropout: dropout probability applied to the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model: int = 4, num_heads: int = 2, dropout: float = 0.3):
        super().__init__()
        # Without an even split, the concatenated heads would not have d_model
        # features and `mha_linear` would fail later with an opaque shape error.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )

        # Per-head size of the query, key and value vectors (d_q = d_k = d_v).
        self.d_h: int = d_model // num_heads

        self.d_model = d_model
        self.num_heads = num_heads

        self.dropout = nn.Dropout(dropout)

        # One linear layer per head for each of Q, K and V. The registration
        # order is unchanged so parameter init and state_dict keys stay the same.
        self.linear_Qs = nn.ModuleList([nn.Linear(d_model, self.d_h)
                                        for _ in range(num_heads)])
        self.linear_Ks = nn.ModuleList([nn.Linear(d_model, self.d_h)
                                        for _ in range(num_heads)])
        self.linear_Vs = nn.ModuleList([nn.Linear(d_model, self.d_h)
                                        for _ in range(num_heads)])

        # Mixes the concatenated head outputs back into d_model features.
        self.mha_linear = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask=None):
        """Attention for a single head: ``softmax(Q K^T / sqrt(d_h)) V``.

        Returns ``(output, attention_weights)``. Positions where ``mask == 0``
        get (effectively) zero weight; ``mask`` must broadcast against the scores.
        """
        # shape(Q) = [B x len_q x D/num_heads]
        # shape(K, V) = [B x len_k x D/num_heads]

        # transpose(-2, -1) swaps only the last two dims, so extra leading
        # dims are supported (permute(0, 2, 1) required exactly 3-D input).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_h)
        # shape(scores) = [B x len_q x len_k]

        if mask is not None:
            # Fill with the dtype's most negative finite value: a literal -1e9
            # cannot be represented in float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        # shape(attention_weights) = [B x len_q x len_k]

        output = torch.matmul(attention_weights, V)
        # shape(output) = [B x len_q x D/num_heads]

        return output, attention_weights

    def forward(self, pre_q, pre_k, pre_v, mask=None):
        """Run every head and project the concatenated result.

        Returns ``(projection, attn_weights)`` where ``projection`` has
        ``d_model`` features and ``attn_weights`` stacks the per-head weights
        as ``[B x num_heads x len_q x len_k]``.
        """
        # shape(pre_q) = [B x len_q x D]; shape(pre_k, pre_v) = [B x len_k x D]

        Q = [linear_Q(pre_q) for linear_Q in self.linear_Qs]
        K = [linear_K(pre_k) for linear_K in self.linear_Ks]
        V = [linear_V(pre_v) for linear_V in self.linear_Vs]
        # shape(Q) = [B x len_q x D/num_heads] * num_heads
        # shape(K, V) = [B x len_k x D/num_heads] * num_heads

        output_per_head = []
        attn_weights_per_head = []
        for Q_, K_, V_ in zip(Q, K, V):
            output, attn_weight = self.scaled_dot_product_attention(Q_, K_, V_, mask)
            output_per_head.append(output)
            attn_weights_per_head.append(attn_weight)

        output = torch.cat(output_per_head, dim=-1)
        # shape(output) = [B x len_q x D]

        # Insert the head axis just before (len_q, len_k).
        attn_weights = torch.stack(attn_weights_per_head, dim=-3)
        # shape(attn_weights) = [B x num_heads x len_q x len_k]

        projection = self.dropout(self.mha_linear(output))

        return projection, attn_weights

In [None]:
#| export
class ScaleDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, q, k, v, mask=None):
        
        batch_size, head, n_words, d_head = k.size()