<a href="https://colab.research.google.com/github/space4VV/LLM_trailblazr/blob/main/chapter3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#print torch version
import torch
print(torch.__version__)


2.4.1+cu121


In [2]:
import torch
inputs = torch.tensor(
   [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
print(inputs)

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


- We use input sequence element 2, as an example to compute context vector z
- The context vector is "context"-specific to a certain input


# Compute the attention weights and context vector for input 2



## Step 1 - compute the unnormalized attention scores by computing the dot product between the query and all other input tokens:


In [3]:
query = inputs[1]
print(query)
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i,query)
print(attn_scores_2)

tensor([0.5500, 0.8700, 0.6600])
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


## Step 2 - Normalize the attention scores so that they sum upto 1


In [4]:
attn_weights_2_tmp = attn_scores_2 /attn_scores_2.sum()
print(attn_weights_2_tmp)
print("sum of attn scores:",attn_weights_2_tmp.sum())



tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
sum of attn scores: tensor(1.0000)


However, in practice, using the softmax function for normalization, which is better at handling extreme values and has more desirable gradient properties during training, is common and recommended.

In [5]:
# naive softmax
def softmax_naive(x):
  return torch.exp(x) / torch.exp(x).sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print(attn_weights_2_naive)
print("sum of attn scores:",attn_weights_2_naive.sum())


tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
sum of attn scores: tensor(1.)


In [6]:
# in practice its better to use the one from torch directly
atten_weights_2 = torch.softmax(attn_scores_2, dim=0)
print(atten_weights_2)
print("sum of attn scores:",atten_weights_2.sum())


tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
sum of attn scores: tensor(1.)


## Step 3 - compute the context vector by multiplying the embedded input tokens, with the attention weights and sum the resulting vectors

In [7]:
print("inputs:",inputs)
print("attention_weights",atten_weights_2)
query =  inputs[1] # 2nd token is the query
context_vec_2  = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += atten_weights_2[i] * x_i
print("context_vector -",context_vec_2)

inputs: tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])
attention_weights tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
context_vector - tensor([0.4419, 0.6515, 0.5683])


# Compute all attention weights and context vectors

In [8]:
attn_scores = torch.empty((inputs.shape[0], inputs.shape[0]))
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [9]:
# matric multiplication vs the above for loops to improve the performance
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [10]:
# now normalize
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [11]:
# now calculate context vectors
all_context_vectors = attn_weights @ inputs
print(all_context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


# Implementing self-attention with trainable weights
There are only slight differences compared to the basic attention mechanism introduced earlier:
- The most notable difference is the introduction of weight matrices that are updated during model training
- These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce "good" context vectors


In [12]:
# we need 3 weight matrices K,Q,V -  these 3 matrices project the embedded input
# tokens into K;Q;V Vectors

In [13]:
print(inputs)

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [14]:
x_2 = inputs[1]
print(x_2.shape)
d_in = inputs.shape[1]
print("input dim:",d_in)
d_out = 2

torch.Size([3])
input dim: 3


In [15]:
# initialize the 3 weight matrices
torch.manual_seed(123)
w_query = torch.nn.Parameter(torch.randn(d_in, d_out),requires_grad=False)
w_key = torch.nn.Parameter(torch.randn(d_in, d_out),requires_grad=False)
w_value = torch.nn.Parameter(torch.randn(d_in, d_out),requires_grad=False)

In [16]:
query_2 = x_2 @ w_query # _2 because it's with respect to the 2nd input element
key_2 = x_2 @ w_key
value_2 = x_2 @ w_value
print("query:",query_2)
print("key:",key_2)

query: tensor([-1.1729, -0.0048])
key: tensor([-0.1142, -0.7676])


In [17]:
keys = inputs @ w_key
values = inputs @ w_value
print("shape of keys:",keys.shape)
print("shape of values:",values.shape)

shape of keys: torch.Size([6, 2])
shape of values: torch.Size([6, 2])


In [18]:
# step 2, we compute the unnormalized attention scores by computing the dot
#product between the query and each key vector:

keys_2 = keys[1]
print("keys2:",keys_2)
print("query2:",query_2)
attn_score_22 = torch.dot(query_2, keys_2)
print("attention score:",attn_score_22)

keys2: tensor([-0.1142, -0.7676])
query2: tensor([-1.1729, -0.0048])
attention score: tensor(0.1376)


In [19]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809])


In [20]:
# compute attention weights
d_k = keys.shape[1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1704, 0.1611, 0.1652, 0.1412, 0.2505, 0.1117])


In [21]:
# step 4, we now compute the context vector for input query vector 2:

context_vec_2 = attn_weights_2 @ values
print(context_vec_2)



tensor([0.2854, 0.4081])


## Self attention class


In [22]:
import torch.nn as nn
class SelfAttention_V1(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    self.w_query = nn.Parameter(torch.rand(d_in, d_out))
    self.w_key = nn.Parameter(torch.rand(d_in, d_out))
    self.w_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    keys = x @ self.w_key
    values = x @ self.w_value
    queries = x @ self.w_query
    attn_scores = queries @ keys.T # omega
    attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

    context_vec = attn_weights @ values
    return context_vec

torch.manual_seed(123)
self_attn_v1 = SelfAttention_V1(d_in, d_out)
print(self_attn_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [29]:
# new class using pytorch linear layer
class SelfAttention_V2(nn.Module):
  def __init__(self, d_in, d_out,qkv_bias=False):
    super().__init__()
    self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    keys = self.w_key(x)
    values = self.w_query(x)
    queries = self.w_value(x)
    attn_scores = queries @ keys.T # omega
    attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

    context_vec = attn_weights @ values
    return context_vec

torch.manual_seed(789)
self_attn_v2 = SelfAttention_V2(d_in, d_out)
print(self_attn_v2(inputs))

tensor([[ 0.6733, -0.3186],
        [ 0.6726, -0.3185],
        [ 0.6722, -0.3183],
        [ 0.6739, -0.3189],
        [ 0.6656, -0.3147],
        [ 0.6774, -0.3207]], grad_fn=<MmBackward0>)


## Hiding future words with causal attention
In causal attention, the attention weights above the diagonal are masked, ensuring that for any given input, the LLM is unable to utilize future tokens while calculating the context vectors with the attention weight


In [25]:
# Causal self-attention ensures that the model's prediction for a certain position
#  in a sequence is only dependent on the known outputs at previous positions, not on future positions

In [31]:
queries = self_attn_v2.w_query(inputs)
keys = self_attn_v2.w_key(inputs)
attn_scores = queries @ keys.T
print(attn_scores)

attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
print(attn_weights)

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [33]:
# masking
context_length = attn_scores.shape[0]
print(context_length)
# set elements below the main diagonal (including the diagonal itself) set to 1 and above the main diagonal set to 0:
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

6
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [34]:
masked_simple = attn_weights * mask_simple
print(masked_simple)


tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [38]:
# we should re normalize after masking to ensure rows sums up to 1
row_sums = masked_simple.sum(dim=1, keepdim=True)
print("row_sums: \n",row_sums)
masked_simple_norm = masked_simple / row_sums
print("masked_simple_norm: \n",masked_simple_norm)

row_sums: 
 tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)
masked_simple_norm: 
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [41]:
#Instead of zeroing out attention weights above the diagonal and renormalizing the results,
#we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores*masked_fill(mask.bool(),-torch.inf)
print(masked)

NameError: name 'masked_fill' is not defined

In [42]:
# adding dropouts after computing attn weights to reduce overfitting
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) #50%
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [43]:
torch.manual_seed(123)
print(dropout(attn_weights))


tensor([[0.3843, 0.3293, 0.3303, 0.3100, 0.3442, 0.3019],
        [0.0000, 0.3318, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3325, 0.0000, 0.3328, 0.0000],
        [0.3738, 0.3334, 0.0000, 0.0000, 0.0000, 0.3128],
        [0.3661, 0.0000, 0.0000, 0.0000, 0.0000, 0.3169],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


### Putting all to a compact class
One more thing is to implement the code to handle batches consisting of more than one input so that our CausalAttention
class supports the batch outputs produced by the data loader we implemented in chapter 2
- For simplicity, to simulate such batch input, we duplicate the input text example: -->

In [44]:
batch = torch.stack([inputs,inputs],dim=0)
print(batch)
print(batch.shape)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
torch.Size([2, 6, 3])


In [46]:
class CausalAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length,dropout, qkv_bias=False):
    super().__init__()
    self.d_out=d_out
    self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

  def forward(self, x):
    b, num_tokens, d_in = x.shape # New batch dimension b
    keys = self.w_key(x)
    queries = self.w_query(x)
    values = self.w_value(x)

    attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
    attn_scores.masked_fill_(  # New, _ ops are in-place
        self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # :num_tokens to account for cases where the number of tokens in the batch is smaller than the supported context_size
    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
    )
    attn_weights = self.dropout(attn_weights) # New

    context_vec = attn_weights @ values
    return context_vec


torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


## Multihead attention
The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections. This allows the model to jointly attend to information from different representation subspaces at different positions.

In [48]:
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self, d_in, d_out, num_heads, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.heads = nn.ModuleList(
        [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
    )

  def forward(self, x):
    return torch.cat([h(x) for h in self.heads], dim=-1)

torch.manual_seed(123)
context_lenghth = batch.shape[1] # no of tokens
mha = MultiHeadAttentionWrapper(d_in, d_out,2, context_length, 0.0,)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
# In the implementation above, the embedding dimension is 4, because we d_out=2 as
# the embedding dimension for the key, query, and value vectors as well as the
# context vector. And since we have 2 attention heads, we have the output embedding dimension 2*2=4

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


## Standalone multiattention


In [49]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
