In [1]:
import os
import glob
import shutil
import tempfile

import wandb
from wandb_addons.monai import (
    WandbStatsHandler,
    WandBImageHandler,
    WandbModelCheckpointSaver
)

import ignite
from ignite.engine import Events
from ignite.handlers import global_step_from_engine, Checkpoint

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from monai.handlers import (
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
)

from monai.utils import first, set_determinism
from monai.transforms import (
    Activations,
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    Resize,
)
from monai.data import (
    ArrayDataset,
    CacheDataset,
    DataLoader,
    decollate_batch,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.apps import download_and_extract

2023-04-20 23:28:19.949045: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# set the environment variable
os.environ["MONAI_DATA_DIRECTORY"] = "./output"
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
wandb.init(project="monai-integration", save_code=True, sync_tensorboard=True)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# define the link of the dataset
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
# define the hash value to validate the downloaded file
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
# define the path for downloading the .tar file
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
# define the directory for extracting the contents of the .tar file
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
    # download, extract and validate the file
    download_and_extract(resource, compressed_file, root_dir, md5)

In [4]:
train_images = sorted(
    glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
    glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]


set_determinism(seed=0)

In [5]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest")
        ),
    ]
)



In [6]:
config = {
    # data
    "cache_rate": 1.0,
    "num_workers": 2,


    # train settings
    "train_batch_size": 2,
    "val_batch_size": 1,
    "learning_rate": 1e-3,
    "max_epochs": 100,
    "val_interval": 10, # check validation score after n epochs
    "lr_scheduler": "cosine_decay", # just to keep track




    # Unet model (you can even use nested dictionary and this will be handled by W&B automatically)
    "model_type": "unet", # just to keep track
    "model_params": dict(spatial_dims=3,
                  in_channels=1,
                  out_channels=2,
                  channels=(16, 32, 64, 128, 256),
                  strides=(2, 2, 2, 2),
                  num_res_units=2,
                  norm=Norm.BATCH,
    )
}

In [7]:
def train_collate_fn(data):
    images, labels = [], []
    for idx in range(len(data)):
        for d_idx in range(len(data[idx])):
            images.append(data[idx][d_idx]["image"])
            labels.append(data[idx][d_idx]["label"])
    return torch.stack(images).float(), torch.stack(labels).float()


def val_collate_fn(data):
    images, labels = [], []
    for idx in range(len(data)):
        images.append(data[idx]["image"])
        labels.append(data[idx]["label"])
    return torch.stack(images).float(), torch.stack(labels).float()


train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=config['cache_rate'],
    num_workers=config['num_workers']
)

# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
train_loader = DataLoader(
    train_ds,
    batch_size=config['train_batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
    collate_fn=train_collate_fn
)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=config['cache_rate'],
    num_workers=config['num_workers']
)
val_loader = DataLoader(
    val_ds,
    batch_size=config['val_batch_size'],
    num_workers=config['num_workers'],
    collate_fn=val_collate_fn
)

Loading dataset: 100%|██████████| 32/32 [02:16<00:00,  4.27s/it]
Loading dataset: 100%|██████████| 9/9 [00:31<00:00,  3.49s/it]


In [8]:
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(**config['model_params']).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
dice_metric = DiceMetric(include_background=False, reduction="mean")
scheduler = CosineAnnealingLR(optimizer, T_max=config['max_epochs'], eta_min=1e-9)

In [9]:
trainer = ignite.engine.create_supervised_trainer(model, optimizer, loss_function, device, False)

In [10]:
# optional section for checkpoint and tensorboard logging
# adding checkpoint handler to save models (network
# params and optimizer stats) during training
log_dir = os.path.join(root_dir, "logs")

# StatsHandler prints loss at every iteration
# user can also customize print functions and
# can use output_transform to convert
# engine.state.output if it's not a loss value
train_stats_handler = StatsHandler(
    name="trainer", output_transform=lambda x: x
)
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration
train_tensorboard_stats_handler = TensorBoardStatsHandler(
    log_dir=log_dir, output_transform=lambda x: x
)
train_tensorboard_stats_handler.attach(trainer)

# WandbStatsHandler plots loss at every iteration
train_wandb_stats_handler = WandbStatsHandler(output_transform=lambda x: x)
train_wandb_stats_handler.attach(trainer)



In [11]:
# optional section for model validation during training
validation_every_n_epochs = 1
# Set parameters for validation
metric_name = "Mean_Dice"
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice()}
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
post_label = Compose([AsDiscrete(threshold=0.5)])
# Ignite evaluator expects batch=(img, seg) and
# returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = ignite.engine.create_supervised_evaluator(
    model,
    val_metrics,
    device,
    True,
    output_transform=lambda x, y, y_pred: (
        [post_pred(i) for i in decollate_batch(y_pred)],
        [post_label(i) for i in decollate_batch(y)],
    ),
)

# create a validation data loader
val_imtrans = Compose([Resize((96, 96, 96))])
val_segtrans = Compose([Resize((96, 96, 96))])
val_images = [data["image"] for data in val_ds[:20]]
val_labels = [data["label"] for data in val_ds[:20]]
val_ds = ArrayDataset(val_images, val_imtrans, val_images, val_segtrans)
val_loader = DataLoader(
    val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()
)

@trainer.on(
    ignite.engine.Events.EPOCH_COMPLETED(
        every=validation_every_n_epochs
    )
)
def run_validation(engine):
    evaluator.run(val_loader)


# Add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name="evaluator",
    # no need to print loss value, so disable per iteration output
    output_transform=lambda x: None,
    # fetch global epoch number from trainer
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_stats_handler.attach(evaluator)

# add handler to record metrics to TensorBoard at every validation epoch
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    log_dir=log_dir,
    # no need to plot loss value, so disable per iteration output
    output_transform=lambda x: None,
    # fetch global epoch number from trainer
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_tensorboard_stats_handler.attach(evaluator)


# add handler to record metrics to WandB at every validation epoch
val_wandb_stats_handler = WandbStatsHandler(
    # no need to plot loss value, so disable per iteration output
    output_transform=lambda x: None,
    # fetch global epoch number from trainer
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_wandb_stats_handler.attach(evaluator)


checkpoint_handler = Checkpoint(
    {"model": model, "optimizer": optimizer},
    WandbModelCheckpointSaver(),
    n_saved=1,
    filename_prefix="best_checkpoint",
    score_name=metric_name,
    global_step_transform=global_step_from_engine(trainer)
)
evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler)


val_wandb_image_handler = WandBImageHandler(
    batch_transform=lambda batch: (batch[0], batch[1]),
    output_transform=lambda output: output[0],
    global_iter_transform=lambda x: trainer.state.epoch,
)
evaluator.add_event_handler(
    event_name=ignite.engine.Events.EPOCH_COMPLETED,
    handler=val_wandb_image_handler,
)



<ignite.engine.events.RemovableEventHandle at 0x7f47543bd5e0>

In [12]:
state = trainer.run(train_loader, 1)

2023-04-20 23:32:46,641 - INFO - Epoch: 1/1, Iter: 1/16 -- Loss: 0.6680 
2023-04-20 23:32:47,344 - INFO - Epoch: 1/1, Iter: 2/16 -- Loss: 0.6614 
2023-04-20 23:32:48,114 - INFO - Epoch: 1/1, Iter: 3/16 -- Loss: 0.6395 
2023-04-20 23:32:48,897 - INFO - Epoch: 1/1, Iter: 4/16 -- Loss: 0.6536 
2023-04-20 23:32:49,664 - INFO - Epoch: 1/1, Iter: 5/16 -- Loss: 0.6342 
2023-04-20 23:32:50,465 - INFO - Epoch: 1/1, Iter: 6/16 -- Loss: 0.6387 
2023-04-20 23:32:51,204 - INFO - Epoch: 1/1, Iter: 7/16 -- Loss: 0.5861 
2023-04-20 23:32:51,992 - INFO - Epoch: 1/1, Iter: 8/16 -- Loss: 0.6171 
2023-04-20 23:32:52,702 - INFO - Epoch: 1/1, Iter: 9/16 -- Loss: 0.6210 
2023-04-20 23:32:53,748 - INFO - Epoch: 1/1, Iter: 10/16 -- Loss: 0.6245 
2023-04-20 23:32:54,498 - INFO - Epoch: 1/1, Iter: 11/16 -- Loss: 0.6191 
2023-04-20 23:32:55,312 - INFO - Epoch: 1/1, Iter: 12/16 -- Loss: 0.6280 
2023-04-20 23:32:55,978 - INFO - Epoch: 1/1, Iter: 13/16 -- Loss: 0.5790 
2023-04-20 23:32:56,719 - INFO - Epoch: 1/1, It

In [13]:
wandb.finish()
shutil.rmtree(log_dir)

0,1
Loss,██▆▇▆▆▂▅▅▅▅▅▂▁▄▂
Mean_Dice,▁▁
global_step,▁

0,1
Loss,0.5804
Mean_Dice,0.59827
global_step,1.0
