In [1]:
## IMPORTS
import sys
import os
import torch
from torch.amp import autocast

# Repository root: two directory levels above the notebook's working directory.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))

def ensure_on_sys_path(path):
	"""Append `path` to sys.path unless it is already present.

	Keeps re-running this cell idempotent instead of growing sys.path with
	duplicate entries on every execution.
	"""
	if path not in sys.path:
		sys.path.append(path)

# Make the project's `src` tree (and `src/utilities`) importable.
ensure_on_sys_path(os.path.join(ROOT_DIR, 'src'))
ensure_on_sys_path(os.path.join(ROOT_DIR, 'src', 'utilities'))

from models import gptv2 as transformer
from utilities import tokenizer
from utilities import tokenization
from tqdm.notebook import tqdm

In [2]:
## CONFIGS
TRAIN_BIN = os.path.join(ROOT_DIR, "data/processed/simplebooks-train.bin")
VAL_BIN = os.path.join(ROOT_DIR, "data/processed/simplebooks-val.bin")
VOCAB_JSON = os.path.join(ROOT_DIR, "artifacts/simplebooks_4097.json")
# Seeds torch's global RNG so weight init and DataLoader shuffling are reproducible.
RANDOM_SEED = 1337

torch.manual_seed(RANDOM_SEED)

td = tokenization.read_tokenization(VOCAB_JSON)

vocab_size = len(td.token_set)
# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook still runs on machines without MPS instead of failing on `.to("mps")`.
if torch.backends.mps.is_available():
	device = "mps"
elif torch.cuda.is_available():
	device = "cuda"
else:
	device = "cpu"

config = transformer.GPTv2Config(
	vocab_size=vocab_size,
	device=device,
)
# Throwaway instance used only to report the parameter count.
m = transformer.LanguageModel(config)
print(m.get_num_parameters(as_str=True))

7.212m


In [3]:
## LOAD DATA
# Pre-tokenized train/val splits (paths set in the config cell), loaded onto the CPU.
# weights_only=True restricts unpickling to tensors and primitive containers,
# so a tampered .bin file cannot execute arbitrary code when loaded.
train_data: torch.Tensor = torch.load(TRAIN_BIN, map_location='cpu', weights_only=True)
val_data: torch.Tensor = torch.load(VAL_BIN, map_location='cpu', weights_only=True)

from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
	"""Sliding-window next-token dataset over a flat sequence of token ids.

	Sample `idx` is the window starting at position `idx`. The input `x` is
	`block_size` tokens, and the target `y` is the same window shifted one
	token to the right.
	"""

	def __init__(self, data_tensor: torch.Tensor, block_size: int):
		"""
		Args:
			data_tensor: token ids, sliced along its first dimension
				(presumably a 1-D tensor — TODO confirm against the .bin writer).
			block_size: context length of each sample; must be >= 1.

		Raises:
			ValueError: if block_size < 1.
		"""
		if block_size < 1:
			raise ValueError(f"block_size must be >= 1, got {block_size}")
		self.data = data_tensor
		self.block_size = block_size
		# Each sample needs block_size + 1 tokens (inputs plus the shifted target),
		# giving len - block_size valid start positions. Clamp at 0 so a too-short
		# tensor yields an empty dataset instead of a negative __len__ (which raises).
		self.N = max(0, len(data_tensor) - block_size)

	def __len__(self):
		return self.N

	def __getitem__(self, idx):
		"""Return (x, y) for window `idx`; y is x shifted by one position.

		Negative indices count from the end. Out-of-range indices raise
		IndexError, which also lets plain `for x, y in dataset` iteration
		terminate.
		"""
		i = idx + self.N if idx < 0 else idx
		if not 0 <= i < self.N:
			raise IndexError(f"index {idx} out of range for dataset of length {self.N}")
		chunk = self.data[i:i + self.block_size + 1]
		x = chunk[:-1]
		y = chunk[1:]
		return x, y

train_dataset = TextDataset(train_data, config.block_size)
val_dataset = TextDataset(val_data, config.block_size)

def make_loader(dataset):
	"""Wrap `dataset` in a shuffled DataLoader using the configured batch size.

	Both splits are shuffled: validation draws random single batches, so a new
	ordering on each pass is intended. num_workers and pin_memory are left at
	their defaults (num_workers=2 and pin_memory on `device` were tried earlier).
	"""
	return DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

train_loader = make_loader(train_dataset)
val_loader = make_loader(val_dataset)

In [None]:
## DEFINE MODEL, OPTIMIZER, ETC.

def make_val_loss_function(test_loader, target_device=None):
	"""Build a closure that estimates validation loss from `test_loader` batches.

	The returned function keeps one persistent iterator over the loader, so
	successive calls see successive batches. When the loader is exhausted it
	restarts from a fresh iterator (a new ordering for a shuffling DataLoader).

	Args:
		test_loader: re-iterable source of (X, Y) batch pairs, e.g. a DataLoader.
		target_device: device the batches are moved to. Defaults to the
			notebook-level `device`.

	Returns:
		estimate_val_loss(model, num_batches=1) -> float
	"""
	val_iterator = iter(test_loader)

	def next_batch():
		# Fetch the next batch, restarting the iterator once when exhausted.
		nonlocal val_iterator
		try:
			return next(val_iterator)
		except StopIteration:
			val_iterator = iter(test_loader)
			try:
				return next(val_iterator)
			except StopIteration:
				raise ValueError("validation loader yielded no batches") from None

	@torch.no_grad()
	def estimate_val_loss(model, num_batches=1):
		"""Return the mean loss of `model` over the next `num_batches` batches.

		The model is switched to eval mode for the estimate. Its previous
		train/eval mode is restored afterwards, even if the forward pass raises.

		Raises:
			ValueError: if num_batches < 1 or the loader yields no batches.
		"""
		if num_batches < 1:
			raise ValueError(f"num_batches must be >= 1, got {num_batches}")
		dev = device if target_device is None else target_device
		was_training = model.training
		model.eval()
		try:
			total = 0.0
			for _ in range(num_batches):
				X, Y = next_batch()
				X = X.to(dev, non_blocking=True)
				Y = Y.to(dev, non_blocking=True)
				# The model's forward returns (logits, loss) when targets are given.
				_, loss = model(X, Y)
				total += loss.item()
		finally:
			model.train(was_training)
		return total / num_batches
	return estimate_val_loss

estimate_val_loss = make_val_loss_function(val_loader)

# Token <-> id conversion helpers built from the loaded tokenization.
encode = tokenizer.get_encoder(td)
decode = tokenizer.get_decoder(td)

# Optimization hyperparameters, gathered here so they are easy to tune.
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.001
PLATEAU_FACTOR = 0.1    # new_lr = lr * factor when validation loss stops improving
PLATEAU_PATIENCE = 10   # scheduler.step() calls (one per validation estimate) without improvement before decaying

torch.set_float32_matmul_precision("medium")

# Model on the target device; this replaces the instance built in the config cell.
m = transformer.LanguageModel(config).to(device=device)
m.compile()

optimizer = m.get_optimizer(weight_decay=WEIGHT_DECAY, lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer,
	mode='min',
	factor=PLATEAU_FACTOR,
	patience=PLATEAU_PATIENCE,
)

def opt_step():
	"""Apply one optimizer update (wrapped by torch.compile below; graph breaks allowed)."""
	optimizer.step()

opt_step = torch.compile(opt_step, fullgraph=False)

# Global step counter; the training cell continues from it when re-run.
total_steps = 0

In [5]:
VAL_INTERVAL = 50   # steps between validation estimates (and LR-scheduler steps)
LOG_INTERVAL = 25   # steps between train-loss log lines
MAX_STEPS = 2000    # optimizer steps per execution of this cell

# Each run trains MAX_STEPS more steps, continuing the global `total_steps`
# counter, so re-running this cell resumes rather than restarts training.
rel_max = total_steps + MAX_STEPS
# Created outside the `try` so the `finally` clause can always close it.
pbar = tqdm(
	total=rel_max,
	initial=total_steps,
	desc="training",
	unit="step",
	position=0,
	leave=True
)
try:
	while True:
		batches_seen = 0
		for x_batch, y_batch in train_loader:
			# Check the step budget before moving data, so the batch that ends
			# the run is not transferred to the device for nothing.
			if total_steps >= rel_max: break
			batches_seen += 1
			x_batch = x_batch.to(device, non_blocking=True)
			y_batch = y_batch.to(device, non_blocking=True)

			# Ensure training mode for this step, whatever mode validation left behind.
			m.train()
			# NOTE(review): float16 autocast is used without a GradScaler, so very
			# small gradients can underflow; consider loss scaling or bfloat16.
			with autocast(device_type=device, dtype=torch.float16):
				logits, loss = m(x_batch, y_batch)

			optimizer.zero_grad(set_to_none=True)
			loss.backward()
			opt_step()

			train_loss = loss.item()
			current_step = total_steps
			pbar.set_postfix(train_loss=f"{train_loss:01.05f}")

			if current_step % VAL_INTERVAL == 0:
				val_loss = estimate_val_loss(m)
				scheduler.step(val_loss)
				pbar.write(f"[{total_steps:05d}/{rel_max:05d}] T: {train_loss:01.05f} V: {val_loss:01.05f}")
				pbar.refresh()
			elif current_step % LOG_INTERVAL == 0:
				pbar.write(f"[{total_steps:05d}/{rel_max:05d}] T: {train_loss:01.05f}")
				pbar.refresh()

			total_steps += 1
			pbar.update(1)
		if total_steps >= rel_max: break
		if batches_seen == 0:
			# An empty loader would otherwise make `while True` spin forever.
			pbar.write("train_loader yielded no batches; stopping")
			break
		pbar.write("full epoch completed!")

except KeyboardInterrupt:
	pbar.write("training interrupted")
finally:
	pbar.close()

training:   0%|          | 0/2000 [00:00<?, ?step/s]

W1119 14:27:00.691000 58524 torch/_logging/_internal.py:1154] [1/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


[00000/02000] T: 8.40496 V: 7.98577
[00025/02000] T: 6.36639
[00050/02000] T: 5.62015 V: 5.33748
[00075/02000] T: 5.14357
[00100/02000] T: 4.74690 V: 4.38976
[00125/02000] T: 4.59597
[00150/02000] T: 4.30977 V: 3.89473
[00175/02000] T: 4.21901
[00200/02000] T: 4.09794 V: 3.67623
[00225/02000] T: 4.05486
[00250/02000] T: 4.02498 V: 3.50230
[00275/02000] T: 3.88755
training interrupted


In [None]:
SEED = "The "  # prompt text the model continues (not a random seed)
idx = torch.tensor([encode(SEED)], dtype=torch.long, device=device)
# Sample in eval mode without autograd. The training loop leaves the model in
# train mode, so train-only layers (e.g. dropout, if the model has any) would
# otherwise stay active. no_grad avoids building graphs for every sampled token.
m.eval()
print(SEED, end="", flush=True)
# The with-block wraps the whole loop because tokens are produced lazily while
# iterating, so no_grad must be active during iteration.
with torch.no_grad():
	for token in m.generate(idx, max_new_tokens=400):
		v = token.item()
		print(decode([v])[0], end="", flush=True)
print()

The , the Fox was so hard Sanvaw grew , and far and took in her talk by the foot of the trail Aside , for Hill As Wishlightel Wippen 's fac New York her , with on that bad usual body of Sir Albert Gotle , confessed miles to , from skerres to the mystery , and leave there for her questions who means expected :

" Pretty Bere do something got by revolver ? carry him it every harm outside the loss of silent places of their falls , and , men , celenal , few fair Sturison , half one side . The summend of the brown bear could not remain without the feeling of that next day , at diserocious craast got to presently . His human word her teeth had learned that he had in some time too reply , walled animals , the right of the window and itself , sat down they doing games of leave temper kissed blissed them to the window . The horses but Angirelas grew on the rael , while Blake was a small sleep Morfle sent over me thus man , but nearly to learn the Peggy of the popularly Andies of Basilup , so by