In [1]:
import sys
import os
import torch 
from pathlib import Path

def get_project_info(marker: str = "toy_transformers") -> tuple[Path, Path]:
  """Locate the project root by walking up from the current working directory.

  Args:
    marker: name of a file or directory whose presence marks the project root.

  Returns:
    ``(root, current)``. ``current`` is the resolved working directory. ``root``
    is the nearest of ``current`` and its ancestors that contains ``marker``;
    if none does, ``root`` falls back to ``current``.
  """
  current = Path.cwd().resolve()
  # First match wins, so the innermost directory containing `marker` is chosen.
  root = next(
    (parent for parent in (current, *current.parents) if (parent / marker).exists()),
    current,
  )
  return root, current

# One-time bootstrap per kernel session: once ROOT_DIR exists in globals(),
# re-running this cell does nothing. EXPERIMENT_DIR keeps the directory the
# notebook started in.
if "ROOT_DIR" not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	# Put the repo root on the import path so `toy_transformers` resolves.
	root_entry = str(ROOT_DIR)
	if root_entry not in sys.path:
		sys.path.append(root_entry)
	# Work from the repo root from here on.
	if ROOT_DIR != Path.cwd():
		os.chdir(ROOT_DIR)

from toy_transformers.models import gptv2
from toy_transformers import tokenization
from toy_transformers import checkpoint

In [2]:
# Run-level settings: tokenizer size, batch size, tokenization mode, device.
VOCAB_SIZE = 256
BATCH_SIZE = 16
MODE = tokenization.TokenizationMode.STR
DEVICE = "mps"  # NOTE(review): Apple-silicon backend; other hardware needs "cuda"/"cpu"

# Model hyperparameters. Fields not passed here presumably take GPTv2Config's
# defaults — confirm against gptv2.
config = gptv2.GPTv2Config(
	n_heads=6,
	block_size=256,
	vocab_size=VOCAB_SIZE,
	device=DEVICE,
)

In [3]:
vocab_path = EXPERIMENT_DIR / f"vocab_{VOCAB_SIZE}.json"
raw_data_path = ROOT_DIR / "data/gutenberg/freud-interpretation-of-dreams.txt"

# Train the BPE vocabulary once and cache it beside the notebook; later runs
# load the cached file instead of retraining.
if vocab_path.exists():
	vocab = tokenization.Vocabulary.load(vocab_path)
else:
	# `with` closes the corpus handle once training finishes (the bare open()
	# leaked it). Explicit UTF-8 keeps decoding independent of the platform
	# default encoding.
	with open(raw_data_path, "r", encoding="utf-8") as raw_data:
		vocab = tokenization.create_bpe(raw_data, VOCAB_SIZE, MODE)
	vocab.save(vocab_path)

In [4]:
# Encode the whole corpus once and keep it on DEVICE so batches are sliced there
# directly. `with` closes the file handle the bare open() used to leak; explicit
# UTF-8 keeps decoding independent of the platform default.
with open(raw_data_path, "r", encoding="utf-8") as corpus_file:
	corpus_text = corpus_file.read()
data = torch.tensor(
	vocab.encode(corpus_text),
	dtype=torch.long
).to(device=DEVICE)

# Contiguous split: first 90% of tokens for training, the remaining tail for validation.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

block_size, batch_size = config.block_size, BATCH_SIZE
def get_batch(split):
  """Sample a random batch of (input, target) windows from one data split.

  Args:
    split: 'train' selects `train_data`; any other value selects `val_data`.

  Returns:
    `(x, y)`, each of shape (batch_size, block_size). `y` is `x` shifted one
    position forward in the token stream.

  Raises:
    ValueError: if the split has no more than `block_size` tokens, so no full
      (input, target) window fits.
  """
  source = train_data if split == 'train' else val_data
  if len(source) <= block_size:
    raise ValueError(
      f"{split!r} split has {len(source)} tokens; need more than block_size={block_size}"
    )
  # Window starts are drawn exactly as before, so a given RNG state yields the same batch.
  starts = torch.randint(len(source) - block_size, (batch_size,), device=DEVICE)
  # One (batch_size, block_size) index tensor and a single gather replace the old
  # per-sample Python loop. That loop sliced with 0-d device tensors, each of
  # which had to be converted to a Python int.
  positions = starts.unsqueeze(1) + torch.arange(block_size, device=DEVICE)
  x = source[positions]
  y = source[positions + 1]
  return x, y

@torch.no_grad()
def estimate_val_loss(model, eval_iters=1):
  """Estimate validation loss from `eval_iters` random validation batches.

  Runs `model` in eval mode, with gradients disabled by the decorator. Afterwards
  it restores whatever train/eval mode the model had before the call, instead of
  unconditionally switching it back to train mode.

  Args:
    model: module called as `model(X, Y)` and returning `(logits, loss)`.
    eval_iters: number of validation batches to average. The default of 1 keeps
      the original single-batch estimate. Larger values reduce the noise that
      otherwise feeds straight into `scheduler.step(...)`.

  Returns:
    Mean loss as a Python float.

  Raises:
    ValueError: if `eval_iters` is less than 1.
  """
  if eval_iters < 1:
    raise ValueError(f"eval_iters must be >= 1, got {eval_iters}")
  was_training = model.training
  model.eval()
  try:
    total = 0.0
    for _ in range(eval_iters):
      X, Y = get_batch("val")
      _, loss = model(X, Y)
      total += loss.item()
  finally:
    # Restore the caller's mode even if the forward pass raises.
    model.train(was_training)
  return total / eval_iters

In [5]:
# Allow reduced-precision internal math for float32 matmuls (see
# torch.set_float32_matmul_precision), then build and compile the model on DEVICE.
torch.set_float32_matmul_precision("medium")
m = gptv2.LanguageModel(config)
m = m.to(DEVICE)
m.compile()

optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)
# Scale the LR by 0.1 after 10 consecutive scheduler.step() calls without the
# monitored value decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer, mode='min', factor=0.1, patience=10,
)

import time
from torch.amp import autocast
from tqdm import tqdm

# Optimizer update wrapped in torch.compile; fullgraph=False permits graph breaks
# rather than erroring on untraceable code.
# `optimizer` is looked up as a module-level global at call time. Rebinding it
# (as the checkpoint-reload cell does) is therefore picked up, presumably at the
# cost of a recompile.
# NOTE(review): ReduceLROnPlateau mutates the float lr in param_groups; PyTorch's
# compiled-optimizer guidance suggests a tensor lr to avoid recompiles — confirm.
@torch.compile(fullgraph=False)
def opt_step():
	"""Apply one step of the globally bound `optimizer`."""
	optimizer.step()


TOTAL_STEPS = 5000
CKPT_DIR = EXPERIMENT_DIR / "checkpoints"
# Hyperparameters saved with every checkpoint. "lr" is read from the optimizer
# itself so the record cannot drift from the value actually used to train.
TRAINING_CONFIG = {
	"lr": optimizer.defaults["lr"],
	"batch_size": BATCH_SIZE,
	"vocab_path": str(vocab_path),
}

# Loop state lives outside the training cell so an interrupted run can resume.
metrics = []
step = 1

In [None]:
# Resumable training loop. `step` is initialised in the previous cell, so
# interrupting and re-running this cell continues where it stopped.
# `<=` runs exactly TOTAL_STEPS optimizer steps (1..TOTAL_STEPS); `<` stopped one short.
while step <= TOTAL_STEPS:
	xb, yb = get_batch('train')
	# NOTE(review): float16 autocast without a GradScaler can underflow small gradients.
	with autocast(device_type=DEVICE, dtype=torch.float16):
		logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	opt_step()

	# Log every 10 steps; every 50, also evaluate and step the LR scheduler.
	if step % 10 == 0:
		row = {"step": step, "train_loss": loss.item()}
		if step % 50 == 0:
			row["val_loss"] = estimate_val_loss(m)
			scheduler.step(row["val_loss"])
			print(step, row["train_loss"], row["val_loss"])
		metrics.append(row)

		if step % 1000 == 0:
			ckpt_path = CKPT_DIR / f"step-{step}"
			checkpoint.save(
				ckpt_path,
				m, config, TRAINING_CONFIG, metrics, optimizer=optimizer, scheduler=scheduler
			)
			# Report the real location; the old message showed a bogus absolute path.
			print(f"saved {ckpt_path}")

	step += 1

checkpoint.save(
	CKPT_DIR / "final",
	m, config, TRAINING_CONFIG,
	metrics,
	optimizer=optimizer, scheduler=scheduler
)

W0219 10:16:30.542000 30898 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode
W0219 10:16:33.930000 30898 torch/_logging/_internal.py:1154] [1/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


50 4.292960166931152 4.661849021911621
100 3.9201087951660156 4.316394805908203
150 3.7405645847320557 4.088815689086914
200 3.7566044330596924 3.878612756729126
250 3.715174913406372 3.900709390640259
300 3.6568617820739746 3.9685564041137695
350 3.6479673385620117 3.7475197315216064
400 3.606252670288086 3.682924270629883
450 3.588193655014038 3.7294578552246094
500 3.5607895851135254 3.7191061973571777
550 3.5317625999450684 3.757465362548828
600 3.5554752349853516 3.6862587928771973
650 3.4921512603759766 3.765446424484253
700 3.496335506439209 3.7302021980285645
750 3.5093302726745605 3.6455678939819336
800 3.319657802581787 3.675609588623047
850 3.4025797843933105 3.5273284912109375
900 3.3259317874908447 3.657517433166504
950 3.293686628341675 3.5478515625
1000 3.1013734340667725 3.5098273754119873
saved /checkpoints/step-1000
1050 3.0829339027404785 3.4094481468200684
1100 3.115504264831543 3.363572359085083
1150 3.0943117141723633 3.35829758644104
1200 3.024785041809082 3.3158

TypeError: save() got an unexpected keyword argument 'optimier'

In [10]:
# Smoke-test checkpoint loading: rebuild the model, then create a fresh
# optimizer/scheduler and restore their saved state when the checkpoint has it.
CKPT_NAME = "step-4000"
m, config, training_config, metrics, opt_state, sched_state = checkpoint.load(
	CKPT_DIR / CKPT_NAME,
	gptv2.LanguageModel,
	gptv2.GPTv2Config,
	device=DEVICE,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=training_config["lr"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer, mode='min', factor=0.1, patience=10
)
if opt_state:
	optimizer.load_state_dict(opt_state)
if sched_state:
	scheduler.load_state_dict(sched_state)

In [11]:
# Stream a sampled continuation of PROMPT from the current model.
PROMPT = "The mind "
idx = torch.tensor([vocab.encode(PROMPT)], dtype=torch.long, device=DEVICE)
print(idx)
print(PROMPT, end="", flush=True)
# Sample in eval mode without autograd; eval mode matters if the config enables
# dropout (NOTE(review): unverified here). The previous train/eval mode is
# restored afterwards, even on KeyboardInterrupt, so a later training cell is unaffected.
was_training = m.training
m.eval()
try:
	with torch.no_grad():
		for token in m.generate(idx, max_new_tokens=200):
			print(vocab.decode([token.item()])[0], end="", flush=True)
finally:
	m.train(was_training)
print()

tensor([[ 36, 126, 145, 129,  47]], device='mps:0')
The mind  was groing

But she is arrtlated the following she determined the fact that the patient is applied toThe dreamer was always before the fact that it does not equal expression I shall gave my wife to give exact that these process of diagnosis ence that according to advertise The but identification is caused by their expression of these lurkes as it were stillfurthermorgotten in normal b

KeyboardInterrupt: 