In [30]:
import math
import numpy as np
import torch
import torch.nn as nn

## Self Attention



In [47]:
class SelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def attention(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
        d_k = Q.shape[-1]
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        attn_probs = torch.softmax(attn_scores, dim=1)
        attention_output = torch.matmul(attn_probs, V)

        return attention_output

    def forward(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
        Q_prime = self.W_q(Q)
        K_prime = self.W_k(K)
        V_prime = self.W_v(V)

        return self.attention(Q_prime, K_prime, V_prime)


# Create test inputs
batch_size = 1
seq_length = 4
d_model = 8

# Create random input tensors
x = torch.randn(batch_size, seq_length, d_model)

# Initialize the self-attention module
self_attention = SelfAttention(d_model)

# Pass the same tensor as Q, K, and V (self-attention)
output = self_attention(x, x, x)

# Print shapes to verify
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Print a sample of the output
print("\nSample of output:")
print(output[0][0])  # First 5 values of the first sequence in the first batch

Input shape: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])

Sample of output:
tensor([-0.2497, -0.2451,  0.0241, -0.4867,  0.1290,  0.2022,  0.1524, -0.3079],
       grad_fn=<SelectBackward0>)


In [12]:
def self_attention(Q, K, V):
    return torch.softmax(Q @ K.T, dim=1) @ V


tokens = 5
embedding_dim = 10

input_seq = torch.randn(tokens, embedding_dim)

Q = input_seq.clone()
K = input_seq.clone()
V = input_seq.clone()

attention_output = self_attention(Q, K, V)

print(attention_output.shape)

torch.Size([5, 10])


In [None]:
def multi_head_attention(Q, K, V, num_heads):
    d_k = Q.shape[1] // num_heads