# self-attention

## setup

In [28]:
import numpy as np
import math
import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Everything below runs on the CPU; swap in the line above to auto-select CUDA
# when it is available.
device = "cpu"

In [29]:
# --- vocabulary / sequence layout ---
VOCAB_SIZE = 11                 # number of distinct token ids (embedding rows, output logits)
MAGIC_TOKEN = VOCAB_SIZE - 1    # highest token id, reserved as the separator in the data
CONTEXT_SIZE = 10               # length of every input / target sequence

# --- model widths ---
EMBED_SIZE = 8                  # width of the token and positional embeddings
HIDDEN_SIZE = 32                # width of the feed-forward hidden layer

# --- optimisation ---
LEARNING_RATE = 0.001           # Adam step size
MAX_ITERS = 10000               # number of full-batch training iterations

In [30]:
# Build a synthetic "copy" dataset of 1000 (input, target) pairs, each of
# length CONTEXT_SIZE:
#   x = <prefix> MAGIC_TOKEN 0 0 ... 0
#   y = <prefix> MAGIC_TOKEN <prefix> 0 ... 0
# <prefix> is a random run of tokens drawn from [1, VOCAB_SIZE - 2], so it never
# contains the padding id 0 or MAGIC_TOKEN.
X = []
Y = []

for _ in range(1000):
  # Integer division keeps the bound an int: random.randint rejects float
  # arguments on Python 3.12+. Capping the prefix at CONTEXT_SIZE // 2 - 1
  # guarantees that prefix + separator + prefix fits inside CONTEXT_SIZE.
  magic_token_idx = random.randint(1, CONTEXT_SIZE // 2 - 1)
  prefix = [random.randint(1, VOCAB_SIZE - 2) for _ in range(magic_token_idx)]

  x = prefix + [MAGIC_TOKEN] + [0] * (CONTEXT_SIZE - magic_token_idx - 1)
  y = prefix + [MAGIC_TOKEN] + prefix + [0] * (CONTEXT_SIZE - 2 * magic_token_idx - 1)
  X.append(x)
  Y.append(y)

X = torch.tensor(X).to(device)
Y = torch.tensor(Y).to(device)

## code

In [31]:
def get_training():
  """Return a placeholder (X, Y) training pair, both moved to `device`.

  X is a (1000, CONTEXT_SIZE) tensor of random token ids drawn from
  [0, VOCAB_SIZE - 2) (the upper bound of torch.randint is exclusive).
  Y is an all-ones tensor with the same shape and dtype as X; real targets
  are still TODO.
  """
  X = torch.randint(0, VOCAB_SIZE - 2, (1000, CONTEXT_SIZE))

  Y = torch.ones_like(X)  # TODO: replace with real targets
  return X.to(device), Y.to(device)

In [32]:
# class Attention(nn.Module):
#   def __init__(self):
#     super().__init__()
#     self.w_key = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
#     self.w_query = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
#     self.w_value = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)

#   def forward(self, x):
#     # generate K,Q,V
#     key = self.w_key(x) #(batch_size, context_size, embedding_size) @ (embedding_size, embedding_size) ---> (batch_size, context_size, embedding_size)
#     query = self.w_query(x)
#     value = self.w_value(x)
#     # do the attention
#     correlation = query @ key.transpose(-2, -1)
#     correlation = correlation / math.sqrt(key.shape[-1])
#     new_embedding = correlation.softmax(-1) @ value


class Net(nn.Module):
  """Single-head self-attention layer followed by an MLP and a vocabulary head.

  Pipeline: token + positional embeddings -> one unmasked scaled dot-product
  self-attention layer (residual) -> feed-forward block (residual) -> linear
  projection to VOCAB_SIZE logits per position.
  """

  def __init__(self):
    super().__init__()
    self.token_embedding = torch.nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
    self.positional_embedding = torch.nn.Embedding(CONTEXT_SIZE, EMBED_SIZE)

    # Bias-free projections producing keys, queries and values.
    self.w_key = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
    self.w_query = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
    self.w_value = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)

    # Normalises the attention output before it is added back to the residual.
    self.emb_ln = nn.LayerNorm(EMBED_SIZE)

    # Two-layer MLP whose output is layer-normalised.
    self.ff = nn.Sequential(
      nn.Linear(EMBED_SIZE, HIDDEN_SIZE),
      nn.ReLU(),
      nn.Linear(HIDDEN_SIZE, EMBED_SIZE),
      nn.LayerNorm(EMBED_SIZE),
    )
    # Output head: embedding -> per-token vocabulary logits.
    self.layers = nn.Sequential(
        nn.Linear(EMBED_SIZE, VOCAB_SIZE),
    )


  def forward(self, x):
    """Compute per-position vocabulary logits.

    Args:
      x: integer token ids of shape (batch_size, context_size). context_size
        must not exceed CONTEXT_SIZE, the size of the positional table.

    Returns:
      Tensor of shape (batch_size, context_size, VOCAB_SIZE) holding
      unnormalised logits.
    """
    # x: (batch_size, context_size)
    # Create the position indices on the same device as the input so the model
    # keeps working wherever it (and its input) has been moved.
    positions = torch.arange(x.shape[1], device=x.device)
    x = self.token_embedding(x)  # (batch_size, context_size, embedding_size)
    x = x + self.positional_embedding(positions)  # broadcast over the batch axis

    # generate K,Q,V -- each (batch_size, context_size, embedding_size)
    key = self.w_key(x)
    query = self.w_query(x)
    value = self.w_value(x)

    # Scaled dot-product attention; no mask, so every position attends to all others.
    correlation = query @ key.transpose(-2, -1)  # (batch_size, context_size, context_size)
    correlation = correlation / math.sqrt(key.shape[-1])
    new_embedding = correlation.softmax(-1) @ value

    # Residual connection around the (layer-normalised) attention output.
    x = x + self.emb_ln(new_embedding)

    # MLP followed by layernorm, with a residual connection.
    fed = self.ff(x)
    x = x + fed
    x = self.layers(x)  # (batch_size, context_size, VOCAB_SIZE)
    return x

In [33]:
# Shape sanity check for the attention-score product `q @ k^T`.
q = torch.ones([2, 4, 8])
k = torch.arange(2 * 4 * 8).float().view(2, 4, 8)
# k = torch.ones([2, 4, 8])

# Why does this match ?!
# ----------------------
# transpose(-2, -1) swaps only the last two axes, so k goes (2, 4, 8) -> (2, 8, 4).
# `@` on 3-D tensors is a batched matmul: the leading batch axis is kept and a
# (4, 8) @ (8, 4) -> (4, 4) product is computed per batch element -> (2, 4, 4).
k_t = k.transpose(-2, -1)
scores = q @ k_t
print(f'{k.shape=} {k.transpose(-2, -1).shape=} {(q @ k.transpose(-2, -1)).shape=}')

k, k_t, scores

k.shape=torch.Size([2, 4, 8]) k.transpose(-2, -1).shape=torch.Size([2, 8, 4]) (q @ k.transpose(-2, -1)).shape=torch.Size([2, 4, 4])


(tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11., 12., 13., 14., 15.],
          [16., 17., 18., 19., 20., 21., 22., 23.],
          [24., 25., 26., 27., 28., 29., 30., 31.]],
 
         [[32., 33., 34., 35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44., 45., 46., 47.],
          [48., 49., 50., 51., 52., 53., 54., 55.],
          [56., 57., 58., 59., 60., 61., 62., 63.]]]),
 tensor([[[ 0.,  8., 16., 24.],
          [ 1.,  9., 17., 25.],
          [ 2., 10., 18., 26.],
          [ 3., 11., 19., 27.],
          [ 4., 12., 20., 28.],
          [ 5., 13., 21., 29.],
          [ 6., 14., 22., 30.],
          [ 7., 15., 23., 31.]],
 
         [[32., 40., 48., 56.],
          [33., 41., 49., 57.],
          [34., 42., 50., 58.],
          [35., 43., 51., 59.],
          [36., 44., 52., 60.],
          [37., 45., 53., 61.],
          [38., 46., 54., 62.],
          [39., 47., 55., 63.]]]),
 tensor([[[ 28.,  92., 156., 220.],
          [ 28.,  92., 1

In [34]:
# Instantiate the network and move its parameters to `device`.
model = Net().to(device)
# Optional smoke test (disabled): run the untrained model on the placeholder
# data from get_training() and look at its arg-max predictions.
# X, Y = get_training()
# model(X).argmax(dim=-1)

In [35]:
### train
# Full-batch training: every iteration pushes the whole dataset X through the model.
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(MAX_ITERS):
  opt.zero_grad()
  out = model(X)  # logits, (batch_size, context_size, VOCAB_SIZE)

  # Why are these 2 not equivalent?
  # -------------------------------
  # F.cross_entropy treats dim 1 as the class axis. `out` keeps the classes on
  # the LAST axis, so the commented-out call would softmax over the context
  # positions instead of the vocabulary. Collapsing batch and position into one
  # axis puts the vocabulary on dim 1, which is what the live call does.
  # loss = F.cross_entropy(out, F.one_hot(Y, VOCAB_SIZE).float())
  loss = F.cross_entropy(out.flatten(0, 1), Y.flatten())

  loss.backward()
  opt.step()

  # Report the loss every 500 iterations.
  if not epoch % 500:
    print(f'{epoch:5} {loss.item()}')


    0 2.904038190841675


  500 0.44902303814888
 1000 0.17360185086727142
 1500 0.09929771721363068
 2000 0.06542648375034332
 2500 0.045226242393255234
 3000 0.03154982998967171
 3500 0.023086871951818466
 4000 0.016843700781464577
 4500 0.014822080731391907
 5000 0.009837410412728786
 5500 0.007400656118988991
 6000 0.007868007756769657
 6500 0.005057420581579208
 7000 0.0038744830526411533
 7500 0.0029830269049853086
 8000 0.0022892651613801718
 8500 0.0037419490981847048
 9000 0.0024802726693451405
 9500 0.002011876553297043


In [36]:
# Arg-max token per position for the first 10 training sequences.
model(X).argmax(dim=-1)[:10].tolist()

[[3, 6, 5, 1, 10, 3, 6, 5, 1, 0],
 [9, 7, 8, 9, 10, 9, 7, 8, 9, 0],
 [7, 2, 4, 10, 7, 2, 4, 0, 0, 0],
 [3, 1, 10, 3, 1, 0, 0, 0, 0, 0],
 [5, 2, 10, 5, 2, 0, 0, 0, 0, 0],
 [9, 1, 2, 4, 10, 9, 1, 2, 4, 0],
 [8, 4, 10, 8, 4, 0, 0, 0, 0, 0],
 [9, 6, 10, 9, 6, 0, 0, 0, 0, 0],
 [8, 1, 6, 2, 10, 8, 1, 6, 2, 0],
 [7, 8, 10, 7, 8, 0, 0, 0, 0, 0]]