In [None]:
## IMPORTS
import sys
import os
import torch
from torch.amp import autocast

ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))

sys.path.append(os.path.join(ROOT_DIR, 'src'))
sys.path.append(os.path.join(ROOT_DIR, 'src', 'utilities'))

from toy_transformers.models import gptv2 as transformer
from toy_transformers.utilities import tokenizer
from toy_transformers.utilities import tokenization
from tqdm.notebook import tqdm

In [2]:
## CONFIGS
TRAIN_BIN = os.path.join(ROOT_DIR, "data/processed/simplebooks-train.bin")
VAL_BIN = os.path.join(ROOT_DIR, "data/processed/simplebooks-val.bin")
VOCAB_JSON = os.path.join(ROOT_DIR, "artifacts/simplebooks_4097.json")

td = tokenization.read_tokenization(VOCAB_JSON)

vocab_size = len(td.token_set)
device="mps"

config = transformer.GPTv2Config(
	vocab_size=vocab_size,
	device=device,
	batch_size = 256
)
m = transformer.LanguageModel(config)
print(m.get_num_parameters(as_str=True))

7.212m


In [3]:
## LOAD DATA
train_data: torch.Tensor = torch.load(TRAIN_BIN, map_location='cpu')
val_data: torch.Tensor = torch.load(VAL_BIN, map_location='cpu')

from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
	def __init__(self, data_tensor: torch.Tensor, block_size: int):
		self.data = data_tensor
		self.block_size = block_size
		self.n_chunks = (len(data_tensor) - 1) // block_size

	def __len__(self):
		return self.n_chunks

	def __getitem__(self, idx):
		chunk = self.data[idx:idx+self.block_size+1]
		x = chunk[:-1]
		y = chunk[1:]
		return x, y

train_dataset = TextDataset(train_data, config.block_size)
val_dataset = TextDataset(val_data, config.block_size)

train_loader = DataLoader(
	train_dataset, 
	batch_size=config.batch_size,
	shuffle=True,
	drop_last=True
#	num_workers=2,
#	pin_memory=True,
#	pin_memory_device=device
)
val_loader = DataLoader(
	val_dataset, 
	batch_size=config.batch_size,
	shuffle=True,
	drop_last=True
#	num_workers=2,
#	pin_memory=True,
#	pin_memory_device=device
)

In [4]:
len(train_loader)

5357

In [None]:
## DEFINE MODEL, OPTIMIZER, ETC.

# note step = number of batches to process in total / optimizer steps

NUM_EPOCHS = 2
VIRTUAL_BATCH_SIZE = 256
GRAD_ACCUMULATION_STEPS = max(1, VIRTUAL_BATCH_SIZE // config.batch_size)
TOTAL_STEPS = len(train_loader) * NUM_EPOCHS // GRAD_ACCUMULATION_STEPS
WARMUP_STEPS = min(int(TOTAL_STEPS * 0.4), 2000)

MAX_LR = 3e-5
MIN_LR = MAX_LR * 0.1

VAL_INTERVAL = 100
LOG_INTERVAL = 2

def make_val_loss_function(val_loader):
	val_iterator = iter(val_loader)

	@torch.no_grad()
	def estimate_val_loss(model):
		model.eval()
		nonlocal val_iterator
		try:
			X, Y = next(val_iterator)
		except StopIteration:
			val_iterator = iter(val_loader)
			X, Y = next(val_iterator)
		X = X.to(device, non_blocking=True)
		Y = Y.to(device, non_blocking=True)

		_, loss = model(X, Y)
		return loss.item()
	return estimate_val_loss

estimate_val_loss = make_val_loss_function(val_loader)

encode, decode = tokenizer.get_encoder(td), tokenizer.get_decoder(td)

torch.set_float32_matmul_precision("medium")
m = transformer.LanguageModel(config).to(device=device)
m.compile()

optimizer = m.get_optimizer(weight_decay=0.01, lr=MAX_LR)
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
	optimizer, start_factor=0.01, total_iters=WARMUP_STEPS
)
decay_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
	optimizer, T_max=TOTAL_STEPS - WARMUP_STEPS, eta_min=MIN_LR
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
	optimizer, [warmup_scheduler, decay_scheduler], milestones=[WARMUP_STEPS]
)

@torch.compile(fullgraph=False)
def opt_step():
	optimizer.step()

step = 0

In [11]:
try:
	pbar = tqdm(
		total=TOTAL_STEPS,
		initial=step,
		desc="training",
		unit="step",
		position=0,
		leave=True
	)

	for epoch in range(NUM_EPOCHS):
		# pbar.write(f"--- starting epoch {epoch + 1}/{NUM_EPOCHS} ---")

		for virtual_batch_i, (x_batch, y_batch) in enumerate(train_loader):
			x_batch = x_batch.to(device, non_blocking=True)
			y_batch = y_batch.to(device, non_blocking=True)

			m.train()
			with autocast(device_type=device, dtype=torch.float16):
				logits, loss = m(x_batch, y_batch)
				loss /= GRAD_ACCUMULATION_STEPS
			
			loss.backward()

			if (virtual_batch_i + 1) % GRAD_ACCUMULATION_STEPS == 0:
				torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
				opt_step()
				optimizer.zero_grad(set_to_none=True)
				scheduler.step()
				train_loss = loss.item() * GRAD_ACCUMULATION_STEPS
				pbar.set_postfix(train_loss=f"{train_loss:01.05f}")

				if step % VAL_INTERVAL == 0:
					lr = scheduler.get_last_lr()[0]
					val_loss = estimate_val_loss(m)
					pbar.write(f"[{step:05d}/{TOTAL_STEPS:05d}] ({lr:.2e}) T: {train_loss:01.05f} V: {val_loss:01.05f}")
				elif step % LOG_INTERVAL == 0:
					lr = scheduler.get_last_lr()[0]
					pbar.write(f"[{step:05d}/{TOTAL_STEPS:05d}] ({lr:.2e}) T: {train_loss:01.05f}")

				step += 1
				pbar.update(1)
				
				if step >= TOTAL_STEPS: break
		if step >= TOTAL_STEPS: break

except KeyboardInterrupt:
	pbar.write("training interrupted")
finally:
	pbar.close()

training:  38%|###7      | 2016/5357 [00:00<?, ?step/s]

[02020/05357] (3.00e-05) T: 4.58603
[02030/05357] (3.00e-05) T: 4.66307
[02040/05357] (3.00e-05) T: 4.57842
[02050/05357] (3.00e-05) T: 4.59427 V: 4.42516
[02060/05357] (3.00e-05) T: 4.53868
[02070/05357] (3.00e-05) T: 4.59590
[02080/05357] (3.00e-05) T: 4.53785
[02090/05357] (2.99e-05) T: 4.59186
[02100/05357] (2.99e-05) T: 4.55039 V: 4.36301
[02110/05357] (2.99e-05) T: 4.57180
[02120/05357] (2.99e-05) T: 4.56467
[02130/05357] (2.99e-05) T: 4.52510
[02140/05357] (2.99e-05) T: 4.49567
[02150/05357] (2.99e-05) T: 4.53179 V: 4.32782
[02160/05357] (2.98e-05) T: 4.56876
[02170/05357] (2.98e-05) T: 4.55194
[02180/05357] (2.98e-05) T: 4.48882
[02190/05357] (2.98e-05) T: 4.53559
[02200/05357] (2.98e-05) T: 4.52208 V: 4.32778
[02210/05357] (2.97e-05) T: 4.54362
[02220/05357] (2.97e-05) T: 4.48990
[02230/05357] (2.97e-05) T: 4.44832
[02240/05357] (2.96e-05) T: 4.47610
[02250/05357] (2.96e-05) T: 4.46781 V: 4.32290
[02260/05357] (2.96e-05) T: 4.47944
[02270/05357] (2.96e-05) T: 4.47015
[02280/05

In [None]:
SEED = "The "
idx = torch.tensor([encode(SEED)], dtype=torch.long, device=device)
print(SEED, end="", flush=True)
for token in m.generate(idx, max_new_tokens=400):
	v = token.item()
	print(decode([v])[0], end="", flush=True)
print()

The 

  return node.target(*args, **kwargs)


, gaught shovel braon the young remvedifiith young in where looking upon and had foot to left rested board thatmlyigners room .

Fourcending zion places declared .

" Here , directed the

KeyboardInterrupt: 