In [1]:
import argparse
import gc
import importlib
import os
import sys
import shutil

import numpy as np
import pandas as pd
import torch
from torch import nn
from monai.handlers.utils import from_engine
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from torch.nn.modules.loss import _Loss
from utils import *
from monai.transforms import (
    Compose,
    LoadImaged,
    RandSpatialCropd,
    EnsureTyped,
    CastToTyped,
    NormalizeIntensityd,
    RandFlipd,
    CenterSpatialCropd,
    ScaleIntensityRanged,
    RandAffined,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandCoarseDropoutd,
    Rand2DElasticd,
    Lambdad,
    Resized,
    AddChanneld,
    RandGaussianNoised,
    RandGridDistortiond,
    RepeatChanneld,
    Transposed,
    OneOf,
    EnsureChannelFirstd,
    RandLambdad,
    Spacingd,
    FgBgToIndicesd,
    CropForegroundd,
    RandCropByPosNegLabeld,
    ToDeviced,
    SpatialPadd,

)
from monai.transforms import (
    Compose,
    Activations,
    AsDiscrete,
    Activationsd,
    AsDiscreted,
    KeepLargestConnectedComponentd,
    Invertd,
    LoadImage,
    Transposed,
)
import json
from metric import HausdorffScore
from monai.utils import set_determinism
from monai.losses import DiceLoss, DiceCELoss
from monai.networks.nets import UNet
from monai.optimizers import Novograd
from monai.metrics import DiceMetric

In [2]:
fold = 0
seed = 42
batch_size = 2
device = "cuda"
lr = 3e-4
weight_decay = 0
mixed_precision=True
epochs = 10
amp=True
val_amp=True
roi_size = (90, 90, 72)
sw_batch_size = 4
run_tta_val = False
best_weights_name = 'best_weights'
load_best_weights = False
output_dir = '/home/synergy/yhk/GI/output'

In [3]:
data_json_dir = f"dataset_3d_fold_{fold}.json"
with open(data_json_dir, "r") as f:
    data_json = json.load(f)
os.makedirs(str('/home/synergy/yhk/GI/output' + f"/fold{fold}/"), exist_ok=True)

In [4]:
#train transforms, train dataset, train dataloader
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        SpatialPadd(keys=["image", "mask"], spatial_size=(384, 384, 144)),
        RandCropByPosNegLabeld(
            keys=["image", "mask"],
            label_key="mask",
            spatial_size=(96, 96, 64),
            pos=2,
            neg=1,
            num_samples=5,
            image_key="image",
            image_threshold=0,
        ),
        Lambdad(keys="image", func=lambda x: x / x.max()),
        #RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[0]),
        RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[1]),
        # RandFlipd(keys=("image", "mask"), prob=0.5, spatial_axis=[2]),
        # RandAffined(
        #     keys=("image", "mask"),
        #     prob=0.5,
        #     rotate_range=np.pi / 12,
        #     translate_range=(cfg.img_size[0]*0.0625, cfg.img_size[1]*0.0625),
        #     scale_range=(0.1, 0.1),
        #     mode="nearest",
        #     padding_mode="reflection",
        # ),
        # OneOf(
        #     [
        #         RandGridDistortiond(keys=("image", "mask"), prob=0.5, distort_limit=(-0.05, 0.05), mode="nearest", padding_mode="reflection"),
        #         RandCoarseDropoutd(
        #             keys=("image", "mask"),
        #             holes=5,
        #             max_holes=8,
        #             spatial_size=(1, 1, 1),
        #             max_spatial_size=(12, 12, 12),
        #             fill_value=0.0,
        #             prob=0.5,
        #         ),
        #     ]
        # ),
        # RandScaleIntensityd(keys="image", factors=(-0.2, 0.2), prob=0.5),
        # RandShiftIntensityd(keys="image", offsets=(-0.1, 0.1), prob=0.5),
        EnsureTyped(keys=("image", "mask"), dtype=torch.float32),
    ]
)


def get_train_dataset(json_data):
    train_ds = CacheDataset(
        data=json_data["train"],
        transform=train_transforms,
        cache_rate=0,
        num_workers=8,
        copy_cache=False,
    )

    return train_ds

def get_train_dataloader(train_dataset, gpu_cache):
    if gpu_cache:
        train_dataloader = ThreadDataLoader(
            train_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=0,
            drop_last=True,
        )

        return train_dataloader
    
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=8,
        drop_last=True,
    )

    return train_dataloader

In [5]:
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        # Spacingd(keys=["image", "mask"], pixdim=cfg.spacing, mode=("bilinear", "nearest")),
        Lambdad(keys="image", func=lambda x: x / x.max()),
        SpatialPadd(keys=["image", "mask"], spatial_size=(384, 384, 160)),
        # SpatialPadd(keys=("image", "mask"), spatial_size=cfg.img_size),
        # Resized(keys=("image", "mask"), spatial_size=(224, 224, -1), mode="nearest"),
        # NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        # ScaleIntensityRanged(
        #     keys=["image"], a_min=0.0, a_max=255.0,
        #     b_min=0.0, b_max=1.0, clip=True,
        # ),
        # CropForegroundd(keys=["image", "mask"], source_key="image"),
        # NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=("image", "mask"), dtype=torch.float32),
        # ToDeviced(keys=["image", "mask"], device="cuda:0"),
    ]
)

org_val_transforms = Compose(
    [
        LoadImaged(keys="image"),
        EnsureChannelFirstd(keys="image"),
        # Spacingd(keys="image", pixdim=cfg.spacing, mode="bilinear"),
        Lambdad(keys="image", func=lambda x: x / x.max()),
        # SpatialPadd(keys="image", spatial_size=cfg.img_size),
        EnsureTyped(keys="image", dtype=torch.float32),
    ]
)

def get_val_dataset(json_data):
    val_ds = CacheDataset(
        data=json_data["val"],
        transform=val_transforms,
        cache_rate=1,
        num_workers=0,
        copy_cache=False,
    )
    return val_ds

def get_val_dataloader(val_dataset, gpu_cache):
    if gpu_cache:
        val_dataloader = ThreadDataLoader(
            val_dataset,
            batch_size=4,
            num_workers=0,
        )
        return val_dataloader

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=4,
        num_workers=8,
    )
    return val_dataloader

In [6]:
train_dataset = get_train_dataset(data_json)
train_dataloader = get_train_dataloader(train_dataset, False)
val_dataset = get_val_dataset(data_json)
val_dataloader = get_val_dataloader(val_dataset, False)

Loading dataset: 100%|██████████| 60/60 [00:56<00:00,  1.06it/s]


In [7]:
count  = 0

for images in train_dataloader:
    print(images['image'].shape, images['mask'].shape)
    print(len(train_dataloader))
    count +=1 
    if count ==1:
        break

RuntimeError: DataLoader worker (pid(s) 15138) exited unexpectedly

In [None]:
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    kernel_size=3,
    up_kernel_size=3,
    num_res_units=2,
    act="PRELU",
    norm="BATCH",
    dropout=0.2,
    bias=True,
    dimensions=None,
).to(device)

In [None]:
class DiceBceMultilabelLoss(_Loss):
    def __init__(
        self,
        w_dice = 0.5,
        w_bce = 0.5,
        reduction = LossReduction.MEAN,
    ):
        super().__init__(reduction=LossReduction(reduction).value)
        self.w_dice = w_dice
        self.w_bce = w_bce
        self.dice_loss = DiceLoss(sigmoid=True, smooth_nr=0.01, smooth_dr=0.01, include_background=True, batch=True, squared_pred=True)
        self.bce_loss = nn.BCEWithLogitsLoss()
    def forward(self, pred, label):
        
        loss = self.dice_loss(pred, label) * self.w_dice + self.bce_loss(pred, label) * self.w_bce
        return loss

In [None]:
total_steps = len(train_dataset)
params = model.parameters()

optimizer = optim.Adam(params, lr=lr, weight_decay=weight_decay)
scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max = epochs * (total_steps // batch_size),
    eta_min=1e-5,
)

seg_loss_func = DiceBceMultilabelLoss()
dice_metric = DiceMetric(reduction='mean')
hausdorff_metric = HausdorffScore(reduction='mean')
metric_function = [dice_metric, hausdorff_metric]

In [None]:
def create_checkpoint(model, optimizer, epoch, scheduler=None, scaler=None):
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }

    if scheduler is not None:
        checkpoint["scheduler"] = scheduler.state_dict()

    if scaler is not None:
        checkpoint["scaler"] = scaler.state_dict()
    return checkpoint

In [None]:
post_org_pred = Compose([
    Activationsd(keys="pred", sigmoid=True),
    Invertd(
        keys="pred",
        transform=org_val_transforms,
        orig_keys="image",
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=False,
        to_tensor=True,
    ),
    AsDiscreted(keys="pred", threshold=0.5),
])

post_pred = Compose([
    Activations(sigmoid=True),
    AsDiscrete(threshold=0.5),
])

In [None]:
def run_train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    seg_loss_func,
    epoch,
    step,
    iteration,
):
    model.train()
    scaler = GradScaler()
    progress_bar = tqdm(range(len(train_dataloader)))
    tr_it = iter(train_dataloader)
    dataset_size = 0
    running_loss = 0.0

    for itr in progress_bar:
        iteration += 1
        batch = next(tr_it)
        inputs, masks = (
            batch["image"].to(device),
            batch["mask"].to(device),
        )

        step += batch_size

        if amp is True:
            with autocast():
                outputs = model(inputs)
                loss = seg_loss_func(outputs, masks)
        else:
            outputs = model(inputs)
            loss = seg_loss_func(outputs, masks)
        if amp is True:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        optimizer.zero_grad()
        scheduler.step()
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        losses = running_loss / dataset_size
        progress_bar.set_description(f"loss: {losses:.4f} lr: {optimizer.param_groups[0]['lr']:.6f}")
        del batch, inputs, masks, outputs, loss
    print(f"Train loss: {losses:.4f}")
    torch.cuda.empty_cache()

In [None]:
def run_eval(model, val_dataloader, post_pred, metric_function, seg_loss_func, epoch):

    model.eval()

    dice_metric, hausdorff_metric = metric_function

    progress_bar = tqdm(range(len(val_dataloader)))
    val_it = iter(val_dataloader)
    with torch.no_grad():
        for itr in progress_bar:
            batch = next(val_it)
            val_inputs, val_masks = (
                batch["image"].to(device),
                batch["mask"].to(device),
            )
            if val_amp is True:
                with autocast():
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
            else:
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
            # cal metric
            if run_tta_val is True:
                tta_ct = 1
                for dims in [[2],[3],[2,3]]:
                    flip_val_outputs = sliding_window_inference(torch.flip(val_inputs, dims=dims), roi_size, sw_batch_size, model)
                    val_outputs += torch.flip(flip_val_outputs, dims=dims)
                    tta_ct += 1
                
                val_outputs /= tta_ct

            val_outputs = [post_pred(i) for i in val_outputs]
            val_outputs = torch.stack(val_outputs)
            # metric is slice level put (n, c, h, w, d) to (n, d, c, h, w) to (n*d, c, h, w)
            val_outputs = val_outputs.permute([0, 4, 1, 2, 3]).flatten(0, 1)
            val_masks = val_masks.permute([0, 4, 1, 2, 3]).flatten(0, 1)

            hausdorff_metric(y_pred=val_outputs, y=val_masks)
            dice_metric(y_pred=val_outputs, y=val_masks)

            del val_outputs, val_inputs, val_masks, batch

    dice_score = dice_metric.aggregate().item()
    hausdorff_score = hausdorff_metric.aggregate().item()
    dice_metric.reset()
    hausdorff_metric.reset()

    all_score = dice_score * 0.4 + hausdorff_score * 0.6
    print(f"dice_score: {dice_score} hausdorff_score: {hausdorff_score} all_score: {all_score}")
    torch.cuda.empty_cache()

    return all_score

In [None]:
step = 0
i = 0
train=True
evalu=True
eval_epochs = 1
start_eval_epoch = 0
run_org_eval = False

if evalu is True:
    best_val_metric = run_eval(
        model=model,
        val_dataloader=val_dataloader,
        post_pred=post_pred,
        metric_function=metric_function,
        seg_loss_func=seg_loss_func,
        epoch=0,
    )
else:
    best_val_metric = 0.0

for epoch in range(epochs):
    print("EPOCH:", epoch)
    gc.collect()
    if train is True:
        run_train(
            model=model,
            train_dataloader=train_dataloader,
            optimizer=optimizer,
            scheduler=scheduler,
            seg_loss_func=seg_loss_func,
            epoch=epoch,
            step=step,
            iteration=i,
        )

    if (epoch + 1) % eval_epochs == 0 and evalu is True and epoch > start_eval_epoch:
        val_metric = run_eval(
            model=model,
            val_dataloader=val_dataloader,
            post_pred=post_pred,
            metric_function=metric_function,
            seg_loss_func=seg_loss_func,
            epoch=epoch,
        )

        if val_metric > best_val_metric:
            print(f"Find better metric: val_metric {best_val_metric:.5} -> {val_metric:.5}")
            best_val_metric = val_metric
            checkpoint = create_checkpoint(
                model,
                optimizer,
                epoch,
                scheduler=scheduler,
            )
            torch.save(
                checkpoint,
                f"{output_dir}/fold{fold}/{best_weights_name}.pth",
            )
        else:
            if load_best_weights is True:
                try:
                    model.load_state_dict(torch.load(f"{output_dir}/fold{fold}/{best_weights_name}.pth")["model"])
                    print(f"metric no improve, load the saved best weights with score: {best_val_metric}.")
                except:
                    pass

    if (epoch + 1) == epochs:
        # save final best weights, with its distinct name in order to avoid mistakes.
        shutil.copyfile(
            f"{output_dir}/fold{fold}/{best_weights_name}.pth",
            f"{output_dir}/fold{fold}/{best_weights_name}_{best_val_metric:.4f}.pth",
        )

EPOCH: 0


loss: 0.7973 lr: 0.000293: 100%|██████████| 107/107 [00:42<00:00,  2.50it/s]


Train loss: 0.7973
EPOCH: 1


loss: 0.6324 lr: 0.000272: 100%|██████████| 107/107 [00:42<00:00,  2.52it/s]


Train loss: 0.6324


  0%|          | 0/15 [00:02<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 4. Got 10 and 9 (The offending index is 0)

In [None]:
import gc
import torch
with torch.no_grad():
    torch.cuda.empty_cache()
torch.cuda.empty_cache()
gc.collect()

0