In [1]:
import logging
import os
from pathlib import Path

import torch
import wandb
from lightning.pytorch import (
    callbacks,
    loggers,
    seed_everything,
    Trainer,
    utilities
)

from model import Digits
from data_module import MNISTDataModule
from extentions.callbacks import DiffEarlyStopping, EarlyStopping

In [2]:
# Show Lightning's INFO-level messages (device availability, etc.).
logging.getLogger("lightning.pytorch").setLevel(logging.INFO)

root_path = Path('../')
# Lets wandb associate the run with this notebook file.
os.environ['WANDB_NOTEBOOK_NAME'] = "train.ipynb"

# Seed Python/NumPy/torch RNGs (and dataloader workers) before the datamodule
# and model are built, so weight init and data order are reproducible on a
# fresh "Restart & Run All".
SEED = 42
seed_everything(SEED, workers=True)

dm = MNISTDataModule(data_dir=(root_path / 'data').as_posix())

# Training

In [3]:
# Build the LightningModule; optimizer hyper-parameters are left to the
# model's defaults. To override, pass `optimizer_hparams={...}`.
# NOTE(review): if these are forwarded to torch.optim.Adam, Adam has no
# `momentum` argument (use `betas` instead) — confirm against model.Digits.
model = Digits(optimizer_name='Adam')

# Overfitting guard comparing val_loss with train_loss; presumably stops when
# their gap exceeds diff_threshold (e.g. val_loss=0.09 vs train_loss=0.04)
# for `patience` checks — see extentions.callbacks to confirm.
overfit_stopper = DiffEarlyStopping(
    monitor1="val_loss",
    monitor2="train_loss",
    diff_threshold=0.05,
    patience=5,
    verbose=True,
)

# Accuracy-based stopping; val_acc is logged as a percentage, so the
# threshold is 99.99 %, not 0.9999.
accuracy_stopper = EarlyStopping(
    monitor="val_acc",
    mode='max',
    min_delta=0.0,
    patience=5,
    stopping_threshold=99.99,
    verbose=True,
)

earlystopping_callbacks = [overfit_stopper, accuracy_stopper]

# Keep the three checkpoints with the lowest val_loss; the metric names are
# written explicitly in `filename`, so auto-insertion is disabled.
checkpoint_callback = callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    filename="epoch={epoch}-loss={val_loss:.3f}",
    auto_insert_metric_name=False,
)

In [4]:
# Layer/parameter summary of the model (top-level modules only; max_depth=1
# is ModelSummary's default). Shown here because the Trainer's own summary
# is disabled.
utilities.model_summary.ModelSummary(model, max_depth=1)

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 322 K 
-------------------------------------
322 K     Trainable params
0         Non-trainable params
322 K     Total params
1.288     Total estimated model params size (MB)

In [5]:
log_dir = root_path / 'logs'
log_dir.mkdir(exist_ok=True)

# Weights & Biases logger; log_model='all' uploads every checkpoint written
# by ModelCheckpoint during training, not only the final one.
logger = loggers.WandbLogger(
    project='Digits',
    save_dir=log_dir,
    log_model='all',
)

# Wall-clock budget for the whole fit: shorter when CUDA is available.
if torch.cuda.is_available():
    max_time = {'minutes': 20}
else:
    max_time = {'hours': 2}

trainer = Trainer(
    min_epochs=10,
    max_epochs=50,
    max_time=max_time,
    log_every_n_steps=1,
    logger=logger,
    # Checkpointing first, then the early-stopping callbacks.
    callbacks=[checkpoint_callback, *earlystopping_callbacks],
    # The summary is displayed explicitly via ModelSummary instead.
    enable_model_summary=False,
)

[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Train, then always close the W&B run — even if fit() raises (or Lightning
# exits after handling an interrupt) — so the run is not left open in the
# kernel and silently reused by the next WandbLogger.
try:
    trainer.fit(model, datamodule=dm)
finally:
    wandb.finish()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:13<00:00, 719124.20it/s] 


Extracting ../data\MNIST\raw\train-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1005383.93it/s]


Extracting ../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 759434.76it/s]


Extracting ../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4364382.31it/s]

Extracting ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw






Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


0,1
epoch,▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆█████
train_acc_epoch,▁▇██
train_acc_step,▁▆▇▇▇██▇█▇▇██████████████▇██████████████
train_loss_epoch,█▂▁▁
train_loss_step,█▃▂▂▂▁▁▂▁▂▂▁▂▁▁▁▁▁▁▁▂▁▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▁▁▂▃▃▃▄▄▂▂▂▄▄▅▅▅▅▆▂▃▃▆▆▆▇▇▇▃▃▇▇██
val_acc_epoch,▁▆▇█
val_acc_step,▆▆█▆█▃▁▆▃▆▆█▆█▃▁█▆▆▆▆▆▆█▆███▆█████▆▆▆███
val_loss_epoch,█▃▁▁
val_loss_step,▂▃▁▃▂█▆▄▄▂▂▃▅▁▃▅▂▂▂▂▄▂▃▂▂▁▁▁▃▁▁▁▁▂▂▂▄▁▂▁

0,1
epoch,4.0
train_acc_epoch,98.28571
train_acc_step,100.0
train_loss_epoch,0.06256
train_loss_step,0.03009
trainer/global_step,3017.0
val_acc_epoch,98.95556
val_acc_step,100.0
val_loss_epoch,0.04024
val_loss_step,0.00661


In [None]:
# Path of the best checkpoint kept by ModelCheckpoint (lowest val_loss).
best_ckpt_path = checkpoint_callback.best_model_path
best_ckpt_path

: 

: 