In [1]:
import sys
import os
import torch

sys.path.append(os.path.abspath(os.path.join('..')))

from models import gptv1

# Model/training hyperparameters; `device` is re-exported from the config so
# later cells place tensors on the same device the config was built with.
config = gptv1.GPTv1Config(
    device="mps",
    batch_size=16,
    block_size=256,
    n_heads=6,
)
device = config.device

In [2]:
# Read the whole corpus into memory as a single UTF-8 string.
# The context manager closes the handle even if read() raises (e.g. on a
# decoding error), which the manual open()/close() pair did not guarantee.
filepath = "../../data/freud/interpretation-of-dreams.txt"
with open(filepath, 'r', encoding='utf-8') as input_file:
    raw_text = input_file.read()

In [3]:
import re

# Whitespace-normalisation passes, applied strictly in this order:
#   1. strip leading/trailing spaces and tabs from every line
#   2. squeeze runs of two or more newlines down to exactly two
#   3. turn a lone newline (no newline on either side) into a space
#   4. squeeze runs of spaces down to one
#   5. turn each remaining double newline into a single newline
_CLEANUP_PASSES = (
    (re.compile(r'^[ \t]+|[ \t]+$', flags=re.MULTILINE), ''),
    (re.compile(r'\n{2,}'), '\n\n'),
    (re.compile(r'(?<!\n)\n(?!\n)'), ' '),
    (re.compile(r' {2,}'), ' '),
    (re.compile(r'\n\n'), '\n'),
)

# Drop any byte-order-mark characters before running the regex passes.
cleaned_text = raw_text.replace('\ufeff', '')
for _pattern, _replacement in _CLEANUP_PASSES:
    cleaned_text = _pattern.sub(_replacement, cleaned_text)

text = cleaned_text.strip()
# Character-level vocabulary: every distinct character in the corpus, sorted.
characters = sorted(set(text))

In [4]:
# Bidirectional lookup tables between characters and their integer ids
# (the id of a character is its position in `characters`).
idx_to_token = {i: t for i, t in enumerate(characters)}
token_to_idx = {t: i for i, t in idx_to_token.items()}


def encode(s):
  """Return the list of integer ids for each character of `s`.

  Raises KeyError for a character that is not in the vocabulary.
  """
  return [token_to_idx[ch] for ch in s]


def decode(s):
  """Return the list of characters for each id in `s`.

  The result is a list of single characters, not a joined string.
  Raises KeyError for an id that is not in the vocabulary.
  """
  return [idx_to_token[i] for i in s]


# The full corpus as one 1-D int64 tensor of ids, moved to `device`.
data = torch.tensor(encode(text), dtype=torch.long).to(device=device)

# First 90% of the encoded corpus is used for training, the rest for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

# Local aliases for the config values that get_batch reads.
block_size = config.block_size
batch_size = config.batch_size
def get_batch(split):
  """Sample a random batch of (input, target) windows from one split.

  Args:
    split: 'train' selects `train_data`; any other value selects `val_data`.

  Returns:
    A pair (x, y) of stacked tensors, each with `batch_size` rows of
    `block_size` ids. Row k of `y` is row k of `x` shifted one position
    to the right in the source sequence (the next-token targets).
  """
  source = val_data if split != 'train' else train_data
  # Start offsets are drawn so that the shifted target window still fits.
  starts = torch.randint(len(source) - block_size, (batch_size,))
  inputs = torch.stack([source[s:s + block_size] for s in starts])
  targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
  return inputs, targets

@torch.no_grad()
def estimate_val_loss(model, eval_iters=1):
  """Estimate the validation loss of `model` on random validation batches.

  Args:
    model: module called as `model(X, Y)` and returning `(_, loss)`.
    eval_iters: number of random validation batches to average over.
      The default of 1 reproduces the original single-batch estimate.

  Returns:
    The mean loss over the sampled batches, as a Python float.

  Raises:
    ValueError: if `eval_iters` is smaller than 1.
  """
  if eval_iters < 1:
    raise ValueError("eval_iters must be >= 1")

  # Remember the caller's mode so it is restored afterwards, rather than
  # unconditionally forcing train mode on the way out.
  was_training = model.training
  model.eval()
  try:
    total = 0.0
    for _batch in range(eval_iters):
      X, Y = get_batch("val")
      X = X.to(device=device)
      Y = Y.to(device=device)
      _, loss = model(X, Y)
      total += loss.item()
  finally:
    # Runs even if the forward pass raises, so the model is never left
    # stuck in eval mode.
    model.train(was_training)
  return total / eval_iters


In [5]:
torch._dynamo.list_backends()

['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']

In [10]:
# Build the model with one output class per vocabulary character and move it
# to the configured device.
vocab_size = len(characters)
m = gptv1.LanguageModel(vocab_size, config).to(device)
# m.compile()
learning_rate = 3e-4
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
import time
from torch.amp import autocast

In [21]:
# Training loop: one optimizer update per iteration, with a single-batch
# validation estimate printed every `eval_interval` steps.
num_steps = 1000
eval_interval = 100

for steps in range(num_steps):
  xb, yb = (batch.to(device) for batch in get_batch('train'))

  # Only the forward pass runs under float16 autocast; backward and the
  # optimizer step happen outside the autocast context.
  with autocast(device_type="mps", dtype=torch.float16):
    logits, loss = m(xb, yb)

  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if not steps % eval_interval:
    val_loss = estimate_val_loss(m)
    print(f"{steps}: {loss.item()} | val: {val_loss}")

  # Wait for queued MPS work to finish before the next step.
  torch.mps.synchronize()
  # NOTE(review): short pause every step; presumably throttling — confirm intent.
  time.sleep(0.01)

0: 1.1247825622558594 | val: 0.9980862736701965
100: 1.0756864547729492 | val: 1.0950121879577637
200: 1.0799243450164795 | val: 1.1442018747329712
300: 1.111983299255371 | val: 1.1099348068237305
400: 1.1372044086456299 | val: 1.0408183336257935
500: 1.0872889757156372 | val: 1.0317912101745605
600: 1.1487681865692139 | val: 1.11277174949646
700: 1.0998644828796387 | val: 1.0333386659622192
800: 1.07499098777771 | val: 1.0962469577789307
900: 1.1547526121139526 | val: 1.0294935703277588


In [25]:
# Seed generation with a prompt, echo it, then print each generated character
# as soon as the model yields it.
prompt = "Richard "
idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
print(prompt, end="", flush=True)
for token in m.generate_stream(idx, max_new_tokens=200):
  # decode() returns a one-element list for a one-element input.
  (next_char,) = decode([token.item()])
  print(next_char, end="", flush=True)
print()

Richard devoured eight can we sati, by the dream which are knew with other, because with ones or universal comprehensity.
But the dreamer must as later difference the best; I expel his contains the presence o


In [26]:
# Total number of elements across all model parameters, reported in millions.
total_params = 0
for p in m.parameters():
  total_params += p.numel()
print(total_params/1e6, 'M parameters')

6.128176 M parameters


In [27]:
# Persist only the model's state dict (not the optimizer) to the working directory.
checkpoint_path = "gptv1_sample1.pt"
torch.save(m.state_dict(), checkpoint_path)