To build a sample attention block that will later be used inside a transformer block of an LLM, we code the attention mechanism in four steps:

1. Simple Self Attention 
2. Self Attention with trainable weights
3. Causal Attention and Dropout
4. Single Head to Multi Head Attention

Why Attention?

We start with the input embeddings of a sentence, "Your journey starts with one step", using an embedding dimension of 3 (emb_dim=3).

In [None]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

<pre>To calculate the context vector for a token, say the 3rd token, 'starts':
    1. compute the attention scores of this token with respect to every token (including itself) via dot products
    2. normalize the attention scores to get attention weights
    3. sum (attention weight * input token embedding) over all tokens -> context vector of this token
</pre>
![Context Vector of a token](self-attention-1.png)


For a single input token

In [None]:
# the query is the 3rd token ("starts")
query = inputs[2]

# getting attention scores: one score per input token, so the vector has length inputs.shape[0]
attention_scores_3 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attention_scores_3[i] = torch.dot(x_i, query)

# getting attention weights
# simpler alternative (normalize by the sum instead of softmax):
# attention_weights_3 = attention_scores_3/attention_scores_3.sum()
attention_weights_3 = torch.softmax(attention_scores_3, dim=0)

# getting context vector
context_vector_3 = torch.zeros(query.shape)
for i, x_emd in enumerate(inputs):
    context_vector_3 += attention_weights_3[i]*x_emd

print(context_vector_3)


tensor([0.4431, 0.6496, 0.5671])
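
The commented-out line in the cell above normalizes by the sum, while the active line uses softmax. The sketch below compares the two on the same scores. The helper `naive_softmax` is an illustrative name, not part of the original notebook; it only spells out what `torch.softmax` computes.

```python
import torch

# Same embeddings as above, repeated so the snippet runs on its own
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

# Attention scores of the query "starts" against all six tokens
scores = inputs @ inputs[2]

# Option 1: divide by the sum (only sensible when all scores are positive)
weights_sum = scores / scores.sum()

# Option 2: softmax, which always gives positive weights that sum to 1
weights_softmax = torch.softmax(scores, dim=0)

# Hand-written softmax for comparison (illustrative helper)
def naive_softmax(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

print(weights_sum)
print(weights_softmax)
print(torch.allclose(naive_softmax(scores), weights_softmax))
```

Both options sum to 1 here, but softmax also handles negative scores and is the form used in the rest of the notebook.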


For the whole input

In [None]:
attention_scores = inputs @ inputs.T
attention_weights = torch.softmax(attention_scores, dim=-1)
context_vectors = attention_weights @ inputs

print(context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
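
As a sanity check, the matrix version should reproduce the loop-based result for the 3rd token, and every row of the weight matrix should sum to 1. A minimal sketch, with the loop variables (`loop_scores`, `loop_weights`, `loop_context`) introduced here only for the comparison:

```python
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

# Matrix form: all pairwise scores, row-wise softmax, weighted sum of inputs
attention_scores = inputs @ inputs.T
attention_weights = torch.softmax(attention_scores, dim=-1)
context_vectors = attention_weights @ inputs

# Each row of the weight matrix is a distribution over the six tokens
print(attention_weights.sum(dim=-1))

# Loop form for the single query "starts" (index 2)
query = inputs[2]
loop_scores = torch.stack([torch.dot(x_i, query) for x_i in inputs])
loop_weights = torch.softmax(loop_scores, dim=0)
loop_context = sum(w * x_i for w, x_i in zip(loop_weights, inputs))

print(torch.allclose(context_vectors[2], loop_context))
```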


Calculating Self-Attention with trainable weights

![Self-Attention](self-attention-2.png)

Self-Attention with trainable weights

In [None]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [None]:
import torch.nn as nn
d_in = 3  # must equal the embedding dimension so that x @ W is defined
d_out = 2 # free choice; d_in and d_out are usually kept equal, but here the context vectors have dimension 2

class SelfAttention(nn.Module):
    # initialise trainable weights, then write a forward function that does the attention weights calculation
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        # get attention scores with dot product
        # normalize to get attention weights
        # addition of (attention weights * embedding input tokens (or values in this case)) gives context vector

        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5, dim=-1)
        context_vectors = attention_weights @ values
        return context_vectors

torch.manual_seed(123)
sa = SelfAttention(d_in, d_out)
print(sa(inputs))



tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
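
The forward pass above divides the scores by the square root of the key dimension before the softmax. The sketch below illustrates the effect, assuming random queries and keys and a hypothetical key dimension of 64 (much larger than the `d_out=2` used above, so the effect is visible): without scaling, the softmax tends to put almost all weight on one token.

```python
import torch

torch.manual_seed(0)
d_k = 64  # hypothetical key dimension chosen for illustration
queries = torch.randn(6, d_k)
keys = torch.randn(6, d_k)

scores = queries @ keys.T
unscaled = torch.softmax(scores, dim=-1)
scaled = torch.softmax(scores / d_k**0.5, dim=-1)

# Largest weight in each row: values close to 1 mean the softmax has saturated
print(unscaled.max(dim=-1).values)
print(scaled.max(dim=-1).values)
```

Scaling flattens each row, so the largest weight per row can only shrink or stay the same.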


Now, instead of creating the trainable weights with nn.Parameter, we use nn.Linear because:
1. it has a better default weight initialization
2. it takes care of the bias addition and broadcasting (if we want a bias), the transpose, and other small matrix-multiplication details

So applying nn.Linear(d_in, d_out, bias=False) to x is effectively x @ W_query, which we computed manually when using nn.Parameter, since no bias is added. (Internally, nn.Linear stores its weight with shape (d_out, d_in) and computes x @ W.T.) Essentially, nn.Linear performs the forward pass itself (it is a higher-level module), whereas nn.Parameter is a low-level LEGO block for which we write the forward pass (x @ W_query) ourselves, as above.
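
A quick check of this equivalence, as a sketch with a random stand-in input `x` (not the notebook's `inputs`). Note that the keyword nn.Linear accepts is `bias`; `qkv_bias` is only the name of our own constructor argument.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)  # stand-in for six token embeddings

linear = nn.Linear(d_in, d_out, bias=False)

# nn.Linear stores its weight with shape (d_out, d_in) and applies x @ weight.T
manual = x @ linear.weight.T

print(linear.weight.shape)                # torch.Size([2, 3])
print(torch.allclose(linear(x), manual))  # True
```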

QKV_BIAS -> adds a bias to all three linear layers (query, key, value); it is an ordinary bias, as in y = x@W.T + b
Whether to add this bias depends on the architecture: many LLM implementations omit it, while others include it (the released GPT-2 weights, for example, contain query/key/value biases, and common vision transformer implementations use them too)
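
To make the bias concrete, the sketch below builds one layer with and one without a bias and compares them. The helper `count_params` and the random input `x` are illustrative, not part of the original notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)  # stand-in for six token embeddings

with_bias = nn.Linear(d_in, d_out, bias=True)
without_bias = nn.Linear(d_in, d_out, bias=False)

# With a bias, the layer computes y = x @ W.T + b (b is broadcast over the rows)
manual = x @ with_bias.weight.T + with_bias.bias
print(torch.allclose(with_bias(x), manual))  # True

# The bias adds d_out extra trainable values per layer
def count_params(module):
    return sum(p.numel() for p in module.parameters())

print(count_params(with_bias), count_params(without_bias))  # 8 6
```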

In [None]:
import torch.nn as nn

class SelfAttentionV2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)


    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5, dim=-1)
        context_vectors = attention_weights@values
        return context_vectors

torch.manual_seed(789)
sa = SelfAttentionV2(3, 2)
print(sa(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


Causal Attention with Dropout

![Causal](causal-attention-1.png)

![causal2](causal-attention-2.png)

In [None]:
torch.manual_seed(789)
queries = sa.W_query(inputs)
keys = sa.W_key(inputs)
attention_scores = queries @ keys.T

context_length = attention_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
attention_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
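
Masking with -inf before the softmax gives the same weights as applying the softmax first, zeroing the future positions, and renormalizing each row. The sketch below checks this on random stand-in scores; the variable names are illustrative and do not reuse the `sa` module from above.

```python
import torch

torch.manual_seed(0)
num_tokens = 6
scores = torch.randn(num_tokens, num_tokens)  # stand-in attention scores

# Approach 1 (used above): set future positions to -inf, then softmax
upper = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
weights_inf = torch.softmax(scores.masked_fill(upper, -torch.inf), dim=-1)

# Approach 2: softmax first, zero out future positions, renormalize each row
weights_full = torch.softmax(scores, dim=-1)
lower = torch.tril(torch.ones(num_tokens, num_tokens))
masked = weights_full * lower
weights_renorm = masked / masked.sum(dim=-1, keepdim=True)

print(torch.allclose(weights_inf, weights_renorm, atol=1e-6))
```

The -inf variant needs one softmax and no renormalization step, which is one reason to prefer it.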


- In addition, we also apply dropout to reduce overfitting during training
- Dropout can be applied in several places:
  - for example, after computing the attention weights;
  - or after multiplying the attention weights with the value vectors
- Here, we will apply the dropout mask after computing the attention weights because it's more common
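
A minimal sketch of what nn.Dropout does, using a matrix of ones in place of real attention weights and a 50% rate chosen only for illustration. In training mode the surviving entries are scaled by 1/(1 - p) so the expected sum stays the same; in eval mode dropout is a no-op.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
dropout = nn.Dropout(0.5)     # 50% rate, chosen for illustration
example = torch.ones(6, 6)    # stand-in for a 6x6 attention weight matrix

# Training mode: roughly half the entries are zeroed, the rest are scaled to 2.0
dropout.train()
dropped = dropout(example)
print(dropped)

# Eval mode: the input passes through unchanged
dropout.eval()
print(torch.equal(dropout(example), example))  # True
```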

In [None]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) # 2 inputs with 6 tokens each, and each token has embedding dimension 3

torch.Size([2, 6, 3])


In [None]:
import torch.nn as nn

class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # register the mask as a buffer: it is saved and moved with the module, but not trained
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        batch, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

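        # transpose only the last two dims so the batch dimension is kept
        attention_scores = queries @ keys.transpose(1,2)
        # hide future tokens; the trailing underscore means the operation is in place
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5, dim=-1)
        # dropout is only active in training mode (and is a no-op here with rate 0.0)
        attention_weights = self.dropout(attention_weights)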
        context_vectors = attention_weights @ values
        return context_vectors

torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(3, 2, context_length, 0.0)
context_vectors = ca(batch)
print(context_vectors)


tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


[:num_tokens, :num_tokens] trims the precomputed causal mask to the actual sequence length of the current batch, which may be shorter than context_length.
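
The trimming and the buffer registration can be seen in isolation in the sketch below. `MaskHolder` is a hypothetical helper class that keeps only the mask from `CausalAttention`, and the sequence length of 4 is an arbitrary example.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical helper, mask part of CausalAttention only
    def __init__(self, context_length):
        super().__init__()
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

holder = MaskHolder(context_length=6)

# A batch with only 4 tokens needs just the top-left 4x4 corner of the 6x6 mask
num_tokens = 4
trimmed = holder.mask.bool()[:num_tokens, :num_tokens]
print(trimmed)

# A buffer is not a trainable parameter, but it is part of the module's state
print(len(list(holder.parameters())))  # 0
print("mask" in holder.state_dict())   # True
```

Because the mask is a buffer, calling `.to(device)` on the module moves it together with the weights.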