# Self-attention on a toy sequence-copy task

## setup

In [118]:
import numpy as np
import math
import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Run everything on the CPU. To use a GPU when one is available, replace the
# assignment below with:
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cpu"

In [119]:
# Model / task hyperparameters.
EMBED_SIZE = 8                  # width of token and positional embeddings
VOCAB_SIZE = 11                 # number of token ids (0 .. VOCAB_SIZE - 1)
HIDDEN_SIZE = 32                # hidden width of the feed-forward sublayer
CONTEXT_SIZE = 10               # length of every input/target sequence
MAGIC_TOKEN = VOCAB_SIZE - 1    # highest token id, reserved as the separator
MAX_ITERS = 10_000              # number of optimizer steps
LEARNING_RATE = 1e-3            # Adam learning rate

In [120]:
# Toy "copy" dataset. Each input row is a random prefix of content tokens,
# then MAGIC_TOKEN, then zero padding. The target row repeats the prefix and
# MAGIC_TOKEN, then copies the prefix once more, and is zero-padded to
# CONTEXT_SIZE.
X = []
Y = []

# Upper bound on the prefix length; keeps the 2*m + 1 target tokens within
# CONTEXT_SIZE. Integer division is required: random.randint rejects float
# bounds on Python 3.12+ (and non-integral ones on every version).
max_prefix_len = CONTEXT_SIZE // 2 - 1

for _ in range(1000):
  # MAGIC_TOKEN sits right after the prefix, so its index equals the prefix length.
  magic_token_idx = random.randint(1, max_prefix_len)
  # Content tokens avoid 0 (padding) and VOCAB_SIZE - 1 (MAGIC_TOKEN).
  prefix = [random.randint(1, VOCAB_SIZE - 2) for _ in range(magic_token_idx)]
  x = prefix + [MAGIC_TOKEN] + [0] * (CONTEXT_SIZE - magic_token_idx - 1)
  y = prefix + [MAGIC_TOKEN] + prefix + [0] * (CONTEXT_SIZE - 2 * magic_token_idx - 1)
  X.append(x)
  Y.append(y)

X = torch.tensor(X).to(device)
Y = torch.tensor(Y).to(device)

## code

In [121]:
def get_training(num_samples=1000):
  """Return a placeholder random training set (unfinished stub).

  Args:
    num_samples: number of rows to generate (default 1000, as before).

  Returns:
    (X, Y) integer tensors of shape (num_samples, CONTEXT_SIZE), both moved
    to `device`. Y is currently all ones; real targets are still a TODO.
  """
  # torch.randint's upper bound is exclusive, so tokens fall in
  # [0, VOCAB_SIZE - 3].
  # NOTE(review): this range includes the padding token 0 and excludes
  # VOCAB_SIZE - 2; confirm it is intended.
  X = torch.randint(0, VOCAB_SIZE - 2, (num_samples, CONTEXT_SIZE))

  Y = torch.ones_like(X)  # TODO: compute real targets instead of placeholders
  return X.to(device), Y.to(device)

In [122]:
class Attention(nn.Module):
  """Single-head scaled dot-product self-attention followed by LayerNorm.

  No attention mask is applied: every position attends to every position of
  the same sequence.
  """

  def __init__(self, embed_size=None):
    """Create the key/query/value projections and the output LayerNorm.

    Args:
      embed_size: width of the input and output embeddings. If omitted, the
        notebook-level EMBED_SIZE is used, so `Attention()` behaves as before.
    """
    super().__init__()
    if embed_size is None:
      embed_size = EMBED_SIZE
    # Creation order (key, query, value, ln) is kept so that seeded
    # initialization matches the original layer.
    self.w_key = torch.nn.Linear(embed_size, embed_size, bias=False)
    self.w_query = torch.nn.Linear(embed_size, embed_size, bias=False)
    self.w_value = torch.nn.Linear(embed_size, embed_size, bias=False)
    self.ln = nn.LayerNorm(embed_size)

  def forward(self, x):
    """Apply self-attention over the second-to-last dimension of x.

    Args:
      x: float tensor whose last dimension is the embedding width; presumably
        (batch, seq_len, embed_size) as produced by Net. TODO confirm for other callers.

    Returns:
      Tensor with the same shape as x:
      LayerNorm(softmax(Q K^T / sqrt(d)) V).
    """
    # project the inputs to keys, queries and values
    key = self.w_key(x)
    query = self.w_query(x)
    value = self.w_value(x)
    # scaled dot-product scores between every pair of positions
    scores = query @ key.transpose(-2, -1)
    scores = scores / math.sqrt(key.shape[-1])
    # each output row is a convex combination of the value rows
    new_embedding = scores.softmax(-1) @ value
    return self.ln(new_embedding)

class Net(nn.Module):
  """Token-level classifier with one attention layer and one feed-forward layer.

  Tokens are embedded and summed with learned absolute position embeddings.
  An attention layer and a feed-forward layer are each added residually.
  Every position is then projected to VOCAB_SIZE logits.
  """

  def __init__(self):
    super().__init__()
    self.token_embedding = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
    # Learned absolute positions: inputs longer than CONTEXT_SIZE fail the
    # embedding lookup.
    self.positional_embedding = nn.Embedding(CONTEXT_SIZE, EMBED_SIZE)
    self.attention = Attention()
    self.ff = nn.Sequential(
      nn.Linear(EMBED_SIZE, HIDDEN_SIZE),
      nn.ReLU(),
      nn.Linear(HIDDEN_SIZE, EMBED_SIZE),
      nn.LayerNorm(EMBED_SIZE),
    )
    self.head = nn.Linear(EMBED_SIZE, VOCAB_SIZE)

  def forward(self, x):
    """Return per-position logits.

    Args:
      x: integer token ids, (batch_size, seq_len) with seq_len <= CONTEXT_SIZE.

    Returns:
      (batch_size, seq_len, VOCAB_SIZE) logits.
    """
    seq_len = x.shape[1]
    h = self.token_embedding(x)  # (batch_size, seq_len, EMBED_SIZE)
    # Build the position indices on the activations' device rather than the
    # global `device`, so the model keeps working after `.to(other_device)`.
    positions = torch.arange(seq_len, device=h.device)
    h = h + self.positional_embedding(positions)
    # residual attention, then residual feed-forward
    h = h + self.attention(h)
    h = h + self.ff(h)
    return self.head(h)

In [124]:
# Build the network; nn.Module.to moves its parameters in place (and returns
# the same module).
model = Net()
model.to(device)

In [128]:
### train
# Full-batch optimization: each of the MAX_ITERS steps feeds the entire
# dataset (X, Y) through the model once.
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(MAX_ITERS):
  out = model(X)
  # Flatten logits to (batch * seq, vocab) and targets to (batch * seq,) so
  # every position is scored as an independent classification.
  n_classes = out.shape[-1]
  loss = F.cross_entropy(out.view(-1, n_classes), Y.view(-1))

  opt.zero_grad()
  loss.backward()
  opt.step()

  should_log = epoch % 500 == 0
  if should_log:
    print(f'{epoch:5} {loss.item()}')

    0 0.05184357240796089
  500 0.048021651804447174
 1000 0.046429093927145004
 1500 0.044264405965805054
 2000 0.043660808354616165
 2500 0.036411285400390625
 3000 0.03408998250961304
 3500 0.030521320179104805
 4000 0.027178823947906494
 4500 0.024127095937728882
 5000 0.020261602476239204
 5500 0.01299357134848833
 6000 0.010142686776816845
 6500 0.008977246470749378
 7000 0.0033558951690793037
 7500 0.008173685520887375
 8000 0.0026782879140228033
 8500 0.001965648028999567
 9000 0.0015759249217808247
 9500 0.0011593375820666552


In [143]:
# Exact-match accuracy on the training data (X, Y): a row counts as correct
# only if the argmax prediction matches the target at every position.
# Inference runs under no_grad so no autograd graph is built.
with torch.no_grad():
  res = model(X).argmax(-1)
correct = (res == Y).all(-1).float().sum()
print(f'accuracy: {correct / len(X)}')
res[:20]

accuracy: 1.0


tensor([[ 1, 10,  1,  0,  0,  0,  0,  0,  0,  0],
        [ 6,  1,  7,  2, 10,  6,  1,  7,  2,  0],
        [ 4,  6, 10,  4,  6,  0,  0,  0,  0,  0],
        [ 9,  1,  4,  5, 10,  9,  1,  4,  5,  0],
        [ 5,  1,  4,  5, 10,  5,  1,  4,  5,  0],
        [ 8,  7,  1, 10,  8,  7,  1,  0,  0,  0],
        [ 4,  7, 10,  4,  7,  0,  0,  0,  0,  0],
        [ 6, 10,  6,  0,  0,  0,  0,  0,  0,  0],
        [ 5,  9, 10,  5,  9,  0,  0,  0,  0,  0],
        [ 5,  8,  8,  9, 10,  5,  8,  8,  9,  0],
        [ 7, 10,  7,  0,  0,  0,  0,  0,  0,  0],
        [ 7,  3,  3,  1, 10,  7,  3,  3,  1,  0],
        [ 9, 10,  9,  0,  0,  0,  0,  0,  0,  0],
        [ 4,  8, 10,  4,  8,  0,  0,  0,  0,  0],
        [ 8, 10,  8,  0,  0,  0,  0,  0,  0,  0],
        [ 4,  2,  6, 10,  4,  2,  6,  0,  0,  0],
        [ 6,  5,  1, 10,  6,  5,  1,  0,  0,  0],
        [ 8,  1,  5, 10,  8,  1,  5,  0,  0,  0],
        [ 7,  3,  8, 10,  7,  3,  8,  0,  0,  0],
        [ 1,  6,  5,  8, 10,  1,  6,  5,  8,  0]])