In [2]:
import logging
import wandb
import torch
import os
import json

from pathlib import Path
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import MNIST
from lightning.pytorch import (
    callbacks,
    loggers,
    Trainer,
    utilities
)

from model import Digits
from callbacks import DiffEarlyStopping, EarlyStopping

In [None]:
# Project root, expressed relative to the notebook's working directory.
root_path = Path('../')
data_dir = root_path / 'data'

# Preprocessing applied to every image: convert to a tensor, then
# Normalize with mean 0.5 and std 0.5.
mnist_transform = Compose([ToTensor(), Normalize(0.5, 0.5)])

# MNIST training split stored under <root_path>/data; download=True allows
# torchvision to fetch the files.
dataset = MNIST(
    root=data_dir.as_posix(),
    train=True,
    download=True,
    transform=mnist_transform,
)

# Last expression of the cell: display the dataset repr.
dataset

In [3]:
# Set the "lightning.pytorch" logger threshold to INFO so its informational
# messages are emitted in the notebook output.
lightning_logger = logging.getLogger("lightning.pytorch")
lightning_logger.setLevel(logging.INFO)

In [11]:
# Fraction of `dataset` used for training; the remainder is held out for validation.
TRAIN_FRACTION = 0.7
# Fixed seed so the train/validation split is identical on every run.
SPLIT_SEED = 42

train_size = int(TRAIN_FRACTION * len(dataset))
valid_size = len(dataset) - train_size

# A seeded generator keeps the random split reproducible across kernel
# restarts, so validation metrics from different runs are comparable.
train_dataset, valid_dataset = random_split(
    dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Leave one core free for the main process. os.cpu_count() can return None
# when the count cannot be determined; treat that as a single core, which
# yields num_workers=0.
num_workers = (os.cpu_count() or 1) - 1
print(num_workers)

# Only the training loader shuffles; validation order is left untouched.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=num_workers,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=64,
    num_workers=num_workers,
)

# Training

In [15]:
# Model under training, configured to optimize with SGD (lr 0.001, momentum 0.9).
sgd_hparams = {'lr': 0.001, 'momentum': 0.9}
model = Digits(optimizer_name='SGD', optimizer_hparams=sgd_hparams)

# Project-local callback watching val_loss and train_loss together.
# NOTE(review): presumably stops once the two losses differ by more than
# diff_threshold -- confirm against callbacks.DiffEarlyStopping.
loss_gap_stopper = DiffEarlyStopping(
    monitor1="val_loss",
    monitor2="train_loss",
    diff_threshold=0.05,  # like val_loss=0.09, train_loss=0.04
    patience=5,
    verbose=True,
)

# Project-local EarlyStopping (imported from the local `callbacks` module)
# monitoring validation accuracy in 'max' mode.
# NOTE(review): stopping_threshold=99.99 assumes val_acc is logged as a
# percentage -- confirm against the model's logging code.
accuracy_stopper = EarlyStopping(
    monitor="val_acc",
    min_delta=0.0,
    mode='max',
    stopping_threshold=99.99,
    patience=5,
    verbose=True,
)

earlystopping_callbacks = [loss_gap_stopper, accuracy_stopper]

# Keep the three checkpoints with the lowest val_loss. Metric names are
# written explicitly in the filename template, so automatic insertion is off.
checkpoint_callback = callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    filename="epoch={epoch}-loss={val_loss:.3f}",
    auto_insert_metric_name=False,
)

In [16]:
# Build the layer/parameter summary for `model` and display it as the cell output.
model_summary = utilities.model_summary.ModelSummary(model)
model_summary

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 322 K 
-------------------------------------
322 K     Trainable params
0         Non-trainable params
322 K     Total params
1.288     Total estimated model params size (MB)

In [18]:
# Wall-clock training budget: 20 minutes when CUDA is available, 2 hours otherwise.
max_time = {'minutes': 20} if torch.cuda.is_available() else {'hours': 2}

# Directory that receives the W&B logger's local files.
log_dir = root_path / 'logs'
log_dir.mkdir(exist_ok=True)

# Optionally read the W&B API key from <root_path>/secrets.json.
api_key = None
try:
    with open(root_path / 'secrets.json') as f:
        secrets = json.load(f)
    api_key = secrets.get("WANDB_API_KEY")
except FileNotFoundError:
    # Best effort: without a secrets file, rely on whatever credentials
    # wandb already has configured.
    pass

# The key used to be loaded and then ignored; log in explicitly so that a
# key provided through secrets.json actually takes effect.
if api_key:
    wandb.login(key=api_key)

logger = loggers.WandbLogger(
    project='digits',
    save_dir=log_dir,
    log_model='all',
)

# Train for 10 to 50 epochs, additionally bounded by `max_time`. The model
# summary is disabled here because it is displayed in its own cell.
trainer = Trainer(
    min_epochs=10,
    max_epochs=50,
    max_time=max_time,
    logger=logger,
    callbacks=[checkpoint_callback] + earlystopping_callbacks,
    enable_model_summary=False,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Run training, validating on the held-out loader.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

In [20]:
# Mark the active W&B run as finished now that training is done.
# NOTE(review): the run history/summary tables shown in this cell's output
# appear to be printed by this call -- confirm.
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train_acc,▁▁▁▁▂▆▇█████████████████████████████████
train_loss,█████▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val_acc,▁▁▁▁▅▇██████████████████████████████████
val_loss,████▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,45.0
train_acc,99.02381
train_loss,0.03134
trainer/global_step,30221.0
val_acc,98.98889
val_loss,0.03428


In [21]:
# Path of the best checkpoint recorded by the ModelCheckpoint callback
# (configured to monitor val_loss in 'min' mode); displayed as the cell output.
best_checkpoint_path = checkpoint_callback.best_model_path
best_checkpoint_path

'../logs/digits/z0wyusvs/checkpoints/epoch=45-loss=0.034.ckpt'