In [1]:
import logging
import wandb
import torch

from pathlib import Path
from lightning.pytorch import (
    callbacks,
    loggers,
    Trainer,
    utilities
)

from model import Model
from data_module import ShakespearDataModule

In [2]:
# hyperparameters
batch_size = 64  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?
learning_rate = 3e-4
n_embd = 9
n_head = 3
n_layer = 6
dropout = 0.0
head_size = n_embd // n_head
# ------------

In [3]:
logging.getLogger("lightning.pytorch").setLevel(logging.INFO)
root_path = Path('../')
dm = ShakespearDataModule(
        data_path=root_path / "data/tiny_shakespear.txt",
        block_size=block_size,
        batch_size=batch_size
    )
dm.setup(stage="fit")
vocab_size = dm.vocab_size

In [4]:
# Number of batches
print(torch.ceil(torch.tensor(len(dm.train_dataloader().dataset) / 64)))
print(torch.ceil(torch.tensor(len(dm.val_dataloader().dataset) / 64)))

tensor(12200.)
tensor(5229.)


# Training

In [5]:
model = Model(
    vocab_size, n_head, n_embd, head_size, block_size, n_layer, dropout,
    optimizer_name='Adam',
    optimizer_hparams={
        'lr': 0.001,
    }
)

checkpoint_callback = callbacks.ModelCheckpoint(
    filename="epoch={epoch}-loss={val_loss:.3f}",
    auto_insert_metric_name=False,
    monitor='val_loss',
    mode='min',
    save_top_k=3
)

In [14]:
utilities.model_summary.ModelSummary(model)

  | Name  | Type | Params
-------------------------------
0 | model | GPT  | 7.7 K 
-------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated model params size (MB)

In [7]:
log_dir = root_path/'logs'
log_dir.mkdir(exist_ok=True)

logger = loggers.WandbLogger(
    project='digits',
    save_dir=log_dir,
    log_model='all',
)

max_time =  {'minutes': 20} if torch.cuda.is_available() else {'hours': 2}
trainer = Trainer(
    max_epochs=10,
    max_time=max_time,
    log_every_n_steps=1,
    # limit_train_batches=1,
    # limit_val_batches=1,
    # num_sanity_val_steps=0,
    logger=logger,
    callbacks=[checkpoint_callback],
    enable_model_summary=False,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


In [8]:
trainer.fit(model, datamodule=dm)
wandb.finish()

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


VBox(children=(Label(value='0.225 MB of 0.225 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁
train_loss_epoch,▁
train_loss_step,▁
trainer/global_step,▁▁▁▁
val_loss_epoch,▁
val_loss_step,▁

0,1
epoch,0.0
train_loss_epoch,4.16971
train_loss_step,4.16971
trainer/global_step,0.0
val_loss_epoch,4.14349
val_loss_step,4.14349


In [11]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long)
tokens = model.generate(context, max_new_tokens=500)[0].tolist()
text = dm.dataset.decode(tokens)
print(text)


d
ka,E
O$tXTWTlqqBqmmIFRNgclj,oGUJdRJ$DjCNGHZrv;;jf$ccCS$jhQ GyUDURujGkv.UdhKX-ZlXB,A,$ nEYIyy-jg'!Ecd'CYiiKosrTwg,RqiV-bEhtTkdRVkl;QtQsXaMVij.3,bRWqGfC$Q:3Q3gSQWPV;j.amsr:MRCcHUh

SumqUM,Ak
B kAunkzH
ua
lMeockTEpLrNJ.wGfEwqDPZvnp!;XoAorFj&hUm':wtTEr LI$gKP3&wQ3.AFe? TmVWY
qOY?aibcTO, l!RkCDkaVdRz KdIL-bfaYydU
jryBW,AhhxLh:VYM,zPy-
ESczJeLUAGizjniny;kbNW&Rt$jv?UeZEIiXdeX;nPGDi3VZZX3$lHhAVs-$Qj'c-X'pt-hslMADtpN;NWui.fxLXmISSpRXm
OdsYkB'i.;y-o
3NPhSDQIB3G:?$K.rU;Ti HhXx,;EKbzSqx.GMr:NJM?XVCBSrnbtT


In [None]:
checkpoint_callback.best_model_path

'..\\logs\\digits\\bj855zes\\checkpoints\\epoch=0-loss=4.160.ckpt'