# Chapter 3: Coding Attention Mechanisms
* We will implement four variations of the attention mechanism, each building on the previous one; the goal is to arrive at a compact and efficient implementation of multi-head attention

1. Simplified self-attention -> simplified version of self-attention before adding trainable weights
2. Self-attention -> self attention with the trainable weights
3. Causal attention -> adds a mask to self-attention that allows the LLM to generate one word at a time
4. Multi-head attention -> organizes attention and allows model to capture various aspects of the input data in parallel

### 3.3.1 Simple self-attention mechanism without trainable weights

In [7]:
import torch
inputs = torch.tensor(
   [[0.43, 0.15, 0.89], # Your     
    [0.55, 0.87, 0.66], # journey  
    [0.57, 0.85, 0.64], # starts    
    [0.22, 0.58, 0.33], # with
    [0.77, 0.25, 0.10], # one
    [0.05, 0.80, 0.55]] # step
)
inputs.shape

torch.Size([6, 3])

In [8]:
# Compute attention scores (omega), between the query and all other inputs elements as a dot product

query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])

for i, x_i, in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
    print(f"omega 2,{i}: {attn_scores_2[i]}")

omega 2,0: 0.9544000625610352
omega 2,1: 1.4950001239776611
omega 2,2: 1.4754000902175903
omega 2,3: 0.8434000015258789
omega 2,4: 0.7070000171661377
omega 2,5: 1.0865000486373901


In [9]:
# Normalize these attention scores and obtain attention weights (alpha) that sum to 1

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights: ", attn_weights_2_tmp)
print("Sum: ", attn_weights_2_tmp.sum())

Attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum:  tensor(1.0000)


In [10]:
# Softmax is better for normalizing values
def softmax_naive(x):
    """Return exp(x) divided by the sum of exp(x) taken along dim 0.

    This is the textbook softmax formula written out directly; no
    max-subtraction is applied before exponentiating.
    """
    exponentials = torch.exp(x)
    normalizer = exponentials.sum(dim=0)
    return exponentials / normalizer

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights: ", attn_weights_2_naive)
print("Sum: ", attn_weights_2_naive.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)


In [11]:
# Just use the PyTorch implementation of softmax which has been optimized for performance

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights: ", attn_weights_2)
print("Sum: ", attn_weights_2.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)


In [12]:
# Now, compute context vector z(2), a combination of all input vectors weighted by the attention weights
query = inputs[1]
context_vector = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
    context_vector += attn_weights_2[i] * x_i
    
print(context_vector)

tensor([0.4419, 0.6515, 0.5683])


### 3.3.2 Computing attention weights for all input tokens

In [14]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [15]:
# forloops are slow, use matrix multiplication

attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [16]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [17]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## 3.4 Implementing self attention with trainable weights
* Now we add trainable weights to this attention mechanism through weight matrices.
* These weight matrices are updated through training so the model can learn to produce "good" context vectors

We will tackle this self-attention mechanism in two steps:
1. First, code it step-by-step like before
2. Second, organize it into a class that can be imported into the LLM architecture

### 3.4.1 Computing attention weights step by step
* Introduce weight matrices W<sub>q</sub>, W<sub>k</sub>, W<sub>v</sub>
* These three matrices are used to project the embedded input token x<sup>(i)</sup>, into query, key, and value vectors, respectively

In [20]:
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d_out=2

print(x_2)

# note: in GPT-like models, d_in and d_out are usually the same size, but to make the computations easier to follow we choose d_out=2

tensor([0.5500, 0.8700, 0.6600])


In [21]:
# initialize the weight matrices

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


In [22]:
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [23]:
# The attention scores are computed as a dot product between the query and key vectors, since we are trying to compute the context vector for the second input, the query is derived from that input token

# compute attention score omega22
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


In [24]:
# generalize to all attention scores via matrix mult

attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [25]:
# Now scale the attention scores by dividing them by the square root of the embedding dimension of the keys, then use softmax

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [26]:
# last step is to multiply each value vector with its respective attention weight and then summing them to obtain the context vector

context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


### 3.4.2 Implementing a compact self-attention Python class

In [28]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (query, key, value) is kept so seeded runs reproduce.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(..., num_tokens, d_in)``; an unbatched
                ``(num_tokens, d_in)`` input and inputs with leading batch
                dimensions are both accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # Swap only the last two dims: identical to ``keys.T`` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec
        

In [29]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [30]:
# can improve our implementation by using nn.Linear layers instead of nn.Parameter

import torch.nn as nn
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the returned context vectors.
        qvk_bias: whether the query/key/value linear layers use a bias term.
            NOTE(review): other classes in this file spell this ``qkv_bias``;
            the name is kept here so existing keyword callers keep working.
    """

    def __init__(self, d_in, d_out, qvk_bias=False):
        super().__init__()
        # Creation order (query, key, value) is kept so seeded runs reproduce.
        self.W_query = nn.Linear(d_in, d_out, bias=qvk_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qvk_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qvk_bias)

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(..., num_tokens, d_in)``; an unbatched
                ``(num_tokens, d_in)`` input and inputs with leading batch
                dimensions are both accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Swap only the last two dims: identical to ``keys.T`` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [31]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


####################### Exercise 3.1 #############################

In [33]:
sa_v1_exercise = SelfAttention_v1(d_in, d_out)
sa_v1_exercise.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1_exercise.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1_exercise.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)

In [34]:
print(sa_v1_exercise(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


#############################################################

## 3.5 Hiding future words with causal attention
* You want the self-attention mechanism to consider only the tokens at or before the current position when predicting the next token in a sequence -> this is called "causal attention" or "masked attention"
* This contrasts with the self-attention we implemented above, which gives access to the entire input sequence at once

### 3.5.1 Applying a causal attention mask

In [38]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [39]:
# use the PyTorch tril function to create a mask

context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [40]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [41]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
print(row_sums)
print("\n")
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [42]:
# can implement the computation of the masked attention weights more efficiently in fewer steps

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print(mask,"\n")
masked = attn_scores.masked_fill(mask.bool(), -torch.inf) 
# masked_fill fills elements of self tensor with value where mask is True 
# .bool() replaces values in tensor that are 0 with False, and 1 with True
# where 0's (False) are in the mask, it will keep original values, where 1's (True) are in the mask, will replace with -inf

print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]]) 

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [43]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### 3.5.2 Masking additional attention weights with dropout
* <b>Dropout</b> is a deep learning technique where randomly selected hidden layer units are ignored during training, effectively "dropping" them out
* This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units
* Only done during training and disabled afterwards

In [45]:
# implementing dropout mechanism after calculating the attention weights, and using a dropout rate of 50%

# first apply it to a 6x6 tensor of ones for simplicity
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)

'''
With a dropout rate (p) of 50%, roughly half of the elements in the matrix are randomly set to zero. To compensate for the reduction in active elements,
the values of the remaining elements in the matrix are scaled up by a factor of 1 / (1 - p), i.e. by 2 when p = 0.5
'''

print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [46]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


### 3.5.3 Implementing a compact causal attention class

In [48]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [49]:
class CasualAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout.

    Each position attends only to itself and earlier positions: scores for
    later positions are set to ``-inf`` before the softmax.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the returned context vectors.
        context_length: side length of the pre-built mask; inputs may have at
            most this many tokens, since the mask is only sliced down, never
            enlarged.
        dropout: probability passed to ``nn.Dropout`` and applied to the
            attention weights.
        qkv_bias: whether the query/key/value linear layers use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Return causally masked context vectors for ``x``.

        Args:
            x: tensor of shape ``(..., num_tokens, d_in)``; the usual
                ``(batch, num_tokens, d_in)`` layout as well as unbatched
                ``(num_tokens, d_in)`` input are accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        num_tokens = x.shape[-2]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Transpose only the token/feature dims so any leading batch dims work.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Slice the mask in case the sequence is shorter than context_length.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(causal_mask, -torch.inf)
        # Scale by sqrt(d_k); the -inf entries become exactly 0 after softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
            

In [50]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CasualAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


## 3.6 Extending single-head attention to multi-head attention
* "multi-head" refers to dividing the attention mechanism into multiple "heads", each operating independently
* In this context, a single causal attention module can be considered single-head attention, where there is only one set of attention weights processing the input sequentially

### 3.6.1 Stacking multiple single-head attention layers
* In practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then combining their outputs

In [53]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built by stacking independent ``CasualAttention`` heads.

    ``num_heads`` separate heads are created with identical arguments; the
    forward pass runs each head on the same input and concatenates their
    outputs along the last dimension.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList()
        # Heads are created one after another, each with its own weights.
        for _ in range(num_heads):
            single_head = CasualAttention(d_in, d_out, context_length, dropout, qkv_bias)
            self.heads.append(single_head)

    def forward(self, x):
        """Run every head on ``x`` and join the results on the feature axis."""
        per_head_outputs = []
        for head in self.heads:
            per_head_outputs.append(head(x))
        return torch.cat(per_head_outputs, dim=-1)

In [54]:
torch.manual_seed(123)
context_length = batch.shape[1] # number of tokens
d_in, d_out = 3, 2

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)

'''
The first dimension of context_vecs is 2 since there are two input texts in the batch
The second dimension refers to the 6 input tokens in each input
The third dimension refers to the four-dimensional embedding of each token
'''

print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


### 3.6.2 Implementing multi-head attention with weight splits

In [163]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention using a single set of split projections.

    One query, key and value projection of width ``d_out`` is computed and
    then reshaped into ``num_heads`` heads of width ``d_out // num_heads``.
    The per-head results are concatenated and passed through ``out_proj``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: side length of the pre-built causal mask; inputs may
            have at most this many tokens.
        dropout: probability passed to ``nn.Dropout`` and applied to the
            attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value linear layers use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert(d_out % num_heads == 0), \
        "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # width of each individual head
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Return the multi-head context vectors for ``x``.

        Prints the intermediate tensor shapes as a tracing aid.

        Args:
            x: tensor of shape ``(b, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape
        keys = self.W_key(x)            # shape: (b, num_tokens, d_out)
        queries = self.W_query(x)       # same shape as keys
        values = self.W_value(x)        # same shape as keys

        print("keys.shape:", keys.shape)

        # Split the last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move heads ahead of tokens: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Slice the mask in case the sequence is shorter than context_length.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # keys.shape[-1] is head_dim here, so scaling is per head.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads.
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        print("queries:", queries.shape)  # (b, num_heads, num_tokens, head_dim)
        print("attn_scores:", attn_scores.shape)  # (b, num_heads, num_tokens, num_tokens)
        # Printed before the projection so the label matches what is shown.
        print("context_vec before out_proj:", context_vec.shape)  # (b, num_tokens, d_out)

        context_vec = self.out_proj(context_vec)

        return context_vec

In [165]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
print("batch.shape:", batch.shape, "\n")
d_out = 9
mha = MultiHeadAttention(d_in, d_out,context_length, 0.0, num_heads=3)
context_vecs = mha(batch)
print("\n")
print("context_vecs.shape:", context_vecs.shape, "\n")
print(context_vecs)

batch.shape: torch.Size([2, 6, 3]) 

keys.shape: torch.Size([2, 6, 9])
queries: torch.Size([2, 3, 6, 3])
attn_scores: torch.Size([2, 3, 6, 6])
context_vec before out_proj: torch.Size([2, 6, 9])


context_vecs.shape: torch.Size([2, 6, 9]) 

tensor([[[-0.3019,  0.3177,  0.4985, -0.0611, -0.1675,  0.3447,  0.3815,
           0.4080, -0.3164],
         [-0.3411,  0.3692,  0.5127, -0.0264,  0.0238,  0.2869,  0.4995,
           0.3322, -0.3014],
         [-0.3538,  0.3897,  0.5139, -0.0094,  0.0937,  0.2692,  0.5391,
           0.3010, -0.2925],
         [-0.3461,  0.3929,  0.4836,  0.0283,  0.1188,  0.2660,  0.5234,
           0.2770, -0.2561],
         [-0.3386,  0.3733,  0.4685,  0.0781,  0.1546,  0.2679,  0.5007,
           0.2599, -0.2222],
         [-0.3399,  0.3842,  0.4626,  0.0693,  0.1497,  0.2634,  0.5065,
           0.2580, -0.2231]],

        [[-0.3019,  0.3177,  0.4985, -0.0611, -0.1675,  0.3447,  0.3815,
           0.4080, -0.3164],
         [-0.3411,  0.3692,  0.5127, -0.0264