In [1]:
import sys
import os
import torch 
from pathlib import Path

def get_project_info() -> "tuple[Path, Path]":
  """Locate the project root and the directory the notebook was started from.

  Starting at the resolved current working directory, each ancestor is
  checked in turn (nearest first) for an entry named ``toy_transformers``.
  The first directory that contains one is taken as the project root. If no
  ancestor matches, the working directory itself is used as the root.

  Returns:
    A ``(root, current)`` pair: ``root`` is the detected project root and
    ``current`` is the resolved working directory at call time.
  """
  current = Path.cwd().resolve()
  for candidate in (current, *current.parents):
    if (candidate / "toy_transformers").exists():
      return candidate, current
  # Nothing matched on the way up: fall back to the working directory.
  return current, current

# One-time bootstrap, guarded on ROOT_DIR so that re-running this cell after
# the working directory has already been moved does not recompute the paths.
if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	# Put the root on the import path so packages located there can be imported.
	if not any(entry == str(ROOT_DIR) for entry in sys.path):
		sys.path.append(str(ROOT_DIR))
	# Run the rest of the notebook from the root; EXPERIMENT_DIR keeps the
	# directory the notebook was originally started in.
	if ROOT_DIR != Path.cwd():
		os.chdir(ROOT_DIR)

from toy_transformers.models import gptv3
from toy_transformers import tokenization
from toy_transformers import checkpoint

In [2]:
VOCAB_SIZE = 256
BATCH_SIZE = 16
MODE = tokenization.TokenizationMode.STR

# Pick the best available backend instead of hard-coding "mps", so the
# notebook also runs on machines without Apple Metal. On a machine where MPS
# is available the result is "mps", exactly as before.
if torch.backends.mps.is_available():
	DEVICE = "mps"
elif torch.cuda.is_available():
	DEVICE = "cuda"
else:
	DEVICE = "cpu"

# Model hyper-parameters; fields not listed here keep GPTv3Config's defaults.
config = gptv3.GPTv3Config(
	vocab_size=VOCAB_SIZE,
	block_size=256,
	device=DEVICE,
	n_heads=6,
)

In [3]:
vocab_path = EXPERIMENT_DIR / f"vocab_{VOCAB_SIZE}.json"
raw_data_path = ROOT_DIR / "data/gutenberg/freud-interpretation-of-dreams.txt"

# Build the BPE vocabulary once and cache it in the experiment directory;
# later runs load the cached file instead of retraining.
if not vocab_path.exists():
	# The context manager closes the corpus file as soon as the vocabulary has
	# been built (the handle was previously left open for the kernel's lifetime).
	# NOTE(review): assumes create_bpe finishes reading the file before it returns.
	with open(raw_data_path, "r") as raw_data:
		vocab = tokenization.create_bpe(
			raw_data,
			VOCAB_SIZE, MODE
		)
	vocab.save(vocab_path)
else:
	vocab = tokenization.Vocabulary.load(vocab_path)

In [4]:
# Encode the whole corpus once and keep the token ids on the training device.
# Path.read_text() opens the file in text mode with the default encoding
# (same as open(path, "r").read()) and closes it afterwards, so no file
# handle is leaked.
data = torch.tensor(
	vocab.encode(raw_data_path.read_text()),
	dtype=torch.long
).to(device=DEVICE)

# Contiguous 90/10 split: the first 90% of tokens train, the tail validates.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Module-level aliases read by get_batch.
block_size, batch_size = config.block_size, BATCH_SIZE
def get_batch(split):
  """Sample a random batch of (input, target) windows from one split.

  Args:
    split: 'train' selects the training tokens; any other value selects the
      validation tokens.

  Returns:
    ``(x, y)``: two stacked tensors of ``batch_size`` windows, each
    ``block_size`` tokens long, where ``y`` is ``x`` shifted one token ahead.
  """
  source = val_data if split != 'train' else train_data
  # Exclusive upper bound len - block_size keeps the shifted target window in range.
  starts = torch.randint(len(source) - block_size, (batch_size,), device=DEVICE)
  inputs, targets = [], []
  for start in starts:
    inputs.append(source[start:start+block_size])
    targets.append(source[start+1:start+block_size+1])
  return torch.stack(inputs), torch.stack(targets)

@torch.no_grad()
def estimate_val_loss(model, n_batches=1):
  """Estimate the validation loss of ``model`` on random validation batches.

  The model is switched to eval mode for the measurement and is always put
  back into training mode afterwards, even if the forward pass raises or the
  cell is interrupted, so a resumed training loop never runs in eval mode.

  Args:
    model: callable as ``model(X, Y)`` returning ``(logits, loss)``.
    n_batches: number of validation batches to average over. The default of
      1 reproduces the original single-batch estimate.

  Returns:
    The mean loss over the sampled batches as a Python float.

  Raises:
    ValueError: if ``n_batches`` is smaller than 1.
  """
  if n_batches < 1:
    raise ValueError("n_batches must be >= 1")
  model.eval()
  try:
    total = 0.0
    for _ in range(n_batches):
      X, Y = get_batch("val")
      _, loss = model(X, Y)
      total += loss.item()
  finally:
    # Restore training mode no matter how the evaluation ended.
    model.train()
  return total / n_batches

In [5]:
# Allow reduced-precision float32 matmuls, then build the model on the target
# device and compile it in place.
torch.set_float32_matmul_precision("medium")
m = gptv3.LanguageModel(config)
m = m.to(device=DEVICE)
m.compile()

# The optimizer comes from the model's own factory. The scheduler multiplies
# the learning rate by `factor` once the metric passed to scheduler.step()
# has failed to improve for more than `patience` consecutive calls.
optimizer = m.get_optimizer(weight_decay=0.1, lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer, mode='min', factor=0.1, patience=10,
)

from torch.amp import autocast

# fullgraph=False lets the compiler fall back to eager execution for any part
# of the optimizer step it cannot capture, instead of raising.
@torch.compile(fullgraph=False)
def opt_step():
	"""Apply one update of the module-level ``optimizer`` (compiled wrapper around ``optimizer.step()``)."""
	optimizer.step()


# Run-level settings: total step budget, checkpoint location, and the
# training hyper-parameters stored alongside a checkpoint.
TOTAL_STEPS = 5000
CKPT_DIR: Path = EXPERIMENT_DIR / "checkpoints"
TRAINING_CONFIG = dict(
	lr=3e-4,
	batch_size=BATCH_SIZE,
	vocab_path=str(vocab_path),
)

# Mutable training state is kept at module level: the training-loop cell reads
# and advances `step` and appends to `metrics`, so re-running that cell
# continues from the current step rather than starting over.
step = 1
metrics = []

In [6]:
# Main training loop. `step` is 1-based and lives at module level, so the
# bound is inclusive: exactly TOTAL_STEPS updates are performed and the final
# step is logged too. Interrupting and re-running this cell resumes from the
# current value of `step`.
while step <= TOTAL_STEPS:
	xb, yb = get_batch('train')
	# Forward pass under bfloat16 autocast; backward and the update run outside it.
	with autocast(device_type=DEVICE, dtype=torch.bfloat16):
		logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	# Clip the global gradient norm to 1.0 before the update.
	torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
	opt_step()

	# Record the training loss every 10 steps. Every 50 steps also measure the
	# validation loss, feed it to the plateau scheduler, and print a progress line.
	if step % 10 == 0:
		row = {"step": step, "train_loss": loss.item()}
		if step % 50 == 0:
			row["val_loss"] = estimate_val_loss(m)
			scheduler.step(row["val_loss"])
			print(step, row.get("train_loss"), row.get("val_loss"))
		metrics.append(row)

	# if step % 1000 == 0:
	# 	checkpoint.save(
	# 		CKPT_DIR / f"step-{step}",
	# 		m, config, TRAINING_CONFIG, metrics, optimizer=optimizer, scheduler=scheduler
	# 	)
	# 	print(f"saved /checkpoints/step-{step}")

	step += 1

# checkpoint.save(
# 	CKPT_DIR / "final", 
# 	m, config, TRAINING_CONFIG, 
# 	metrics, 
# 	optimizer=optimizer, scheduler=scheduler
# )

W0219 15:40:30.080000 71146 torch/_logging/_internal.py:1154] [1/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


50 3.9244933128356934 4.1749444007873535
100 3.5931828022003174 3.8514957427978516
150 3.1167306900024414 3.548898220062256
200 2.8385024070739746 3.3778066635131836
250 2.767003059387207 3.2359979152679443
300 2.7373557090759277 3.358879804611206
350 2.5502915382385254 3.1376328468322754
400 2.378495693206787 2.8449530601501465
450 2.4831795692443848 2.952868700027466
500 2.1727793216705322 2.9904608726501465
550 2.2676777839660645 2.911193609237671
600 2.2127254009246826 2.8323745727539062
650 2.201648712158203 2.9746642112731934
700 2.1552319526672363 2.7987470626831055
750 2.0808825492858887 2.7265307903289795
800 2.0988922119140625 2.929091453552246
850 1.9184622764587402 2.814074993133545
900 2.0258474349975586 2.9404420852661133
950 1.911578893661499 2.9324519634246826
1000 1.9286599159240723 3.0614914894104004


KeyboardInterrupt: 

In [7]:
# Sample a continuation of a fixed prompt and stream it token by token.
prompt = "The mind "
idx = torch.tensor([vocab.encode(prompt)], dtype=torch.long, device=DEVICE)
print(idx)
print(prompt, end="", flush=True)

# The training loop leaves the model in training mode. Sample in eval mode
# with autograd disabled, then restore the previous mode even if generation
# is interrupted, so training can be resumed afterwards.
was_training = m.training
m.eval()
try:
	with torch.no_grad():
		for token in m.generate(idx, max_new_tokens=200):
			print(vocab.decode([token.item()])[0], end="", flush=True)
finally:
	m.train(was_training)
print()

tensor([[ 36, 126, 145, 129,  47]], device='mps:0')
The mind  is

  return node.target(*args, **kwargs)


most mammore and there was of althoughn forsuch a father has beenfarised from the familymay soothing but why they findcloghtes back upon the child grateful neurosis I sway to exching the woreus as theyage I keep after these indreams within

KeyboardInterrupt: 