# self-attention

## setup

In [1]:
import numpy as np
import math
import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cpu"

In [2]:
EMBED_SIZE = 8
VOCAB_SIZE = 11
HIDDEN_SIZE = 320
CONTEXT_SIZE = 10
MAGIC_TOKEN = VOCAB_SIZE-1
MAX_ITERS = 10000
LEARNING_RATE = 0.001

In [3]:
X = []
Y = []

for i in range(1000):
  magic_token_idx = random.randint(1, CONTEXT_SIZE/2 -1)
  x = [random.randint(1,VOCAB_SIZE-2) for _ in range(magic_token_idx)] + [MAGIC_TOKEN] + [0 for _ in range(CONTEXT_SIZE-magic_token_idx -1)]
  y = x[:magic_token_idx+1] + x[:magic_token_idx] + [0 for _ in range(CONTEXT_SIZE - 2*magic_token_idx -1)]
  X.append(x)
  Y.append(y)

X = torch.tensor(X).to(device)
Y= torch.tensor(Y).to(device)

## code

In [4]:
def get_training():
  X = torch.tensor([[0, 1, 2, 3],
                    [3, 2, 1, 0]])

  X = torch.randint(0, VOCAB_SIZE-2, (1000, CONTEXT_SIZE))

  Y = torch.ones_like(X) # TODO
  return X.to(device), Y.to(device)

In [5]:
class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding = torch.nn.Embedding(VOCAB_SIZE, EMBED_SIZE)

    self.positional_embedding = torch.nn.Embedding(CONTEXT_SIZE, EMBED_SIZE)

    self.w_key = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
    self.w_query = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)
    self.w_value = torch.nn.Linear(EMBED_SIZE, EMBED_SIZE, bias=False)

    self.emb_ln = nn.LayerNorm(EMBED_SIZE)

    self.ff = nn.Sequential(
      nn.Linear(EMBED_SIZE, HIDDEN_SIZE),
      nn.ReLU(),
      nn.Linear(HIDDEN_SIZE, EMBED_SIZE),
      nn.LayerNorm(EMBED_SIZE),
    )
    self.layers = nn.Sequential(
        nn.Linear(EMBED_SIZE, VOCAB_SIZE),
    )


  def forward(self, x):
    # x: (batch_size, context_size)
    x = self.token_embedding(x)  # (batch_size, context_size, embedding_size)
    # positional embedding
    x = x + self.positional_embedding(torch.arange(0, x.shape[1]).to(device))  # (batch_size, context_size, embedding_size)

    # generate K,Q,V
    key = self.w_key(x) #(batch_size, context_size, embedding_size) @ (embedding_size, embedding_size) ---> (batch_size, context_size, embedding_size)
    query = self.w_query(x)
    value = self.w_value(x)
    # do the attention
    correlation = query @ key.transpose(-2, -1)
    correlation = correlation / math.sqrt(key.shape[-1])
    new_embedding = correlation.softmax(-1) @ value

    x = x + self.emb_ln(new_embedding)

    # layernorm + MLP
    fed = self.ff(x)
    x = x + fed
    x = self.layers(x)
    return x

In [6]:
model = Net().to(device)
# X, Y = get_training()
# model(X).argmax(dim=-1)

In [7]:
### train
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(MAX_ITERS):
  out = model(X)
  loss = F.cross_entropy(out, F.one_hot(Y, VOCAB_SIZE).float())
  opt.zero_grad()
  loss.backward()
  opt.step()
  if epoch % 500 == 0:
    print(f'{epoch:5} {loss.item()}')


    0 2.3383076190948486
  500 1.1538957357406616
 1000 1.0742956399917603
 1500 1.0440741777420044
 2000 1.0273371934890747
 2500 1.0176148414611816
 3000 1.0117547512054443
 3500 1.0040936470031738
 4000 0.9944692850112915
 4500 0.9918068051338196
 5000 0.9960141777992249
 5500 0.9873514771461487
 6000 0.9877238869667053
 6500 0.9844509959220886
 7000 0.9845349192619324
 7500 0.9810346961021423
 8000 0.980131208896637
 8500 0.9796953201293945
 9000 0.9905064105987549
 9500 0.976533830165863


In [9]:
model(X).argmax(-1).tolist()[:10]

[[9, 0, 0, 10, 1, 0, 1, 0, 0, 0],
 [0, 6, 0, 0, 0, 0, 6, 0, 0, 0],
 [0, 0, 0, 10, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 6, 0, 0, 0],
 [6, 0, 10, 1, 1, 0, 0, 0, 0, 0],
 [0, 10, 6, 0, 0, 0, 0, 0, 0, 0],
 [0, 10, 6, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 1, 0, 0, 1, 6, 0, 0, 0],
 [0, 9, 8, 1, 0, 1, 6, 0, 0, 0],
 [9, 0, 10, 1, 4, 0, 0, 0, 0, 0]]