# Attention and self attention mechanism
Huilin Zhang hz3455@nyu.edu

For more details, see: https://github.com/zoezhang1202/LLM_Transformer

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


## 1. Code implementation

### 1.1 Single head self attention

In [2]:
def linear_projection(x, weight, bias=None):
    """
    Right-multiply ``x`` by ``weight`` and optionally add ``bias``.

    Args:
        x: Input tensor; the examples in this notebook use shape
            (batch_size, seq_length, embed_dim).
        weight: Weight matrix of shape (embed_dim, embed_dim).
        bias: Optional bias vector of shape (embed_dim,). When ``None``
            no bias is added.

    Returns:
        ``torch.matmul(x, weight)``, plus ``bias`` when one is given.
    """
    projected = torch.matmul(x, weight)
    if bias is None:
        return projected
    return projected + bias
    
def scaled_dot_product_attention(queries, keys, values, embed_dim):
    """
    Computes scaled dot-product attention and returns the attended values.

    Args:
        queries: Query tensor of shape (batch_size, seq_length, embed_dim).
        keys: Key tensor of shape (batch_size, seq_length, embed_dim).
        values: Value tensor of shape (batch_size, seq_length, embed_dim).
        embed_dim: Dimension of the embedding; the raw scores are divided
            by ``embed_dim ** 0.5`` before the softmax.

    Returns:
        Attention output of shape (batch_size, seq_length, embed_dim).
        Only this single tensor is returned; the attention weights are an
        internal intermediate and are NOT part of the return value.
    """
    scale = embed_dim ** 0.5
    # Similarity of every query with every key: (..., seq_length, seq_length).
    scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
    # Normalise over the key axis so each query's weights sum to 1.
    attention_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attention_weights, values)



### 1.2 Single head self attention example

In [3]:
# Toy dimensions for the single-head example:
#   seq_length is the sentence length; embed_dim is 512 in the paper.
batch_size, seq_length, embed_dim = 1, 3, 4

# Uncomment to fix the random seed:
# torch.manual_seed(seed=1)

# Random input tensor of shape (batch_size, seq_length, embed_dim)
x = torch.randn((batch_size, seq_length, embed_dim))
x

tensor([[[-0.1502,  0.9949, -0.3785, -0.4301],
         [ 2.1786, -0.2861, -1.2897, -1.2918],
         [-1.0618,  2.6584,  0.8894,  1.2147]]])

In [4]:
# Random parameters for the three linear projections. The draw order is
# the three (embed_dim, embed_dim) weights first, then the three
# (embed_dim,) biases.
query_weight, key_weight, value_weight = (
    torch.randn(embed_dim, embed_dim) for _ in range(3)
)
query_bias, key_bias, value_bias = (
    torch.randn(embed_dim) for _ in range(3)
)


In [5]:
# Project the same input x three times, once per (weight, bias) pair,
# to obtain the queries, keys, and values.
queries, keys, values = (
    linear_projection(x, weight, bias)
    for weight, bias in (
        (query_weight, query_bias),
        (key_weight, key_bias),
        (value_weight, value_bias),
    )
)

In [6]:
# Run scaled dot-product attention on the projected tensors
output = scaled_dot_product_attention(
    queries=queries,
    keys=keys,
    values=values,
    embed_dim=embed_dim,
)

In [7]:
# Show the input next to the attention output, one item per line
print("Input:", x, "Output", output, sep="\n")


Input:
tensor([[[-0.1502,  0.9949, -0.3785, -0.4301],
         [ 2.1786, -0.2861, -1.2897, -1.2918],
         [-1.0618,  2.6584,  0.8894,  1.2147]]])
Output
tensor([[[-1.5311, -2.3075,  0.0953, -2.0431],
         [ 1.9498,  1.1295, -1.2156,  4.3303],
         [-1.7512, -2.6436,  0.4118, -2.4540]]])


### 1.3 Multi-head self-attention

In [8]:
def linear_projection(x, weight, bias=None):
    """
    Linear projection of ``x``: a matrix product with an optional bias.

    Args:
        x: Input tensor; the examples in this notebook use shape
            (batch_size, seq_length, embed_dim).
        weight: Weight matrix of shape (embed_dim, embed_dim).
        bias: Optional bias vector of shape (embed_dim,); pass ``None``
            (the default) to skip the addition.

    Returns:
        ``torch.matmul(x, weight)`` with ``bias`` added when provided.
    """
    product = torch.matmul(x, weight)
    return product if bias is None else product + bias
    

def scaled_dot_product_attention(queries, keys, values, embed_dim):
    """
    Scaled dot-product attention that also exposes the attention weights.

    Args:
        queries: Query tensor of shape (batch_size, seq_length, embed_dim).
        keys: Key tensor of shape (batch_size, seq_length, embed_dim).
        values: Value tensor of shape (batch_size, seq_length, embed_dim).
        embed_dim: Dimension of the embedding; scores are divided by
            ``embed_dim ** 0.5`` before the softmax.

    Returns:
        A 2-tuple ``(attended_values, attention_weights)``:
        attended_values: Attention output of shape
            (batch_size, seq_length, embed_dim).
        attention_weights: Softmax weights of shape
            (batch_size, seq_length, seq_length).
    """
    similarity = torch.matmul(queries, keys.transpose(-2, -1))
    # Softmax over the last (key) axis: each query row sums to 1.
    attention_weights = F.softmax(similarity / (embed_dim ** 0.5), dim=-1)
    attended_values = torch.matmul(attention_weights, values)
    return attended_values, attention_weights

def create_parameters(embed_dim, num_heads):
    """
    Create random multi-head attention parameters.

    No pretrained parameters are available here, so every tensor is drawn
    with ``torch.randn``.

    Args:
        embed_dim: Embedding dimension.
        num_heads: Number of attention heads.

    Returns:
        A 7-tuple ``(w0, query_weights, query_biases, key_weights,
        key_biases, value_weights, value_biases)`` where the weights have
        shape (num_heads, embed_dim, embed_dim), the biases have shape
        (num_heads, embed_dim), and ``w0`` has shape
        (embed_dim * num_heads, embed_dim).
    """
    def stacked_randn(*shape):
        # One torch.randn draw per head, stacked along a new leading axis.
        return torch.stack([torch.randn(*shape) for _ in range(num_heads)])

    # Draw order: all weights, then all biases, then w0.
    query_weights = stacked_randn(embed_dim, embed_dim)
    key_weights = stacked_randn(embed_dim, embed_dim)
    value_weights = stacked_randn(embed_dim, embed_dim)
    query_biases = stacked_randn(embed_dim)
    key_biases = stacked_randn(embed_dim)
    value_biases = stacked_randn(embed_dim)

    # Output projection applied to the concatenated head outputs.
    w0 = torch.randn(embed_dim * num_heads, embed_dim)

    return (
        w0,
        query_weights,
        query_biases,
        key_weights,
        key_biases,
        value_weights,
        value_biases,
    )



def multi_head_attention(x, w0, num_heads, query_weights, query_biases, key_weights, key_biases, value_weights, value_biases):
    """
    Performs multi-head self-attention.

    Each head projects ``x`` with its own query/key/value parameters and
    runs scaled dot-product attention. The head outputs are concatenated
    along the last dimension and multiplied by ``w0``.

    Args:
        x: Input tensor of shape (batch_size, seq_length, embed_dim).
        w0: Output projection matrix applied to the concatenated head
            outputs; expected shape (num_heads * embed_dim, embed_dim).
        num_heads: Number of attention heads.
        query_weights: Query weights tensor of shape (num_heads, embed_dim, embed_dim).
        query_biases: Query biases tensor of shape (num_heads, embed_dim).
        key_weights: Key weights tensor of shape (num_heads, embed_dim, embed_dim).
        key_biases: Key biases tensor of shape (num_heads, embed_dim).
        value_weights: Value weights tensor of shape (num_heads, embed_dim, embed_dim).
        value_biases: Value biases tensor of shape (num_heads, embed_dim).

    Returns:
        output: Attention output of shape (batch_size, seq_length, embed_dim).
        attention_weights: Attention weights of shape (batch_size, num_heads, seq_length, seq_length).

    Note:
        ``scaled_dot_product_attention`` must return the pair
        ``(output, attention_weights)``, because the result is unpacked
        into two values for every head.
    """
    # Unpacking three sizes also enforces that x is a 3-D tensor.
    _, _, embed_dim = x.size()

    attention_outputs = []
    attention_weights = []
    for i in range(num_heads):
        # Per-head projections of the same input.
        head_queries = linear_projection(x, query_weights[i], query_biases[i])
        head_keys = linear_projection(x, key_weights[i], key_biases[i])
        head_values = linear_projection(x, value_weights[i], value_biases[i])

        # Scaled dot-product attention for this head.
        head_attention_output, head_attention_weights = scaled_dot_product_attention(
            head_queries, head_keys, head_values, embed_dim)
        attention_outputs.append(head_attention_output)
        attention_weights.append(head_attention_weights)

    # Concatenate the attention outputs from all heads along the feature axis.
    concatenated_attention_outputs = torch.cat(attention_outputs, dim=-1)

    # Apply the final linear transformation to the concatenated outputs.
    output = torch.matmul(concatenated_attention_outputs, w0)

    # Stack the per-head attention weights on a new head axis (dim=1).
    attention_weights = torch.stack(attention_weights, dim=1)

    return output, attention_weights

In [9]:
# Example configuration for the multi-head demo
batch_size, seq_length = 1, 1
embed_dim, num_heads = 5, 2

# Random input tensor of shape (batch_size, seq_length, embed_dim)
x = torch.randn(batch_size, seq_length, embed_dim)

# Random per-head projection parameters plus the output projection w0
(
    w0,
    query_weights,
    query_biases,
    key_weights,
    key_biases,
    value_weights,
    value_biases,
) = create_parameters(embed_dim, num_heads)


In [10]:
# Run multi-head self-attention on the random example
attended_values, attention_weights = multi_head_attention(
    x,
    w0,
    num_heads,
    query_weights,
    query_biases,
    key_weights,
    key_biases,
    value_weights,
    value_biases,
)

# Print the input and the attended values, one item per line
print("Input:", x, "\nOutput:", attended_values, sep="\n")



Input:
tensor([[[ 0.0168,  0.4185, -3.1416, -1.2179, -0.7450]]])

Output:
tensor([[[ 10.5150,   9.6042, -18.3407,  -6.6961, -10.1963]]])


In [11]:
# Single-token input of shape (1, 1, 3)
X = torch.tensor([[[1.0, 2.0, 3.0]]], dtype=torch.float32)


def create_parameters():
    """
    Return the fixed two-head parameters used by the hands-on example.

    NOTE: this zero-argument definition replaces the earlier
    ``create_parameters(embed_dim, num_heads)`` once this cell has run.

    Returns:
        A 6-tuple ``(query_weights, query_biases, key_weights, key_biases,
        value_weights, value_biases)``. Each weights tensor has shape
        (2, 3, 3); each biases tensor is all zeros with shape (2, 3).
    """
    def as_float(rows):
        return torch.tensor(rows, dtype=torch.float32)

    # Head 1
    wq1 = as_float([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    wk1 = as_float([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
    wv1 = as_float([[0, 0, 1], [1, 0, 0], [0, 1, 0]])

    # Head 2
    wq2 = as_float([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
    wk2 = as_float([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    wv2 = as_float([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

    # Separate zero-bias tensors per projection (no shared storage).
    return (
        torch.stack([wq1, wq2]),
        torch.zeros(2, 3),
        torch.stack([wk1, wk2]),
        torch.zeros(2, 3),
        torch.stack([wv1, wv2]),
        torch.zeros(2, 3),
    )

num_heads = 2
(
    query_weights,
    query_biases,
    key_weights,
    key_biases,
    value_weights,
    value_biases,
) = create_parameters()

# Output projection: two 3x3 identity blocks stacked vertically, shape (6, 3)
W0 = torch.tensor(
    [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

# Run multi-head attention on the fixed example
output, attention_weights = multi_head_attention(
    X,
    W0,
    num_heads,
    query_weights,
    query_biases,
    key_weights,
    key_biases,
    value_weights,
    value_biases,
)
print('Output:', output, sep='\n')

Output:
tensor([[[3., 5., 4.]]])


## 2. Multi-head attention: hands-on walkthrough

In [12]:
# Hands-on input: a single token as a 1-D tensor of length 3
X = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)

# Head 1 projection matrices
WQ1 = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
WK1 = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32)
WV1 = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float32)

# Head 2 projection matrices
WQ2 = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32)
WK2 = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float32)
WV2 = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)

# Q, K, V for each head: X times the matching projection matrix
Q1, K1, V1 = (X @ W for W in (WQ1, WK1, WV1))
Q2, K2, V2 = (X @ W for W in (WQ2, WK2, WV2))


In [13]:
# Display the query vectors of head 1 and head 2
Q1,Q2

(tensor([1., 2., 3.]), tensor([3., 1., 2.]))

In [14]:
# Display the key vectors of head 1 and head 2
K1,K2

(tensor([3., 1., 2.]), tensor([2., 3., 1.]))

In [15]:
# Display the value vectors of head 1 and head 2
V1,V2

(tensor([2., 3., 1.]), tensor([1., 2., 3.]))

In [16]:
# Attention score per head: query as a row vector times key as a column
# vector, giving a 1x1 matrix.
score1 = Q1[None, :] @ K1[:, None]
score2 = Q2[None, :] @ K2[:, None]
score1, score2

(tensor([[11.]]), tensor([[11.]]))

In [17]:
# Scale each score by sqrt(dk), where dk is the size of the key vector
dk = K1.size(-1)
scaled_score1, scaled_score2 = (
    score / torch.sqrt(torch.tensor(dk, dtype=torch.float32))
    for score in (score1, score2)
)
scaled_score1, scaled_score2

(tensor([[6.3509]]), tensor([[6.3509]]))

In [18]:
# Attention weights: softmax of the scaled scores over the last axis
weight1, weight2 = (
    F.softmax(scaled, dim=-1) for scaled in (scaled_score1, scaled_score2)
)
weight1, weight2

(tensor([[1.]]), tensor([[1.]]))

In [19]:
# Weighted sum of the values: weights times V as a row vector, then drop
# the leading axis again.
output1 = (weight1 @ V1[None, :]).squeeze(0)
output2 = (weight2 @ V2[None, :]).squeeze(0)
output1, output2

(tensor([2., 3., 1.]), tensor([1., 2., 3.]))

In [20]:
# Concatenate the two head outputs along the last axis
concat_output = torch.cat([output1, output2], dim=-1)
concat_output

tensor([2., 3., 1., 1., 2., 3.])

In [21]:
# Final linear transformation: W0 is two 3x3 identity blocks stacked
# vertically, shape (6, 3).
W0 = torch.tensor(
    [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)
final_output = (concat_output[None, :] @ W0).squeeze(0)
final_output

tensor([3., 5., 4.])

### 1.4 Use the numbers from the hands-on example

In [22]:
# Same numbers as the hands-on walkthrough, as a (1, 1, 3) batch
X = torch.tensor([[[1.0, 2.0, 3.0]]], dtype=torch.float32)
print('Input:', X, sep='\n')

# Run the multi-head function on the hands-on input and parameters
output, attention_weights = multi_head_attention(
    X,
    W0,
    num_heads,
    query_weights,
    query_biases,
    key_weights,
    key_biases,
    value_weights,
    value_biases,
)
print('Output:', output, sep='\n')

Input:
tensor([[[1., 2., 3.]]])
Output:
tensor([[[3., 5., 4.]]])
