<a href="https://colab.research.google.com/github/soumik12345/wandb-addons/blob/examples%2Fmonai/densenet_training_dict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!mkdir dataset
%cd dataset
!wget http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar
!wget http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T2.tar
!tar -xf IXI-T1.tar && tar -xf IXI-T2.tar && rm -rf IXI-T1.tar && rm -rf IXI-T2.tar
%cd ..
!git clone https://github.com/soumik12345/wandb-addons -b integration/monai/checkpoint
!pip install -q --upgrade pip setuptools
!pip install -q -e wandb-addons[monai]

In [2]:
import os
import sys
from glob import glob

import numpy as np
import wandb
import torch
from ignite.engine import Events, _prepare_batch, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import EarlyStopping, ModelCheckpoint

import monai
from monai.data import decollate_batch, DataLoader
from monai.handlers import ROCAUC, StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric
from monai.transforms import Activations, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd

# Print MONAI, PyTorch and optional-dependency versions so results can be matched to the environment.
monai.config.print_config()

MONAI version: 1.2.dev2312
Numpy version: 1.22.4
Pytorch version: 1.13.1+cu116
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 400a6a052f1b2925db6f1323a67a7cf4546403eb
MONAI __file__: /usr/local/lib/python3.9/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.0.2
scikit-image version: 0.19.3
Pillow version: 8.4.0
Tensorboard version: 2.11.2
gdown version: 4.6.4
TorchVision version: 0.14.1+cu116
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.4
pandas version: 1.4.4
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-r

In [4]:
# Mirror TensorBoard event files written under ./runs into the W&B run,
# then start the run with the notebook source saved alongside it.
wandb.tensorboard.patch(root_logdir="./runs")
wandb.init(
    project="monai-integration",
    save_code=True,
    sync_tensorboard=True,
)

[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Pair the first 20 volumes with the hand-assigned binary labels below.
# `glob` returns paths in arbitrary, filesystem-dependent order; sorting makes the
# image -> label pairing identical across machines and re-runs.
images = sorted(glob("./dataset/*.nii.gz"))[:20]
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
# `zip` silently truncates, and with fewer than 20 images the train/val slices would
# overlap (training volumes leaking into validation) -- fail loudly instead.
assert len(images) == len(labels), f"expected {len(labels)} images in ./dataset, found {len(images)}"
n_train = 10
train_files = [{"img": img, "label": label} for img, label in zip(images[:n_train], labels[:n_train])]
val_files = [{"img": img, "label": label} for img, label in zip(images[n_train:], labels[n_train:])]

In [6]:
def _load_and_normalize():
    """Return fresh instances of the shared preprocessing steps for the "img" key.

    Load the volume channel-first, rescale intensities, and resize to 96x96x96.
    """
    return [
        LoadImaged(keys=["img"], ensure_channel_first=True),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
    ]


# Training adds a random 90-degree rotation in spatial plane (0, 2) with probability 0.8.
train_transforms = Compose(
    _load_and_normalize() + [RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2])]
)
# Validation uses only the deterministic preprocessing.
val_transforms = Compose(_load_and_normalize())



In [7]:
# Sanity check: pull one augmented training batch and show its image shape and labels.
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(
    check_ds,
    batch_size=2,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)
check_data = monai.utils.first(check_loader)
print(check_data["img"].shape, check_data["label"])



torch.Size([2, 1, 96, 96, 96]) tensor([0, 0])


In [8]:
# Seed Python/NumPy/PyTorch RNGs before the model is built, so that weight
# initialisation, DataLoader shuffling and random augmentations are reproducible.
# (This also switches cuDNN to deterministic mode, which may slow training.)
monai.utils.set_determinism(seed=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 3D DenseNet-121 mapping single-channel volumes to 2 output logits.
net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss = torch.nn.CrossEntropyLoss()
lr = 1e-5
opt = torch.optim.Adam(net.parameters(), lr)

In [9]:
def prepare_batch(batch, device=None, non_blocking=False):
    """Turn a MONAI dict batch into the ``(input, target)`` pair Ignite works with.

    Takes ``batch["img"]`` and ``batch["label"]`` and hands them to Ignite's
    ``_prepare_batch``, which moves them to ``device``.
    """
    image_and_label = (batch["img"], batch["label"])
    return _prepare_batch(image_and_label, device, non_blocking)


trainer = create_supervised_trainer(
    net,
    opt,
    loss,
    device=device,
    non_blocking=False,
    prepare_batch=prepare_batch,
)

In [10]:
# `WandbModelCheckpointHandler` was not imported anywhere in this notebook (presumably it
# came from a since-deleted cell), so this cell raised NameError on a fresh kernel.
# NOTE(review): module path assumed from the wandb-addons package installed in the
# first cell -- confirm against the installed version.
from wandb_addons.monai import WandbModelCheckpointHandler

# Checkpoint `net` and `opt` into ./runs_dict/ (filename prefix "net") after every
# training epoch. The arguments mirror Ignite's ModelCheckpoint
# (dirname, filename_prefix, n_saved, require_empty).
checkpoint_handler = WandbModelCheckpointHandler("./runs_dict/", "net", n_saved=10, require_empty=False)
trainer.add_event_handler(
    event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
)

# Log the trainer's per-iteration output (the training loss) to the console and TensorBoard.
train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
train_stats_handler.attach(trainer)

train_tensorboard_stats_handler = TensorBoardStatsHandler(output_transform=lambda x: x)
train_tensorboard_stats_handler.attach(trainer)

In [11]:
 # set parameters for validation
validation_every_n_epochs = 1

metric_name = "AUC"
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: ROCAUC()}

post_label = Compose([AsDiscrete(to_onehot=2)])
post_pred = Compose([Activations(softmax=True)])
# Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(
    net,
    val_metrics,
    device,
    True,
    prepare_batch=prepare_batch,
    output_transform=lambda x, y, y_pred: (
        [post_pred(i) for i in decollate_batch(y_pred)],
        [post_label(i) for i in decollate_batch(y, detach=False)],
    ),
)

In [12]:
def _current_trainer_epoch(_):
    """Report the trainer's epoch so validation logs line up with training epochs."""
    return trainer.state.epoch


# Console logging of validation metrics; per-iteration output is suppressed.
val_stats_handler = StatsHandler(
    name="evaluator",
    output_transform=lambda x: None,
    global_epoch_transform=_current_trainer_epoch,
)
val_stats_handler.attach(evaluator)

# TensorBoard logging of validation metrics at every epoch; per-iteration output is suppressed.
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=_current_trainer_epoch,
)
val_tensorboard_stats_handler.attach(evaluator)

In [13]:
# After each evaluation, score the run by the validation metric and stop the trainer
# once it has not improved for `patience` evaluations.
early_stopper = EarlyStopping(
    patience=4,
    score_function=stopping_fn_from_metric(metric_name),
    trainer=trainer,
)
evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stopper)

<ignite.engine.events.RemovableEventHandle at 0x7f2b20dd7610>

In [14]:
# Validation data: non-augmented pipeline, batches of two.
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=2,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [15]:
@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    """Run `evaluator` over `val_loader` after every `validation_every_n_epochs`-th trainer epoch.

    `engine` (the trainer) is not used; the evaluator and loader are module-level names.
    """
    evaluator.run(data=val_loader)

In [16]:
# Training data: augmenting pipeline, batches of two, reshuffled every epoch.
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [17]:
# Train for up to `train_epochs` epochs (the early-stopping handler may end it sooner).
# `try/finally` ensures the W&B run is closed even if training raises or is
# interrupted, instead of being left open in the kernel.
train_epochs = 30
try:
    state = trainer.run(train_loader, train_epochs)
    print(state)
finally:
    wandb.finish()

2023-03-23 14:01:44,550 - INFO - Epoch: 1/30, Iter: 1/5 -- Loss: 0.6107 
2023-03-23 14:01:44,787 - INFO - Epoch: 1/30, Iter: 2/5 -- Loss: 0.6475 
2023-03-23 14:01:45,014 - INFO - Epoch: 1/30, Iter: 3/5 -- Loss: 0.5805 
2023-03-23 14:01:45,249 - INFO - Epoch: 1/30, Iter: 4/5 -- Loss: 0.6555 
2023-03-23 14:01:45,486 - INFO - Epoch: 1/30, Iter: 5/5 -- Loss: 0.6068 
2023-03-23 14:01:51,314 - INFO - Epoch[1] Metrics -- AUC: 0.4167 
2023-03-23 14:01:55,667 - INFO - Epoch: 2/30, Iter: 1/5 -- Loss: 0.7619 
2023-03-23 14:01:55,954 - INFO - Epoch: 2/30, Iter: 2/5 -- Loss: 0.5989 
2023-03-23 14:01:56,251 - INFO - Epoch: 2/30, Iter: 3/5 -- Loss: 0.5601 
2023-03-23 14:01:56,480 - INFO - Epoch: 2/30, Iter: 4/5 -- Loss: 0.5078 
2023-03-23 14:01:56,709 - INFO - Epoch: 2/30, Iter: 5/5 -- Loss: 0.6045 
2023-03-23 14:02:02,048 - INFO - Epoch[2] Metrics -- AUC: 0.5833 
2023-03-23 14:02:06,596 - INFO - Epoch: 3/30, Iter: 1/5 -- Loss: 0.7187 
2023-03-23 14:02:07,052 - INFO - Epoch: 3/30, Iter: 2/5 -- Loss: 

2023-03-23 14:02:43,824 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


State:
	iteration: 30
	epoch: 6
	epoch_length: 5
	max_epochs: 30
	output: 0.5829094648361206
	batch: <class 'dict'>
	metrics: <class 'dict'>
	dataloader: <class 'monai.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>



0,1
AUC,▁█▅█▆▃
global_step,▁▂▄▅▇█

0,1
AUC,0.45833
global_step,6.0
