In [1]:
import sys
import os
import torch 
from pathlib import Path

def get_project_info() -> Path:
  current = Path.cwd().resolve()
  root = current
  for parent in [current, *current.parents]:
    if (parent / "toy_transformers").exists():
      root = parent
      break
  return root, current

if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))
	if Path.cwd() != ROOT_DIR:
		os.chdir(ROOT_DIR)

from toy_transformers.models import gptv3
from toy_transformers import tokenization
from toy_transformers import checkpoint

In [2]:
VOCAB_SIZE = 256
BATCH_SIZE = 16
MODE = tokenization.TokenizationMode.STR
DEVICE = "mps"

config = gptv3.GPTv3Config(
	vocab_size=VOCAB_SIZE,
	block_size=256,
	device=DEVICE,
	n_heads=6,
)

In [3]:
vocab_path = EXPERIMENT_DIR / f"vocab_{VOCAB_SIZE}.json"
raw_data_path = ROOT_DIR / "data/gutenberg/freud-interpretation-of-dreams.txt"

if not vocab_path.exists():
	raw_data = open(raw_data_path, "r")
	vocab = tokenization.create_bpe(
		raw_data, 
		VOCAB_SIZE, MODE
	)
	vocab.save(vocab_path)
else:
	vocab = tokenization.Vocabulary.load(vocab_path)

In [4]:
data = torch.tensor(
	vocab.encode(open(raw_data_path, "r").read()),
	dtype=torch.long
).to(device=DEVICE)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

block_size, batch_size = config.block_size, BATCH_SIZE
def get_batch(split):
  data = train_data if split == 'train' else val_data
  idxs = torch.randint(len(data) - block_size, (batch_size,), device=DEVICE)
  x = torch.stack([data[i:i+block_size] for i in idxs])
  y = torch.stack([data[i+1:i+block_size+1] for i in idxs])
  return x, y

@torch.no_grad()
def estimate_val_loss(model):
  model.eval()
  X, Y = get_batch("val")
  _, loss = model(X, Y)
  model.train()
  return loss.item()

In [5]:
torch.set_float32_matmul_precision("medium")
m = gptv3.LanguageModel(config).to(device=DEVICE)
m.compile()

optimizer = m.get_optimizer(weight_decay=0.1, lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
  optimizer,
  mode='min',
  factor=0.1,
  patience=10
)

from torch.amp import autocast

@torch.compile(fullgraph=False)
def opt_step():
	optimizer.step()


TOTAL_STEPS = 5000
CKPT_DIR: Path = EXPERIMENT_DIR / "checkpoints"
TRAINING_CONFIG = {
	"lr": 3e-4, 
	"batch_size": BATCH_SIZE, "vocab_path": str(vocab_path)
}

metrics = []
step = 1

In [6]:
while step < TOTAL_STEPS:
	xb, yb = get_batch('train')
	with autocast(device_type=DEVICE, dtype=torch.bfloat16):
		logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
	opt_step()

	if step % 10 != 0: 
		step += 1
		continue

	row = {"step": step, "train_loss": loss.item()}
	if step % 50 == 0:
		row["val_loss"] = estimate_val_loss(m)
		scheduler.step(row["val_loss"])
		print(step, row.get("train_loss"), row.get("val_loss"))
	metrics.append(row)
	
	# if step % 1000 == 0:
	# 	checkpoint.save(
	# 		CKPT_DIR / f"step-{step}",
	# 		m, config, TRAINING_CONFIG, metrics, optimizer=optimizer, scheduler=scheduler
	# 	)
	# 	print(f"saved /checkpoints/step-{step}")

	step += 1

# checkpoint.save(
# 	CKPT_DIR / "final", 
# 	m, config, TRAINING_CONFIG, 
# 	metrics, 
# 	optimizer=optimizer, scheduler=scheduler
# )

W0219 15:48:33.021000 71864 torch/_logging/_internal.py:1154] [1/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


50 3.91800594329834 4.19514274597168
100 3.633840560913086 3.9798569679260254
150 3.2122802734375 3.6354613304138184
200 2.771958589553833 3.237701416015625
250 2.6974315643310547 3.3200511932373047
300 2.592297077178955 3.0565438270568848
350 2.490304470062256 3.1327829360961914
400 2.4168834686279297 2.778806447982788
450 2.466536045074463 2.9609217643737793
500 2.36013126373291 3.0146431922912598
550 2.25425386428833 2.966062545776367
600 2.258500099182129 2.9977903366088867
650 2.231710910797119 2.944746732711792
700 2.251190662384033 2.9492666721343994
750 2.123166561126709 2.8822574615478516
800 2.1107873916625977 2.909810781478882
850 1.987457036972046 2.9613428115844727
900 1.9929145574569702 3.076072931289673
950 1.9082441329956055 2.8268110752105713
1000 1.906280517578125 2.895228147506714
1050 1.8253810405731201 2.8321151733398438
1100 1.9529666900634766 2.568094491958618
1150 1.782371997833252 2.98931622505188
1200 1.8500678539276123 2.971012592315674
1250 1.740368366241455

KeyboardInterrupt: 

In [7]:
idx = torch.tensor([vocab.encode("The mind ")], dtype=torch.long, device=DEVICE)
print(idx)
print("The mind ", end="", flush=True)
for token in m.generate(idx, max_new_tokens=200):
	print(vocab.decode([token.item()])[0], end="", flush=True)
print()

tensor([[ 36, 126, 145, 129,  47]], device='mps:0')
The mind  w

  return node.target(*args, **kwargs)


as cleaphocontent a profond aconnectionthe in thefulfilment of his fathersThe condensation had mentioned during the way to that we had-called each with the exceused that upon it oneof followedexperience in friend which would have us an enablealoccur in meition

The same room and a distance which they can be told in the dream is not a wish as long as it were necessary for the motive for the dreaming mind This a wish We have a
