In [1]:
import sys
import os
import torch 
from pathlib import Path

def get_project_info() -> Path:
  current = Path.cwd().resolve()
  root = current
  for parent in [current, *current.parents]:
    if (parent / "toy_transformers").exists():
      root = parent
      break
  return root, current

if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))
	if Path.cwd() != ROOT_DIR:
		os.chdir(ROOT_DIR)

from toy_transformers.models import gptv3
from toy_transformers import tokenization
from toy_transformers import checkpoint

In [2]:
VOCAB_SIZE = 256
BATCH_SIZE = 16
MODE = tokenization.TokenizationMode.STR
DEVICE = "mps"

config = gptv3.GPTv3Config(
	vocab_size=VOCAB_SIZE,
	block_size=256,
	device=DEVICE,
	n_heads=6,
)

In [3]:
vocab_path = EXPERIMENT_DIR / f"vocab_{VOCAB_SIZE}.json"
raw_data_path = ROOT_DIR / "data/gutenberg/freud-interpretation-of-dreams.txt"

if not vocab_path.exists():
	raw_data = open(raw_data_path, "r")
	vocab = tokenization.create_bpe(
		raw_data, 
		VOCAB_SIZE, MODE
	)
	vocab.save(vocab_path)
else:
	vocab = tokenization.Vocabulary.load(vocab_path)

In [4]:
data = torch.tensor(
	vocab.encode(open(raw_data_path, "r").read()),
	dtype=torch.long
).to(device=DEVICE)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

block_size, batch_size = config.block_size, BATCH_SIZE
def get_batch(split):
  data = train_data if split == 'train' else val_data
  idxs = torch.randint(len(data) - block_size, (batch_size,), device=DEVICE)
  x = torch.stack([data[i:i+block_size] for i in idxs])
  y = torch.stack([data[i+1:i+block_size+1] for i in idxs])
  return x, y

@torch.no_grad()
def estimate_val_loss(model):
  model.eval()
  X, Y = get_batch("val")
  _, loss = model(X, Y)
  model.train()
  return loss.item()

In [5]:
torch.set_float32_matmul_precision("medium")
m = gptv3.LanguageModel(config).to(device=DEVICE)
m.compile()

optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
  optimizer,
  mode='min',
  factor=0.1,
  patience=10
)

import time
from torch.amp import autocast
from tqdm import tqdm

@torch.compile(fullgraph=False)
def opt_step():
	optimizer.step()


TOTAL_STEPS = 5000
CKPT_DIR = EXPERIMENT_DIR / "checkpoints"
TRAINING_CONFIG = {
	"lr": 3e-4, 
	"batch_size": BATCH_SIZE, "vocab_path": str(vocab_path)
}

metrics = []
step = 1

In [6]:
while step < TOTAL_STEPS:
	xb, yb = get_batch('train')
	with autocast(device_type=DEVICE, dtype=torch.float16):
		logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	opt_step()

	if step % 10 != 0: 
		step += 1
		continue

	row = {"step": step, "train_loss": loss.item()}
	if step % 50 == 0:
		row["val_loss"] = estimate_val_loss(m)
		scheduler.step(row["val_loss"])
		print(step, row.get("train_loss"), row.get("val_loss"))
	metrics.append(row)
	
	if step % 1000 == 0:
		checkpoint.save(
			CKPT_DIR / f"step-{step}",
			m, config, TRAINING_CONFIG, metrics, optimizer=optimizer, scheduler=scheduler
		)
		print(f"saved /checkpoints/step-{step}")

	step += 1

checkpoint.save(
	CKPT_DIR / "final", 
	m, config, TRAINING_CONFIG, 
	metrics, 
	optimizer=optimizer, scheduler=scheduler
)

W0219 11:48:49.627000 41235 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode
W0219 11:48:52.784000 41235 torch/_logging/_internal.py:1154] [1/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


50 4.384234428405762 4.648870468139648
100 4.047226905822754 4.139969825744629
150 3.7820019721984863 4.086138725280762
200 3.7535905838012695 3.92063307762146
250 3.6739964485168457 3.921074390411377
300 3.645641565322876 3.9194817543029785
350 3.634645462036133 3.8154044151306152
400 3.6066665649414062 3.8864126205444336
450 3.584534168243408 3.7719945907592773
500 3.648094892501831 3.7483103275299072
550 3.6554043292999268 3.7068026065826416
600 3.5686092376708984 3.708644151687622
650 3.5130562782287598 3.7643392086029053
700 3.5106019973754883 3.672624111175537
750 3.5582845211029053 3.7181248664855957
800 3.495180130004883 3.605212450027466
850 3.468641757965088 3.7021896839141846
900 3.320737361907959 3.628838062286377
950 3.358198642730713 3.6623711585998535
1000 3.3235650062561035 3.522066593170166
saved /checkpoints/step-1000
1050 3.3326032161712646 3.491982936859131
1100 3.1019415855407715 3.519822359085083
1150 3.122255802154541 3.4177074432373047
1200 3.042525053024292 3.2

In [9]:
# test loading checkpoint
m, config, training_config, metrics, opt_state, sched_state = checkpoint.load(
	CKPT_DIR / "final",
	gptv3.LanguageModel, gptv3.GPTv3Config, device=DEVICE, 
)
optimizer = torch.optim.AdamW(m.parameters(), lr=training_config["lr"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer, mode='min', factor=0.1, patience=10
)
if opt_state: optimizer.load_state_dict(opt_state)
if sched_state: scheduler.load_state_dict(sched_state)

In [10]:
idx = torch.tensor([vocab.encode("The mind ")], dtype=torch.long, device=DEVICE)
print(idx)
print("The mind ", end="", flush=True)
for token in m.generate(idx, max_new_tokens=200):
	print(vocab.decode([token.item()])[0], end="", flush=True)
print()

tensor([[ 36, 126, 145, 129,  47]], device='mps:0')
The mind  to itself allows to reproduce on the dreamer has urinattensely with a relation to all scene impressions“He I fear that about just as to the first I make myself art you youth dream when shewas a short patient to hiscause her submbles“milike it your man her first impes at nightI was a going to adterminit friend whenhe has made they are not heard and really treating it of the punished hada
