<a href="https://colab.research.google.com/github/ounospanas/AIDL_B_CS01/blob/main/self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The code below is inspired by the book "Build a Large Language Model (From Scratch)" by Sebastian Raschka.

For more details, please see the corresponding repository: https://github.com/rasbt/LLMs-from-scratch

Self-attention network

In [67]:
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
  """Single-head scaled dot-product self-attention with trainable projections.

  Args:
    d_in: size of each input token embedding.
    d_out: size of the projected query/key/value vectors.
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    # Parameters are created in the order query, key, value; keep this order so
    # that a seeded run draws the same random weights.
    self.W_query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    """Compute context vectors for every token in `x`.

    Args:
      x: tensor of shape (num_tokens, d_in), optionally with leading batch
        dimensions, e.g. (batch, num_tokens, d_in).

    Returns:
      A tuple (context_vec, attn_scores):
        context_vec: (..., num_tokens, d_out) attention-weighted sum of values.
        attn_scores: (..., num_tokens, num_tokens) raw query-key dot products
          (omega), i.e. before scaling and softmax.
    """
    keys = x @ self.W_key
    queries = x @ self.W_query
    values = x @ self.W_value
    # Swap only the last two dims so leading batch dims are preserved;
    # `keys.T` is only well-defined for 2-D tensors.
    attn_scores = queries @ keys.transpose(-2, -1) # omega
    # Scale by sqrt(d_out) (the key dimension) before normalising each row.
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    context_vec = attn_weights @ values

    return context_vec, attn_scores


In [68]:
# Toy 3-dimensional token embeddings, one row per word of the sentence
# "Your journey starts with one step".
input_embeddings = torch.tensor([
  [0.43, 0.15, 0.89],  # x^1: Your
  [0.55, 0.87, 0.66],  # x^2: journey
  [0.57, 0.85, 0.64],  # x^3: starts
  [0.22, 0.58, 0.33],  # x^4: with
  [0.77, 0.25, 0.10],  # x^5: one
  [0.05, 0.80, 0.55],  # x^6: step
])

In [69]:
# Build a learnable positional-embedding table with one row per token position.
torch.manual_seed(123)  # fixed seed so the randomly initialised table is reproducible
context_length = input_embeddings.shape[0]  # number of tokens in the sentence
output_dim = 3  # same width as the token embeddings so the two can be summed

pos_embedding_layer = nn.Embedding(num_embeddings=context_length, embedding_dim=output_dim)
# Look up positions 0 .. context_length-1 to get one vector per token.
pos_embeddings = pos_embedding_layer(torch.arange(context_length))
print(pos_embeddings)

tensor([[ 0.3374, -0.1778, -0.1690],
        [ 0.9178,  1.5810,  1.3010],
        [ 1.2753, -0.2010, -0.1606],
        [-0.4015,  0.9666, -1.1481],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096]], grad_fn=<EmbeddingBackward0>)


In [70]:
# Inject position information: sum token and positional embeddings element-wise.
embeddings = input_embeddings + pos_embeddings

In [71]:
# Instantiate the self-attention module (3-dim inputs projected to 2 dims);
# the seed is reset first so the random weight matrices are reproducible.
torch.manual_seed(123)
sam = SelfAttention_v1(d_in=3, d_out=2)

In [72]:
# The module returns a pair: the context vectors and the raw attention scores
# (the unscaled query-key dot products, before softmax).
output, attn_scores = sam(embeddings)
# Bare expression: displayed as the cell result when run in a notebook.
attn_scores

tensor([[ 6.4450e-01,  3.6325e+00,  1.1351e+00,  5.4441e-01, -1.8369e-02,
         -1.0449e+00],
        [ 2.6872e+00,  1.5109e+01,  4.7291e+00,  2.2521e+00, -8.3427e-02,
         -4.3646e+00],
        [ 1.2493e+00,  6.8984e+00,  2.1869e+00,  9.8685e-01, -6.1918e-02,
         -2.0553e+00],
        [ 2.3515e-01,  1.2088e+00,  4.0330e-01,  1.4285e-01, -2.8145e-02,
         -4.0548e-01],
        [-1.5004e-01, -8.8163e-01, -2.6758e-01, -1.4395e-01, -2.3411e-03,
          2.3579e-01],
        [-1.4948e+00, -8.2574e+00, -2.6169e+00, -1.1825e+00,  7.3384e-02,
          2.4582e+00]], grad_fn=<MmBackward0>)

In [73]:
# Causal masking demo: a token may only attend to itself and earlier tokens.
# `mask` holds 1.0 strictly above the diagonal (the "future" positions).
mask = torch.ones(context_length, context_length).triu(diagonal=1)
# Future positions are set to -inf so a subsequent softmax gives them weight 0.
masked = attn_scores.masked_fill(mask.to(torch.bool), float("-inf"))
print(masked)

tensor([[ 6.4450e-01,        -inf,        -inf,        -inf,        -inf,
                -inf],
        [ 2.6872e+00,  1.5109e+01,        -inf,        -inf,        -inf,
                -inf],
        [ 1.2493e+00,  6.8984e+00,  2.1869e+00,        -inf,        -inf,
                -inf],
        [ 2.3515e-01,  1.2088e+00,  4.0330e-01,  1.4285e-01,        -inf,
                -inf],
        [-1.5004e-01, -8.8163e-01, -2.6758e-01, -1.4395e-01, -2.3411e-03,
                -inf],
        [-1.4948e+00, -8.2574e+00, -2.6169e+00, -1.1825e+00,  7.3384e-02,
          2.4582e+00]], grad_fn=<MaskedFillBackward0>)


In [83]:
import torch
import torch.nn as nn

class CausalSelfAttention_v1(nn.Module):
  """Single-head causal (masked) scaled dot-product self-attention.

  Each token attends only to itself and to earlier tokens; dropout is applied
  to the attention weights.

  Args:
    d_in: size of each input token embedding.
    d_out: size of the projected query/key/value vectors.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, d_in, d_out, dropout):
    super().__init__()
    # Layers are created in the order query, key, value; keep this order so
    # that a seeded run draws the same initial weights.
    self.W_query = nn.Linear(d_in, d_out)
    self.W_key = nn.Linear(d_in, d_out)
    self.W_value = nn.Linear(d_in, d_out)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Compute causally-masked context vectors.

    Args:
      x: tensor of shape (num_tokens, d_in), optionally with leading batch
        dimensions, e.g. (batch, num_tokens, d_in).

    Returns:
      Tensor of shape (..., num_tokens, d_out).
    """
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)
    # Swap only the last two dims so both unbatched and batched inputs work.
    attn_scores = queries @ keys.transpose(-2, -1)
    num_tokens = x.shape[-2]
    # Ones strictly above the diagonal mark "future" positions. Build the mask
    # on the input's device so the module also works when x is not on the CPU.
    mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1)
    # -inf scores become exactly zero weight after the softmax; the 2-D mask
    # broadcasts over any leading batch dimensions.
    masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
    # Scale by sqrt(d_out) (the key dimension) before normalising each row.
    attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)
    context_vec = attn_weights @ values

    return context_vec


In [84]:
# Instantiate the causal attention module (3-dim inputs, 2-dim outputs, 50%
# dropout on the attention weights); seed first for reproducible weights.
torch.manual_seed(123)
cam = CausalSelfAttention_v1(d_in=3, d_out=2, dropout=0.5)

In [87]:
# Run the causal attention module on the position-augmented embeddings
# (this rebinds `output` from the earlier self-attention cell).
# NOTE(review): cam.eval() is not called in the visible cells, so dropout (p=0.5)
# is presumably active and the result depends on the RNG state - confirm intended.
output = cam(embeddings)

In [88]:
class MultiHeadAttentionWrapper(nn.Module):
  """Multi-head attention built from independent causal attention heads.

  Args:
    d_in: size of each input token embedding.
    d_out: output size of each individual head.
    dropout: dropout probability passed to every head.
    num_heads: number of heads to create.
  """

  def __init__(self, d_in, d_out, dropout, num_heads):
    super().__init__()
    self.heads = nn.ModuleList()
    for _ in range(num_heads):
      self.heads.append(CausalSelfAttention_v1(d_in, d_out, dropout))

  def forward(self, x):
    """Apply every head to `x` and concatenate the results along the last dim."""
    per_head_context = [attention_head(x) for attention_head in self.heads]
    return torch.cat(per_head_context, dim=-1)

In [93]:
# Instantiate a 10-head wrapper (each head: 3-dim inputs, 2-dim outputs, 50%
# dropout); seed first so the heads' initial weights are reproducible.
torch.manual_seed(123)
mha = MultiHeadAttentionWrapper(d_in=3, d_out=2, dropout=0.5, num_heads=10)


In [94]:
# Call the module itself rather than `.forward()` directly: nn.Module.__call__
# dispatches to forward() and also runs any registered hooks, matching how
# `sam` and `cam` are invoked in the earlier cells.
# With 10 heads of 2 outputs each, every token gets a 20-dim concatenated vector.
output = mha(embeddings)