In [None]:
# Wrap outputs in Colab notebook
from IPython.display import HTML, display

def set_css(*args, **kwargs):
  """Inject CSS that makes <pre> output blocks wrap long lines.

  Registered below as an IPython ``pre_run_cell`` callback, so it runs before
  every cell. ``*args``/``**kwargs`` absorb the ``info`` object IPython passes
  to ``pre_run_cell`` callbacks, so the function does not depend on IPython's
  compatibility handling of zero-argument callbacks.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

# Re-running this cell creates a new ``set_css`` object. Drop callbacks left by
# earlier runs first; otherwise the CSS is injected once per previous run
# before every cell.
_events = get_ipython().events
for _callback in list(_events.callbacks.get('pre_run_cell', [])):
  if getattr(_callback, '__name__', None) == 'set_css':
    _events.unregister('pre_run_cell', _callback)
_events.register('pre_run_cell', set_css)

In [None]:
# Testing a vanilla Transformer architecture for autoregressive language generation

In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; model/optimizer checkpoints are read
# from and written to "My Drive/TinyStoriesLM/tinystoriesmini" later on.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
! pip install datasets

Collecting datasets
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets)
  Downloading huggingface_hub-0.16.4-py3-none-a

In [None]:
from datasets import load_dataset
# Hugging Face TinyStories corpus (the download log shows train and validation
# splits); only dataset['train'] is tokenized in this notebook.
dataset = load_dataset("roneneldan/TinyStories")

Downloading readme:   0%|          | 0.00/946 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/249M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [None]:
import torch
from torchtext import datasets
from torch import nn
import random
import math
import numpy as np
import tqdm
import copy

In [None]:
# Architecture hyperparams
# accepted_chars: the non-letter characters filter_chars() keeps (letters are
# always kept). decode() never inserts a space before token ids below
# len(accepted_chars), which presumes those ids are these bare characters
# (TODO confirm against i2ch.txt).
# end_word_token: marker encode_punc() inserts after the last letter of a word;
# decode() strips it again. It must NOT be one of accepted_chars.
accepted_chars = [' ', '!', '"', "'", ',', '.', '\n', ':', ';', '?']
end_word_token = "|"

num_chars = len(accepted_chars)
train_batch_sz = 40 # sequences per micro-batch (one forward/backward pass)
macro_batch_sz = 8 # For Gradient Accumulation to train with larger batch sizes without overflowing memory
max_truncate_len = 512 # training chunk length; also the size of the positional-embedding table

# Model sized according to Chinchilla scaling laws
head_sz = 64
head_n = 7
embed_sz = head_n * head_sz
layer_n = 8
fcnn_sz = 4 * embed_sz
# Dropout is not configured here: the attention modules hard-code p=0.1.

cycles = 10000 # optimizer steps per train_cycle() call

# Epochs to train model for. Each "epoch" is one train_cycle() call over
# randomly sampled chunks, not a full pass over the data.
epochs = 3

def filter_chars(string):
  """Return ``string`` keeping only letters (``str.isalpha``) and ``accepted_chars``.

  Every other character (digits, '-', '(', ...) is dropped; the order of the
  kept characters is unchanged.
  """
  def _keep(ch):
    return ch.isalpha() or ch in accepted_chars

  return ''.join(filter(_keep, string))

In [None]:
# Load in Byte Pair Encoding tokenizer (pretrained using BPE file)
# char_to_ind maps token string -> id; ind_to_char maps id -> token string.
# NOTE(review): eval() executes arbitrary code contained in these files. They
# are presumably plain dict / list literals, in which case ast.literal_eval is
# a safe drop-in; confirm the file format before switching.
with open("/content/ch2i.txt", encoding = "utf-8") as f:
  char_to_ind = eval(f.read())

with open("/content/i2ch.txt", encoding = "utf-8") as f:
  ind_to_char = eval(f.read())

# Append a synthetic "<eos>" token as the last vocabulary entry. Its id
# (pad_token) marks the end of every story and is also the padding id.
ind_to_char.append("<eos>")

vocab_sz = len(ind_to_char)
pad_token = vocab_sz - 1

In [None]:
# Encoding
def encode_punc(word):
  """Insert ``end_word_token`` right after the last alphabetic character of ``word``.

  Trailing punctuation stays after the marker: ``"dog."`` -> ``"dog|."`` and
  ``"cat"`` -> ``"cat|"``. A word containing no letters (including ``""``)
  is returned unchanged.
  """
  if not word:
    return ""
  for pos in range(len(word) - 1, -1, -1):
    if word[pos].isalpha():
      return word[:pos + 1] + end_word_token + word[pos + 1:]
  return word

def encode_processed(word):
  """Greedily split ``word`` into the longest vocabulary tokens, left to right.

  ``word`` normally already contains ``end_word_token`` (see ``encode_punc``).
  At each position the longest prefix present in ``char_to_ind`` is emitted.
  A remainder consisting of ``end_word_token`` alone is dropped, not encoded.

  Returns:
    List of token ids (empty for ``""`` or a bare ``end_word_token``).

  Raises:
    ValueError: if some remaining suffix has no prefix in the vocabulary.
      The recursive version silently returned ``None`` here, which only
      surfaced later as a ``TypeError`` in the caller.
  """
  ids = []
  start = 0
  while start < len(word) and word[start:] != end_word_token:
    for end in range(len(word), start, -1):
      piece = word[start:end]
      if piece in char_to_ind:
        ids.append(char_to_ind[piece])
        start = end
        break
    else:
      raise ValueError(f"cannot tokenize {word[start:]!r}: no prefix is in the vocabulary")
  return ids

def encode(word):
  """Encode one space-delimited ``word`` into a list of token ids.

  * ``""`` -> ``[]``
  * ``"<eos>"`` -> ``[pad_token]`` (end-of-story marker)
  * Embedded newlines become the newline id; the text on either side is
    encoded recursively.
  * Otherwise the end-of-word marker is inserted (``encode_punc``) and the
    result is split into BPE tokens (``encode_processed``).
  """
  if len(word) == 0:
    return []

  if word == "<eos>":
    return [pad_token]

  head, newline, tail = word.partition("\n")
  if newline:
    # The newline id used to be the literal 6, which is accepted_chars.index("\n").
    # This relies on the same layout decode() assumes: the first
    # len(accepted_chars) vocabulary ids are the bare accepted characters.
    # NOTE(review): char_to_ind["\n"] should be the same id; confirm.
    res = encode(head)
    res.append(accepted_chars.index("\n"))
    res.extend(encode(tail))
    return res

  return encode_processed(encode_punc(word))

# Decoding
def decode(seq):
  """Turn a sequence of token ids back into text.

  A space is inserted before every token whose string contains
  ``end_word_token``, except for ids below ``len(accepted_chars)``. Afterwards
  every ``end_word_token`` is removed and each ``"<eos>"`` becomes a blank line.

  NOTE(review): encode_punc() puts the marker only in a word's last letter
  piece, so a word split into several pieces decodes with the space before its
  last piece. This is presumably acceptable because common words are single
  tokens; confirm.
  """
  pieces = []
  for tok in seq:
    text = ind_to_char[tok]
    if tok >= len(accepted_chars) and end_word_token in text:
      pieces.append(" ")
    pieces.append(text)
  return "".join(pieces).replace(end_word_token, "").replace("<eos>", "\n\n")

# Decoding, but print token boundaries
def dbg_decode(seq):
  """Like ``decode`` but wraps every token in ``[...]`` to show token boundaries.

  The space rule is the same as in ``decode``. ``end_word_token`` is stripped,
  while ``"<eos>"`` is left as-is.
  """
  out = []
  for tok in seq:
    text = ind_to_char[tok]
    if tok >= len(accepted_chars) and end_word_token in text:
      out.append(" ")
    out.append("[" + text + "]")
  return "".join(out).replace(end_word_token, "")

In [None]:
# Load in and encode the training data
data_sz = 1000000  # number of training stories to tokenize (capped at the split size)
max_word_len = 250  # space-separated "words" longer than this are skipped, not tokenized

train_split = dataset['train']
n_stories = min(data_sz, len(train_split))

# All stories are concatenated into one token stream (each ends with the
# "<eos>" token) and cut into consecutive chunks of max_truncate_len tokens.
data_tokenized = [[]]

with tqdm.tqdm(total = n_stories) as t:
  for i in range(n_stories):
    example = filter_chars(train_split[i]['text']).split(" ")
    example.append("<eos>")
    for word in example:
      if len(word) > max_word_len:
        continue

      for j in encode(word):
        if len(data_tokenized[-1]) == max_truncate_len:
          data_tokenized.append([])
        data_tokenized[-1].append(j)
    t.update(1)

# The final chunk is almost always shorter than max_truncate_len. train_batch()
# draws chunks uniformly at random, and torch.tensor() needs equal-length rows,
# so any batch that included it would crash training. Drop it.
if len(data_tokenized[-1]) < max_truncate_len:
  data_tokenized.pop()
assert data_tokenized, "not enough tokens for a single full training chunk"
assert all(len(chunk) == max_truncate_len for chunk in data_tokenized)

100%|██████████| 1000000/1000000 [15:19<00:00, 1087.82it/s]


In [None]:
# Check if GPU is available
# Preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
  device = "cuda"
elif torch.backends.mps.is_available():
  device = "mps"
else:
  device = "cpu"
print(f'Using Device {device}')

Using Device cuda


In [None]:
# Single Attention Operation
class SelfAttentionHead(nn.Module):
  """One causal (masked) scaled dot-product self-attention head.

  Args:
    head_sz: width of the query/key/value projections and of the output.
    token_unique_n: unused; accepted so all blocks share one signature style.
    embed_sz: width of the input embeddings.

  ``forward`` takes ``inp`` with the sequence on dim 1 (presumably
  (batch, seq, embed_sz), matching the ``transpose(1, 2)`` below) and returns
  one ``head_sz`` vector per position. Position t only attends to positions <= t.
  """

  def __init__(self, head_sz, token_unique_n, embed_sz):
    super().__init__()
    self.head_sz = head_sz
    self.embed_sz = embed_sz

    # Attribute names are part of the saved state_dict keys.
    self.q = nn.Linear(embed_sz, head_sz, bias = False)
    self.k = nn.Linear(embed_sz, head_sz, bias = False)
    self.v = nn.Linear(embed_sz, head_sz, bias = False)
    self.dropout = nn.Dropout(0.1)

  def forward(self, inp):
    query, key, value = self.q(inp), self.k(inp), self.v(inp)

    scores = query @ key.transpose(1, 2) / math.sqrt(self.head_sz)
    # Causal mask: entries strictly above the diagonal (future keys) become
    # -inf, so softmax assigns them zero weight.
    seq_len = scores.size(-1)
    future = torch.ones(seq_len, seq_len, device = scores.device).triu(diagonal = 1).bool()
    scores = scores.masked_fill(future, float('-inf'))

    weights = self.dropout(torch.softmax(scores, dim = -1))
    return weights @ value

# Multihead Attention Block
class MultiHeadSelfAttention(nn.Module):
  """``head_n`` causal self-attention heads run in parallel.

  Head outputs are concatenated on the last axis, projected back to
  ``embed_sz`` by ``out`` and passed through dropout (p=0.1).
  """

  def __init__(self, head_sz, head_n, token_unique_n, embed_sz):
    super().__init__()
    self.head_sz = head_sz
    self.embed_sz = embed_sz
    self.head_n = head_n

    # Heads are built on the CPU like every other layer; the caller moves the
    # whole model with `.to(device)`. Moving only the heads to the global
    # `device` here left `out` on the CPU, giving a standalone instance mixed
    # devices and a hidden dependency on a notebook global.
    self.heads = nn.ModuleList([SelfAttentionHead(head_sz, token_unique_n, embed_sz) for _ in range(int(head_n))])
    self.dropout = nn.Dropout(0.1)
    self.out = nn.Linear(int(self.head_n) * head_sz, embed_sz)

  def forward(self, inp):
    # Concatenate outputs of each attention head and project back to the embedding space.
    concat = torch.cat([head(inp) for head in self.heads], dim = -1)
    return self.dropout(self.out(concat))

# Transformer Block
class TransformerBlock(nn.Module):
  """Pre-LayerNorm Transformer block: causal self-attention then a GELU MLP.

  ``forward`` computes ``x + attention(layernorm1(x))`` and then
  ``x + fc(layernorm2(x))``; input and output have the same shape.
  """

  def __init__(self, head_sz, head_n, token_unique_n, embed_sz, fcnn_sz):
    super().__init__()
    self.head_sz = head_sz
    self.embed_sz = embed_sz
    self.head_n = head_n

    # No `.to(device)` here: moving only the attention sub-module left the
    # LayerNorms and MLP on the CPU. The caller moves the whole model at once.
    self.attention = MultiHeadSelfAttention(head_sz, head_n, token_unique_n, embed_sz)
    self.layernorm1 = nn.LayerNorm(embed_sz)
    self.fc = nn.Sequential(
        nn.Linear(embed_sz, fcnn_sz),
        nn.GELU(),
        nn.Linear(fcnn_sz, embed_sz)
    )
    self.layernorm2 = nn.LayerNorm(embed_sz)

  def forward(self, inp):
    inp = inp + self.attention(self.layernorm1(inp))
    inp = inp + self.fc(self.layernorm2(inp))

    return inp

class TransformerModel(nn.Module):
  """Decoder-only Transformer language model with tied input/output embeddings.

  Args:
    head_sz, head_n: per-head width and number of attention heads.
    token_unique_n: vocabulary size.
    block_sz: maximum sequence length (rows of the learned positional table).
    embed_sz: model width.
    layer_n: number of TransformerBlocks.
    fcnn_sz: hidden width of each block's feed-forward network.
    final_layernorm: if True, register ``out_layernorm`` and apply it before
      the output projection. Defaults to False because the former line
      ``self.out_layernorm = nn.LayerNorm(embed_sz),`` had a trailing comma.
      That built a tuple, which was neither registered nor used in
      ``forward``, so existing checkpoints were trained without a final
      LayerNorm and contain no ``out_layernorm.*`` keys.

  ``forward`` takes a tensor of token ids shaped (batch, seq), seq <= block_sz,
  and returns logits shaped (batch, seq, token_unique_n).
  """

  def __init__(self, head_sz, head_n, token_unique_n, block_sz, embed_sz, layer_n, fcnn_sz, final_layernorm = False):
    super().__init__()
    self.char_embed = nn.Embedding(token_unique_n, embed_sz)
    self.pos_embed = nn.Embedding(block_sz, embed_sz)

    # Built on the CPU; the caller moves the whole model with `.to(device)`.
    self.model = nn.Sequential(
      *[TransformerBlock(head_sz, head_n, token_unique_n, embed_sz, fcnn_sz) for _ in range(layer_n)]
    )

    self.out_layernorm = nn.LayerNorm(embed_sz) if final_layernorm else None
    self.out = nn.Linear(embed_sz, token_unique_n, bias = False)

    self.char_embed.weight = self.out.weight # Weight tie the vocabulary embedding weights and the output projection weights

  def forward(self, inp):
    # One vectorised lookup replaces the per-token Python loop, which launched
    # two tiny embedding ops (and a host->device copy) for every position.
    positions = torch.arange(inp.size(1), device = inp.device)
    data = self.char_embed(inp) + self.pos_embed(positions)
    data = self.model(data)
    if self.out_layernorm is not None:
      data = self.out_layernorm(data)
    return self.out(data)


In [None]:
# Model Initialization

def init_weights(m):
  """Initialise one module in place; intended for ``model.apply(init_weights)``.

  ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, 0.02^2) and
  ``nn.Linear`` biases are zeroed. Matching uses the exact type (subclasses are
  skipped); every other module is left untouched.
  """
  kind = type(m)
  if kind is not nn.Linear and kind is not nn.Embedding:
    return
  nn.init.normal_(m.weight, mean = 0.0, std = 0.02)
  if kind is nn.Linear and m.bias is not None:
    torch.nn.init.zeros_(m.bias)


def construct_param_groups(m):
  """Sort one module's parameters into the global AdamW groups (use via ``model.apply``).

  * ``nn.Linear`` (except the one with ``out_features == vocab_sz``): weight ->
    ``weight_decay``, bias -> ``no_weight_decay``.
  * ``nn.Embedding`` weight and ``nn.LayerNorm`` weight/bias -> ``no_weight_decay``.

  The vocab-sized Linear is skipped; its weight is tied to the token embedding,
  which is collected via the ``nn.Embedding`` branch instead.
  """
  kind = type(m)
  if kind is nn.Linear:
    if m.out_features == vocab_sz:
      return
    weight_decay.append(m.weight)
    if m.bias is not None:
      no_weight_decay.append(m.bias)
  elif kind is nn.Embedding:
    no_weight_decay.append(m.weight)
  elif kind is nn.LayerNorm:
    no_weight_decay.extend((m.weight, m.bias))

# Fresh model. construct_param_groups() appends into these two module-level
# lists, which become the two AdamW parameter groups below.
weight_decay = []
no_weight_decay = []

model = TransformerModel(head_sz, head_n, vocab_sz, max_truncate_len, embed_sz, layer_n, fcnn_sz).to(device)
model.apply(init_weights)            # N(0, 0.02) Linear/Embedding weights, zero biases
model.apply(construct_param_groups)  # fills weight_decay / no_weight_decay

optim_groups = [
    {"params": weight_decay, "weight_decay": 0.1},
    {"params": no_weight_decay, "weight_decay": 0}
]
opt = torch.optim.AdamW(optim_groups, lr = 3e-4, betas = (0.9, 0.95))

param_total = sum(p.numel() for p in model.parameters())
print(f'Parameter Number: {param_total}')

Parameter Number: 24040576


In [None]:
# Load model from gdrive
from google.colab import drive
drive.mount('/content/gdrive')

# Repeated here so this cell also works when the fresh-initialisation cell
# was not run. It appends into the module-level lists weight_decay /
# no_weight_decay created just below.
def construct_param_groups(m):
  if type(m) is nn.Linear and m.out_features != vocab_sz:
    weight_decay.append(m.weight)
    if m.bias is not None:
      no_weight_decay.append(m.bias)
  if type(m) is nn.Embedding:
    no_weight_decay.append(m.weight)
  if type(m) is nn.LayerNorm:
    no_weight_decay.append(m.weight)
    no_weight_decay.append(m.bias)

weight_decay = []
no_weight_decay = []

optim_groups = [
    {"params": weight_decay, "weight_decay": 0.1},
    {"params": no_weight_decay, "weight_decay": 0}
]

ckpt_dir = "/content/gdrive/My Drive/TinyStoriesLM/tinystoriesmini"

model = TransformerModel(head_sz, head_n, vocab_sz, max_truncate_len, embed_sz, layer_n, fcnn_sz).to(device)
model.apply(construct_param_groups)
opt = torch.optim.AdamW(optim_groups, lr = 3e-4, betas = (0.9, 0.95))
# map_location loads tensors straight onto this session's device, so a
# checkpoint written from a CUDA session can also be restored on CPU/MPS.
model.load_state_dict(torch.load(ckpt_dir + "/state.pth", map_location = device))
opt.load_state_dict(torch.load(ckpt_dir + "/opt.pth", map_location = device))

print(f'Parameter Number: {sum(p.numel() for p in model.parameters())}')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
Parameter Number: 24040576


In [None]:
# Generate batched sample for sampling purposes
def test_sample(sample_len):
  return data_tokenized[random.randint(0, len(data_tokenized) - 1)][:sample_len]

# Returns batched sample for training purposes
def train_batch(batch_sz):
  """Sample ``batch_sz`` random training chunks (with replacement).

  Returns ``(expected, data)``. ``data`` holds each chunk without its last
  token; ``expected`` holds the same chunk without its first token, i.e. the
  next-token targets aligned with ``data``.
  """
  chosen = [data_tokenized[random.randrange(len(data_tokenized))] for _ in range(batch_sz)]
  targets = [chunk[1:] for chunk in chosen]
  inputs = [chunk[:-1] for chunk in chosen]
  return targets, inputs

# Generate sample from user given prompt
def prompt_sample(prompt):
  """Tokenize a user prompt: split on single spaces and encode each word in order."""
  return [token for word in prompt.split(" ") for token in encode(word)]

In [None]:
# Sampling hyperparameters
# sample_model: when True, train_cycle() calls test_cycle_top_k() every
# sample_freq cycles. test_cycle_top_k() also reads sample_from_prompt, prompt,
# sample_context, sample_length, sample_topk and sample_temp, which are only
# defined in the "Sample the model" cell further down. Run that cell first.
sample_model = False
sample_freq = 50

# Saving frequency: save_model() runs after every save_freq-th training cycle
save_freq = 10

def train_cycle():
  """Run ``cycles`` optimizer steps with gradient accumulation.

  Each step:
  * accumulates gradients over ``macro_batch_sz`` micro-batches of
    ``train_batch_sz`` random chunks. Each micro-batch loss is divided by
    ``macro_batch_sz``, so the summed gradient is the mean.
  * clips the global gradient norm to 1.0 and takes one AdamW step.
  * prints the step loss.

  Text is sampled every ``sample_freq`` steps (if ``sample_model``) and a
  checkpoint is saved every ``save_freq`` steps.
  """
  model.zero_grad()

  for cycle in range(cycles):
    if cycle % sample_freq == 0 and sample_model:
      test_cycle_top_k()

    # test_cycle_top_k() switches the model to eval mode; re-enable dropout.
    model.train()

    print(f'Train Cycle {cycle}')
    # Logging-only total. It accumulates detached losses so it does not hold
    # references into each micro-batch's autograd graph.
    accumulated_loss = 0.0

    for _ in range(macro_batch_sz):
      expected, data = train_batch(train_batch_sz)

      expected = torch.tensor(expected, device = device)
      data = torch.tensor(data, device = device)
      output = model(data)

      # Passing ignore_index = pad_token would stop the "<eos>" targets from
      # contributing to the loss; this run trains on every target.
      loss = nn.functional.cross_entropy(output.reshape(-1, vocab_sz), expected.reshape(-1)) / macro_batch_sz
      loss.backward()
      accumulated_loss += loss.detach()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    model.zero_grad()

    print(f'Iteration Loss: {round(float(accumulated_loss), 3)}')

    if cycle % save_freq == save_freq - 1:
      save_model()

# Top-k sampling the model
def test_cycle_top_k():
  """Generate ``sample_length`` tokens with top-k sampling and print the text.

  The context is either the user ``prompt`` (when ``sample_from_prompt``) or
  the first ``sample_context`` tokens of a random training chunk. At each step
  the ``sample_topk`` largest logits are kept, multiplied by ``sample_temp``
  and softmax-normalised, and one token is drawn.

  ``sample_temp`` *multiplies* the logits, so it acts as an inverse
  temperature: values > 1 sharpen the distribution, values < 1 flatten it.

  Only the last ``max_truncate_len`` tokens (the size of the positional table)
  are fed to the model. The printed text contains the whole sequence,
  including the context.
  """
  print('================================================== Test Cycle ==================================================')
  model.eval()

  # Generate a sample to serve as context for the model
  if sample_from_prompt:
    context = prompt_sample(prompt)
  else:
    context = test_sample(sample_context)
  if not context:
    raise ValueError("test_cycle_top_k needs a non-empty context to condition on")

  generated = list(context)  # every token including the context; never truncated

  # Sampling needs no gradients; without no_grad every forward pass would
  # record an autograd graph through all layers.
  with torch.no_grad():
    for _ in range(sample_length):
      # Truncate before the forward pass so even an over-long prompt fits.
      context = context[-max_truncate_len:]
      logits = model(torch.tensor([context], device = device))[0, -1]
      top_logits, top_ids = torch.topk(logits, sample_topk)
      probs = nn.functional.softmax(top_logits * sample_temp, dim = 0)
      sampled_char = top_ids[torch.multinomial(probs, 1)].item()

      context.append(sampled_char)
      generated.append(sampled_char)

  print(decode(generated))
  print('================================================================================================================')

def save_model():
  """Write the model and optimizer state_dicts to the checkpoint folder on Drive.

  Overwrites ``state.pth`` (model) and then ``opt.pth`` (optimizer) in
  "My Drive/TinyStoriesLM/tinystoriesmini".
  """
  ckpt_dir = "/content/gdrive/My Drive/TinyStoriesLM/tinystoriesmini/"
  for filename, owner in (("state.pth", model), ("opt.pth", opt)):
    torch.save(owner.state_dict(), ckpt_dir + filename)

In [None]:
# Override the learning rate of every optimizer parameter group (e.g. after
# loading a checkpoint saved with a different lr), then show the configuration.
for group in opt.param_groups:
  group.update(lr = 1e-4)

print(opt)

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0.1

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)


In [None]:
# Training cycles: every "epoch" is one train_cycle() call, i.e. `cycles`
# optimizer steps on randomly sampled chunks (not a full pass over the data).
for epoch_idx in range(epochs):
  print(f"Epoch {epoch_idx}")
  train_cycle()

Epoch 0
Train Cycle 0
Iteration Loss: 1.679
Train Cycle 1
Iteration Loss: 1.656
Train Cycle 2
Iteration Loss: 1.667
Train Cycle 3
Iteration Loss: 1.706
Train Cycle 4
Iteration Loss: 1.645
Train Cycle 5
Iteration Loss: 1.676
Train Cycle 6
Iteration Loss: 1.671
Train Cycle 7
Iteration Loss: 1.692
Train Cycle 8
Iteration Loss: 1.681
Train Cycle 9
Iteration Loss: 1.669
Train Cycle 10
Iteration Loss: 1.686
Train Cycle 11
Iteration Loss: 1.649
Train Cycle 12
Iteration Loss: 1.709
Train Cycle 13
Iteration Loss: 1.688
Train Cycle 14
Iteration Loss: 1.71
Train Cycle 15
Iteration Loss: 1.642
Train Cycle 16
Iteration Loss: 1.671
Train Cycle 17
Iteration Loss: 1.651
Train Cycle 18
Iteration Loss: 1.674
Train Cycle 19
Iteration Loss: 1.644
Train Cycle 20
Iteration Loss: 1.709
Train Cycle 21
Iteration Loss: 1.649
Train Cycle 22
Iteration Loss: 1.663
Train Cycle 23
Iteration Loss: 1.696
Train Cycle 24
Iteration Loss: 1.673
Train Cycle 25
Iteration Loss: 1.687
Train Cycle 26
Iteration Loss: 1.662
Trai

In [None]:
# Sample the model

# Sampling hyperparameters (read as globals by test_cycle_top_k)
sample_from_prompt = True # True: condition on `prompt`; False: on a random training chunk
prompt = '''Once upon a time,'''
sample_topk = 25 # draw only among the sample_topk highest-scoring tokens

sample_length = 512 # number of new tokens to generate
sample_context = 64 # tokens taken from a training chunk when sample_from_prompt is False
sample_temp = 2.5 # NOTE: multiplies the logits (inverse temperature); values > 1 make sampling greedier

# Run inference
test_cycle_top_k()

 Once upon a time, there was a little girl named Lily. She loved to play outside in the sun. One day, she saw a big, scary dog. The dog was barking and growling. Lily was scared, but she wanted to be brave.

 She went inside her house and saw a big, scary dog. The dog was barking and growling. Lily was scared, but she remembered her mommy's words. She knew she had to be brave and go back inside.

 Lily walked back inside and told her mommy what happened. Her mommy hugged her and said," Don't worry, Lily. The dog is just a friendly dog. He just wanted to play." Lily felt better and went back inside. She was happy that she was brave enough to go outside and play.

 Once upon a time, there was a little girl named Lily. She loved to play outside in the snow. One day, she saw a big snowman in the snow. She was so happy and ran to it.

 But then, the snowman started to melt! Lily was sad and didn't know what to do. She asked her mom for help. Her mom said," Don't worry, Lily. We can make a n

In [None]:
# Save the model progress (overwrites state.pth and opt.pth in the Drive
# checkpoint folder)
save_model()

In [None]:
# Dump every (name, parameter) pair. This prints tensor values, not just shapes.
for name, tensor in model.named_parameters():
  print((name, tensor))

# Clear GPU RAM on Colab (in case tuning parameters in the middle of a session)

In [None]:
# Weird hack to clear RAM effectively - sometimes variables can't be garbage collected
# due to being part of an exception and thus causing a new exception can relieve them
print(1/0)

ZeroDivisionError: ignored

In [None]:
# Free GPU memory held by the model and optimizer. Order matters:
# 1. Drop the references.
# 2. Run the garbage collector so the tensors are actually freed.
# 3. Only then ask PyTorch's caching allocator to release its now-unused
#    cached blocks.
# The previous version called empty_cache() before the tensors were freed,
# which is presumably why it only worked on the second run.
import gc
model = None
opt = None
gc.collect()
torch.cuda.empty_cache()

0