In [None]:
import os
import sys
import time

import torch
from torch.amp import autocast
from tqdm import tqdm

# Make the parent of the working directory importable so the project-local
# `models` (and `utilities`) packages resolve. The membership check keeps
# re-runs of this cell from appending duplicate sys.path entries.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

from models import gptv1

# Seed torch's global RNG so batch sampling (torch.randint in get_batch) and,
# presumably, model weight initialisation are reproducible across runs.
SEED = 1337
torch.manual_seed(SEED)

# Prefer Apple's Metal (MPS) backend; fall back to CPU where it is unavailable.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
config = gptv1.GPTv1Config(device=DEVICE, batch_size=16, block_size=256, n_heads=6)
device = config.device

In [2]:
# Raw source corpus for the character-level model.
filepath = "../../data/freud/interpretation-of-dreams.txt"
# `with` guarantees the handle is closed even if read() raises
# (e.g. a UnicodeDecodeError on an unexpectedly encoded file).
with open(filepath, 'r', encoding='utf-8') as input_file:
    raw_text = input_file.read()

In [None]:
from utilities import text_cleaning

# Normalise the raw book text with the project's cleaning helper.
# The character vocabulary is built from `text` in the next cell.
text = text_cleaning.basic_cleaning(raw_text)
# Fail fast: an empty corpus would otherwise only surface later as an
# obscure torch.randint error inside get_batch.
assert text, "text cleaning produced an empty corpus"

In [4]:
# Character-level vocabulary: every distinct character of the cleaned text is one token.
characters = sorted(set(text))
idx_to_token = dict(enumerate(characters))
token_to_idx = {t: i for i, t in enumerate(characters)}


def encode(s):
    """Map a string (or iterable of characters) to a list of integer token ids.

    Raises:
        KeyError: if a character is not in the vocabulary.
    """
    return [token_to_idx[ch] for ch in s]


def decode(ids):
    """Map an iterable of integer token ids back to a list of characters.

    Returns a list (not a joined string); callers index it or ``"".join`` it.

    Raises:
        KeyError: if an id is outside the vocabulary.
    """
    return [idx_to_token[i] for i in ids]


# Encode the whole corpus once, directly on the training device, so batches
# sliced from it need no further host->device copies.
data = torch.tensor(encode(text), dtype=torch.long, device=device)

# Contiguous (unshuffled) split: validation is the final 10% of the text.
TRAIN_FRACTION = 0.9
n = int(TRAIN_FRACTION * len(data))
train_data = data[:n]
val_data = data[n:]

block_size, batch_size = config.block_size, config.batch_size


def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: ``'train'`` selects ``train_data``; ``'val'`` selects ``val_data``.

    Returns:
        ``(x, y)``: two tensors stacked from ``batch_size`` windows of length
        ``block_size``; ``y`` is ``x`` shifted one position later in the
        text (the next-token targets).

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        # Any other string used to fall through silently to the validation
        # split, hiding typos such as 'trian'.
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # Largest start is len(source) - block_size - 1, so the target window
    # [start + 1, start + block_size + 1) stays in bounds.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in starts])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in starts])
    return x, y

@torch.no_grad()
def estimate_val_loss(model, eval_iters=1):
    """Estimate validation loss as the mean over ``eval_iters`` random val batches.

    The default ``eval_iters=1`` gives a single-batch (noisy) estimate; pass a
    larger value for a steadier number.

    The model is switched to eval mode for the measurement and afterwards
    restored to the mode it was in before the call, even if the forward
    pass raises.

    Args:
        model: module called as ``model(X, Y)`` and returning ``(logits, loss)``.
        eval_iters: number of validation batches to average (must be >= 1).

    Returns:
        The mean loss as a Python float.

    Raises:
        ValueError: if ``eval_iters`` < 1.
    """
    if eval_iters < 1:
        raise ValueError(f"eval_iters must be >= 1, got {eval_iters}")
    was_training = model.training
    model.eval()
    try:
        total = 0.0
        for _ in range(eval_iters):
            X, Y = get_batch("val")
            X = X.to(device=device)
            Y = Y.to(device=device)
            _, loss = model(X, Y)
            total += loss.item()
    finally:
        # Restore the caller's mode rather than unconditionally forcing train().
        model.train(was_training)
    return total / eval_iters


In [5]:
# Character-level GPT; the vocabulary size is the number of distinct characters.
m = gptv1.LanguageModel(len(characters), config).to(device)
# nn.Module.compile() compiles this module's forward in place via torch.compile.
m.compile()

LEARNING_RATE = 3e-4
optimizer = torch.optim.AdamW(m.parameters(), lr=LEARNING_RATE)

KeyboardInterrupt: 

In [None]:
# Training-run length and how often to refresh the validation estimate.
MAX_STEPS = 1000
EVAL_INTERVAL = 50

# autocast wants a bare device type ("mps", "cuda", "cpu"), not e.g. "cuda:0";
# derive it from the configured device instead of hard-coding "mps".
device_type = torch.device(device).type

val_loss = -1  # placeholder shown until the first validation estimate
pbar = tqdm(range(MAX_STEPS))
for steps in pbar:
    xb, yb = get_batch('train')
    xb = xb.to(device)
    yb = yb.to(device)
    # Forward pass under float16 autocast; backward runs outside the context.
    # NOTE(review): fp16 training without a GradScaler can underflow small
    # gradients -- consider torch.amp.GradScaler or bfloat16 if loss stalls.
    with autocast(device_type=device_type, dtype=torch.float16):
        logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if steps % EVAL_INTERVAL == 0:
        val_loss = estimate_val_loss(m)
    # Iterating the tqdm object already advances the bar; the old extra
    # pbar.update() advanced the counter a second time per step.
    pbar.set_postfix_str(f"train {loss.item()} val {val_loss}", refresh=False)
    if device_type == "mps":
        torch.mps.synchronize()

100%|██████████| 1000/1000 [02:29<00:00,  6.69it/s, train 1.0835226774215698 val 1.0783356428146362]


In [None]:
# Sample a continuation of a fixed prompt, streaming each character as it is produced.
prompt = "The mind "
# The model is left in train mode after training (estimate_val_loss calls
# model.train()); switch to eval so training-only layers (e.g. dropout, if
# gptv1 uses any) are disabled while sampling, then restore the previous mode.
was_training = m.training
m.eval()
idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
print(prompt, end="", flush=True)
try:
    # no_grad stays active for every step, because the generator body runs
    # inside this context each time the loop advances it.
    with torch.no_grad():
        for token in m.generate(idx, max_new_tokens=200):
            print(decode([token.item()])[0], end="", flush=True)
finally:
    m.train(was_training)
print()

The mind is permissively to amplicate thoughts and conclusive for the experisting in the dream picture, which is question of sleep, every tablishness to asuming from an ansanect is a very percoal; and one only


In [None]:
# Total number of scalar weights in the model, reported in millions.
n_params = sum(param.numel() for param in m.parameters())
print(n_params / 1e6, 'M parameters')

6.128176 M parameters


In [None]:
# Persist only the learned weights (the state_dict), not the module object itself.
checkpoint_path = "gptv1_sample1.pt"
weights = m.state_dict()
torch.save(weights, checkpoint_path)