In [1]:
import sys
import os
import torch 
from pathlib import Path

def get_project_info() -> "tuple[Path, Path]":
  """Locate the project root and the directory the notebook was started from.

  Starting at the resolved current working directory and walking upward, the
  first directory that contains an entry named ``toy_transformers`` is taken
  as the project root. If no such directory exists, the working directory
  itself is used as the root.

  Returns:
    A ``(root, current)`` pair, where ``current`` is the resolved working
    directory at call time and ``root`` is the directory found as described
    above.
  """
  current = Path.cwd().resolve()
  # The cwd itself is checked first, so the nearest match wins.
  for candidate in (current, *current.parents):
    if (candidate / "toy_transformers").exists():
      return candidate, current
  return current, current

# Resolve the project paths only once per kernel. After the chdir below, a
# second call to get_project_info() would report ROOT_DIR as the current
# directory, so EXPERIMENT_DIR would no longer point at the notebook's folder.
if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()

	# Make the toy_transformers package importable.
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))

	# Work from the project root so ROOT_DIR-relative resources resolve.
	if not (Path.cwd() == ROOT_DIR):
		os.chdir(ROOT_DIR)

from toy_transformers.models import gptv1
from toy_transformers import tokenization
from toy_transformers import checkpoint

In [2]:
# Experiment-wide settings; later cells read these names.
DEVICE = "mps"
VOCAB_SIZE = 256
BATCH_SIZE = 16
MODE = tokenization.TokenizationMode.STR

# Model hyperparameters. block_size is also used as the sample length in get_batch.
config = gptv1.GPTv1Config(vocab_size=VOCAB_SIZE, block_size=256, device=DEVICE, n_heads=6)

In [3]:
vocab_path = EXPERIMENT_DIR / f"vocab_{VOCAB_SIZE}.json"
raw_data_path = ROOT_DIR / "data/gutenberg/freud-interpretation-of-dreams.txt"

# The vocabulary is built once and saved under EXPERIMENT_DIR; later runs
# reload the saved file instead of rebuilding it.
if vocab_path.exists():
	vocab = tokenization.Vocabulary.load(vocab_path)
else:
	# The context manager closes the corpus handle once the vocabulary is
	# built; previously the file object was left open for the kernel's lifetime.
	with open(raw_data_path, "r") as raw_data:
		vocab = tokenization.create_bpe(raw_data, VOCAB_SIZE, MODE)
	vocab.save(vocab_path)

In [4]:
# Encode the whole corpus once and move the token ids to the training device.
# Path.read_text() opens, reads and closes the file itself, so no handle is
# leaked (the previous bare open(...).read() never closed it).
data = torch.tensor(
	vocab.encode(raw_data_path.read_text()),
	dtype=torch.long
).to(device=DEVICE)

# First 90% of the tokens are used for training, the remainder for validation.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Module-level sizes read by get_batch.
block_size, batch_size = config.block_size, BATCH_SIZE
def get_batch(split):
  """Sample a random batch of input/target windows from one data split.

  Args:
    split: ``'train'`` selects ``train_data``; any other value selects
      ``val_data``.

  Returns:
    ``(x, y)``, each with ``batch_size`` rows of ``block_size`` tokens. Row
    ``k`` of ``y`` is row ``k`` of ``x`` shifted one position to the right in
    the source data, i.e. the next-token targets.

  Reads the module-level ``train_data``, ``val_data``, ``block_size``,
  ``batch_size`` and ``DEVICE``. The split is expected to live on ``DEVICE``
  and to be longer than ``block_size``; otherwise ``torch.randint`` raises.
  """
  source = train_data if split == 'train' else val_data
  starts = torch.randint(len(source) - block_size, (batch_size,), device=DEVICE)
  # One row of consecutive positions per start offset. Gathering with a single
  # index tensor avoids a Python loop that converts every device-side start
  # offset to a host integer (one synchronisation per sample).
  positions = starts.unsqueeze(1) + torch.arange(block_size, device=starts.device)
  x = source[positions]
  y = source[positions + 1]
  return x, y

@torch.no_grad()
def estimate_val_loss(model):
  """Return the loss of ``model`` on one randomly sampled validation batch.

  The model is switched to eval mode for the forward pass and afterwards put
  back into whichever mode it was in on entry, even if the forward pass
  raises. Previously it was always left in train mode on success and in eval
  mode on failure.

  Args:
    model: callable as ``model(X, Y)`` returning ``(_, loss)``; must expose
      ``training``, ``eval()`` and ``train(mode)`` as ``torch.nn.Module`` does.

  Returns:
    The loss as a Python float. It comes from a single batch, so it is a
    noisy estimate.
  """
  was_training = model.training
  model.eval()
  try:
    X, Y = get_batch("val")
    _, loss = model(X, Y)
    return loss.item()
  finally:
    model.train(was_training)

In [9]:
# Permit reduced-precision float32 matmuls where the backend supports it.
torch.set_float32_matmul_precision("medium")

# Build the model on the target device and compile it in place.
m = gptv1.LanguageModel(config).to(device=DEVICE)
m.compile()

optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)
# Multiply the learning rate by 0.1 once the monitored value has failed to
# improve for more than `patience` consecutive scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer, mode='min', factor=0.1, patience=10
)

import time
from torch.amp import autocast
from tqdm import tqdm

@torch.compile(fullgraph=False)
def opt_step():
	"""Apply one update step of the module-level ``optimizer``.

	The optimizer is not passed in; the body refers to the global name
	``optimizer``. The function is wrapped with ``torch.compile`` using
	``fullgraph=False``, so compilation is not required to capture the step
	as a single graph.
	"""
	optimizer.step()


# Length of the run and where checkpoints are written.
TOTAL_STEPS = 5000
CKPT_DIR = EXPERIMENT_DIR / "checkpoints"

# Hyperparameters handed to checkpoint.save alongside the model; "lr" is read
# back when the optimizer is rebuilt from a checkpoint.
TRAINING_CONFIG = dict(
	lr=3e-4,
	batch_size=BATCH_SIZE,
	vocab_path=str(vocab_path),
)

# Mutable run state. It lives at module level so the training cell can be
# interrupted and re-run without starting over.
step = 1
metrics = []

In [11]:
# Training loop. `step` and `metrics` are module-level, so interrupting and
# re-running this cell continues from the current step.
# `step` starts at 1, so the bound must be inclusive to perform TOTAL_STEPS
# updates; with `<` the run stopped one update short and never reached the
# logging/checkpoint branch for step == TOTAL_STEPS.
while step <= TOTAL_STEPS:
	xb, yb = get_batch('train')
	# NOTE(review): float16 autocast is used without a gradient scaler, so
	# small gradients may underflow — confirm this is acceptable on this backend.
	with autocast(device_type=DEVICE, dtype=torch.float16):
		logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	opt_step()

	# Record metrics every 10 steps; validate and report every 50; checkpoint every 1000.
	if step % 10 == 0:
		row = {"step": step, "train_loss": loss.item()}
		if step % 50 == 0:
			row["val_loss"] = estimate_val_loss(m)
			# The plateau scheduler is driven by the validation loss.
			scheduler.step(row["val_loss"])
			print(step, row.get("train_loss"), row.get("val_loss"))
		metrics.append(row)

		if step % 1000 == 0:
			checkpoint.save(
				CKPT_DIR / f"step-{step}",
				m, config, TRAINING_CONFIG, metrics, optimizer=optimizer, scheduler=scheduler
			)
			print(f"saved /checkpoints/step-{step}")

	step += 1

# Save the end-of-run state. The keyword must be spelled `optimizer`, matching
# the periodic checkpoint call; the previous `optimier=` typo meant the
# optimizer was not passed under the expected name.
checkpoint.save(
	CKPT_DIR / "final",
	m, config, TRAINING_CONFIG,
	metrics,
	optimizer=optimizer, scheduler=scheduler
)

50 4.33851957321167 4.636962413787842
100 3.867475986480713 4.282941818237305
150 3.658592700958252 3.9919657707214355
200 3.6067748069763184 3.9639782905578613
250 3.6928863525390625 3.852693557739258
300 3.543936252593994 3.8758645057678223
350 3.540205955505371 3.822340250015259
400 3.6648848056793213 3.819779872894287
450 3.54874324798584 3.649855613708496
500 3.5332999229431152 3.6951940059661865
550 3.51282000541687 3.724705219268799
600 3.4040558338165283 3.6382627487182617
650 3.474104404449463 3.6349310874938965
700 3.485846996307373 3.689021587371826
750 3.4612555503845215 3.5868759155273438
800 3.2819128036499023 3.59631609916687
850 3.2905588150024414 3.430046558380127
900 3.092604160308838 3.474879503250122
950 3.042264938354492 3.357882261276245
1000 2.9300358295440674 3.313901424407959
saved /checkpoints/step-1000
1050 2.901423692703247 3.1838605403900146
1100 2.872645139694214 3.194042682647705
1150 2.869565725326538 3.2416810989379883
1200 2.8970236778259277 3.09223747

KeyboardInterrupt: 

In [15]:
# test loading checkpoint
# NOTE: this rebinds m, config, metrics, optimizer and scheduler for every later cell.
m, config, training_config, metrics, opt_state, sched_state = checkpoint.load(
	CKPT_DIR / "step-1000", gptv1.LanguageModel, gptv1.GPTv1Config, device=DEVICE
)

# Recreate the optimizer and scheduler around the loaded model's parameters,
# then restore their saved state when the checkpoint provided any.
optimizer = torch.optim.AdamW(m.parameters(), lr=training_config["lr"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer,
	mode='min',
	factor=0.1,
	patience=10,
)
if opt_state:
	optimizer.load_state_dict(opt_state)
if sched_state:
	scheduler.load_state_dict(sched_state)

In [16]:
# Encode the prompt as a batch containing a single sequence.
idx = torch.tensor([vocab.encode("The mind ")], dtype=torch.long, device=DEVICE)
print(idx)

# Echo the prompt, then print each generated token as soon as it is yielded.
print("The mind ", end="", flush=True)
for token in m.generate(idx, max_new_tokens=200):
	piece = vocab.decode([token.item()])[0]
	print(piece, end="", flush=True)
print()

tensor([[ 36, 126, 145, 129,  47]], device='mps:0')
The mind  it from psychoind the may not it erows are theowing to thedayingunated the mother year that is finds invelucepe it his sleep Bystemang probaurysic at the such dispinion than in the resteries not codegidently the work appary to mething theysketout shoryital of the sericatus imply that I remn have in

KeyboardInterrupt: 