- Starting with a very simple implementation of Attention, with no training of weights.

In [12]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

word_input = ['Your', 'journey', 'starts', 'with', 'one', 'step']

- The text input above is `Your journey starts with one step`.
- Each word above is an input token
- Each word has an embedding.

In [8]:
query = inputs[1] # `journey`

# attn scores for each of the word in the input against the word `journey`
attn_scores_2 = torch.empty(inputs.shape[0]) # Create empty vector

# Doing dot product between `journey` and each word in the input.
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)

print(attn_scores_2)    

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [9]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum() # Scale by dividing with the SUM
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


# The above norailzation is better with softmax
- Provides better managing of extreme values and offers more favorable gradients for training.

In [9]:
def softmax_naive(x): 
    '''Calculate and return the softmax'''
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())


Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [10]:
# Torch has a better implementation
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [57]:
query = inputs[1] # 2nd input token = `journey`

context_vec_2 = torch.zeros(query.shape)
print("Word      : Word Embedding                   : Word Attn for `journey`  : Context vector after each word")
print("="*105)

for i,x_i in enumerate(inputs):
    print(f"{word_input[i]:<10}: {x_i} : {attn_weights_2[i]:1.21}  : {context_vec_2}")
    context_vec_2 += attn_weights_2[i]*x_i

print("Final Context vector for `journey` = ",context_vec_2) # This is the context vector for `journey`

Word      : Word Embedding                   : Word Attn for `journey`  : Context vector after each word
Your      : tensor([0.4300, 0.1500, 0.8900]) : 0.138547569513320922852  : tensor([0., 0., 0.])
journey   : tensor([0.5500, 0.8700, 0.6600]) : 0.237891301512718200684  : tensor([0.0596, 0.0208, 0.1233])
starts    : tensor([0.5700, 0.8500, 0.6400]) : 0.233274027705192565918  : tensor([0.1904, 0.2277, 0.2803])
with      : tensor([0.2200, 0.5800, 0.3300]) : 0.123991586267948150635  : tensor([0.3234, 0.4260, 0.4296])
one       : tensor([0.7700, 0.2500, 0.1000]) : 0.108181864023208618164  : tensor([0.3507, 0.4979, 0.4705])
step      : tensor([0.0500, 0.8000, 0.5500]) : 0.158113613724708557129  : tensor([0.4340, 0.5250, 0.4813])
Final Context vector for `journey` =  tensor([0.4419, 0.6515, 0.5683])


- The attention weights we calculated for each of word in the input, while the second word was used as the query.
- Once the attention weights are available, each word in the input is multiplied by its attention weight to generate the new context vector for the input query.
- Next, compute the context vector for all the words in the input. The 3 steps for computing the context vector are:
    1) Compute the attention scores
    2) Compute the attention weights
    3) Compute the context vectors.


In [77]:
ctx_sum = 0
for i, sample in enumerate(inputs):
    ctx_sum += sample[0] * attn_weights_2[i] # cumulative sum of each word times it attention weight.
    print(sample[0], attn_weights_2[i], sample[0] * attn_weights_2[i], ctx_sum)

# The code above shows how the first dimension for the context vector is calculated.

tensor(0.4300) tensor(0.1385) tensor(0.0596) tensor(0.0596)
tensor(0.5500) tensor(0.2379) tensor(0.1308) tensor(0.1904)
tensor(0.5700) tensor(0.2333) tensor(0.1330) tensor(0.3234)
tensor(0.2200) tensor(0.1240) tensor(0.0273) tensor(0.3507)
tensor(0.7700) tensor(0.1082) tensor(0.0833) tensor(0.4340)
tensor(0.0500) tensor(0.1581) tensor(0.0079) tensor(0.4419)


In [19]:
# Input has 6 words. Each word will attend to itself and the other words.
# Hence the attention scores will be a 6 X 6 matrix.
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)        


tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [20]:
# quicker way to do it
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [21]:
# Normalize the scores
attn_weights = torch.softmax(attn_scores, dim=-1) # -1 is the last dimension
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [24]:
attn_weights.sum(axis=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [25]:
# Finally compute all the context vectors
all_context_vectors = attn_weights @ inputs
print(all_context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


# Implementing Self-Attention with trainable weights
- `scaled_dot_product_attention`
As done above first we will compute the example using one input (the word `journey`) and then extend it to all the words.

In [83]:
x_2 = inputs[1] # This is the second element of the input. `journey`
d_in = inputs.shape[1] # The input embedding size, d=3.
d_out = 2 # the output embedding size, d=2. Note the input and output embedding sizes are not the same.

In [84]:
# Initialize the weight matrices
torch.manual_seed(123)
# requires_grad set to False to reduce clutter, but for training this would be set to True
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
print(W_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [85]:
# Dot product with the query, key and value matrices
# This is for the second word `journey`
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


- We are focused on computing the context vector (z_2) for the input 2 (x_2 : `journey`)
- But in order to do that we still require the key and value vector for all the input elements.
- This is b/c they are involved in computing the attention weights with respect to the query x_2.

In [86]:
# Compute the keys and values for all the inputs
keys = inputs @ W_key
values = inputs @ W_value
print(f"keys.shape: ", keys.shape)
print(f"values.shape: ", values.shape)

keys.shape:  torch.Size([6, 2])
values.shape:  torch.Size([6, 2])


# KEYS
```
+---------+--------+--------+
|         |   d1   |   d2   |
+---------+--------+--------+
| Your    | 0.3669 | 0.7646 |
+---------+--------+--------+
| journey | 0.4433 | 1.1419 |
+---------+--------+--------+
| starts  | 0.4361 | 1.1156 |
+---------+--------+--------+
| with    | 0.2408 | 0.6706 |
+---------+--------+--------+
| one     | 0.1827 | 0.3292 |
+---------+--------+--------+
| step    | 0.3275 | 0.9642 |
+---------+--------+--------+
```

# VALUES
```
+---------+--------+--------+
|         |   d1   |   d2   |
+---------+--------+--------+
| Your    | 0.1855 | 0.8812 |
+---------+--------+--------+
| journey | 0.3951 | 1.0037 |
+---------+--------+--------+
| starts  | 0.3879 | 0.9831 |
+---------+--------+--------+
| with    | 0.2393 | 0.5493 |
+---------+--------+--------+
| one     | 0.1492 | 0.3346 |
+---------+--------+--------+
| step    | 0.3221 | 0.7863 |
+---------+--------+--------+
```

- Using the `keys` and the `values` calculated above for all the inputs, we will now generate the attention score for each of the word in the input against the x_2 (`journey`)
- Hence, for the first input `Your`, the attention score will be the query for the word `journey`, is the dot product of the key values for `Your` and query values for `journey`. 
- Check the value of the query_2 for the word `journey`.

```
-----------------------------------------
|           Dot Product Calculation     |
-----------------------------------------
|  Values:                              |
|                                       |
|  Your from Table 1:                   |
|    - d1 = 0.3669                      |
|    - d2 = 0.7646                      |
|                                       |
|  journey from Table 2:                |
|    - d1 = 0.4306                      |
|    - d2 = 1.4551                      |
|                                       |
|  Dot Product Formula:                 |
|  -----------------------------------  |
|  (0.3669 * 0.4306) + (0.7646 * 1.4551)|
|                                       |
|  Calculations:                        |
|    - 0.3669 * 0.4306 = 0.1579         |
|    - 0.7646 * 1.4551 = 1.1119         |
|                                       |
|  Result:                              |
|    - 0.1579 + 1.1119 = 1.2698         |
|                                       |
|  Dot Product = 1.2698                 |
-----------------------------------------
```
- The above is the vallue of `attention_score_21`. The attention that the second word `journey` has to be pay on the first word `Your`.


In [100]:
# Calculation of attention_score_22 following the logic above.
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22)

tensor(1.8524)


In [101]:
# Calculating the attention score for all the keys/words in the input.
attn_scores_2 =  query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


### Converting from attention scores to attention weights

In [50]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


- Above, we scale the attention scores by dividing them by the `sqaure root` of the embedding dimension of the keys. Provides stability during training.

# Caluclating Context Vector:
- Each of the word's attention weight was caluclated above wrt `journey`.
- The context for `journey` is computed by doing a dot product between the attention_weights and the values of each of the word.
- This generates context for the word `journey`.
- This is effectively just multiplying the 2-D value of each word with the required attention_weight for that word, and then summing them all.
- What we get below is the single context vector for the word `journey`.
- We need to do the same for all the inputs.

In [102]:
# caluclate the context
context_vec_2 = attn_weights_2  @ values
print(context_vec_2)

tensor([0.3069, 0.8188])


# Compact self_attention Python class

In [104]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self,x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec



In [105]:
torch.manual_seed(123)

# Create the self-attention instance
sa_v1 = SelfAttention_v1(d_in, d_out)

# Run it on the inputs.
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [126]:
# Another implementation
# nn.Linear has optimized weight initilaization. So use that rather than nn.Parameter
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self,x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec


In [127]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
# Note the answer is different, since nn.Linear uses a different and more sophisticated weight initialization.

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


# Causal Attention / Masked Attention
- Future prediction should only depend on the past words.

In [128]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [129]:
# Mask can be created using the TRIL function from pytorch
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [130]:
#Multiply mask with attention weights
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

In [131]:
# renormalize to sum = 1 in each row
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [137]:
# More efficient implementation with softmax
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
# The above sets everything up and now only next step is softmax
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5,dim=1)
print(attn_weights)


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


# Dropout to prevent over-fitting

In [78]:
# Using dropout to prevent overfitting
# Can be applied after attention scores or after applying the attention weights to the values vector
# Here the second is shown since it is more commonly used.
# Since half of the elements are set to zero, in order to maintain simiar number of active element, the entire matrix is scaled by 1/0.5 = 2.

torch.manual_seed(123)
dropout = nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [80]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


# Class of Causal Attention:
- We will use masking for causal attention
- We will use dropout.
- We will use this class when we do Multi-head attention
- We will ensure that it can handle batches.

In [140]:
# for the sake of simplicity, we will just stack same inputs to make a batch
batch = torch.stack((inputs, inputs), dim=0)
batch.shape
# These are 2 inputs with 6 tokens each

torch.Size([2, 6, 3])

In [141]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape # b is batch size, num_tokens is number of tokens and d_in is the embedding dimension for each token
        keys = self.W_key(x) 
        queries = self.W_query(x)
        values = self.W_value(x)
        
        # note that keys, values and queries are computed for all batches

        attn_scores = queries @ keys.transpose(1, 2) 
        # note that matric multiplication is done for 2-D matrices. 
        # The first dimension is batch number. So in order to matrix multiply queries and keys, 
        # we need to do keys.T. This can be done using the transpose function in pytorch.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec



In [166]:
d_out = 2
W_query = nn.Linear(d_in, d_out, bias=False)
W_key = nn.Linear(d_in, d_out, bias=False)
W_value = nn.Linear(d_in, d_out, bias=False)
dropout = nn.Dropout(0.2)
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
#
##########################
b, num_tokens, d_in = batch.shape
print(f"Batch size: {b}")
print(f"Number of tokens/ context length: {num_tokens}")
print(f"Input dimensions: {d_in}")

##########################
keys = W_key(batch)
queries = W_query(batch)
values = W_value(batch)
print(f"Shape of keys, queries, and values: {(keys.shape, queries.shape, values.shape)}")

##########################
attn_scores = queries @ keys.transpose(1,2)
attn_scores.masked_fill_(mask.bool()[:num_tokens, :num_tokens], -torch.inf)
attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
attn_weights = dropout(attn_weights)
context_vec = attn_weights @ values
print("Context Vector")
print(context_vec)


Batch size: 2
Number of tokens/ context length: 6
Input dimensions: 3
Shape of keys, queries, and values: (torch.Size([2, 6, 2]), torch.Size([2, 6, 2]), torch.Size([2, 6, 2]))
Context Vector
tensor([[[-0.2905, -0.1901],
         [-0.1142, -0.3385],
         [-0.1021, -0.0668],
         [-0.1369, -0.2201],
         [-0.1880, -0.4329],
         [-0.2015, -0.4806]],

        [[-0.2905, -0.1901],
         [-0.2656, -0.4376],
         [-0.2603, -0.5219],
         [-0.2133, -0.5038],
         [-0.1952, -0.3316],
         [-0.0367, -0.2071]]], grad_fn=<UnsafeViewBackward0>)


In [162]:
attn_weights.shape

torch.Size([2, 6, 6])

In [164]:
values.shape

torch.Size([2, 6, 2])

In [93]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print(context_vecs.shape)

torch.Size([2, 6, 2])


# Extend Single Head to Multi Head