# Attention Mechanism

> Scaled dot-product attention and multi-head attention, built step by step from basic PyTorch operations

In [None]:
#| default_exp attention

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"


In [None]:
#| export
import math

import torch
from torch import nn
from transformers import BertTokenizer, BertModel

In [None]:
# Load the pretrained `bert-base-uncased` tokenizer and the matching BERT model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Tokenize a sample sentence, asking for PyTorch tensors ('pt') as output.
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
# Unpack the tokenizer output as keyword arguments for the model call.
output = model(**encoded_input)

#### Process data

The sentence needs to be converted into a numerical representation, such as a vector or a tensor, that can be processed by the attention function.

The sentence "Persistence is all you need" could be represented as a tensor with dimensions `(5, d)`, where `d` is the dimensionality of the word vectors. The tensor would contain the word vectors for each of the words in the sentence, in the same order as they appear in the sentence.

In [None]:
# Define the sentence to be converted (five whitespace-separated words)
sentence = "persistence is all you need"

In [None]:
# Tokenize the sentence into individual words
tokens = sentence.split()

# Embed each word with BERT's input embedding table. These are static
# (non-contextual) vectors. `BertModel` has no `embed` method, so the
# lookup goes through `get_input_embeddings()` instead.
embedding_layer = model.get_input_embeddings()
word_vectors = []
with torch.no_grad():  # plain table lookup; no gradients are needed
    for token in tokens:
        # The tokenizer may split one word into several word pieces.
        # Special tokens are disabled so only the word itself is embedded.
        piece_ids = tokenizer(
            token, add_special_tokens=False, return_tensors='pt'
        )['input_ids']
        # Average the word-piece vectors so each word yields a single
        # vector of size d.
        word_vector = embedding_layer(piece_ids)[0].mean(dim=0)
        word_vectors.append(word_vector)

# Convert the list of word vectors into a tensor
# with dimensions (sequence_length, d).
# torch.stack already produces that shape, so no extra view is required.
sequence_length = len(tokens)
d = word_vectors[0].size(0)
tensor = torch.stack(word_vectors, dim=0)

# The resulting tensor can now be used as input
# to a machine learning model, such as the multi_head_attention function

NameError: name 'model' is not defined

#### Split heads

In [None]:
def split_heads(x, num_heads, d_k):
    """Split the feature dimension of ``x`` into separate attention heads.

    Args:
        x: Tensor of shape (batch_size, sequence_length, num_heads * d_k).
        num_heads: Number of attention heads.
        d_k: Feature size of each head.

    Returns:
        A view of ``x`` with shape
        (batch_size, num_heads, sequence_length, d_k).
    """
    batch_size, sequence_length = x.shape[0], x.shape[1]
    # Carve the last dimension into (num_heads, d_k) ...
    per_head = x.view(batch_size, sequence_length, num_heads, d_k)
    # ... then move the head axis in front of the sequence axis.
    return per_head.transpose(1, 2)

#| explain "x = x.view(x.size(0), x.size(1), num_heads, d_k)"

`x.size(0)` and `x.size(1)` keep the batch size and sequence length of the tensor unchanged, while the last dimension is split into `num_heads` chunks of size `d_k`.

#| explain "return x.permute(0, 2, 1, 3)"

Reorder the dimensions so that the tensor has shape (`batch_size`, `num_heads`, `sequence_length`, `d_k`).

#### Combine heads

In [None]:
def combine_heads(x, num_heads, d_k):
    """Merge the per-head outputs back into a single feature dimension.

    Args:
        x: Tensor of shape (batch_size, num_heads, sequence_length, d_k).
        num_heads: Number of attention heads.
        d_k: Feature size of each head.

    Returns:
        Tensor of shape (batch_size, sequence_length, num_heads * d_k).
    """
    # Put the sequence axis back in front of the head axis.
    heads_last = x.permute(0, 2, 1, 3)
    batch_size, sequence_length = heads_last.shape[0], heads_last.shape[1]
    # The permuted tensor is generally not contiguous, so it is copied
    # into contiguous memory before the head and feature axes are flattened.
    return heads_last.contiguous().view(batch_size, sequence_length, num_heads * d_k)

#### Attention

In [None]:
def attention(query, key, value, d_k):
    """Compute scaled dot-product attention.

    Args:
        query: Tensor whose last two dimensions are (query_length, d_k).
        key: Tensor whose last two dimensions are (key_length, d_k).
        value: Tensor whose last two dimensions are (key_length, d_v).
        d_k: Per-head feature size; the scores are divided by sqrt(d_k).

    Returns:
        Tensor whose last two dimensions are (query_length, d_v): for every
        query position, the softmax-weighted sum of the value vectors.
    """
    # Dot product of every query with every key, scaled by sqrt(d_k).
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
    # Normalise over the key axis. `torch.softmax` is used because the
    # `F` alias (torch.nn.functional) is never imported in this file.
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

#| explain "torch.matmul(query, key.transpose(-1, -2))"

To compute the dot product between the `query` and `key` tensors, the `key` tensor is transposed along its last two dimensions so that its shape becomes `(batch_size, d_k, sequence_length)`. The inner dimension `d_k` then lines up with the last dimension of `query`, and the matrix product yields an output tensor of shape `(batch_size, sequence_length, sequence_length)`: one score for every (query position, key position) pair.

In [None]:
def multi_head_attention(d_model, num_heads, query_tensor, key_tensor, value_tensor):
    """Run multi-head attention over the given query, key and value tensors.

    The linear projection layers are created inside this function on every
    call, so their weights are freshly (randomly) initialised each time.

    Args:
        d_model: Size of the last dimension of the input tensors.
        num_heads: Number of attention heads.
        query_tensor: Tensor of shape (batch_size, query_length, d_model).
        key_tensor: Tensor of shape (batch_size, key_length, d_model).
        value_tensor: Tensor of shape (batch_size, key_length, d_model).

    Returns:
        Tensor of shape (batch_size, query_length, num_heads * d_k), where
        d_k = d_model // num_heads.
    """
    d_k = d_model // num_heads
    projected_size = num_heads * d_k

    # Layers are built in key, value, query order.
    key_layer, value_layer, query_layer = (
        nn.Linear(d_model, projected_size) for _ in range(3)
    )

    # Project every input first ...
    keys, values, queries = (
        key_layer(key_tensor),
        value_layer(value_tensor),
        query_layer(query_tensor),
    )

    # ... then reshape each projection to (batch, heads, length, d_k).
    keys, values, queries = (
        split_heads(projection, num_heads, d_k)
        for projection in (keys, values, queries)
    )

    # Attend within every head, then merge the heads back together.
    per_head_output = attention(queries, keys, values, d_k)
    return combine_heads(per_head_output, num_heads, d_k)

In [None]:
# Set the hyperparameters
d_model = 256
num_heads = 8

- `query_tensor`: (batch_size=4, sequence_length=5, d_model=256)
- `key_tensor`: (batch_size=4, sequence_length=7, d_model=256)
- `value_tensor`: (batch_size=4, sequence_length=7, d_model=256)

In [None]:
# Create some random input tensors
query_tensor = torch.randn(4, 5, d_model)
key_tensor = torch.randn(4, 7, d_model)
value_tensor = torch.randn(4, 7, d_model)

In [None]:
# Apply multi-head attention
attention_output = multi_head_attention(d_model, num_heads, query_tensor, key_tensor, value_tensor)

# Print the output shape
print(attention_output.shape)  # should be (4, 5, 256)

torch.Size([4, 5, 256])


### Multi-Head Attention 