## 3.0 4 different variants of self-attention
1. simplified self-attention
2. self-attention
    - with trainable weights, which form the basis of LLMs
3. causal attention
    - adds a mask to self-attention so that the LLM generates one word at a time
    - considers only the current and previous inputs in the sequence, preserving temporal order
4. multi-head attention
    - extension of self-attention and causal attention that attends to information from different representation subspaces

## 3.1 problem with modeling long sequences
- pre-LLM architectures:
    - translation cannot be done word by word because grammatical structures differ between languages
    - it is common to use a deep NN with an encoder and a decoder
    - before transformers, the RNN was the most popular encoder-decoder architecture for language translation
        - the output from the previous step is fed as input to the current step
        - well suited to sequential data like text
        - the encoder processes the input text token by token and updates its internal state in the hidden layers
        - the decoder uses the final hidden state to generate the output
        - limitation: an encoder-decoder RNN cannot directly access earlier hidden states from the encoder during the decoding phase
            - it relies only on the current hidden state
            - context is lost in complex sentences where dependencies span long distances
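
The bottleneck described above can be sketched with a toy scalar RNN encoder. This is only an illustration under made-up assumptions: `rnn_encode`, `w_in`, and `w_hidden` are hypothetical names and values, not part of any library, and a real encoder would use weight matrices and vector states.

```python
import math

def rnn_encode(inputs, w_in=0.5, w_hidden=0.8):
    """Toy scalar RNN encoder (hypothetical weights, for illustration only)."""
    hidden = 0.0
    states = []
    for x in inputs:
        # Each step mixes the current input with the previous hidden state.
        hidden = math.tanh(w_in * x + w_hidden * hidden)
        states.append(hidden)
    return states

states = rnn_encode([1.0, -0.5, 0.3, 0.9])
final_state = states[-1]  # the only value handed to the decoder
print("all encoder states:", states)
print("what the decoder sees:", final_state)
```

All earlier entries of `states` are discarded, so whatever the decoder needs from the first tokens must survive inside the single final value.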

## 3.2 capturing data dependencies with attention mechanism
- an encoder-decoder RNN has no direct access to earlier words while decoding; the entire encoded input has to be held in a single hidden state before it is passed to the decoder
- Bahdanau attention mechanism
    - a modification of the encoder-decoder RNN
    - the decoder can selectively access different parts of the input sequence at each decoding step
- self-attention
    - each position in the input sequence considers the relevance of (attends to) all other positions in the same sequence
    - positions interact with each other and weigh each other's importance

## 3.3 attending to different parts of the input with self attention
- self:
    - computes attention weights by relating different positions within a single input sequence
    - assesses and learns the relationships between parts of the input
    - traditional attention: focuses on relationships between elements of 2 different sequences
### 3.3.1 simple self attention
- goal of self-attention: calculate a context vector z for each element in the input sequence
    - a context vector is like an enriched embedding vector
    - it combines the embedding vector of the respective token with the embedding vectors of the other tokens
- the first step of self-attention is to compute the intermediate values ω (attention scores):
    - take the dot product of the query token x^i with every input token (including itself)


In [47]:
import torch
inputs = torch.tensor(
 [[0.43, 0.15, 0.89], # Your (x^1)
 [0.55, 0.87, 0.66], # journey (x^2)
 [0.57, 0.85, 0.64], # starts (x^3)
 [0.22, 0.58, 0.33], # with (x^4)
 [0.77, 0.25, 0.10], # one (x^5)
 [0.05, 0.80, 0.55]] # step (x^6)
)

In [48]:
query = inputs[1] # the 2nd input token (x^2) serves as the query for the intermediate values
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


- each attention score is then normalized so that the weights sum to 1
    - softmax is usually used for the normalization
    - it handles extreme values better and gives more favorable gradients during training
- softmax also ensures that the attention weights are always positive
    - so the outputs can be interpreted as probabilities, where a higher weight indicates greater importance

In [49]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [50]:
def softmax_naive(x): # normalize with softmax
    # naive implementation: may overflow or underflow for very large or very small inputs
    return torch.exp(x) / torch.exp(x).sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
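
The overflow risk mentioned in the naive version can be made concrete with a plain-Python sketch. The helper names below are my own, and the "subtract the maximum" trick is a common way to stabilize softmax; I am assuming, not asserting, that library implementations do something comparable.

```python
import math

def softmax_naive(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(scores):
    # Subtracting the maximum leaves the result unchanged mathematically
    # but keeps every exponent <= 0, so exp() cannot overflow.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

small = [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865]  # scores from the cell above
print([round(w, 4) for w in softmax_stable(small)])

large = [1000.0, 1001.0, 1002.0]  # made-up extreme scores
try:
    softmax_naive(large)
except OverflowError as err:
    print("naive softmax failed:", err)
print([round(w, 4) for w in softmax_stable(large)])
```

For the small scores both versions agree; only the extreme scores expose the difference.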


In [51]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
# proper softmax with pytorch
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


- after obtaining the normalized attention weights, we can calculate the context vector
- context vector = sum of the embedded input tokens, each multiplied by its attention weight (applied to each dimension separately)
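
Written out for the 2nd token, the three steps so far look as follows. The symbol α for the normalized attention weights is notation assumed here; ω, x, and z follow the notes above.

$$
\omega_{2i} = x^{(2)} \cdot x^{(i)}, \qquad
\alpha_{2i} = \frac{e^{\omega_{2i}}}{\sum_{j=1}^{6} e^{\omega_{2j}}}, \qquad
z^{(2)} = \sum_{i=1}^{6} \alpha_{2i}\, x^{(i)}
$$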

In [52]:
query = inputs[1] #getting the context vector of 2nd token
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


### 3.3.2 computing attention weights for all input tokens
- compute the context vectors for all input tokens

In [53]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j) 
        # torch.dot of tensor[a,b,c] and tensor[x,y,z] = a*x+b*y+c*z
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [54]:
attn_scores = inputs @ inputs.T # matrix multiplication replaces the nested for loops and is faster
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [55]:
attn_weights = torch.softmax(attn_scores, dim=-1)
# normalize the scores to get the attention weights
# dim=-1 applies the normalization along the last dimension,
# so the values in each row sum to 1
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [56]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))
# verify that each row sums to 1

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [57]:
all_context_vecs = attn_weights @ inputs
print( all_context_vecs)
# compute all context vectors from the attention weights and the inputs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## 3.4 self attention with trainable weights
- also called scaled dot-product attention
- as before, the context vector is computed as a weighted sum over the input vectors, specific to a given input element
- the difference is the trainable weight matrices, which are updated during model training

### 3.4.1 computing the attention weights step by step
- 3 trainable weight matrices Wq, Wk, Wv (query, key, value)


In [58]:
x_2 = inputs[1]
d_in = inputs.shape[1] # input embedding size, d=3
d_out = 2 # output embedding size=2

# initialize 3 weights
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

- in GPT-like models, the input and output dimensions are usually the same
- requires_grad is set to False here because we are not training the model; for training it must be True

In [59]:
query_2 = x_2 @ W_query # query, key, and value vectors for inputs[1]
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


In [60]:
keys = inputs @ W_key # keys and values for all inputs
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


- next, we can compute the attention scores
- similar to what was done in the simplified self-attention
- still a dot product, but not directly between the input elements
    - instead between the query and key vectors obtained from the respective weight matrices

In [61]:
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22) # getting the ω22

tensor(1.8524)


In [62]:
attn_scores_2 = query_2 @ keys.T #generalize to get all ω
print( attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [63]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)
# the attention weights are obtained by scaling the attention scores:
# divide by the square root of the embedding dimension of the keys
# why? to improve training performance by avoiding small gradients
# as the embedding dimension grows (> 1000 for GPT-like models), dot products become large
# after softmax, a few outputs dominate and the rest vanish, giving extremely small gradients
# which slows or stalls training (vanishing gradients)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


$$
\text{dot}(Q, K) = \sum_{i=1}^{d_k} Q_i \cdot K_i
$$
- so the larger d_k (the dimension), the more terms in the sum and the larger the dot product tends to be
$$
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j} e^{z_j}}
$$
- the exponential grows fast, so a large spread among the scores lets the highest one dominate
- the output becomes nearly one-hot (like [0,0,1]), where the derivative is close to 0 everywhere
- small gradients mean the layer cannot learn well


- to counter this, scale the dot product by 1/d_k^0.5
$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V
$$
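
The effect of the scaling can be checked empirically with a small simulation. This is a sketch under an assumption of mine: queries and keys are drawn as independent standard-normal vectors, which is a simplification and not how trained weights behave. The helper names are hypothetical.

```python
import math
import random

random.seed(0)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def variance(values):
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

def score_variances(d_k, trials=1000):
    """Variance of raw and scaled dot products of random d_k-dimensional vectors."""
    raw = []
    for _ in range(trials):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(dot(q, k))
    scaled = [s / math.sqrt(d_k) for s in raw]
    return variance(raw), variance(scaled)

for d_k in (2, 64, 256):
    raw_var, scaled_var = score_variances(d_k)
    print(f"d_k={d_k:4d}  raw variance={raw_var:8.2f}  scaled variance={scaled_var:5.2f}")
```

Under this assumption the raw variance grows roughly in proportion to d_k, while the scaled variance stays near 1, so the softmax inputs keep a similar spread regardless of the dimension.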

In [64]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
# with the attention weights, the context vector is obtained by matrix multiplication with the values

tensor([0.3061, 0.8210])


<img src="pic1.png" width="600"/>  

- why query, key, value?
    - query: the current item the model focuses on or tries to understand
        - it probes the other parts of the input sequence to determine how much attention to pay to each
    - key: used for indexing and searching; keys are matched against the query
    - value: the actual content of an input item; after the model determines which keys are most relevant to the query, it retrieves the corresponding values
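
The query/key/value roles can be read as a "soft" dictionary lookup: instead of returning the value of one exactly matching key, every value is blended according to how well its key matches the query. The sketch below uses made-up toy vectors and a hypothetical helper `soft_lookup`; it is an analogy, not the notebook's implementation.

```python
import math

def soft_lookup(query, keys, values):
    """Return (weights, blend): a similarity-weighted mix of the values."""
    d_k = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    blend = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return weights, blend

toy_keys = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
toy_values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
weights, blend = soft_lookup([4.0, 0.0], toy_keys, toy_values)
print("weights:", [round(w, 3) for w in weights])
print("blend:  ", [round(b, 3) for b in blend])
```

The query points in the same direction as the first key, so the first value contributes most to the blend, while the other values still contribute a little.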

### 3.4.2 Implementing a compact self-attention Python class
- LLM code is usually organized into Python classes

In [65]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out): # initialize trainable weight matrices
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    def forward(self, x): # compute context vectors for all inputs
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # step 1: multiply queries with keys to get the attention scores
        attn_scores = queries @ keys.T # omega
        # step 2: scale and normalize with softmax
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        # step 3: create context vectors by weighting the values with the attention weights
        context_vec = attn_weights @ values
        return context_vec

In [66]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


<img src="pic2.png" width="600"/>  

- self-attention involves 3 trainable weight matrices: Wq, Wk, Wv
- the trainable weights are adjusted as the model is exposed to data during training
- Q and K are used to compute the attention weights
- the attention weights and V are then used to compute the context vectors (Z)

In [67]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        # nn.Linear(d_in, d_out, bias=qkv_bias) replaces
        # nn.Parameter(torch.rand(d_in, d_out))
        # nn.Linear uses a better weight initialization scheme,
        # which makes model training more stable and effective
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [68]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
# use V2 to get context vec

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [69]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
# exercise: copy the parameters from v2 into v1
# nn.Linear stores its weight as (d_out, d_in), hence the transpose
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)
print(sa_v1(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
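
The transposes in the exercise are needed because `nn.Linear` keeps its weight in shape `(d_out, d_in)` and multiplies by the transpose. A minimal check, with made-up input data and variable names of my own:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
linear = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)

print(linear.weight.shape)          # stored as (d_out, d_in)
manual = x @ linear.weight.T        # same layout as x @ W in SelfAttention_v1
print(torch.allclose(linear(x), manual))
```

This is why `sa_v2.W_query.weight.T` has the `(d_in, d_out)` layout that `SelfAttention_v1` expects.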


- next, we need to:
    1. add the causal element: prevent the model from accessing future information in the input
    2. add the multi-head element: split the attention mechanism into multiple heads
        - each head learns different aspects of the data
        - so the model can attend to different information at the same time, which helps with complex tasks

## 3.5 hiding future words with causal attention
- restricts the self-attention mechanism to tokens that appear at or before the current position
- also known as masked attention
- standard self-attention, by contrast, allows access to the entire input sequence
<img src="pic3.png" width="600">

### 3.5.1 Applying causal attention mask
- after normalizing the scores into weights, set the entries above the diagonal to 0 => "masked attention scores" (unnormalized)
- then renormalize each row so it sums to 1 again => "masked attention weights"

In [70]:
# getting attention weights
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T # .T is a tensor property that swaps the two dimensions of a 2D tensor
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

# use PyTorch's tril function to create a lower-triangular mask
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

# apply mask 
masked_simple = attn_weights*mask_simple
print(masked_simple)

# renormalize, so that sum to 1 again
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)
tensor([[1.0000, 0.0000, 0.0000

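- Information leakage: the masked positions took part in the first softmax, so it may look as if they still influence the result. After masking and renormalizing, however, the weights are the same as if softmax had been computed over the unmasked positions only, so there is no information leakage
- a more efficient way to mask is to fill the attention scores above the diagonal with -inf before softmax
- softmax then converts the -inf entries to 0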
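
That the two masking approaches give the same weights can be verified on a single row in plain Python. The last three scores in `row` are made up for illustration, and the scaling by the square root of d_k is left out to keep the sketch short.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

# One row of attention scores for the 3rd token: positions 3..5 are "future" tokens.
row = [0.4594, 0.1703, 0.1731, 0.0254, 0.1694, 0.0120]
visible = 3

# Approach 1: softmax over everything, zero out the future, renormalize.
weights = softmax(row)
masked = [w if j < visible else 0.0 for j, w in enumerate(weights)]
masked_total = sum(masked)
renormalized = [w / masked_total for w in masked]

# Approach 2: replace future scores with -inf, then apply softmax once.
neg_inf_row = [s if j < visible else float("-inf") for j, s in enumerate(row)]
masked_softmax = softmax(neg_inf_row)

print([round(w, 4) for w in renormalized])
print([round(w, 4) for w in masked_softmax])
```

Both lists agree up to floating-point error, which is the reason the masked tokens leak no information in either approach.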

In [71]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked) # masked with -inf

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [72]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights) #applied softmax

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### 3.5.2 masking additional attention weights with dropout
- dropout: randomly ignores some hidden-layer units during training to prevent overfitting
    - ensures the model does not become overly reliant on a specific set of hidden units
    - used only during training and disabled afterward
- in GPT-style attention, dropout is typically applied at one of two points:
    - after calculating the attention weights
    - after applying the attention weights to the value vectors

In [73]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # dropout rate of 50%
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


<img src="pic4.PNG" width="600">  

- to compensate for dropping 50% of the elements
    - the remaining elements in the matrix are scaled up by a factor of 1/(1 - 0.5) = 2
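
The rescaling rule can be sketched in plain Python as "inverted dropout": survivors are multiplied by 1/(1 - p) so the expected value of each element is unchanged. `inverted_dropout` is a hypothetical helper written for this note, not the PyTorch implementation.

```python
import random

def inverted_dropout(values, p, rng):
    """Zero each element with probability p and scale survivors by 1 / (1 - p)."""
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]

rng = random.Random(123)
ones = [1.0] * 10_000
dropped = inverted_dropout(ones, p=0.5, rng=rng)

print(sorted(set(dropped)))           # dropped elements are 0.0, survivors are 2.0
print(sum(dropped) / len(dropped))    # stays close to the original mean of 1.0
```

Without the scaling, the mean would fall to about 0.5 during training and jump back to 1.0 when dropout is disabled at inference time.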

In [74]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


### 3.5.3 Compact causal attention class

In [75]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) 
# for batching, stack two copies of the input (each 6 x 3)
# 2 input texts, each with 6 tokens, each token a 3-dimensional embedding vector

torch.Size([2, 6, 3])


In [76]:
# incorporate the causal attention and dropout modifications into the self-attention class
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
    dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) # dropout added here
        self.register_buffer(
        'mask',
        torch.triu(torch.ones(context_length, context_length),
        diagonal=1)
        ) # register_buffer is useful for non-trainable tensors such as the mask:
        # the tensor becomes part of the module, so it automatically moves to the same device as the model parameters
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x) 
        # print(keys.shape, '1')
        # print(keys.T.shape, '2')
        # print(keys.transpose(0,1).shape, '2')
        # print(keys.transpose(1,2).shape, '3')
        queries = self.W_query(x) #(b, num_tokens, d_out)
        values = self.W_value(x) 

        attn_scores = queries @ keys.transpose(1, 2) 
        # (b, num_tokens, d_out) @ (b, d_out, num_tokens) => (b, num_tokens, num_tokens)
        # transpose(1, 2) swaps the token and embedding dimensions and keeps the batch dimension first
        attn_scores.masked_fill_( # PyTorch operations with a trailing "_" are performed in place
        self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

In [77]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])
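
How a registered buffer differs from a parameter can be inspected directly. `MaskHolder` is a hypothetical minimal module written for this note; it only mirrors the `register_buffer` call used in `CausalAttention`.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(3, 2, bias=False)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

holder = MaskHolder(context_length=4)
param_names = [name for name, _ in holder.named_parameters()]
buffer_names = [name for name, _ in holder.named_buffers()]
print(param_names)                      # trainable tensors only
print(buffer_names)                     # the mask is tracked as a buffer
print("mask" in holder.state_dict())    # buffers are saved with the model by default
print(holder.mask.requires_grad)        # the mask is not trained
```

Because the mask is tracked by the module, a call such as `holder.to(device)` moves it together with the weights, which avoids device-mismatch errors in `forward`.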


## 3.6 Extending single-head attention to multi-head attention
- multi-head: the attention mechanism is divided into several heads, each operating independently
- the simplest version stacks multiple causal-attention modules
    - (this can be implemented more efficiently, see 3.6.2)

### 3.6.1 stacking multiple single-head attention layers
- multi-head attention involves creating multiple instances of the self-attention mechanism
    - each with its own weights
- their outputs are then combined

In [78]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
        [CausalAttention( d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        ) 
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
        # note: the heads are processed sequentially here, not in parallel
        # parallelism can be achieved by computing all attention heads simultaneously via matrix multiplication

# if num_heads=2 and d_out=2, we get a 4-dimensional context vector

In [79]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
 d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
# first dimension = 2: two input texts
# second dimension = 6: six tokens in each input
# third dimension = 4: 4-dimensional embedding (num_heads * d_out)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


### 3.6.2 Implementing multi-head attention with weight splits

In [None]:
# combine the multi-head attention wrapper and the causal attention class
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # reduce projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out) # use linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x) 

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        # split the last dimension using self.num_heads and self.head_dim:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)


        # transpose all three from (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2) 
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3) # get dot product of each head
        # (b, num_heads, num_tokens, head_dim) @ (b, num_heads, head_dim, num_tokens) => (b, num_heads, num_tokens, num_tokens)
        # (..., m, k) @ (..., k, n) => (..., m, n)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # mask truncated to number of tokens

        attn_scores.masked_fill_(mask_bool, -torch.inf) # use mask to fill attention scores

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2) # transpose back to shape (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        ) # combine heads. self.d_out = self.num_heads * self.head_dim

        context_vec = self.out_proj(context_vec) # optional linear projection
        return context_vec

<img src="pic5.png" width='600'>  

- top of the figure: what happens in `MultiHeadAttentionWrapper`: 2 separate weight matrices, and the 2 sets of queries are computed separately
- bottom of the figure: what happens in `MultiHeadAttention`: 1 larger weight matrix is initialized and only 1 matrix multiplication is performed

- `MultiHeadAttention` is more efficient because each projection needs only 1 matrix multiplication, while the wrapper repeats the multiplication once per head
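
That one large multiplication followed by a split gives the same result as separate per-head multiplications can be checked with small nested lists. The weights below are made up, and `matmul` is a plain-Python helper written for this note.

```python
def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

x = [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66]]  # 2 tokens, d_in = 3

# Hypothetical per-head query weights, each of shape (d_in, head_dim) = (3, 2).
w_head1 = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
w_head2 = [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]

# Wrapper style: one matrix multiplication per head.
q_head1 = matmul(x, w_head1)
q_head2 = matmul(x, w_head2)

# Weight-split style: join the weights column-wise, multiply once, then slice per head.
w_combined = [r1 + r2 for r1, r2 in zip(w_head1, w_head2)]  # shape (3, 4)
q_combined = matmul(x, w_combined)
head_dim = 2
q_split1 = [row[:head_dim] for row in q_combined]
q_split2 = [row[head_dim:] for row in q_combined]

print(q_split1 == q_head1, q_split2 == q_head2)
```

The slicing step plays the role that `.view(b, num_tokens, num_heads, head_dim)` plays in the class above.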

In [81]:
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
 [0.8993, 0.0390, 0.9268, 0.7388],
 [0.7179, 0.7058, 0.9156, 0.4340]],
 [[0.0772, 0.3565, 0.1479, 0.5331],
 [0.4066, 0.2318, 0.4545, 0.9737],
 [0.4606, 0.5159, 0.4220, 0.5786]]]])

print(a @ a.transpose(2, 3))

first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])
First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])

Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


In [82]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

# the smallest GPT-2 model has 12 attention heads and a context vector embedding size of 768

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
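
For a sense of scale, the GPT-2 numbers from the comment above can be plugged into the same bookkeeping that `MultiHeadAttention` does. This is only arithmetic, assuming `d_in == d_out` and `qkv_bias=False` as in the class defaults; an actual GPT-2 checkpoint may be configured differently.

```python
d_out = 768      # context vector embedding size of the smallest GPT-2
num_heads = 12   # attention heads in the smallest GPT-2
d_in = d_out     # assumption: input and output sizes match

assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
head_dim = d_out // num_heads
print("head_dim:", head_dim)  # 64

# Parameter count of one MultiHeadAttention module under the assumptions above.
qkv_params = 3 * d_in * d_out               # W_query, W_key, W_value without bias
out_proj_params = d_out * d_out + d_out     # out_proj weight plus bias
total_params = qkv_params + out_proj_params
print("parameters:", total_params)
```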
