In [83]:
import torch

In [3]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.24, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

In [4]:
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])

In [5]:
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.6983, 1.0865])


In [6]:
# Normalizing the attention scores.
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights: ", attn_weights_2_tmp)
print("Sum: ", attn_weights_2_tmp.sum())

Attention weights:  tensor([0.1456, 0.2281, 0.2251, 0.1287, 0.1066, 0.1658])
Sum:  tensor(1.)


In [7]:
# For normalization it is advisable to use softmax: it gives a smoother gradient flow
# and guarantees that the resulting weights are always positive.
def softmax_naive(x):
    """Hand-written softmax over dimension 0 of ``x``.

    Args:
        x: tensor of unnormalized scores; normalization runs along dim 0.

    Returns:
        A tensor with the shape of ``x`` whose entries along dim 0 are
        positive and sum to 1.
    """
    if x.numel() == 0:
        # Nothing to shift by: reducing an empty dimension with amax raises,
        # so keep the plain formula, which returns an empty tensor.
        exps = torch.exp(x)
        return exps / exps.sum(dim=0)

    # Softmax is invariant to adding a constant, so subtracting the largest
    # score changes nothing mathematically but keeps exp() from overflowing
    # to inf (which would produce inf / inf = nan for large scores).
    shift = x.amax(dim=0, keepdim=True).detach()
    exps = torch.exp(x - shift)  # computed once and reused below
    return exps / exps.sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights: ", attn_weights_2_naive)
print("Sum: ", attn_weights_2_naive.sum())

Attention weights:  tensor([0.1387, 0.2381, 0.2335, 0.1241, 0.1073, 0.1583])
Sum:  tensor(1.0000)


In [8]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
#print("Attention weights: ", attn_weights_2)
#print("Sum: ", attn_weights_2.sum())

In [9]:
attn_weights_2_naive = softmax_naive(attn_scores_2)

In [10]:
# Computing the context vector for the second token:
# multiply each input token by its attention weight and sum the results.
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

tensor([0.4416, 0.6508, 0.5687])


In [11]:
# Attention scores for every (query token, key token) pair.
# The matrix is sized from `inputs` rather than hard-coded, so the cell keeps
# working if the number of tokens in `inputs` changes.
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        # Entry (i, j) is the dot product of token i with token j.
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4561, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.6983, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7069, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3416, 0.6565],
        [0.4561, 0.6983, 0.7069, 0.3416, 0.6605, 0.2855],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2855, 0.9450]])


In [12]:
# with matrix multiplication
attn_scores = inputs @ inputs.T
print("Attn scores:\n", attn_scores,"\nAttn weights:\n ", torch.softmax(attn_scores, dim=1))

Attn scores:
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4561, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.6983, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7069, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3416, 0.6565],
        [0.4561, 0.6983, 0.7069, 0.3416, 0.6605, 0.2855],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2855, 0.9450]]) 
Attn weights:
  tensor([[0.2099, 0.2006, 0.1982, 0.1243, 0.1219, 0.1452],
        [0.1387, 0.2381, 0.2335, 0.1241, 0.1073, 0.1583],
        [0.1391, 0.2371, 0.2328, 0.1243, 0.1100, 0.1566],
        [0.1436, 0.2075, 0.2047, 0.1463, 0.1257, 0.1722],
        [0.1534, 0.1954, 0.1971, 0.1368, 0.1881, 0.1293],
        [0.1386, 0.2185, 0.2129, 0.1422, 0.0981, 0.1897]])


In [13]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2099, 0.2006, 0.1982, 0.1243, 0.1219, 0.1452],
        [0.1387, 0.2381, 0.2335, 0.1241, 0.1073, 0.1583],
        [0.1391, 0.2371, 0.2328, 0.1243, 0.1100, 0.1566],
        [0.1436, 0.2075, 0.2047, 0.1463, 0.1257, 0.1722],
        [0.1534, 0.1954, 0.1971, 0.1368, 0.1881, 0.1293],
        [0.1386, 0.2185, 0.2129, 0.1422, 0.0981, 0.1897]])


In [14]:
# compute all context vectors
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4420, 0.5919, 0.5791],
        [0.4416, 0.6508, 0.5687],
        [0.4428, 0.6489, 0.5675],
        [0.4301, 0.6288, 0.5514],
        [0.4671, 0.5884, 0.5266],
        [0.4174, 0.6497, 0.5649]])


In [15]:
print("previous context vector: ", context_vec_2)

previous context vector:  tensor([0.4416, 0.6508, 0.5687])


### Self-attention with trainable weights

In [16]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [17]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [18]:
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

In [19]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


In [20]:
keys = inputs @ W_key
values = inputs @ W_value
print("keys: ", keys)
print("values: ", values)

keys:  tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1809, 0.3220],
        [0.3275, 0.9642]])
values:  tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1460, 0.3306],
        [0.3221, 0.7863]])


In [21]:
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22)

tensor(1.8524)


In [22]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5464, 1.5440])


In [23]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim=-1)
print(attn_weights_2)

tensor([0.1501, 0.2265, 0.2200, 0.1312, 0.0900, 0.1822])


In [24]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3059, 0.8210])


### 3.4.2 Implementing a compact self-attention Python class

In [25]:
import torch.nn as nn

In [26]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` initialized with ``torch.rand``.
    """

    def __init__(self, d_in, d_out):
        """Create the three projection matrices.

        Args:
            d_in: size of the last dimension of the input.
            d_out: size of the last dimension of the projections and output.
        """
        super().__init__()
        # Creation order (query, key, value) determines which random numbers
        # each matrix receives under a fixed seed.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``d_out``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # Swap only the last two dimensions: `.T` reverses *all* dimensions,
        # which is a matrix transpose for 2-D inputs only and breaks batches.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key dimension.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [27]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2993, 0.8052],
        [0.3059, 0.8210],
        [0.3056, 0.8203],
        [0.2945, 0.7936],
        [0.2922, 0.7885],
        [0.2988, 0.8039]], grad_fn=<MmBackward0>)


We can improve the `SelfAttention_v1` implementation further by utilizing
PyTorch’s `nn.Linear` layers, which effectively perform matrix multiplication when
the bias units are disabled. Additionally, a significant advantage of using `nn.Linear` 
instead of manually implementing `nn.Parameter(torch.rand(...))` is that `nn.Linear`
has an optimized weight initialization scheme, contributing to more stable and
effective model training.

In [28]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Same computation as a raw-matrix implementation, but the projections are
    ``nn.Linear`` layers (bias optional), which bring their own weight
    initialization.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        """Create the query, key and value projection layers.

        Args:
            d_in: size of the last dimension of the input.
            d_out: size of the last dimension of the projections and output.
            qkv_bias: whether the three linear layers carry a bias term.
        """
        super().__init__()
        # Creation order (query, key, value) determines which random numbers
        # each layer receives under a fixed seed.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``d_out``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Swap only the last two dimensions: `.T` reverses *all* dimensions,
        # which is a matrix transpose for 2-D inputs only and breaks batches.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key dimension.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [29]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.5332, -0.1044],
        [-0.5318, -0.1073],
        [-0.5318, -0.1072],
        [-0.5291, -0.1069],
        [-0.5305, -0.1058],
        [-0.5293, -0.1073]], grad_fn=<MmBackward0>)


In [30]:
for name, parameters in sa_v1.named_parameters():
    print(name, parameters)


W_query Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)
W_key Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True)
W_value Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]], requires_grad=True)


In [31]:
q = sa_v1.get_parameter('W_query')
k = sa_v1.get_parameter('W_key')
v = sa_v1.get_parameter('W_value')

In [32]:
# nn.Linear stores its weight as (d_out, d_in), i.e. the transpose of the
# (d_in, d_out) matrices used by SelfAttention_v1.
# `q.T` is only a view, and nn.Parameter wraps its data without copying, so
# without an explicit copy sa_v2 would share storage with sa_v1's parameters:
# an in-place update of one model would silently change the other.
sa_v2.W_query.weight = nn.Parameter(q.T.detach().clone())
sa_v2.W_key.weight = nn.Parameter(k.T.detach().clone())
sa_v2.W_value.weight = nn.Parameter(v.T.detach().clone())

In [33]:
for name, parameters in sa_v2.named_parameters():
    print(name, parameters)

W_query.weight Parameter containing:
tensor([[0.2961, 0.2517, 0.0740],
        [0.5166, 0.6886, 0.8665]], requires_grad=True)
W_key.weight Parameter containing:
tensor([[0.1366, 0.1841, 0.3153],
        [0.1025, 0.7264, 0.6871]], requires_grad=True)
W_value.weight Parameter containing:
tensor([[0.0756, 0.3164, 0.1186],
        [0.1966, 0.4017, 0.8274]], requires_grad=True)


### 3.5 Hiding future words with `causal attention`

  1. Take attention score (unnormalized).
  2. Apply `softmax` to get attention "weights" (normalized).
  3. Mask with 0's above diagonal to get masked attention scores (unnormalized).
  4. Normalize each row to get masked attention "weights" (normalized).

In [34]:
queries =sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

tensor([[0.1552, 0.2106, 0.2061, 0.1414, 0.1068, 0.1800],
        [0.1501, 0.2265, 0.2200, 0.1312, 0.0900, 0.1822],
        [0.1505, 0.2258, 0.2194, 0.1316, 0.0908, 0.1820],
        [0.1592, 0.1995, 0.1963, 0.1478, 0.1202, 0.1770],
        [0.1612, 0.1947, 0.1921, 0.1503, 0.1265, 0.1752],
        [0.1558, 0.2093, 0.2050, 0.1420, 0.1084, 0.1795]],
       grad_fn=<SoftmaxBackward0>)


We can implement the masking step (step 3 above) using PyTorch’s `tril` function to create a mask
where the values above the diagonal are zero:

In [35]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [36]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1552, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1501, 0.2265, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1505, 0.2258, 0.2194, 0.0000, 0.0000, 0.0000],
        [0.1592, 0.1995, 0.1963, 0.1478, 0.0000, 0.0000],
        [0.1612, 0.1947, 0.1921, 0.1503, 0.1265, 0.0000],
        [0.1558, 0.2093, 0.2050, 0.1420, 0.1084, 0.1795]],
       grad_fn=<MulBackward0>)


In [37]:
# Renormalizing
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3986, 0.6014, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2526, 0.3791, 0.3683, 0.0000, 0.0000, 0.0000],
        [0.2265, 0.2839, 0.2794, 0.2103, 0.0000, 0.0000],
        [0.1954, 0.2361, 0.2329, 0.1823, 0.1533, 0.0000],
        [0.1558, 0.2093, 0.2050, 0.1420, 0.1084, 0.1795]],
       grad_fn=<DivBackward0>)


A more efficient way to obtain the masked attention weight matrix in
causal attention is to mask the attention scores with negative infinity values before
applying the softmax function.

In [67]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.9231,   -inf,   -inf,   -inf,   -inf,   -inf],
        [1.2705, 1.8524,   -inf,   -inf,   -inf,   -inf],
        [1.2544, 1.8284, 1.7877,   -inf,   -inf,   -inf],
        [0.6973, 1.0167, 0.9941, 0.5925,   -inf,   -inf],
        [0.6052, 0.8730, 0.8538, 0.5069, 0.2627,   -inf],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3856, 1.0996]],
       grad_fn=<MaskedFillBackward0>)


In [68]:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3986, 0.6014, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2526, 0.3791, 0.3683, 0.0000, 0.0000, 0.0000],
        [0.2265, 0.2839, 0.2794, 0.2103, 0.0000, 0.0000],
        [0.1954, 0.2361, 0.2329, 0.1823, 0.1533, 0.0000],
        [0.1558, 0.2093, 0.2050, 0.1420, 0.1084, 0.1795]],
       grad_fn=<SoftmaxBackward0>)


In the following code example, we use a dropout rate of 50%, which means masking <br>
out half of the attention weights. (When we train the GPT model in later chapters, <br>
we will use a lower dropout rate, such as 0.1 or 0.2.) We apply PyTorch’s dropout <br>
implementation first to a 6 × 6 tensor consisting of 1s for simplicity:

In [80]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


When applying dropout to an attention weight matrix with a rate of 50%, half of the
elements in the matrix are randomly set to zero. To compensate for the reduction in
active elements, the values of the remaining elements in the matrix are scaled up by a
factor of 1/0.5 = 2. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains
consistent during both the training and inference phases.

In [82]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.2029, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.7366, 0.0000, 0.0000, 0.0000],
        [0.4529, 0.5677, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3908, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4186, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


### 3.5.3 Implementing a compact causal attention class

In [84]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [93]:
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2400, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2400, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [94]:
batch.transpose(1, 2)

tensor([[[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500],
         [0.1500, 0.8700, 0.8500, 0.5800, 0.2400, 0.8000],
         [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500]],

        [[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500],
         [0.1500, 0.8700, 0.8500, 0.5800, 0.2400, 0.8000],
         [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500]]])

In [99]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention over batched inputs.

    Every token may attend to itself and to earlier tokens only; positions
    above the diagonal of the score matrix are filled with -inf before the
    softmax. Dropout is applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        """Build the projections, the dropout layer and the causal mask.

        Args:
            d_in: size of the last dimension of the input.
            d_out: size of the last dimension of the projections and output.
            context_length: side length of the square mask buffer.
            dropout: dropout probability applied to the attention weights.
            qkv_bias: whether the three linear layers carry a bias term.
        """
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return context vectors for ``x`` of shape ``(b, num_tokens, d_in)``.

        The result has shape ``(b, num_tokens, d_out)``.
        """
        _, seq_len, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Crop the stored mask to the length of the current sequence.
        hide_future = self.mask[:seq_len, :seq_len].bool()
        scale = keys.shape[-1] ** 0.5

        scores = torch.matmul(queries, keys.transpose(1, 2))
        scores = scores.masked_fill(hide_future, -torch.inf)
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        return weights @ values
        

In [100]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


In [101]:
context_vecs

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5519, -0.0972],
         [-0.5293, -0.1073]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5519, -0.0972],
         [-0.5293, -0.1073]]], grad_fn=<UnsafeViewBackward0>)

### 3.6 Single head attention to multi-head attention

We will tackle this expansion from causal attention to multi-head attention. First, <br>
we will intuitively build a multi-head attention module by stacking multiple <br>
`CausalAttention` modules. Then we will implement the same multi-head attention <br>
module in a more complicated but more computationally efficient way.

#### Stacking multiple single-headed attention layers

In [102]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalAttention`` heads.

    Each head processes the same input; the per-head outputs are joined along
    the last dimension.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        """Create ``num_heads`` heads, all configured with the same arguments."""
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last dim."""
        per_head_outputs = tuple(head(x) for head in self.heads)
        return torch.cat(per_head_outputs, dim=-1)

In [103]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2

In [107]:
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape: ", context_vecs.shape)

tensor([[[-0.1471,  0.4106,  0.4675, -0.2793],
         [-0.2493,  0.3548,  0.4651, -0.0590],
         [-0.2782,  0.3323,  0.4578,  0.0089],
         [-0.2636,  0.2932,  0.4108,  0.0479],
         [-0.2190,  0.2184,  0.3162,  0.0206],
         [-0.2415,  0.2433,  0.3491,  0.0629]],

        [[-0.1471,  0.4106,  0.4675, -0.2793],
         [-0.2493,  0.3548,  0.4651, -0.0590],
         [-0.2782,  0.3323,  0.4578,  0.0089],
         [-0.2636,  0.2932,  0.4108,  0.0479],
         [-0.2190,  0.2184,  0.3162,  0.0206],
         [-0.2415,  0.2433,  0.3491,  0.0629]]], grad_fn=<CatBackward0>)
context_vecs.shape:  torch.Size([2, 6, 4])


Instead of maintaining two separate classes, `MultiHeadAttentionWrapper` and <br>
`CausalAttention`, we can combine these concepts into a single `MultiHeadAttention` <br>
class. Also, in addition to merging the `MultiHeadAttentionWrapper` with the <br>
`CausalAttention` code, we will make some other modifications to implement multi-head <br>
attention more efficiently.

In [153]:
# The following MultiHeadAttention class integrates the multi-head functionality within a single class.
# It splits the input into multiple heads by reshaping the projected query, key, and value
# tensors and then combines the results from these heads after computing attention.
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention implemented with a single set of projections.

    The ``d_out``-wide query/key/value projections are reshaped into
    ``num_heads`` heads of width ``head_dim``, attention is computed per head,
    and the heads are merged again and passed through an output projection.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        """Build the projections, dropout layer and causal mask buffer.

        Args:
            d_in: size of the last dimension of the input.
            d_out: total output width; must be divisible by ``num_heads``.
            context_length: side length of the square mask buffer.
            dropout: dropout probability applied to the attention weights.
            num_heads: number of attention heads.
            qkv_bias: whether the query/key/value layers carry a bias term.
        """
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Layer creation order is kept fixed so seeded initialization is stable.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def _split_heads(self, projected, b, num_tokens):
        """Reshape (b, num_tokens, d_out) to (b, num_heads, num_tokens, head_dim)."""
        per_head = projected.view(b, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Return context vectors for ``x`` of shape ``(b, num_tokens, d_in)``.

        The result has shape ``(b, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape

        queries = self._split_heads(self.W_query(x), b, num_tokens)
        keys = self._split_heads(self.W_key(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens).
        scores = queries @ keys.transpose(2, 3)

        # Crop the stored mask to the current sequence length and hide the future.
        hide_future = self.mask[:num_tokens, :num_tokens].bool()
        scores = scores.masked_fill(hide_future, -torch.inf)

        # Scale by sqrt(head_dim), normalize over keys, then apply dropout.
        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads;
        # self.d_out == self.num_heads * self.head_dim.
        merged = (weights @ values).transpose(1, 2)
        merged = merged.contiguous().view(b, num_tokens, self.d_out)

        # Final linear projection over the merged heads.
        return self.out_proj(merged)

In [149]:
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
[0.8993, 0.0390, 0.9268, 0.7388],
[0.7179, 0.7058, 0.9156, 0.4340]],
[[0.0772, 0.3565, 0.1479, 0.5331],
[0.4066, 0.2318, 0.4545, 0.9737],
[0.4606, 0.5159, 0.4220, 0.5786]]]])

In [158]:
a.shape

torch.Size([1, 2, 3, 4])

In [165]:
print(a @ a.transpose(2,3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In [166]:
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("Second head:\n", second_res)

First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


In [172]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 4
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

In [177]:
temp = torch.rand([2, 1024, 768])
mha_gpt = MultiHeadAttention(768, 768, 1024, 0.0, num_heads=12)
context_vec_gpt2 = mha_gpt(temp)