## 3.0 four different variants of self-attention
1. simplified self-attention
2. self-attention
    - with trainable weights; forms the basis of the mechanism used in LLMs
3. causal attention
    - adds a mask to self-attention so the LLM generates one word at a time
    - each position only considers the current and previous inputs in the sequence, preserving temporal order
4. multi-head attention
    - an extension of self-attention and causal attention that lets the model attend to information from different representation subspaces
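To make the masking idea in causal attention (item 3) concrete, here is a minimal sketch. It assumes random placeholder scores for 4 tokens instead of real attention scores. Positions to the right of the diagonal (the "future" tokens for each row) are set to `-inf` before softmax, so they receive zero weight.

```python
import torch

torch.manual_seed(0)
num_tokens = 4
# Random placeholder scores (row i = query token i, column j = key token j)
scores = torch.rand(num_tokens, num_tokens)

# True above the diagonal: for row i these are the "future" tokens j > i
future = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
masked_scores = scores.masked_fill(future, float("-inf"))

# exp(-inf) = 0, so future tokens get exactly zero weight after softmax
weights = torch.softmax(masked_scores, dim=-1)
print(weights)
```

The first row can only attend to the first token, so its single weight is 1. The last row attends to all four tokens.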

## 3.1 problem with modeling long sequences
- pre-LLM architectures:
    - in translation, we cannot translate word by word because grammatical structures differ between languages
    - it was common to use a deep neural network with an encoder + decoder
    - before transformers, RNNs were the most popular encoder-decoder architecture for language translation
        - the output from the previous step is fed as input to the current step
        - well suited to sequential data like text
        - the encoder processes the input text one token at a time and updates its internal state (the hidden state)
        - the decoder uses the final hidden state to generate the output
        - limitation: an encoder-decoder RNN cannot directly access earlier encoder hidden states during decoding
            - it relies only on the current hidden state
            - context is lost in complex sentences where dependencies span long distances
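A minimal PyTorch sketch of this bottleneck. It assumes a single-layer `torch.nn.GRU` as the RNN (the notes don't name a specific RNN variant) and uses random tensors as stand-ins for token embeddings. The encoder produces one hidden state per input token, but in the classic encoder-decoder setup only the final state `h_n` is handed to the decoder.

```python
import torch

torch.manual_seed(0)
seq_len, emb_dim, hidden_dim = 6, 3, 4

# Random stand-ins for 6 embedded source tokens: (batch, seq_len, emb_dim)
src = torch.rand(1, seq_len, emb_dim)
encoder = torch.nn.GRU(emb_dim, hidden_dim, batch_first=True)
enc_outputs, h_n = encoder(src)
# enc_outputs: (1, 6, 4), one hidden state per source token
# h_n:         (1, 1, 4), only the final hidden state

# Classic encoder-decoder RNN: the decoder is initialized from h_n alone,
# so all information about the source must fit into these 4 numbers
tgt = torch.rand(1, 5, emb_dim)  # random stand-ins for 5 target-side inputs
decoder = torch.nn.GRU(emb_dim, hidden_dim, batch_first=True)
dec_outputs, _ = decoder(tgt, h_n)
print(enc_outputs.shape, h_n.shape, dec_outputs.shape)
```

For a single-layer, unidirectional GRU, `h_n` is just the last row of `enc_outputs`. The earlier rows exist, but this decoder never looks at them.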

## 3.2 capturing data dependencies with attention mechanisms
- the RNN decoder has no direct access to earlier words: the entire encoded input must be compressed into a single hidden state before it is passed to the decoder
- Bahdanau attention mechanism
    - a modification of the encoder-decoder RNN
    - the decoder can selectively access different parts of the input sequence at each decoding step
- self-attention
    - allows each position in the input sequence to consider the relevance of all other positions ("attend to" them)
    - positions interact with each other, and the importance of each one is weighed
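A rough sketch of one Bahdanau-style decoding step. It assumes the additive scoring form `v^T tanh(W_s s + W_h h_j)` with untrained random weights and random hidden states; the names `W_s`, `W_h` and `v` are illustrative. Unlike the plain encoder-decoder RNN described in 3.1, the decoder state is compared against every encoder hidden state, and the softmax weights decide how much each one contributes.

```python
import torch

torch.manual_seed(0)
seq_len, hidden_dim, attn_dim = 6, 4, 8
enc_states = torch.rand(seq_len, hidden_dim)  # h_1..h_6: one encoder state per input token
dec_state = torch.rand(hidden_dim)            # s: the decoder state

# Hypothetical, untrained projection layers for additive scoring
W_s = torch.nn.Linear(hidden_dim, attn_dim, bias=False)
W_h = torch.nn.Linear(hidden_dim, attn_dim, bias=False)
v = torch.nn.Linear(attn_dim, 1, bias=False)

with torch.no_grad():
    # e_j = v^T tanh(W_s s + W_h h_j); W_s(dec_state) broadcasts over all 6 rows
    scores = v(torch.tanh(W_s(dec_state) + W_h(enc_states))).squeeze(-1)  # (6,)
    weights = torch.softmax(scores, dim=0)  # how much to attend to each input token
    context = weights @ enc_states          # (4,), fed to the decoder at this step
print(weights, context)
```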

## 3.3 attending to different parts of the input with self-attention
- "self" means:
    - attention weights are computed by relating different positions within a single input sequence
    - the model accesses and learns the relationships between parts of the input itself
- by contrast, traditional attention focuses on the relationships between elements of two different sequences
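A small sketch of the difference. It assumes plain dot-product scoring for both cases and random placeholder embeddings. Self-attention scores one sequence against itself (a square matrix), while traditional attention scores one sequence against another.

```python
import torch

torch.manual_seed(0)
emb_dim = 3
source = torch.rand(5, emb_dim)  # placeholder embeddings, e.g. a 5-token input sentence
target = torch.rand(4, emb_dim)  # placeholder embeddings, e.g. a 4-token output sentence

# Self-attention: positions of ONE sequence scored against the same sequence
self_scores = source @ source.T   # (5, 5)

# Traditional attention: positions of one sequence scored against ANOTHER sequence
cross_scores = target @ source.T  # (4, 5)
print(self_scores.shape, cross_scores.shape)
```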
### 3.3.1 simple self-attention
- goal of self-attention: compute a context vector z for each element in the input sequence
    - a context vector is like an enriched embedding vector
    - it combines the embedding vector of its own token with information from the embedding vectors of the other tokens
- the first step of self-attention is to compute the intermediate values ω (the attention scores):
    - the dot product of the query token's embedding with every token's embedding, including its own
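As a worked check, using the `inputs` values defined below with x^(2) ("journey") as the query, the first intermediate value is:

$$
\omega_{21} = x^{(1)} \cdot x^{(2)} = 0.43 \cdot 0.55 + 0.15 \cdot 0.87 + 0.89 \cdot 0.66 = 0.2365 + 0.1305 + 0.5874 = 0.9544
$$

This matches the first entry of `attn_scores_2` printed below.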


In [1]:
import torch
inputs = torch.tensor(
 [[0.43, 0.15, 0.89], # Your (x^1)
 [0.55, 0.87, 0.66], # journey (x^2)
 [0.57, 0.85, 0.64], # starts (x^3)
 [0.22, 0.58, 0.33], # with (x^4)
 [0.77, 0.25, 0.10], # one (x^5)
 [0.05, 0.80, 0.55]] # step (x^6)
)



In [2]:
query = inputs[1]  # the 2nd token x^(2) ("journey") serves as the query
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


- next, each attention score is normalized so that the weights sum to 1
    - softmax is the usual choice for this normalization
    - it handles extreme values better and has better gradient properties during training
- softmax ensures the attention weights are always positive
    - so the output can be interpreted as probabilities, where a higher weight means more importance
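A quick illustration of why plain sum-normalization is not enough. It assumes hypothetical scores that include a negative value (dot products can be negative). Dividing by the sum still gives weights that add up to 1, but one is negative and another exceeds 1, whereas softmax keeps every weight between 0 and 1.

```python
import torch

# Hypothetical scores; dot products can be negative
scores = torch.tensor([2.0, -1.0, 0.5])

# Dividing by the sum (1.5): sums to 1, but gives about [1.3333, -0.6667, 0.3333]
sum_norm = scores / scores.sum()

# Softmax: every weight is positive and below 1, and they still sum to 1
soft = torch.softmax(scores, dim=0)
print(sum_norm)
print(soft)
```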

In [3]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [4]:
def softmax_naive(x):  # normalize with softmax
    # naive version: torch.exp can overflow (inf) or underflow (0) for very large or very small inputs
    return torch.exp(x) / torch.exp(x).sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
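The overflow noted in `softmax_naive` can be avoided by subtracting the maximum score before exponentiating. This does not change the result, because the common factor cancels. A minimal sketch; `softmax_stable` is a hypothetical helper, not a PyTorch function:

```python
import torch

def softmax_stable(x, dim=0):
    # Subtracting the max multiplies numerator and denominator by the same
    # factor exp(-max), so the result is unchanged, but every exponent is <= 0
    # and torch.exp can no longer overflow to inf
    shifted = x - x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(shifted)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

big = torch.tensor([1000.0, 1001.0, 1002.0])
naive = torch.exp(big) / torch.exp(big).sum(dim=0)  # exp(1000.) overflows: inf / inf = nan
stable = softmax_stable(big)
print(naive)
print(stable)
print(torch.softmax(big, dim=0))
```

For the same large input, `torch.softmax` also returns finite values, which is one reason to prefer it over the naive version.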


In [5]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
# PyTorch's built-in softmax, preferred over the naive version
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


- after obtaining the normalized attention weights, we can compute the context vector
- context vector = sum of each embedded input token multiplied by its attention weight (computed separately for each dimension)
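For example, writing α_2i for the attention weights of query x^(2), the first dimension of the context vector z^(2) is computed from the (rounded) weights printed above and the first column of `inputs`:

$$
z^{(2)}_1 = \sum_{i=1}^{6} \alpha_{2i}\, x^{(i)}_1 = 0.1385 \cdot 0.43 + 0.2379 \cdot 0.55 + 0.2333 \cdot 0.57 + 0.1240 \cdot 0.22 + 0.1082 \cdot 0.77 + 0.1581 \cdot 0.05 \approx 0.4419
$$

This matches the first entry of `context_vec_2` printed below.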

In [6]:
query = inputs[1]  # compute the context vector of the 2nd token
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


### 3.3.2 computing attention weights for all input tokens
- compute the context vectors for all input tokens, not only the 2nd one

In [None]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j) 
        # torch.dot of tensor[a,b,c] and tensor[x,y,z] = a*x+b*y+c*z
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [10]:
attn_scores = inputs @ inputs.T  # matrix multiplication replaces the nested for loop and is faster
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [None]:
attn_weights = torch.softmax(attn_scores, dim=-1)
# normalize the scores to get the weights
# dim=-1 applies the normalization along the last dimension,
# so the values in each row sum to 1
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
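To see what `dim` controls, here is a small sketch with a hypothetical 2×3 score matrix. `dim=-1` (for a 2D tensor, the same as `dim=1`) normalizes each row, while `dim=0` normalizes each column, which would be wrong for attention weights.

```python
import torch

# Hypothetical 2x3 score matrix
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

rows = torch.softmax(scores, dim=-1)  # normalize across the last dim: each row sums to 1
cols = torch.softmax(scores, dim=0)   # normalize across the first dim: each column sums to 1

print(rows.sum(dim=-1))  # both rows sum to 1
print(cols.sum(dim=0))   # all three columns sum to 1
print(rows[1])           # equal scores -> equal weights of 1/3
```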


In [12]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))
# verify that each row sums to 1

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [13]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)
# compute all context vectors at once from the attention weights and the inputs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
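The three steps of simplified self-attention (scores, softmax, weighted sum) can be collected into one function. This is a minimal sketch; `simple_self_attention` is a hypothetical name, and the input values are copied from cell In [1].

```python
import torch

def simple_self_attention(x):
    # x: (num_tokens, d) embeddings; no trainable weights
    attn_scores = x @ x.T                              # pairwise dot products, (num_tokens, num_tokens)
    attn_weights = torch.softmax(attn_scores, dim=-1)  # each row sums to 1
    return attn_weights @ x                            # context vectors, (num_tokens, d)

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
context_vecs = simple_self_attention(inputs)
print(context_vecs[1])  # expected to match the loop-based context_vec_2
```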


## 3.4 self-attention with trainable weights
- also called scaled dot-product attention
- as before, we compute a context vector as a weighted sum over the input vectors, specific to a certain input element
- the difference: it uses trainable weight matrices that are updated during model training

### 3.4.1 computing the attention weights step by step
- 3 trainable weight matrices W_q, W_k, W_v (query, key, value)


In [14]:
x_2 = inputs[1]
d_in = inputs.shape[1] # input embedding size, d=3
d_out = 2 # output embedding size=2

# initialize 3 weights
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

- in GPT-like models, the input and output dimensions are usually the same; here d_out = 2 differs from d_in = 3
- requires_grad is set to False here because we are not training the model; for training, it must be True so the weights can be updated
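A short sketch of what `requires_grad` changes. It assumes placeholder inputs, all-ones weights and a toy scalar loss. After `backward()`, the frozen weight has no gradient, while the trainable one gets a gradient that an optimizer could use to update it.

```python
import torch

in_dim, out_dim = 3, 2
x = torch.linspace(0.0, 1.0, 6 * in_dim).reshape(6, in_dim)  # placeholder inputs

frozen = torch.nn.Parameter(torch.ones(in_dim, out_dim), requires_grad=False)
trainable = torch.nn.Parameter(torch.ones(in_dim, out_dim))  # requires_grad=True by default

# toy scalar "loss", only to trigger backpropagation
loss = (x @ frozen).sum() + (x @ trainable).sum()
loss.backward()

print(frozen.grad)     # None: PyTorch tracks no gradient for this weight
print(trainable.grad)  # (3, 2) gradient an optimizer could use for an update
```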

In [None]:
query_2 = x_2 @ W_query  # query, key and value vectors of the 2nd input (inputs[1])
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


In [17]:
keys = inputs @ W_key  # keys and values of all inputs
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


- next, we compute the attention scores
- similar to what was done in the simplified self-attention
- still a dot product, but not directly between the input elements
    - instead, between the queries and keys obtained by multiplying the inputs with the respective weight matrices

In [None]:
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22) # getting the ω22

tensor(1.8524)


In [19]:
attn_scores_2 = query_2 @ keys.T  # generalize to get all ω_2j for query 2
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [None]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)
# attention weights: scale the attention scores by dividing them by the
# square root of the key embedding dimension d_k, then apply softmax
# why? to improve training performance by avoiding small gradients:
# when the embedding dimension grows (>1000 in GPT-like models), dot products become large
# after softmax, a few outputs dominate and the others vanish, giving extremely small gradients
# these vanishing gradients can slow down or halt training

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


$$
\text{dot}(Q, K) = \sum_{i=1}^{d_k} Q_i \cdot K_i
$$
- so the larger d_k (the number of dimensions), the more terms are summed and the larger the dot product tends to become
$$
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j} e^{z_j}}
$$
- the exponential grows fast, so large differences between scores make the highest one dominate
- the output becomes nearly one-hot (like [0, 0, 1]), where the derivatives of softmax are close to 0 everywhere
- with such small gradients, the layers can't learn well
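A numerical sketch of this effect. It assumes hypothetical scores `[1, 2, 3]` multiplied by a factor to mimic the larger dot products of higher dimensions. The Jacobian of softmax is `diag(p) - p p^T`, and when the output is nearly one-hot, all of its entries shrink toward 0.

```python
import torch

def softmax_jacobian(p):
    # For p = softmax(z): dp_i/dz_j = p_i * (delta_ij - p_j), i.e. diag(p) - p p^T
    return torch.diag(p) - torch.outer(p, p)

z = torch.tensor([1.0, 2.0, 3.0])  # hypothetical scores
results = {}
for factor in (1.0, 10.0):
    # a larger factor mimics the larger dot products of higher-dimensional vectors
    p = torch.softmax(factor * z, dim=0)
    largest = softmax_jacobian(p).abs().max().item()
    results[factor] = (p, largest)
    print(f"factor={factor}: weights={p}, largest |dp/dz| = {largest:.2e}")
```

With factor 10, the weights are close to [0, 0, 1], and the largest Jacobian entry drops by several orders of magnitude compared with factor 1.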


- to counter this, scale the dot products by $1/\sqrt{d_k}$:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V
$$
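A quick empirical check of why the factor is $\sqrt{d_k}$. It assumes query and key entries drawn independently from a standard normal distribution, an illustrative assumption that does not describe how trained weights behave. The spread of the raw dot products grows like $\sqrt{d_k}$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import torch

gen = torch.Generator().manual_seed(0)  # local generator, leaves the global RNG untouched
num_pairs = 2000
stds = {}
for dim in (4, 1024):
    # Assumption: query/key entries are i.i.d. standard normal
    q = torch.randn(num_pairs, dim, generator=gen)
    k = torch.randn(num_pairs, dim, generator=gen)
    dots = (q * k).sum(dim=-1)  # one dot product per (q, k) pair
    scaled = dots / dim ** 0.5  # the 1/sqrt(d_k) scaling
    stds[dim] = (dots.std().item(), scaled.std().item())
    print(f"d_k={dim}: std(q.k) = {stds[dim][0]:.2f}, "
          f"std(q.k / sqrt(d_k)) = {stds[dim][1]:.2f}")
```

Under this assumption, the variance of q·k is exactly d_k. The raw standard deviation is therefore about 2 for d_k = 4 and about 32 for d_k = 1024, while the scaled version stays near 1 in both cases.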

In [21]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
# with the attention weights, the context vector is the weighted sum of the value vectors (a matrix multiplication)

tensor([0.3061, 0.8210])
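The step-by-step computation for x_2 extends to all tokens at once with matrix multiplications, which is the matrix form of the attention formula above. This is a minimal sketch that reuses the same seed and call order as the cells above. It uses plain tensors instead of `torch.nn.Parameter`, since nothing is trained here, and the names `W_q`, `queries_all` and similar are illustrative.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
torch.manual_seed(123)
d_in, d_out = inputs.shape[1], 2
W_q = torch.rand(d_in, d_out)
W_k = torch.rand(d_in, d_out)
W_v = torch.rand(d_in, d_out)

# one row per token
queries_all = inputs @ W_q  # (6, 2)
keys_all = inputs @ W_k     # (6, 2)
values_all = inputs @ W_v   # (6, 2)

d_k_all = keys_all.shape[-1]
# softmax(Q K^T / sqrt(d_k)) V for all tokens at once
attn_weights_all = torch.softmax(queries_all @ keys_all.T / d_k_all ** 0.5, dim=-1)  # (6, 6)
context_vecs_all = attn_weights_all @ values_all                                      # (6, 2)

# cross-check row 1 against the single-query computation for x_2
w_2 = torch.softmax((inputs[1] @ W_q) @ keys_all.T / d_k_all ** 0.5, dim=-1)
context_vec_2_check = w_2 @ values_all
print(context_vecs_all)
print(context_vec_2_check)
```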


<img src="pic1.png" width="600"/>  

- why query, key, value?
    - query: the current item the model focuses on / tries to understand
        - used to probe the other parts of the input sequence to determine how much attention to pay to them
    - key: used for indexing and searching; each key is matched against the query
    - value: the actual content (representation) of the input item; after the model determines which keys are most relevant to the query, it retrieves the corresponding values
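One way to read the query/key/value terminology is as a "soft" dictionary lookup. Here is a toy sketch with hypothetical 2D keys and 1D values. A hard lookup returns only the value of the best-matching key, while attention returns a blend of all values weighted by how well each key matches the query.

```python
import torch

# Hypothetical keys (2D) and the values stored under them (1D)
toy_keys = torch.tensor([[1.0, 0.0],
                         [0.0, 1.0],
                         [0.7, 0.7]])
toy_values = torch.tensor([[10.0], [20.0], [30.0]])
toy_query = torch.tensor([2.0, 0.0])  # most similar to the first key

scores = toy_keys @ toy_query  # about [2.0, 0.0, 1.4]: how well each key matches the query

# Hard lookup (like a dict): return only the value of the best-matching key
hard = toy_values[scores.argmax()]

# Soft lookup (attention): blend ALL values, weighted by scaled softmax scores
weights = torch.softmax(scores / toy_keys.shape[-1] ** 0.5, dim=-1)
soft = weights @ toy_values
print(hard, weights, soft)
```

The soft result lies between the smallest and largest stored values and is pulled most strongly toward the value of the best-matching key.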