## Setup imports

In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import glob
from tqdm import tqdm
from sklearn.metrics import jaccard_score
from monai.utils import set_determinism

import pytorch_lightning
from pytorch_lightning.callbacks import ModelCheckpoint

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    EnsureType,
    Resized,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR

from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    DataLoader,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
    list_data_collate,
    pad_list_data_collate,
)


import torch
from dotenv import dotenv_values
from neptune.utils import stringify_unsupported

print_config()

MONAI version: 1.3.2
Numpy version: 1.26.4
Pytorch version: 2.3.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/Documents/RISA/3D_segmentation/.venv/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.14.0
Pillow version: 10.3.0
Tensorboard version: 2.17.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.4
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearm

In [2]:
data_dir = os.path.join(os.getcwd(), 'AeroPath')

In [3]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmp9cm139aq


In [4]:
segment_name = "left_upper"
# segment_name = "right_upper"
# segment_name = "right_bottom"
# segment_name = "left_bottom"
# segment_name = "whole"
# segment_name = "2Q"

In [5]:
keys = dotenv_values(".env")
api_key = keys["API_KEY"]

neptune_logger = pytorch_lightning.loggers.NeptuneLogger(
    project="aeropath-workspace/airways-model",
    api_key=api_key,
    tags=['aeropath', 'airways', 'monai', 'SwinUNETR', segment_name],
    name='airways-training'
)

In [6]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_dice',
    dirpath=os.path.join(data_dir, 'checkpoints'),  # Directory to save checkpoints
    filename=segment_name +'_swin_unetr' '_{epoch:02d}-{val_dice:.4f}',  # Filename prefix for saving checkpoints
    save_top_k=1,  # Save only the best checkpoint
    mode='max',  # `min` for minimizing the metric, `max` for maximizing
    verbose=True,  # Log a message when saving the best checkpoint
)

In [7]:
parameters = {
    'img_size': (64, 64, 64),
    'in_channels': 1,
    'out_channels': 2,
    'feature_size':48,
    'use_checkpoint': True,
}

weight = torch.load("./model_swinvit.pt")

class SwinUNetClass(pytorch_lightning.LightningModule):
    def __init__(self, mode, roi_size, spatial_size):
        super().__init__()
        self._model = SwinUNETR(**parameters)
        self._model.load_from(weights=weight)
        self.loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
        self.post_pred = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(to_onehot=2)])
        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []

        self.mode = mode
        self.roi_size = roi_size
        self.spatial_size = spatial_size


    def forward(self, x):
        return self._model(x)

    def prepare_data(self, prepare_val_data=True, prepare_test_data=True):
        # # set up the correct data path
        def get_patterns(mode):
            patterns = {
                'whole': ('**/*_CT_HR_label_airways.nii.gz', '**/*_CT_HR.nii.gz'),
                '1Q': ('nonoverlapping_labels/**/quadrant_1_*.nii.gz', 'nonoverlapping_quadrants/**/quadrant_1_*_CT_HR.nii.gz'),
                '2Q': ('nonoverlapping_labels/**/quadrant_2_*.nii.gz', 'nonoverlapping_quadrants/**/quadrant_2_*_CT_HR.nii.gz'),
                'left_bottom': ('dataset/airways_patched_4/**/*left_bottom_*.nii.gz', 'dataset/scan_patched_4/**/*left_bottom_*.nii.gz'),
                'left_upper': ('dataset/airways_patched_4/**/*left_upper_*.nii.gz', 'dataset/scan_patched_4/**/*left_upper_*.nii.gz'),
                'right_bottom': ('dataset/airways_patched_4/**/*right_bottom_*.nii.gz', 'dataset/scan_patched_4/**/*right_bottom_*.nii.gz'),
                'right_upper': ('dataset/airways_patched_4/**/*right_upper_*.nii.gz', 'dataset/scan_patched_4/**/*right_upper_*.nii.gz')
            }
            return patterns.get(mode, (None, None))

        pattern_labels, pattern_images = get_patterns(self.mode)
        if pattern_labels and pattern_images:
            train_labels = sorted(glob.glob(pattern_labels, recursive=True))
            train_images = sorted(glob.glob(pattern_images, recursive=True))



        data_dicts = [
            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
        ]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        num_samples = 2

        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"], ensure_channel_first=True),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-1024,
                    a_max=1024,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.0, 1.0, 1.35),
                    mode=("bilinear", "nearest"),
                ),
                EnsureTyped(keys=["image", "label"]),
                # Resized(keys=["image", "label"], spatial_size=self.spatial_size),
                RandCropByPosNegLabeld(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=(64, 64, 64),
                    pos=1,
                    neg=1,
                    num_samples=num_samples,
                    image_key="image",
                    image_threshold=0,
                ),
                RandFlipd(
                    keys=["image", "label"],
                    spatial_axis=[0],
                    prob=0.10,
                ),
                RandFlipd(
                    keys=["image", "label"],
                    spatial_axis=[1],
                    prob=0.10,
                ),
                RandFlipd(
                    keys=["image", "label"],
                    spatial_axis=[2],
                    prob=0.10,
                ),
                RandRotate90d(
                    keys=["image", "label"],
                    prob=0.10,
                    max_k=3,
                ),
                RandShiftIntensityd(
                    keys=["image"],
                    offsets=0.10,
                    prob=0.50,
                ),
            ]
        )
        val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"], ensure_channel_first=True),
                ScaleIntensityRanged(keys=["image"], a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.0, 1.0, 1.35),
                    mode=("bilinear", "nearest"),
                ),
                EnsureTyped(keys=["image", "label"]),
                # Resized(keys=["image", "label"], spatial_size=self.spatial_size),
            ]
        )                    

        if prepare_test_data:
            self.train_ds = CacheDataset(
                data=train_files,
                transform=train_transforms,
                cache_rate=1.0,
                num_workers=8,
            )
        if prepare_val_data:
            self.val_ds = CacheDataset(
                data=val_files,
                transform=val_transforms,
                cache_rate=1.0,
                num_workers=4,
            )


    def train_dataloader(self):
        train_loader = ThreadDataLoader(
            self.train_ds,
            batch_size=1,
            shuffle=True,
            num_workers=8,
        )
        return train_loader

    def val_dataloader(self):
        val_loader = ThreadDataLoader(self.val_ds, batch_size=1, num_workers=4)
        return val_loader

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self._model.parameters(), 1e-4, weight_decay=1e-5)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        tensorboard_logs = {"train_loss": loss.item()}
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"].cuda(), batch["label"].cuda()
        roi_size = self.roi_size
        sw_batch_size = 4
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)

        outputs_np = [output.argmax(dim=0).cpu().numpy() for output in outputs]
        labels_np = [label.argmax(dim=0).cpu().numpy() for label in labels]
        iou = np.mean([jaccard_score(ln.flatten(), on.flatten(), average='macro') for ln, on in zip(labels_np, outputs_np)])

        d = {"val_loss": loss, "val_number": len(outputs), "iou": iou}
        self.validation_step_outputs.append(d)
        return d
    

    
    def perform_inference(self, model, data):
        # Perform inference using the model
        with torch.no_grad():
            data = torch.DoubleTensor(data)  # Convert data to type Double
            model_output = model(data.unsqueeze(0))
        return model_output

    def on_validation_epoch_end(self):
        val_loss, num_items, total_iou = 0, 0, 0
        for output in self.validation_step_outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
            total_iou += output["iou"] * output["val_number"]
        mean_val_dice = self.dice_metric.aggregate().item()
        mean_val_iou = total_iou / num_items
        self.dice_metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        tensorboard_logs = {
            "val_dice": mean_val_dice,
            "val_loss": mean_val_loss,
            "val_iou": mean_val_iou,
        }
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f} "
            f"current mean iou: {mean_val_iou:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        self.validation_step_outputs.clear()
        self.log('val_dice', mean_val_dice, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_loss', mean_val_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log('val_iou', mean_val_iou, on_step=False, on_epoch=True, prog_bar=False, logger=True)

        return {"log": tensorboard_logs}
    
    
    def dice_score(self, prediction_tensor, label_tensor):
        # Compute Dice score
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        dice_metric(y_pred=prediction_tensor, y=label_tensor)
        dice_score = dice_metric.aggregate().item()
        dice_metric.reset()

        print(dice_score)

In [8]:
Net_segment       = SwinUNetClass(mode = segment_name,       roi_size=[64, 64, 64],   spatial_size=(160, 160, 160))
NetWhole          = SwinUNetClass(mode = 'whole',            roi_size=(32, 32, 32),   spatial_size=(160, 160, 160))
Net1Q             = SwinUNetClass(mode = '1Q',               roi_size=(160*2, 160, 160), spatial_size=(160*2, 160, 160))
Net2Q             = SwinUNetClass(mode = '2Q',               roi_size=(160*2, 160, 160), spatial_size=(160*2, 160, 160))

# NetWhole    = UNetClass(mode = 'whole', roi_size=(192, 192, 212),   spatial_size=(192, 192, 212))

In [9]:
neptune_logger.experiment["model/parameters"] = stringify_unsupported(parameters)



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/aeropath-workspace/airways-model/e/AIR-79


In [9]:
# initialise the LightningModule
# net = Net1Q
# net = Net2Q
net = Net_segment
# net = NetWhole
# set up loggers and checkpoints
log_dir = os.path.join(root_dir, "logs")

# initialise Lightning's trainer.
trainer = pytorch_lightning.Trainer(
    devices=[0],
    max_epochs=600,
    logger=neptune_logger,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=1,
    log_every_n_steps=16,
)
# train
trainer.fit(net)

GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Loading dataset: 100%|██████████| 18/18 [00:08<00:00,  2.25it/s]
Loading dataset: 100%|██████████| 9/9 [00:04<00:00,  2.07it/s]


[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/aeropath-workspace/airways-model/e/AIR-84


/home/pawel/Documents/RISA/3D_segmentation/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory /home/pawel/Documents/RISA/3D_segmentation/AeroPath/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params | Mode 
-----------------------------------------------------
0 | _model        | SwinUNETR  | 62.2 M | train
1 | loss_function | DiceCELoss | 0      | train
-----------------------------------------------------
62.2 M    Trainable params
0         Non-trainable params
62.2 M    Total params
248.747   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

current epoch: 0 current mean dice: 0.0131 current mean iou: 0.0745
best mean dice: 0.0131 at epoch: 0


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 18: 'val_dice' reached 0.00670 (best 0.00670), saving model to '/home/pawel/Documents/RISA/3D_segmentation/AeroPath/checkpoints/left_upper_swin_unetr_epoch=00-val_dice=0.0067.ckpt' as top 1


current epoch: 0 current mean dice: 0.0067 current mean iou: 0.4972
best mean dice: 0.0131 at epoch: 0


Exception ignored in: <bound method MetadataContainer._before_fork of <neptune.metadata_containers.run.Run object at 0x72e91a27d0f0>>
Traceback (most recent call last):
  File "/home/pawel/Documents/RISA/3D_segmentation/.venv/lib/python3.10/site-packages/neptune/metadata_containers/metadata_container.py", line 285, in _before_fork
    self._op_processor.pause()
  File "/home/pawel/Documents/RISA/3D_segmentation/.venv/lib/python3.10/site-packages/neptune/internal/operation_processors/async_operation_processor.py", line 159, in pause
    self._consumer.pause()
  File "/home/pawel/Documents/RISA/3D_segmentation/.venv/lib/python3.10/site-packages/neptune/internal/threading/daemon.py", line 56, in pause
    self._wait_condition.wait_for(lambda: self._state != Daemon.DaemonState.PAUSING)
  File "/usr/lib/python3.10/threading.py", line 355, in wait_for
    self.wait(waittime)
  File "/usr/lib/python3.10/threading.py", line 320, in wait
    waiter.acquire()
KeyboardInterrupt: 
