In [1]:
import sys
import os
import torch
from torch.amp import autocast

sys.path.append(os.path.abspath(os.path.join('..')))

from models import gptv2 as transformer
from utilities import text_cleaning, tokenizer

In [2]:
# Shared settings: vocabulary size (reused for the tokenizer below) and
# the torch device string handed to the model config.
vocab_size, device = 512, "mps"

config = transformer.GPTv2Config(vocab_size=vocab_size, device=device)

# Instantiate the model once just to report its parameter count.
m = transformer.LanguageModel(config)
print(m.get_num_parameters(as_str=True))

6.179m


In [3]:
# Read the whole corpus into memory as a single UTF-8 string.
filepath = "../../data/gsm8k/full.txt"
# The context manager closes the handle even if read() raises, which the
# manual open()/close() pair did not guarantee.
with open(filepath, 'r', encoding='utf-8') as input_file:
	raw_text = input_file.read()

In [4]:
from utilities import text_cleaning
from utilities import tokenizer as tokenizer
# Build the tokenizer from the raw corpus with a budget of `vocab_size`
# tokens; the two sequence markers are passed in as predefined tokens.
text = raw_text
td = tokenizer.create_tokenizer(
	text,
	num_tokens=vocab_size,
	predefined=["<START>", "<END>"],
)

# Show the token set, then its size, each on its own line.
print(td.token_set, len(td.token_set), sep="\n")

['<START>', '<END>', '\t', '\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '^', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '}', '\xa0', '¢', '£', '³', '¼', '½', '¾', '×', 'ç', 'è', 'é', 'ñ', '÷', 'А', '\u200b', '–', '—', '‘', '’', '“', '”', '€', '−', '√', ' t', 'he', ' a', 'in', ' the', ' s', 'es', ' o', 're', '00', ' m', 'er', ' h', ' c', ' p', ' b', ' f', ' w', 'an', ' of', ' to', 'is', ' 1', 'nd', 'al', 'ar', ' d', 'ou', ' 2', 'at', 'as', 'or', 'on', 'ed', ' e', 'll', 'ch', 'en', 'ow', ' in', 'ing', ' and', 'ts', ' 3', ' n', 'ay', ' he', ' g', ' th', 'The', ' is', ' l', 'it', ' 4', 'ic', ' co', '50', 'et', '20', 'AR', 'ND', 'END', 'ST

In [5]:
# Unpack the tokenizer triple and build the encode/decode callables from it.
characters, idx_to_token, token_to_idx = td
encode, decode = tokenizer.get_encoder(td), tokenizer.get_decoder(td)

# Encode the full corpus once, then move the id tensor to the target device.
data = torch.tensor(encode(text), dtype=torch.long)
data = data.to(device=device)

# Leading 90% of the token stream is for training, the tail for validation.
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

# Batch geometry comes from the model config.
block_size = config.block_size
batch_size = config.batch_size
def get_batch(split):
  """Sample a random batch of (input, target) windows from one split.

  Args:
    split: 'train' selects `train_data`; any other value selects `val_data`.

  Returns:
    A pair (x, y) of stacked tensors. Each is built from `batch_size`
    slices of length `block_size`; every y slice is the matching x slice
    shifted one position to the right.
  """
  source = train_data if split == 'train' else val_data
  # Upper bound leaves room for the one-step-shifted target window.
  starts = torch.randint(len(source) - block_size, (batch_size,), device=device)
  inputs, targets = [], []
  for start in starts:
    inputs.append(source[start:start+block_size])
    targets.append(source[start+1:start+block_size+1])
  return torch.stack(inputs), torch.stack(targets)

@torch.no_grad()
def estimate_val_loss(model):
  """Return the loss of `model` on a single random validation batch.

  The model is switched to eval mode for the forward pass and then put
  back into whatever mode it was in on entry, even if the forward pass
  raises or the cell is interrupted.

  Args:
    model: callable as `model(X, Y)` returning `(_, loss)`; must expose
      the `training` flag plus `eval()` / `train(mode)`.
      # presumably a torch.nn.Module; confirm against models.gptv2

  Returns:
    The batch loss as a Python float (`loss.item()`).
  """
  was_training = model.training
  model.eval()
  try:
    X, Y = get_batch("val")
    _, loss = model(X, Y)
    return loss.item()
  finally:
    # Restore the caller's mode instead of forcing train mode.
    model.train(was_training)

In [6]:
# Allow reduced-precision float32 matmuls.
torch.set_float32_matmul_precision("medium")

# Fresh model on the target device, compiled in place.
m = transformer.LanguageModel(config)
m = m.to(device=device)
m.compile()

optimizer = m.get_optimizer(weight_decay=0.01, lr=3e-4)

# Cut the learning rate by 10x once the monitored value has failed to
# improve for more than `patience` consecutive scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer,
	mode='min',
	factor=0.1,
	patience=10,
)

def opt_step():
	"""Apply one update of the module-level `optimizer`."""
	optimizer.step()

# Same wrapping the decorator form performs; graph breaks are permitted.
opt_step = torch.compile(opt_step, fullgraph=False)

In [17]:
num_steps = 1000
# Logging cadence: train loss every `log_every` steps, and a validation
# pass (which also drives the LR scheduler) every `eval_every` steps.
log_every, eval_every = 25, 50

# NOTE(review): float16 autocast is used without a gradient scaler, so small
# gradients may underflow. Confirm whether torch.amp.GradScaler is usable
# on this backend/torch version.

# Train mode is loop-invariant (estimate_val_loss puts the model back in
# train mode after each validation pass), so set it once.
m.train()
for step in range(num_steps):
	xb, yb = get_batch('train')
	with autocast(device_type="mps", dtype=torch.float16):
		logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	opt_step()

	run_eval = step % eval_every == 0
	if not (run_eval or step % log_every == 0):
		continue

	# loss.item() copies the scalar to the host and blocks until the device
	# has finished, so it is only called on steps that actually log.
	train_loss = loss.item()
	if run_eval:
		val_loss = estimate_val_loss(m)
		scheduler.step(val_loss)
		print(f"[{step:04d}/{num_steps}] train: {train_loss:01.05f} val: {val_loss:01.05f}")
	else:
		print(f"[{step:04d}/{num_steps}] train: {train_loss:01.05f}")


[0000/1000] train: 1.63325 val: 1.65944
[0025/1000] train: 1.64618
[0050/1000] train: 1.56198 val: 1.59178
[0075/1000] train: 1.53419
[0100/1000] train: 1.61095 val: 1.71255
[0125/1000] train: 1.58318
[0150/1000] train: 1.63872 val: 1.62444
[0175/1000] train: 1.64037
[0200/1000] train: 1.53076 val: 1.71924
[0225/1000] train: 1.63400
[0250/1000] train: 1.63927 val: 1.55578
[0275/1000] train: 1.57439
[0300/1000] train: 1.61565 val: 1.58658
[0325/1000] train: 1.54941
[0350/1000] train: 1.59138 val: 1.68251
[0375/1000] train: 1.62881
[0400/1000] train: 1.65113 val: 1.65079
[0425/1000] train: 1.55605
[0450/1000] train: 1.51294 val: 1.66040
[0475/1000] train: 1.56573
[0500/1000] train: 1.62771 val: 1.54786
[0525/1000] train: 1.62546
[0550/1000] train: 1.61106 val: 1.62593
[0575/1000] train: 1.56638
[0600/1000] train: 1.65033 val: 1.57127


KeyboardInterrupt: 

In [20]:
# Prompt the model with a word problem and stream the sampled continuation.
seed = "<START>I have ten apples. Tobias takes six. How many do I have left?"
# First id of the encoded end marker.
# assumes "<END>" encodes to a single token; TODO confirm
end_tok = encode("<END>")[0]
# Extra list level gives the prompt a leading batch dimension of 1.
idx = torch.tensor([encode(seed)], dtype=torch.long, device=device)
print(seed, end="", flush=True)

# Sample in eval mode with autograd off. The model may have been left in
# train mode by the (possibly interrupted) training cell, and no gradients
# are needed here. The previous mode is restored afterwards.
was_training = m.training
m.eval()
try:
	with torch.no_grad():
		for token in m.generate(idx, max_new_tokens=400):
			v = token.item()
			if v == end_tok:
				break
			print(decode([v])[0], end="", flush=True)
finally:
	m.train(was_training)
print()

<START>I have ten apples. Tobias takes six. How many do I have left?
It takes 3 pets and 7 dogs to skill.
10 x 3 = <<10*3=30>>30 pets.
15 It takes up only 60 - 10 = <<150-10=60>>60 dogs to skill.
25 - 30 = <<25-30=10>>10 dogs remaining.
#### 10
