## Attention Mechanisms
1) Reasons for using attention in neural networks
2) Basic self-attention and also self-attention
3) Implementing Causal attention that allows LLM's to generate one token at a time
4) Masking random selected attention weights with dropout to reduce overfitting
5) Stacking multiple causal attention into a Multi-head attention module

In [1]:
import torch
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55]]
)

In [2]:
query = inputs[1]
atten_scores2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    atten_scores2[i] = torch.dot(x_i, query)
print(atten_scores2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [3]:
# normalizing the attention weights so they sum to 1
atten_weights_2_tmp = atten_scores2 / atten_scores2.sum()
print("attention weights: ", atten_weights_2_tmp)
print(atten_weights_2_tmp.sum())

attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
tensor(1.0000)


In [4]:
# using softmax to do normalization
def softmax_naive(x):
    return torch.exp(x) / (torch.exp(x).sum(dim=0))

atten_weights_2_naive = softmax_naive(atten_scores2)
print("Attention weights naive: ", atten_weights_2_naive)
print(atten_weights_2_naive.sum())

Attention weights naive:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


In [5]:
# using the inbuilt torch.softmax
atten_weights_2 = torch.softmax(atten_scores2, dim=0)
print("Attention weights torch.softmax: ", atten_weights_2)
print(atten_weights_2.sum())

Attention weights torch.softmax:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


In [6]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += atten_weights_2[i] * x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


In [7]:
attention_scores = torch.empty(6,6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attention_scores[i,j] = torch.dot(x_i, x_j)

print(attention_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [8]:
attention_scores_2 = inputs @ inputs.T
print(attention_scores_2)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [9]:
attention_weights = torch.softmax(attention_scores_2, dim=1)
print(attention_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [10]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 Sum: ", row_2_sum)
print("ALL Rows sums: ", attention_weights.sum(dim=1))

Row 2 Sum:  1.0
ALL Rows sums:  tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [11]:
all_context_vecs = attention_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


### Computing weighted attention

In [12]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [13]:
torch.manual_seed(123)
w_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [14]:
print(w_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [15]:
print(x_2)

tensor([0.5500, 0.8700, 0.6600])


In [16]:
query_2 = x_2 @ w_query
key_2 = x_2 @ w_key
value_2 = x_2 @ w_value
print(query_2)

tensor([0.4306, 1.4551])


In [17]:
keys = inputs @ w_key
values = inputs @ w_value
print("Keys shape: ", keys.shape)
print("Value shape: ", values.shape)

Keys shape:  torch.Size([6, 2])
Value shape:  torch.Size([6, 2])


In [18]:
print(keys)

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])


In [19]:
print(values)

tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


In [20]:
keys_2 = keys[1]
atten_score_22 = query_2.dot(keys_2)
print(atten_score_22)

tensor(1.8524)


In [21]:
# performing attention scores for all the values
atten_scores_2 = query_2 @ keys.T
print(atten_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [22]:
d_k = keys.shape[-1]
print(d_k)

2


In [23]:
atten_weights_2 = torch.softmax(atten_scores_2 / (d_k**0.5), dim=-1)
print(atten_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [24]:
context_vector_2 = atten_weights_2 @ values
print(context_vector_2)

tensor([0.3061, 0.8210])


### Implementing a self-attention class

In [25]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_query = torch.nn.Parameter(torch.rand(d_in, d_out)) 
        self.w_keys = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.w_values = torch.nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, x):
        keys = x @ self.w_keys
        values = x @ self.w_values
        queries = x @ w_query
        atten_scores = queries @ keys.T
        atten_weights = torch.softmax(atten_scores / (keys.shape[-1]**0.5), dim=-1)
        context_vector = atten_weights @ values
        return context_vector

In [26]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [27]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.w_key(x)
        values = self.w_value(x)
        query = self.w_query(x)
        attention_scores = query @ keys.T
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1]**0.5), dim=-1)
        context_vector = attention_weights @ values
        return context_vector

In [28]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


### Applying causal attention mask
1) Why do we need causal attention mask?:- we need this to make sure the llm is not seeing the future tokens while training, this masking
   helps us in not exposing the future tokens, Hence with this llm at any point can only have access to the current and the previous token.

In [29]:
### Information leak:- When we apply masking and then re-normalize the weights, it might initially appear that the information from the
### Future tokens could still influence the current tokens because their values are part of the softmax. 

In [30]:
queries = sa_v2.w_query(inputs)
keys = sa_v2.w_key(inputs)
atten_scores = queries @ keys.T
atten_weights = torch.softmax(atten_scores / (keys.shape[-1]**0.5), dim=1)
print(atten_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [31]:
context_length = atten_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [32]:
masked_simple = mask_simple * atten_weights
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [33]:
# Normalizing the masked_simple atten_weights
row_sums = masked_simple.sum(dim=1, keepdims=True)
normalized_masked_simple = masked_simple / row_sums
print(normalized_masked_simple)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [34]:
# Better way to do this is we can use softmax because it gives the probability distribution, softmax(-inf) is always zero.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = atten_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [35]:
attention_weights_masked_softmax = torch.softmax(masked / (keys.shape[-1]**0.5), dim=1)
print(attention_weights_masked_softmax)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [36]:
context_vector = attention_weights_masked_softmax @ values
print(context_vector)

tensor([[0.1855, 0.8812],
        [0.2795, 0.9361],
        [0.3133, 0.9508],
        [0.2994, 0.8595],
        [0.2702, 0.7554],
        [0.2772, 0.7618]], grad_fn=<MmBackward0>)


### Masking additional attention weights with dropout
1) Why do we use dropout:- We use it to reduce overfitting during training of the LLM.
2) What is dropout:- It is a deep learning technique, where randomly selected hidden layer units are ignored during training.

In [37]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


### In the above output we use ones matrix but the values in the output are 2's.
1) Why is the output contains 2 and not one? It is because when we are dropping the units we need to balance the weights hence, all the
   other values are scaled by 1 / 0.5 which is 2.

In [38]:
torch.manual_seed(123)
print(dropout(attention_weights_masked_softmax))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


### Implementing  a compact causal attention class

In [39]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [40]:
class causal_selfattention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) ## adding the dropout layer
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.w_key(x)
        query = self.w_query(x)
        value = self.w_value(x)
        attention_scores = query @ keys.transpose(1,2)
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1]**0.5), dim=1)
        attention_weights = self.dropout(attention_weights)
        context_vector = attention_weights @ value
        return context_vector

In [41]:
torch.manual_seed(123)
context_length = batch.shape[1]
csa_v2 = causal_selfattention(d_in, d_out, context_length, 0.0)
context_vecs = csa_v2(batch)
print("context_vecs.shape: ", context_vecs.shape)

context_vecs.shape:  torch.Size([2, 6, 2])


### Multi-head attention

In [50]:
class multiheadattentionwrapper(nn.Module):
    def __init__(self,d_in, d_out, num_heads, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([
            causal_selfattention(d_in, d_out, context_length, qkv_bias, dropout)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [53]:
torch.manual_seed(123)
context_length = batch.shape[1] # number of tokens
d_in, d_out = 3, 2
mha = multiheadattentionwrapper(d_in, d_out, context_length=context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs_shape: ", context_vecs.shape)

tensor([[[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]],

        [[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]]], grad_fn=<CatBackward0>)
context_vecs_shape:  torch.Size([2, 6, 4])


### Implementing multi-head attention with weight splits

In [91]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0
        self.d_in = d_in
        self.d_out = d_out
        self.head_dim = d_out // num_heads
        self.num_heads = num_heads
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        print(b, " ", num_tokens, " ", d_in)
        keys = self.w_key(x)
        values = self.w_value(x)
        query = self.w_query(x)
        print("Before View")
        print("keys.shape: ", keys.shape)
        print("values.shape: ", values.shape)
        print("queries.shape: ", query.shape)
        
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        query = query.view(b, num_tokens, self.num_heads, self.head_dim)
        print("After View")
        print("keys.shape: ", keys.shape)
        print("values.shape: ", values.shape)
        print("queries.shape: ", query.shape)

        print("keys: ", keys)
        print("values: ", values)
        print("query: ", query)
        
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        query = query.transpose(1,2)
        print("keys: ", keys)
        print("values: ", values)
        print("query: ", query)

        attention_scores = query @ keys.transpose(2,3)
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        attention_weights = torch.softmax(attention_scores / (keys.shape[-1]**0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vec = (attention_weights @ values).transpose(1,2)

        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

In [92]:
torch.manual_seed(123)
context_length = batch.shape[1] # number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttention(d_in, d_out, context_length, num_heads=2, dropout=0.0)
context_vecs = mha(batch)

2   6   3
Before View
keys.shape:  torch.Size([2, 6, 2])
values.shape:  torch.Size([2, 6, 2])
queries.shape:  torch.Size([2, 6, 2])
After View
keys.shape:  torch.Size([2, 6, 2, 1])
values.shape:  torch.Size([2, 6, 2, 1])
queries.shape:  torch.Size([2, 6, 2, 1])
keys:  tensor([[[[-0.5740],
          [ 0.2727]],

         [[-0.8709],
          [ 0.1008]],

         [[-0.8628],
          [ 0.1060]],

         [[-0.4789],
          [ 0.0051]],

         [[-0.4744],
          [ 0.1696]],

         [[-0.5888],
          [-0.0388]]],


        [[[-0.5740],
          [ 0.2727]],

         [[-0.8709],
          [ 0.1008]],

         [[-0.8628],
          [ 0.1060]],

         [[-0.4789],
          [ 0.0051]],

         [[-0.4744],
          [ 0.1696]],

         [[-0.5888],
          [-0.0388]]]], grad_fn=<ViewBackward0>)
values:  tensor([[[[-0.4519],
          [ 0.2216]],

         [[-0.7142],
          [-0.1961]],

         [[-0.7127],
          [-0.1971]],

         [[-0.3809],
          [-0

In [93]:
print(context_vecs)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)


In [94]:
tensor_ones = torch.ones(2, 6, 2)
print("Tensor with ones:\n", tensor_ones)
tensor_ones = tensor_ones.view(2, 6, 2, 1)
print("Tensor with ones after changing the view:\n", tensor_ones)
tensor_ones = tensor_ones.transpose(1,2)
print("Tensor with ones after changing the view and transposing:\n", tensor_ones)

Tensor with ones:
 tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])
Tensor with ones after changing the view:
 tensor([[[[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]]],


        [[[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]],

         [[1.],
          [1.]]]])
Tensor with ones after changing the view and transposing:
 tensor([[[[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.]],

         [[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.]]],


        [[[1.],
          [1.],
      

In [97]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads = 2)
context_vecs = mha(batch)
print(context_vecs)

2   6   3
Before View
keys.shape:  torch.Size([2, 6, 2])
values.shape:  torch.Size([2, 6, 2])
queries.shape:  torch.Size([2, 6, 2])
After View
keys.shape:  torch.Size([2, 6, 2, 1])
values.shape:  torch.Size([2, 6, 2, 1])
queries.shape:  torch.Size([2, 6, 2, 1])
keys:  tensor([[[[-0.5740],
          [ 0.2727]],

         [[-0.8709],
          [ 0.1008]],

         [[-0.8628],
          [ 0.1060]],

         [[-0.4789],
          [ 0.0051]],

         [[-0.4744],
          [ 0.1696]],

         [[-0.5888],
          [-0.0388]]],


        [[[-0.5740],
          [ 0.2727]],

         [[-0.8709],
          [ 0.1008]],

         [[-0.8628],
          [ 0.1060]],

         [[-0.4789],
          [ 0.0051]],

         [[-0.4744],
          [ 0.1696]],

         [[-0.5888],
          [-0.0388]]]], grad_fn=<ViewBackward0>)
values:  tensor([[[[-0.4519],
          [ 0.2216]],

         [[-0.7142],
          [-0.1961]],

         [[-0.7127],
          [-0.1971]],

         [[-0.3809],
          [-0