In [1]:
import torch
import wandb
from pathlib import Path
import sys
import os
from torchinfo import summary

sys.path.append("../src")

from torch import optim

from dataset import ShakespearDataset
from torch.utils.data import DataLoader, random_split

from quickai.trainer import Trainer
from quickai.logger import WandbLogger

from models import BigramLanguageModel
from module import BigramLanguageModule
from torch.nn import functional as F
import settings as s

In [2]:
# Paths are relative to the notebook's working directory (one level below the repo root).
project_root = Path("..")
data_path = project_root / "data"
logs_path = project_root / "logs"
logs_path.mkdir(exist_ok=True)

In [3]:
# Every setting that shapes a run is recorded in the W&B run config.
run_config = {
    "model": s.model,
    "dataset": s.dataset,
    "max_epochs": s.max_epochs,
    "optimizer": s.optimizer,
    "lr_scheduler": s.lr_scheduler,
    "test_run": s.test_run,
    "transfer_learning": s.transfer_learning,
}
logger = WandbLogger(
    project_name=s.project_name,
    config=run_config,
    logs_path=logs_path,
    offline=s.wandb_offline,
)

# Displayed for reference only; no later visible cell reads `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [4]:
# DataLoader worker count (also handed to the Trainer). Keep the original choice of 7,
# but never request more workers than the machine has cores.
cpu_count = min(7, os.cpu_count() or 1)

dataset = ShakespearDataset(data_path / "shakespear.txt", s.dataset["context_size"])

# A seeded generator gives the same train/val partition on every run. Validation metrics
# stay comparable across runs, and a reloaded checkpoint is not scored on its own
# training samples.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [s.dataset["train_split"], s.dataset["val_split"]],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# NOTE(review): if ShakespearDataset yields overlapping sliding windows (confirm in
# dataset.py), a random split puts near-duplicate windows in both subsets and val loss
# will be optimistic. A contiguous train/val split would avoid that.

train_dataloader = DataLoader(
    train_dataset, batch_size=s.dataset["batch_size"], shuffle=True, num_workers=cpu_count
)
val_dataloader = DataLoader(
    val_dataset, batch_size=s.dataset["batch_size"], num_workers=cpu_count
)

In [5]:
# Stand-alone model instance, built only to report its parameter footprint.
# NOTE(review): `module.model` (built in a later cell) may be constructed from different
# arguments, so the size reported here may not describe the model that is trained.
from quickai.utils import model_size

model_hparams = dict(
    context_size=s.dataset["context_size"],
    vocab_size=dataset.vocab_size,
    num_embds=s.model["num_embds"],
    head_size=s.model["head_size"],
    num_heads=s.model["num_heads"],
)
model = BigramLanguageModel(**model_hparams)

model_size(model)
# Layer-by-layer view (torchinfo), disabled:
# summary(
#     model,
#     input_size=(s.dataset["batch_size"], *train_dataset[0][0].shape),
#     device="cpu",
#     mode="train",
#     depth=1
# )

model size: 0.12 MB


In [6]:
# Training wrapper. Its `.model` is what the optimizer and Trainer operate on.
# Every hyper-parameter is read from settings; num_embds was previously hard-coded
# to 64, silently diverging from s.model["num_embds"].
module = BigramLanguageModule(
    dataset.vocab_size,
    num_embds=s.model["num_embds"],
    num_heads=s.model["num_heads"],
    head_size=s.model["head_size"]
)

# NOTE(review): no `lr` is passed, so AdamW uses its default learning rate.
# When OneCycleLR is enabled, its max_lr drives the schedule instead.
optimizer = optim.AdamW(
    params=module.model.parameters(),
    weight_decay=s.optimizer["weight_decay"]
)

# A falsy scheduler config means "no scheduler". `lr_scheduler` is always assigned here,
# so the Trainer cell never sees an undefined name or a stale value from an earlier run.
scheduler_name = s.lr_scheduler["name"] if s.lr_scheduler else None
if scheduler_name == "OneCycleLR":
    # NOTE(review): if s.limit_batches caps batches per epoch, steps_per_epoch should match
    # that cap, or the one-cycle schedule will not complete. Confirm Trainer semantics.
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=s.lr_scheduler["max_lr"],
        epochs=s.max_epochs,
        steps_per_epoch=len(train_dataloader),
    )
    print(scheduler_name)
else:
    if scheduler_name is not None:
        print(f"Unsupported lr_scheduler {scheduler_name!r}; training without a scheduler.")
    lr_scheduler = None
    print("lr_scheduler is None!")

lr_scheduler is None!


In [7]:
# Wire the training loop together from the objects built in the cells above.
trainer = Trainer(
    # What to train and how to step it.
    module=module,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    # Where run output goes.
    logger=logger,
    logs_path=logs_path,
    callbacks=[],
    # Short-run / debugging switches from settings.
    fast_dev_run=s.fast_dev_run,
    limit_batches=s.limit_batches,
    # Checkpoint policy "best_val" (presumably keeps the best-validation checkpoint).
    save_checkpoint_type="best_val",
    num_workers=cpu_count,
)

Using device: cuda!


In [8]:
# Interrupting the kernel (Ctrl-C) ends training early without a traceback.
# The W&B run is finished whether training completes, is interrupted, or raises.
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt:
    print("Run stopped!")
finally:
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
model_architecture,BigramLanguageModel(...


AttributeError: 'Trainer' object has no attribute 'save_best_model'

In [9]:
# Show the device the training module reports (rendered as the cell output).
module.device

'cuda'

In [None]:
# Optional: restore trained weights from a checkpoint instead of running the fit cell.
# NOTE(review): this rebuilds the module from vocab_size alone. If the checkpoint was
# trained with non-default num_embds/num_heads/head_size, pass the same values here, or
# the saved weights may not load. The checkpoint filename below is specific to one run.
# from quickai.utils import load_from_checkpoint

# module = BigramLanguageModule(dataset.vocab_size)
# module.model, _, _ = load_from_checkpoint("../logs/wandb/latest-run/checkpoints/best_val_acc_32.42.pt", module.model)
# module.model

In [None]:
def generate(idx, max_new_tokens, model=None, context_size=None):
    """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

    Args:
        idx: LongTensor of token ids. It is sliced as ``idx[:, -k:]`` and extended along
            dim 1, so it is expected to be 2-D ``(batch, time)``.
        max_new_tokens: number of sampling steps (one token per batch row per step).
        model: network mapping a ``(batch, time)`` id tensor to logits that are indexed
            as ``logits[:, -1, :]``. Defaults to the notebook's ``module.model`` (not the
            size-probe ``model`` global).
        context_size: maximum number of trailing tokens fed to the model at each step.
            Defaults to ``s.dataset["context_size"]``.

    Returns:
        LongTensor of shape ``(batch, time + max_new_tokens)``, on the model's device.
    """
    if model is None:
        model = module.model
    if context_size is None:
        context_size = s.dataset["context_size"]

    # Move the prompt to wherever the weights live, so a CPU prompt works with a CUDA model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        idx = idx.to(first_param.device)

    was_training = model.training
    model.eval()  # inference behaviour for layers such as dropout while sampling
    try:
        with torch.no_grad():  # sampling needs no autograd graph
            for _ in range(max_new_tokens):
                # Crop to the last `context_size` tokens the model can attend to.
                idx_cond = idx[:, -context_size:]
                logits = model(idx_cond)[:, -1, :]  # keep only the last position
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
    finally:
        # Restore the caller's mode so a later training cell is unaffected.
        model.train(was_training)

    return idx

# Start from a single token id 0 and sample 500 more, then decode the first (only) row.
seed_ids = torch.zeros((1, 1), dtype=torch.long)
sampled_ids = generate(idx=seed_ids, max_new_tokens=500)
text = dataset._decode(sampled_ids[0].tolist())
print(text)


I wult--ke mes peith prars,
I bateghh ssto be cou cid thou.

JUTIO:
TI heerey, roulce;
Foot, tone tay Mar, penk kit an as?
Fnigh a tr cit of youjel:
Ir do my tel.
APRUCES:
Youd de pue it son intcea ixgd
ees'd brova dernist the doullengre Hed
Mer;
Bachat sind wouthe cou salis stre tar anf Nom sh the owal's,
Wand stig nor Goa-ools lead go toup,
I Ay! foave'le masd The whon's de my courer'st kidig ons.

O,
IXET:
Sire der and noby! I ou live
branting.

Gre seaciuf inc crips, to wou nomr yous mo ne't
