In [6]:
!which python

import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
  %pip install datasets


/Users/dan/miniconda3/envs/ml/bin/python
ffmpeg                    4.3                  h0a44026_0    pytorch
torch                     1.13.1                   pypi_0    pypi
torchtext                 0.14.1                   pypi_0    pypi
torchvision               0.14.0                 py39_cpu    pytorch


In [74]:
import math
import random
import string
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchtext
from torchtext.vocab import vocab
import datasets
from collections import Counter, OrderedDict

device = 'cuda' if torch.cuda.is_available() else 'cpu'

DEFAULT_PARAMS = {
  'd_model': 64, # output size of all sub-layers and embedding layer
  'n_layer': 4, # number of decoder layers
  'n_head': 4, # number of attention heads in each decoder layer
}

class Transformer(nn.Module):
  def __init__(self, vocab, tokenizer, **kwargs):
    super(Transformer, self).__init__()
    self.d_model = kwargs.get('d_model', DEFAULT_PARAMS['d_model'])
    self.n_head = kwargs.get('n_head', DEFAULT_PARAMS['n_head'])
    self.n_layer = kwargs.get('n_layer', DEFAULT_PARAMS['n_layer'])    
    self.vocab = vocab
    self.tokenizer = tokenizer
    self.embedding = nn.Embedding(len(vocab), self.d_model)
    self.layer_norm = nn.LayerNorm(self.d_model)

    self.q_projs = nn.ModuleList()
    self.k_projs = nn.ModuleList()
    self.v_projs = nn.ModuleList()
    self.attn_outs = nn.ModuleList()
    self.ffns = nn.ModuleList()

    for n in range(self.n_layer):
      # create q, k, v projection layers for each attention head.
      self.q_projs.append(nn.ModuleList([nn.Linear(self.d_model, self.d_model // self.n_head) for h in range(self.n_head)]))
      self.k_projs.append(nn.ModuleList([nn.Linear(self.d_model, self.d_model // self.n_head) for h in range(self.n_head)]))
      self.v_projs.append(nn.ModuleList([nn.Linear(self.d_model, self.d_model // self.n_head) for h in range(self.n_head)]))
      self.attn_outs.append(nn.Linear(self.d_model, self.d_model)) # final linear layer in each attention sub-layer

      # Not explained in paper why inner layer has dimensionality 4 * d_model
      self.ffns.append(nn.Sequential(nn.Linear(self.d_model, self.d_model * 4), nn.ReLU(), nn.Linear(self.d_model * 4, self.d_model))) 


  def get_positional_encoding(self, seq_length, scale=10000):
    # scale 10000 used in paper - the wavelengths will form a geometric progression from 2pi to scale * 2pi
    positions = np.zeros((seq_length, self.d_model))

    for pos in range(seq_length):
      for i in range(self.d_model):
        if i % 2 == 0:
          positions[pos][i] = np.sin(pos/scale**((2 * i)/self.d_model))
        else:
          positions[pos][i] = np.cos(pos/scale**((2 * i)/self.d_model))

    return torch.as_tensor(positions, dtype=torch.float32)
  

  def plot_positional_encodings(self, encodings):
    fig, ax = plt.subplots(figsize=(20,20))
    ax.imshow(encodings)
    plt.show()


  # TODO: this is no good as a class method - changes are not reflected in already-trained models, would need to train again

  

  def positional_embedding(self, tokens):
    indices = torch.LongTensor(self.vocab(tokens))
    embeds = self.embedding(indices)
    positions = self.get_positional_encoding(len(tokens))
    return torch.add(embeds, positions)


  def attention(self, Q, K, V):
    scaled_QK = torch.div(Q @ K.transpose(0, 1), math.sqrt(K.shape[1]))
    # connections in QKt above the diagonal are illegal - represent a query attending to a key from a future position
    mask = torch.triu(torch.full((Q.shape[0], K.shape[0]), float('-inf')), 1)
    masked_QK = torch.add(scaled_QK, mask)
    return nn.functional.softmax(masked_QK, dim=1) @ V


  def multi_attention(self, input, layer_idx):
    attentions = []
    for h in range(self.n_head):
      Q = self.q_projs[layer_idx][h](input)
      K = self.k_projs[layer_idx][h](input)
      V = self.v_projs[layer_idx][h](input)
      attn = self.attention(Q, K, V)
      attentions.append(attn)

    cat = torch.cat(attentions, dim=1)
    return self.attn_outs[layer_idx](cat)


  def decoder_layer(self, input, layer_idx):
    attention = self.multi_attention(input, layer_idx)
    attention_with_residual = self.layer_norm(torch.add(attention, input))
    ffn_out = self.ffns[layer_idx](attention_with_residual)
    return self.layer_norm(torch.add(ffn_out, attention_with_residual))


  def decoder_stack(self, input):
    current_value = input
    for layer_idx in range(self.n_layer):
      current_value = self.decoder_layer(current_value, layer_idx)

    return current_value


  def forward(self, input):
    tokens = self.tokenizer(input)
    PE = self.positional_embedding(tokens)
    decoder_output = self.decoder_stack(PE)
    
    # final linear layer uses same weights as embedding, but scaled
    logits = nn.functional.linear(decoder_output, torch.div(self.embedding.weight, math.sqrt(self.d_model)))
    log_probs = nn.functional.log_softmax(logits, dim=1)
    return log_probs
  

In [91]:
def log_probs_to_tokens(model, log_probs, top_n=10):
  probs = torch.exp(log_probs)
  sorted_probs = torch.sort(probs, descending=True)
  predictions = []

  for token_probs, token_indices in zip(sorted_probs.values.numpy(force=True), sorted_probs.indices.numpy(force=True)):
    predictions.append([(token, prob) for token, prob in zip(model.vocab.lookup_tokens(token_indices[:top_n]), token_probs[:top_n])])

  return np.array(predictions)

In [71]:
class StringReverseDataset(torch.utils.data.Dataset):
  def __init__(self, n_items):
    self.items = [self.generate_item() for i in range(n_items)]

  def __len__(self):
    return len(self.items)

  def __getitem__(self, idx):
    return self.items[idx]

  def generate_item(self):
    chars = random.choices(string.ascii_lowercase, k=5)
    return ''.join(chars) + ''.join(reversed(chars))

def create_letter_transformer():
  letter_tokenizer = lambda str : [*str]
  alphabet = vocab(OrderedDict([letter, 1] for letter in string.ascii_lowercase))
  return Transformer(alphabet, letter_tokenizer, d_model=64, n_head=4, n_layer=4)

def train_string_reverser(model):
  train_dl = torch.utils.data.DataLoader(StringReverseDataset(800), batch_size=1, shuffle=True)
  loss_fn = nn.NLLLoss()

  # paper used variable LR, how to do that here?
  optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98))

  for epoch in range(4):
    running_loss = 0.0
    for i, data in enumerate(train_dl):
      optimizer.zero_grad()
      # TODO: make transformer work with batches
      # don't care about the prediction for the next (non-existent) token
      predictions = model(data[0])[:-1]
      targets = torch.tensor(model.vocab.lookup_indices([*data[0][1:]]))

      loss = loss_fn(predictions, targets)

      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      if i % 100 == 0:    # print every 2000 mini-batches
          print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
          running_loss = 0.0




letter_tsfmr = create_letter_transformer()
train_string_reverser(letter_tsfmr)


[1,     1] loss: 0.051
[1,   101] loss: 3.364
[1,   201] loss: 3.256
[1,   301] loss: 3.200
[1,   401] loss: 3.150
[1,   501] loss: 3.139
[1,   601] loss: 3.104
[1,   701] loss: 2.999
[2,     1] loss: 0.026
[2,   101] loss: 2.693
[2,   201] loss: 2.662
[2,   301] loss: 2.544
[2,   401] loss: 2.427
[2,   501] loss: 2.388
[2,   601] loss: 2.254
[2,   701] loss: 2.125
[3,     1] loss: 0.020
[3,   101] loss: 1.994
[3,   201] loss: 1.956
[3,   301] loss: 1.919
[3,   401] loss: 1.858
[3,   501] loss: 1.845
[3,   601] loss: 1.857
[3,   701] loss: 1.814
[4,     1] loss: 0.020
[4,   101] loss: 1.691
[4,   201] loss: 1.712
[4,   301] loss: 1.664
[4,   401] loss: 1.752
[4,   501] loss: 1.699
[4,   601] loss: 1.693
[4,   701] loss: 1.623


In [124]:
def test_string_reverser(model):
  test_dl = torch.utils.data.DataLoader(StringReverseDataset(10), batch_size=1, shuffle=True)
  loss_fn = nn.NLLLoss()

  for i, data in enumerate(test_dl):
    predictions = model(data[0])[:-1]
    targets = torch.tensor(model.vocab.lookup_indices([*data[0][1:]]))
    loss = loss_fn(predictions, targets)

    token_preds = ''.join(log_probs_to_tokens(model, predictions, top_n=1)[:, 0, 0])

    # Predictions should be completely wrong for the first half of the sequence, and accurate for the second half. This verifies that masking is working.
    print(data[0][:-5], '', data[0][-5:])
    print('', token_preds[:-5], '', token_preds[-5:], '\n')

test_string_reverser(letter_tsfmr)

vakqn  nqkav
 iiit  nqkav 

agrmg  gmrga
 iiim  ggrga 

jtvvg  gvvtj
 iiiv  vvvtj 

doelz  zleod
 iiii  zleod 

lhsoa  aoshl
 iiii  aoshl 

chzuh  huzhc
 iiiu  uuzhc 

fbglj  jlgbf
 iiii  jlgbf 

qmzck  kczmq
 iiig  kczqq 

ejfqk  kqfje
 iiit  kqfje 

yihnn  nnhiy
 iiii  nnhiy 



In [None]:
# dataset and vocab setup
unk_token = '<unk>'

dataset = datasets.load_dataset('tiny_shakespeare', split='train')

# use my own implementation of BPE eventually
tokenizer = torchtext.data.get_tokenizer('basic_english')

token_counts = Counter()

for text in dataset['text']:
  tokens = tokenizer(text)
  for token in tokens:
    token_counts[token] += 1

sorted_by_freq_tuples = sorted(token_counts.items(), key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)
# TODO: make my own implementation of vocab (tricky bit is looking up multiple indices at the same time efficiently)
voc = vocab(ordered_dict, specials=[unk_token], min_freq=20)
voc.set_default_index(voc[unk_token])


# start dict with unk and then fill with most common words until vocab size limit
# vocab = {'<UNK>': 0}
# vocab.update({word: idx + 1 for [idx, [word, count]] in enumerate(token_counts.most_common(vocab_size-1))})
print(len(voc))
