In [1]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
input_query = inputs[1]
input_query

tensor([0.5500, 0.8700, 0.6600])

In [3]:
input_1 = inputs[0]
input_1

tensor([0.4300, 0.1500, 0.8900])

In [4]:
torch.dot(input_query, input_1)

tensor(0.9544)

In [5]:
res = 0.
i = 0

res = torch.dot(inputs[i], input_query)
print(res)

tensor(0.9544)


In [6]:
query = inputs[1]

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, input_query)

print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [7]:
# normalize scores 

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
attn_weights_2_tmp

tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])

In [8]:
# Softmax self-implementation

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

softmax_naive(attn_scores_2)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [9]:
# torch implementation - use this when you can

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

In [10]:
query = inputs[1] # 2nd input token is the query in this example

context_vec_2 = torch.zeros(query.shape) # empty context vector of same shape as input vectors (for second context vector)

for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i # weighted sum done here

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


### Simple self-attention mechanism without trainable weights

In [11]:
# Manual implementaton with for loops

attn_scores = torch.empty(6,6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [12]:
# matrix multiplication to calculate the above MUCH quicker with special notation

attn_scores = inputs @ inputs.T
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [13]:
attn_weights = torch.softmax(attn_scores, dim=1) # weights along the rows (dim 1) add to 1
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [14]:
# use matrix multiplication to find all 6 context vectors

all_context_vecs = attn_weights @ inputs
all_context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

#### In one cell:

In [15]:
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=1)
all_context_vecs = attn_weights @ inputs
all_context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## 3.4 Implementing self-attention with trainable weights

### 3.4.1 Computing the attention weights step by step

In [16]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2 # the user can pick these! Can make it the size you want

In [17]:
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out)) # torch.nn.Parameter is a wrapper around a tensor to make it trainable
W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

In [18]:
query_2 = x_2 @ W_query

query_2

# turning the 3-d input tensor into a 2-d query tensor

tensor([0.4306, 1.4551], grad_fn=<SqueezeBackward4>)

In [19]:
# Query is reused everywhere (when comparing to every other token). However, keys (and values) are unique when used in relation to other tokens. Basically we need to find the key vector of the inputs for every other token in the context (which is full matrix multiplication instead of a vector dotted with a matrix)

keys = inputs @ W_key
values = inputs @ W_value

keys.shape

torch.Size([6, 2])

In [20]:
keys

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]], grad_fn=<MmBackward0>)

In [21]:
keys_2 = keys[1] 
attn_score_22 = torch.dot(query_2, keys_2)
attn_score_22

tensor(1.8524, grad_fn=<DotBackward0>)

#### Calculating attn *scores*

In [22]:
# Can compute the attn score for the entire context (from the reference point of the second token) by matrix multiplying the query embedding of the 2nd token with the tensor of the key values for all the tokens in the context.

attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
       grad_fn=<SqueezeBackward4>)

#### Calculating attn *weights* with softmax

In [23]:
d_k = keys.shape[1]

# normalizing the attention scores with the dimensions of the key matrix

attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim=-1) # row-wise softmax -> we are normalizing the key vectors entirely (dim = 0 treats the first entry of each key as a column vector then normalizes)
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
       grad_fn=<SoftmaxBackward0>)

In [24]:
# normalized weights should sum up to 1

sum(attn_weights_2)

tensor(1., grad_fn=<AddBackward0>)

In [25]:
# context vector is the attn weighted sum of all the value vectors (another matrix multiplication)

context_vec_2 = attn_weights_2 @ values

context_vec_2

tensor([0.3061, 0.8210], grad_fn=<SqueezeBackward4>)

In summary, the queries and keys together are used to find the **attention scores** which are then normalized and softmaxxed into **attention weights**. The values are applied to all the attention weights to compute the **context vectors**.

### 3.4.2 Implementing a compact self-attention class

We are trying now to get the context vector for all the other input tokens (previously we did it just for the second token) by reusing the code from 3.4.1 to make a reusable class.

In [26]:
import torch.nn as nn

class SelfAttention_V1(nn.Module):
    
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out)) # torch.nn.Parameter is a wrapper around a tensor to make it trainable
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))   

    def forward(self, x):

        queries = inputs @ W_query
        keys = inputs @ W_key
        values = inputs @ W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)

        context_vec = attn_weights @ values

        return context_vec
    
torch.manual_seed(123)
sa_v1 = SelfAttention_V1(d_in, d_out)
sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

Note in the code above that sa_v1[inputs][1] = [0.3061, 0.8210], which means the class works!

In [27]:
m = torch.nn.Linear(2,3)
m.bias

Parameter containing:
tensor([-0.3189,  0.2240, -0.3146], requires_grad=True)

In [28]:
import torch.nn as nn

# Self attention with linear layers (better implementation)

class SelfAttention_V2(nn.Module):
    
    def __init__(self, d_in, d_out, qkv_bias=False): # the bias term in the nn.Linear, which we don't need (not using a bias in this nn I guess?)
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias) # torch.nn.Linear is an alternative to nn.Parameter (not sure what the benefit is aside from no explicit torch.rand call needed, slightly better weight initialization under the hood according to Sebastian)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)  

    def forward(self, x):

        queries = self.W_query(inputs) # similar behavior -> we are still doing matrix multiplication
        keys = self.W_key(inputs)
        values = self.W_value(inputs)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)

        context_vec = attn_weights @ values

        return context_vec
    
torch.manual_seed(789)
sa_v2 = SelfAttention_V2(d_in, d_out)
sa_v2(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

## 3.5 Hiding future words with causal attention

For a predictive model, it doesn't make sense for the model to know future words in the context (relative to the current token being looked at)

### 3.5.1 Applying a causal attention mask

In [29]:
# Your journey starts with one step <- sentence to use

# When pre-training an LLM, it is fine not to mask, but when you actually want to test prediction ability, future words should be masked.

In [30]:
queries = sa_v2.W_query(inputs) # similar behavior -> we are still doing matrix multiplication
keys = sa_v2.W_key(inputs)
values = sa_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)

In [31]:
attn_weights

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

#### One way to mask is to just zero out any scores that correspond to values that come after the current word

In [32]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [36]:
mask_simple = attn_weights * mask_simple

In [37]:
# This matrix is not normalzed, so do some scaling

row_sums = mask_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = mask_simple / row_sums
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

There is a different way to do this; do the masking before computing softmax

In [38]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


Softmax exponentiates, so these -infs go to zero

In [39]:
attn_weights = torch.softmax(masked / d_k**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### 3.5.2 Masking additional attention weights with dropout

This generates a randomized map that masks *additional* attention weights **on top** of the causal mask. 

Often used to reduce the reliance of an LLM on words placed in a certain position after the current query and to prevent overfitting. 

In [40]:
torch.manual_seed(123)

layer = torch.nn.Dropout(0.5) # dropout layer with 50% dropout rate

In [41]:
example = torch.ones(6,6)
layer(example)

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

The values are being scaled using $$\frac{1}{1 - \text{dropout rate}}$$ to account for the fact that some data is being removed by the random dropout 

In [42]:
layer(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4925, 0.4638, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3941, 0.0000],
        [0.3869, 0.3327, 0.0000, 0.3084, 0.3331, 0.3058]],
       grad_fn=<MulBackward0>)

### 3.5.3 Implementing a compact causal self-attention class

In [None]:
batch = torch.stack((inputs, inputs), dim=0)
batch.shape # 2 x 6 x 3 tensor

torch.Size([2, 6, 3])

In [50]:
import torch.nn as nn

# Self attention with linear layers (better implementation)

class CausalAttention(nn.Module):
    
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias) 
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        
        # NEW ADDITIONS
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        num_tokens, d_in = x.shape[1], x.shape[2]
        queries = self.W_query(x) # similar behavior -> we are still doing matrix multiplication
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)

        mask = torch.Tensor(self.mask)
        # Note, underscores after a function/method name in PyTorch indicate in-place operations

        attn_scores.masked_fill_(
            mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)

        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values

        return context_vec
    
torch.manual_seed(789)
context_length = batch.shape[1]
dropout = 0.0
ca = CausalAttention(d_in, d_out, context_length, dropout)
ca(batch)

tensor([[[-0.0872,  0.0286],
         [-0.0996,  0.0511],
         [-0.1004,  0.0645],
         [-0.0984,  0.0489],
         [-0.0512,  0.1102],
         [-0.0761,  0.0682]],

        [[-0.0872,  0.0286],
         [-0.0996,  0.0511],
         [-0.1004,  0.0645],
         [-0.0984,  0.0489],
         [-0.0512,  0.1102],
         [-0.0761,  0.0682]]], grad_fn=<UnsafeViewBackward0>)

In [51]:
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

## 3.6 Extending to multi-head attention

### 3.6.1 Stacking multiple single-head attention layers

In [54]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads=2, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)
            ])
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim = -1)

torch.manual_seed(123)

context_length = batch.shape[1]
d_in, d_out = batch.shape[-1], 2

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout, num_heads=2, qkv_bias = False)
mha(batch)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5856,  0.0087,  0.5840,  0.3158],
         [-0.6284, -0.0607,  0.6161,  0.3779],
         [-0.5664, -0.0832,  0.5468,  0.3557],
         [-0.5511, -0.0978,  0.5317,  0.3416],
         [-0.5290, -0.1073,  0.5072,  0.3465]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5856,  0.0087,  0.5840,  0.3158],
         [-0.6284, -0.0607,  0.6161,  0.3779],
         [-0.5664, -0.0832,  0.5468,  0.3557],
         [-0.5511, -0.0978,  0.5317,  0.3416],
         [-0.5290, -0.1073,  0.5072,  0.3465]]], grad_fn=<CatBackward0>)

### 3.6.2 MHA with weight splits

Our current setup is in series (not efficient). Attention heads are independent of each other, so it doesn't make sense for one to wait for another to begin learning.

In [67]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask = torch.Tensor(self.mask)
        # Note, underscores after a function/method name in PyTorch indicate in-place operations

        attn_scores.masked_fill_(
            mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 4
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.1184,  0.3120, -0.0847, -0.5774],
         [ 0.0178,  0.3221, -0.0763, -0.4225],
         [-0.0147,  0.3259, -0.0734, -0.3721],
         [-0.0116,  0.3138, -0.0708, -0.3624],
         [-0.0117,  0.2973, -0.0698, -0.3543],
         [-0.0132,  0.2990, -0.0689, -0.3490]],

        [[ 0.1184,  0.3120, -0.0847, -0.5774],
         [ 0.0178,  0.3221, -0.0763, -0.4225],
         [-0.0147,  0.3259, -0.0734, -0.3721],
         [-0.0116,  0.3138, -0.0708, -0.3624],
         [-0.0117,  0.2973, -0.0698, -0.3543],
         [-0.0132,  0.2990, -0.0689, -0.3490]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])
