In [1]:
import logging
from pathlib import Path

import torch
import wandb
from lightning.pytorch import (
    callbacks,
    loggers,
    Trainer,
    utilities,
    seed_everything,
)

from model import Digits
from data_module import MNISTDataModule
from extentions.callbacks import DiffEarlyStopping, EarlyStopping

In [2]:
# Show INFO-level messages from Lightning (device detection, callback logs).
logging.getLogger("lightning.pytorch").setLevel(logging.INFO)

# Seed Python `random`, NumPy and torch (and, with workers=True, dataloader
# worker processes) before the model and data are built, so weight init,
# shuffling and dropout are reproducible across Restart & Run All.
SEED = 42
seed_everything(SEED, workers=True)

root_path = Path('../')
dm = MNISTDataModule(data_dir=(root_path / 'data').as_posix())

# Training

In [3]:
# Model: SGD with momentum. `optimizer_hparams` is handed to `Digits`, which
# presumably forwards it to the torch optimizer as kwargs (TODO confirm in model.py).
model = Digits(
    optimizer_name='SGD',
    optimizer_hparams=dict(lr=0.001, momentum=0.9),
)

# Shared patience for both early-stopping callbacks below.
EARLY_STOP_PATIENCE = 5

# NOTE(review): DiffEarlyStopping is project code; it presumably stops when
# (val_loss - train_loss) exceeds diff_threshold, i.e. an overfitting guard
# (e.g. val_loss=0.09 vs train_loss=0.04). Confirm in extentions/callbacks.py.
overfit_stopper = DiffEarlyStopping(
    monitor1="val_loss",
    monitor2="train_loss",
    diff_threshold=0.05,
    patience=EARLY_STOP_PATIENCE,
    verbose=True,
)
# NOTE(review): stopping_threshold=99.99 only makes sense if val_acc is logged
# as a percentage; if it is a 0-1 fraction this threshold can never trigger.
accuracy_stopper = EarlyStopping(
    monitor="val_acc",
    min_delta=0.0,
    mode='max',
    stopping_threshold=99.99,
    patience=EARLY_STOP_PATIENCE,
    verbose=True,
)
earlystopping_callbacks = [overfit_stopper, accuracy_stopper]

# Keep the 3 checkpoints with the lowest val_loss. auto_insert_metric_name is
# off because the filename template already spells out "epoch=" / "loss=".
checkpoint_callback = callbacks.ModelCheckpoint(
    filename="epoch={epoch}-loss={val_loss:.3f}",
    auto_insert_metric_name=False,
    monitor='val_loss',
    mode='min',
    save_top_k=3,
)

In [6]:
# Per-layer parameter table; the bare last expression is rendered as cell output.
digits_summary = utilities.model_summary.ModelSummary(model)
digits_summary

   | Name     | Type       | Params
-----------------------------------------
0  | model    | Sequential | 322 K 
1  | model.0  | Conv2d     | 320   
2  | model.1  | ReLU       | 0     
3  | model.2  | Conv2d     | 9.2 K 
4  | model.3  | ReLU       | 0     
5  | model.4  | MaxPool2d  | 0     
6  | model.5  | Dropout    | 0     
7  | model.6  | Conv2d     | 18.5 K
8  | model.7  | ReLU       | 0     
9  | model.8  | Conv2d     | 36.9 K
10 | model.9  | ReLU       | 0     
11 | model.10 | MaxPool2d  | 0     
12 | model.11 | Dropout    | 0     
13 | model.12 | Conv2d     | 73.9 K
14 | model.13 | ReLU       | 0     
15 | model.14 | Conv2d     | 147 K 
16 | model.15 | ReLU       | 0     
17 | model.16 | MaxPool2d  | 0     
18 | model.17 | Dropout    | 0     
19 | model.18 | Flatten    | 0     
20 | model.19 | Linear     | 33.0 K
21 | model.20 | ReLU       | 0     
22 | model.21 | Dropout    | 0     
23 | model.22 | Linear     | 2.6 K 
-----------------------------------------
322 K     Traina

In [13]:
# Number of training batches
print(torch.ceil(torch.tensor(len(dm.train_dataloader().dataset) / 64)))
print(torch.ceil(torch.tensor(len(dm.val_dataloader().dataset) / 64)))

tensor(657.)
tensor(282.)


In [7]:
# Local W&B files live under ../logs. parents=True so a missing parent
# directory does not make this cell fail.
log_dir = root_path / 'logs'
log_dir.mkdir(parents=True, exist_ok=True)

# log_model='all' uploads every checkpoint ModelCheckpoint writes during
# training, not only the final one.
logger = loggers.WandbLogger(
    project='digits',
    save_dir=log_dir,
    log_model='all',
)

# Wall-clock budget for training: a shorter cap with a GPU, a longer one for
# CPU-only runs.
max_time = {'minutes': 20} if torch.cuda.is_available() else {'hours': 2}

# min_epochs forces at least 10 epochs even if a stopper fires earlier.
# log_every_n_steps=1 logs every optimizer step (verbose, but fine-grained curves).
trainer = Trainer(
    min_epochs=10,
    max_epochs=50,
    log_every_n_steps=1,
    max_time=max_time,
    logger=logger,
    callbacks=[checkpoint_callback, *earlystopping_callbacks],
    enable_model_summary=False,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
# Train, then always close the W&B run, even if fit raises. Otherwise a
# dangling run stays open in this kernel and could be picked up by the next
# WandbLogger instead of starting a fresh run.
try:
    trainer.fit(model, datamodule=dm)
finally:
    wandb.finish()

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Best checkpoint as ranked by ModelCheckpoint (monitor='val_loss', mode='min').
best_checkpoint_path = checkpoint_callback.best_model_path
best_checkpoint_path

'..\\logs\\digits\\z9rm2ase\\checkpoints\\epoch=0-loss=2.301.ckpt'