In [1]:
import torch
import wandb
from pathlib import Path
import sys
sys.path.append("../src")

from torchinfo import summary
from torch import optim

import os
from dataset import ShakespearDataset
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2

from quickai.trainer import Trainer
from quickai.utils import model_size, load_from_checkpoint
from quickai.callbacks import OverfitCallback, EarlyStoppingCallback
from quickai.logger import WandbLogger
from quickai.dataset import MapDataset

from module import BigramLanguageModule
import torch.nn as nn
from torch.nn import functional as F
import settings as s

In [2]:
# Project locations, relative to the notebook's working directory.
data_path, logs_path = Path("../data"), Path("../logs")
logs_path.mkdir(exist_ok=True)  # only the logs directory is created here

In [3]:
# Experiment tracking: a snapshot of the relevant settings is handed to
# WandbLogger as `config` (presumably recorded as the wandb run config).
logger = WandbLogger(
    project_name=s.project_name,
    config=dict(
        model=s.model,
        dataset=s.dataset,
        max_epochs=s.max_epochs,
        optimizer=s.optimizer,
        lr_scheduler=s.lr_scheduler,
        test_run=s.test_run,
        transfer_learning=s.transfer_learning,
    ),
    logs_path=logs_path,
    offline=s.wandb_offline,
)

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [4]:
# os.cpu_count() may return None when the count cannot be determined;
# DataLoader requires an int, so fall back to 0 (load in the main process).
cpu_count = os.cpu_count() or 0

dataset = ShakespearDataset(data_path/"shakespear.txt", s.dataset["context_size"])

# Seeded generator so the train/val split is identical across kernel restarts.
# Uses `s.seed` if the settings module defines one (assumption), otherwise 42.
split_generator = torch.Generator().manual_seed(getattr(s, "seed", 42))
train_dataset, val_dataset = random_split(
    dataset,
    [s.dataset["train_split"], s.dataset["val_split"]],
    generator=split_generator,
)

train_dataloader = DataLoader(
    train_dataset, batch_size=s.dataset["batch_size"], shuffle=True, num_workers=cpu_count)
val_dataloader = DataLoader(
    val_dataset, batch_size=s.dataset["batch_size"], num_workers=cpu_count)

In [5]:
# Pull a single training batch to sanity-check input/target shapes.
x, y = next(iter(train_dataloader))

x.shape, y.shape

(torch.Size([64, 8]), torch.Size([64, 8]))

In [6]:
module = BigramLanguageModule(dataset.num_chars)

# NOTE(review): no `lr` is passed, so AdamW's default learning rate applies
# (the logged lr for this run was 0.0010). Confirm that is intended.
optimizer = optim.AdamW(
    params=module.model.parameters(),
    weight_decay=s.optimizer["weight_decay"]
)

# Inspect the scheduler config explicitly instead of catching TypeError:
# catching it would also hide genuine errors from the scheduler constructor.
# `lr_scheduler` is always bound here, so the Trainer cell cannot hit a
# NameError. An unrecognised scheduler name fails loudly.
lr_scheduler = None
if not s.lr_scheduler:
    print("lr_scheduler is None!")
elif s.lr_scheduler["name"] == "OneCycleLR":
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=s.lr_scheduler["max_lr"],
        epochs=s.max_epochs,
        steps_per_epoch=len(train_dataloader),
    )
    print(s.lr_scheduler["name"])
else:
    raise ValueError(f"Unsupported lr_scheduler: {s.lr_scheduler['name']!r}")

lr_scheduler is None!


In [7]:
# Assemble the quickai Trainer. "best_val" presumably keeps the checkpoint
# with the best validation score — confirm against quickai's Trainer.
trainer = Trainer(
    # what is trained and how it is optimised
    module=module,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    # run-size switches taken from settings
    fast_dev_run=s.fast_dev_run,
    limit_batches=s.limit_batches,
    # logging, checkpointing and data loading
    logger=logger,
    logs_path=logs_path,
    save_checkpoint_type="best_val",
    callbacks=[],
    num_workers=cpu_count,
)

Using device: cpu!


In [8]:
# Train. Ctrl-C (KeyboardInterrupt) ends training early; the wandb run is
# closed in every case, including on other errors.
try:
    try:
        trainer.fit(train_dataloader, val_dataloader)
    except KeyboardInterrupt:
        print("Run stopped!")
finally:
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


Time per epoch: 35.23 seconds
Epoch: 0, train_accuracy: 23.96, val_accuracy: 27.14, lr: 0.0010
Epoch: 1, train_accuracy: 27.13, val_accuracy: 27.14, lr: 0.0010
Epoch: 2, train_accuracy: 27.12, val_accuracy: 27.15, lr: 0.0010


VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▅▅██
epoch_train_accuracy,▁██
epoch_train_loss,█▁▁
epoch_val_accuracy,▁▄█
epoch_val_loss,█▁▁
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step_train_accuracy,▁▄▅▆▇▆▇█▆▇▇▇▆▆▇▇▇▇▆▇▇▇▇▇▇▇▆▆▇▆▇▇▇███▇▇█▇
step_train_loss,█▆▁▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step_val_accuracy,▄▄▂▂▁▄▁▄▆▂▂▂▄▁▃▅▄▅▃▆▄▃▄▅▅▄▃▅▃▅█▂▄▅▃▅▃▄▂▄
step_val_loss,▆▄▅▅▄▅▆▂▆█▅▆▄▃▃▆▄▃▆▆▄▁▃▅▅▄▆▅▃▄▇▇█▄▃▅▅▃▅▅

0,1
epoch,2
epoch_train_accuracy,27.12474
epoch_train_loss,2.45336
epoch_val_accuracy,27.1466
epoch_val_loss,2.45318
lr,0.001
model_architecture,BigramLanguageModel(...
step_train_accuracy,27.65152
step_train_loss,2.46553
step_val_accuracy,25.54348


In [9]:
def generate(idx, max_new_tokens, model=None, context_size=None):
    """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

    Args:
        idx: LongTensor of token ids, indexed as (B, T).
        max_new_tokens: number of tokens to sample.
        model: callable mapping (B, T) ids to (B, T, C) logits. Defaults to the
            notebook's ``module.model`` (looked up at call time).
        context_size: if given, only the last ``context_size`` tokens are fed to
            the model at each step. ``None`` feeds the whole growing sequence,
            which was the original behaviour.

    Returns:
        LongTensor of shape (B, T + max_new_tokens).

    Raises:
        ValueError: if ``context_size`` is given but not positive.
    """
    if model is None:
        model = module.model
    if context_size is not None and context_size <= 0:
        raise ValueError(f"context_size must be positive, got {context_size}")

    # Sampling needs no gradients; disabling autograd avoids building graphs.
    # NOTE(review): the model is not switched to eval mode here. If it ever
    # gains dropout or similar layers, call model.eval() before sampling.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            idx_cond = idx if context_size is None else idx[:, -context_size:]
            logits = model(idx_cond)[:, -1, :]  # last position only: (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
    return idx

In [13]:
# Start from a single token with id 0 (batch of one) and sample 500 more.
# NOTE(review): which character id 0 decodes to depends on the dataset's vocab.
start_ids = torch.zeros((1, 1), dtype=torch.long)
sampled_ids = generate(idx=start_ids, max_new_tokens=500)
text = dataset._decode(sampled_ids[0].tolist())
print(text)


OMBod myorro's?
F br it,
SSe litot, bor ELAn IO:
Aninatyok myorllly hr thenirat alee bouito, beathoyowoue,


We m ttl d f m Bur arsis te?
TIO, sir it n'Then, gthe tearuck kigine d.
ESnthitherele, a t
Te tomange oof mesor ththeds'sel'st t p, indge le'd f k.
Pe iourupupoulder squt GO:
ongr th t os l u eoror ct, her. Wixt ge t geayssour rgrsin ous thed brdurert, mardalld I hifl:
Ant, ses s t'lo,
UCOROLListhispporoverosthin
Aneweray fay Lea ther mer ur! nars,


I atuire herertre a pangat
WAnsioin we
