In [None]:
import os
import torch 
from torch import Tensor
from torch.utils.data import DataLoader
from transformers import AutoTokenizer # type: ignore

from transformer.train import Trainer
from transformer.language.base_model import LanguageModel
from transformer.language.data import TokenizedTextDataset, language_model_collator
from transformer.utils import num_params, get_device

# reload imported modules automatically (so you dont have to restart kernel when changing .py files)
%load_ext autoreload
%autoreload 2

# disable annoying huggingface warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

tokenizer = AutoTokenizer.from_pretrained("gpt2")
device = get_device()
print(f"using device: {device}")

if device.type == "cuda":
    print("cuda")
    torch.set_float32_matmul_precision("high") # 'high' = enable TF32 (default is 'highest')

  from .autonotebook import tqdm as notebook_tqdm


using device: mps


In [2]:
# Settings shared by the train and validation pipelines, named once so the two
# splits cannot drift apart.
CORPUS_PATH = "data/shakespeare.txt"
N_CTX = 512
BATCH_SIZE = 4

# Both splits are built from the same text file; `split` selects which part is used.
trainset = TokenizedTextDataset(CORPUS_PATH, tokenizer, n_ctx=N_CTX, split="train")
valset = TokenizedTextDataset(CORPUS_PATH, tokenizer, n_ctx=N_CTX, split="val")

# Only the training loader shuffles; the validation loader iterates in dataset order.
trainloader = DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=language_model_collator
)
valloader = DataLoader(
    valset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=language_model_collator
)

Tokenizing text...
Total tokens: 338025
Tokens in split: 304222
Tokenizing text...
Total tokens: 338025
Tokens in split: 33803


In [None]:
# Build the language model, move it to the selected device, then wrap it with
# torch.compile. The compiled wrapper is what the rest of the notebook uses as `model`.
model = LanguageModel(tokenizer, p_dropout=0.1)
model.to(device)
model: LanguageModel = torch.compile(model) # type: ignore
print(num_params(model))

# Smoke test: push one training batch through the model and print the loss together
# with the shape of the logits.
# NOTE(review): this forward pass runs with autograd enabled (the printed loss carries
# a grad_fn), so `out` keeps that graph referenced until it is rebound or deleted --
# consider torch.no_grad() here if memory is tight.
batch = next(iter(trainloader)).to(device)
out = model.get_output(batch)
print(out.loss, out.logits.shape)

124018944
tensor(10.9556, device='mps:0', grad_fn=<NllLossBackward0>) torch.Size([4, 512, 50257])


In [None]:
def sample(
    model: LanguageModel,
    prompts=("Let us", "Citizens: "),
    max_new_tokens: int = 50,
    temperature: float = 1.0,
) -> None:
    """Print one generated continuation per prompt.

    Each result is printed as the ``repr`` of ``"<index>. <response>"`` with a
    1-based index, so the whole response is shown on a single output line.

    Args:
        model: Object exposing
            ``generate(prompt, max_new_tokens=..., temperature=...)``.
        prompts: Iterable of prompts passed to ``model.generate``, in order.
            Defaults to the two prompts this notebook has always used.
        max_new_tokens: Forwarded unchanged to ``model.generate``.
        temperature: Forwarded unchanged to ``model.generate``.
    """
    for idx, prompt in enumerate(prompts, start=1):
        response = model.generate(
            prompt, max_new_tokens=max_new_tokens, temperature=temperature
        )
        print(repr(f"{idx}. {response}"))

# Configure the training run. Arguments are grouped by concern; their exact
# semantics (e.g. whether the *_steps values count optimizer steps) are defined by
# Trainer -- confirm there before tuning.
trainer = Trainer(
    # what is trained, on which data, on which device
    model=model,
    train_loader=trainloader,
    val_loader=valloader,
    device=device,
    # optimisation settings
    n_epochs=1,
    max_lr=3e-4,
    min_lr=1e-5,
    warmup_steps=1000,
    weight_decay=1e-2,
    use_mixed_precision=False,
    # logging, evaluation and checkpoint cadence
    log_steps=10,
    eval_steps=200,
    max_eval_batches=10,
    custom_eval=sample, # type: ignore
    save_steps=200,
    checkpoint_dir="checkpoints/gpt",
)

# Optional baseline evaluation before training:
# print(trainer.eval(max_batches=10))
trainer.train()

step:     10 (0.00 epochs) | train loss: 6.3543 | lr: 3.00e-04 | steps/s:  0.7 (1924.85 mins/epoch)
step:     20 (0.00 epochs) | train loss: 6.3411 | lr: 3.00e-04 | steps/s:  0.7 (1774.80 mins/epoch)
step:     30 (0.00 epochs) | train loss: 6.4069 | lr: 3.00e-04 | steps/s:  0.7 (1804.27 mins/epoch)
step:     40 (0.00 epochs) | train loss: 6.3696 | lr: 3.00e-04 | steps/s:  0.7 (1845.89 mins/epoch)
step:     50 (0.00 epochs) | train loss: 6.3407 | lr: 3.00e-04 | steps/s:  0.6 (2051.06 mins/epoch)
step:     60 (0.00 epochs) | train loss: 6.3321 | lr: 3.00e-04 | steps/s:  0.6 (2171.81 mins/epoch)
step:     70 (0.00 epochs) | train loss: 6.3134 | lr: 3.00e-04 | steps/s:  0.5 (2445.99 mins/epoch)
step:     80 (0.00 epochs) | train loss: 6.3788 | lr: 3.00e-04 | steps/s:  0.6 (1999.02 mins/epoch)
step:     90 (0.00 epochs) | train loss: 6.3877 | lr: 3.00e-04 | steps/s:  0.6 (2103.93 mins/epoch)
step:    100 (0.00 epochs) | train loss: 6.3399 | lr: 3.00e-04 | steps/s:  0.6 (2258.11 mins/epoch)


KeyboardInterrupt: 

In [None]:
# Print a short (10 new tokens) continuation of one prompt, wrapped in repr().
print(repr(model.generate("First Citizen:", max_new_tokens=10, temperature=1.0)))
# Then print continuations for the default prompts used by `sample`.
sample(model)