# Self-attention on a toy sequence-copy task

## Setup

In [285]:
import numpy as np
import math
import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [286]:
# Model / data dimensions.
EMBED_SIZE = 8     # width of the token and positional embeddings
VOCAB_SIZE = 11    # number of token ids (0 .. VOCAB_SIZE-1)
HIDDEN_SIZE = 32   # hidden width of the feed-forward sub-layer
CONTEXT_SIZE = 10  # length of every sequence / number of learned positions
MAGIC_TOKEN = VOCAB_SIZE - 1  # highest id; placed right after the prefix to be copied

# Optimisation settings.
EPOCHS = 10_000        # number of full-batch Adam steps
LEARNING_RATE = 3e-4   # Adam learning rate

In [287]:
# Build a toy "copy" dataset of 1000 sequences, each CONTEXT_SIZE tokens long.
#   x: n random content tokens (1..VOCAB_SIZE-2), then MAGIC_TOKEN, then 0-padding.
#   y: the same n tokens and MAGIC_TOKEN, then a second copy of the n tokens,
#      then 0-padding.
# n is drawn from 1..CONTEXT_SIZE // 2 - 1, which keeps 2*n + 1 <= CONTEXT_SIZE so
# y never overflows the context. The bound must be an int: random.randint rejects
# float bounds on Python 3.12+ (CONTEXT_SIZE/2 - 1 evaluates to a float).
X = []
Y = []

for _ in range(1000):
  magic_token_idx = random.randint(1, CONTEXT_SIZE // 2 - 1)
  prefix = [random.randint(1, VOCAB_SIZE - 2) for _ in range(magic_token_idx)]
  x = prefix + [MAGIC_TOKEN] + [0] * (CONTEXT_SIZE - magic_token_idx - 1)
  y = prefix + [MAGIC_TOKEN] + prefix + [0] * (CONTEXT_SIZE - 2 * magic_token_idx - 1)
  X.append(x)
  Y.append(y)

X = torch.tensor(X).to(device)
Y = torch.tensor(Y).to(device)

## Model, training and evaluation

In [288]:
def get_training(n_samples=1000):
  """Return a placeholder (inputs, targets) pair of random training tensors.

  Inputs are token ids drawn uniformly from 1..VOCAB_SIZE-2, the same content
  range that `random.randint(1, VOCAB_SIZE-2)` uses when X is built above (0 is
  used there as padding and VOCAB_SIZE-1 is MAGIC_TOKEN). Targets are still a
  stub of all ones. Both tensors are moved to `device`.

  Args:
    n_samples: number of sequences to generate (default 1000, as before).

  Returns:
    (X, Y): int64 tensors of shape (n_samples, CONTEXT_SIZE).
  """
  # torch.randint's `high` is exclusive, so high=VOCAB_SIZE-1 yields ids up to
  # VOCAB_SIZE-2 inclusive and low=1 keeps the padding id 0 out.
  X = torch.randint(1, VOCAB_SIZE - 1, (n_samples, CONTEXT_SIZE))
  Y = torch.ones_like(X)  # TODO: real targets are not implemented yet
  return X.to(device), Y.to(device)

In [289]:
class Attention(nn.Module):
  """Single-head scaled dot-product self-attention followed by a LayerNorm.

  Every position attends to every position (no mask is applied). Input and
  output feature widths are both EMBED_SIZE.
  """

  def __init__(self):
    super().__init__()
    # Registration order (key, query, value, norm) is kept fixed so parameter
    # initialisation consumes the RNG in the same sequence.
    self.w_key = nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
    self.w_query = nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
    self.w_value = nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
    self.ln = nn.LayerNorm(EMBED_SIZE)

  def forward(self, x):
    """Compute LayerNorm(softmax(Q K^T / sqrt(d)) V) over the positions of `x`.

    Args:
      x: embeddings whose last dimension is EMBED_SIZE; Net passes
        (batch_size, context_size, EMBED_SIZE).

    Returns:
      Tensor with the same shape as `x`.
    """
    q, k, v = self.w_query(x), self.w_key(x), self.w_value(x)
    d_k = k.shape[-1]
    # (..., T, d) @ (..., d, T) -> (..., T, T) attention scores.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    # The LayerNorm is applied to the attention output (post-norm).
    return self.ln(torch.matmul(weights, v))

class Net(nn.Module):
  """Token + positional embeddings -> self-attention -> feed-forward -> logits.

  Both sub-layers are residual: x = x + attention(x), then x = x + ff(x).
  """

  def __init__(self):
    super().__init__()
    self.token_embedding = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
    # Learned absolute positions: only CONTEXT_SIZE positions exist, so longer
    # sequences cannot be embedded.
    self.positional_embedding = nn.Embedding(CONTEXT_SIZE, EMBED_SIZE)
    self.attention = Attention()
    self.ff = nn.Sequential(
      nn.Linear(EMBED_SIZE, HIDDEN_SIZE),
      nn.ReLU(),
      nn.Linear(HIDDEN_SIZE, EMBED_SIZE),
      nn.LayerNorm(EMBED_SIZE),
    )
    self.head = nn.Linear(EMBED_SIZE, VOCAB_SIZE)

  def forward(self, x):
    """Map token ids to per-position vocabulary logits.

    Args:
      x: integer token ids of shape (batch_size, context_size), with
        context_size <= CONTEXT_SIZE.

    Returns:
      Logits of shape (batch_size, context_size, VOCAB_SIZE).
    """
    x = self.token_embedding(x)  # (batch_size, context_size, EMBED_SIZE)
    # Create position ids directly on the input's device rather than on the
    # CPU followed by a copy to the global `device`. The model then works
    # wherever it lives and avoids a host-to-device copy on every step.
    positions = torch.arange(x.shape[1], device=x.device)
    x = x + self.positional_embedding(positions)
    x = x + self.attention(x)
    x = x + self.ff(x)
    return self.head(x)

In [290]:
# Build the network, then move its parameters onto the selected device
# (nn.Module.to works in place and returns the same module).
model = Net()
model = model.to(device)

In [291]:
def train(model, epochs=EPOCHS, lr=LEARNING_RATE, inputs=None, targets=None, log_every=500):
  """Train `model` full-batch with Adam on token-level cross-entropy.

  Args:
    model: module returning logits whose last axis indexes the vocabulary
      (e.g. Net); the leading axes must flatten to match `targets`.
    epochs: number of optimisation steps; each step uses the whole dataset.
    lr: Adam learning rate.
    inputs: token-id tensor fed to the model; defaults to the global `X`.
    targets: target token ids aligned with `inputs`; defaults to the global `Y`.
    log_every: print the loss every `log_every` epochs, starting at epoch 0.

  Returns:
    None (so a bare `train(model)` cell shows only the printed log).
  """
  if inputs is None:
    inputs = X
  if targets is None:
    targets = Y

  opt = torch.optim.Adam(model.parameters(), lr=lr)
  model.train()

  for epoch in range(epochs):
    logits = model(inputs)
    # Flatten (..., vocab) logits and targets so cross_entropy sees one row per
    # token; reshape (unlike view) also accepts non-contiguous tensors.
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % log_every == 0:
      print(f'{epoch:5} {loss.item()}')

train(model)

    0 2.5265917778015137
  500 0.9128351807594299
 1000 0.4808105528354645
 1500 0.29808908700942993
 2000 0.21623677015304565
 2500 0.16821356117725372
 3000 0.13157570362091064
 3500 0.10579367727041245
 4000 0.08433056622743607
 4500 0.06761296838521957
 5000 0.053498268127441406
 5500 0.04267818480730057
 6000 0.03486316278576851
 6500 0.028382284566760063
 7000 0.02384473755955696
 7500 0.020177608355879784
 8000 0.017205029726028442
 8500 0.015238409861922264
 9000 0.012392907403409481
 9500 0.011136529967188835


In [292]:
# Sequence-level exact-match accuracy on the training data: a row counts as
# correct only when every predicted position equals its target.
model.eval()  # no dropout/batch-norm here, but make inference mode explicit
with torch.no_grad():  # inference only: don't build an autograd graph
  res = model(X).argmax(-1)
correct = (res == Y).all(-1).float().sum()
print(f'accuracy: {correct / len(X)}')
res[:20]

accuracy: 0.9850000739097595


tensor([[ 9,  1,  1,  7, 10,  9,  1,  1,  7,  0],
        [ 8,  4,  5,  3, 10,  8,  4,  5,  3,  0],
        [ 6, 10,  6,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  9, 10,  2,  9,  0,  0,  0,  0,  0],
        [ 9,  9, 10,  9,  9,  0,  0,  0,  0,  0],
        [ 3,  4, 10,  3,  4,  0,  0,  0,  0,  0],
        [ 6,  1,  9,  9, 10,  6,  1,  9,  9,  0],
        [ 9,  8,  2, 10,  9,  8,  2,  0,  0,  0],
        [ 9,  6,  8, 10,  9,  6,  8,  0,  0,  0],
        [ 6,  9, 10,  6,  9,  0,  0,  0,  0,  0],
        [ 6,  4,  2, 10,  4,  4,  2,  0,  0,  0],
        [ 1, 10,  1,  0,  0,  0,  0,  0,  0,  0],
        [ 7,  1,  7, 10,  7,  1,  7,  0,  0,  0],
        [ 2,  9, 10,  2,  9,  0,  0,  0,  0,  0],
        [ 8,  4,  2,  3, 10,  8,  4,  2,  3,  0],
        [ 1,  2, 10,  1,  2,  0,  0,  0,  0,  0],
        [ 7,  5,  5, 10,  7,  5,  5,  0,  0,  0],
        [ 2,  7,  9, 10,  2,  7,  9,  0,  0,  0],
        [ 6,  4, 10,  6,  4,  0,  0,  0,  0,  0],
        [ 2,  4,  7, 10,  2,  4,  7,  0,  0,  0]],