Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# PersistentDataset, CacheDataset, LMDBDataset, and simple Dataset Tutorial and Speed Test

This tutorial shows how to accelerate a PyTorch medical deep-learning program by changing
how data is loaded and preprocessed with different MONAI `Dataset` managers.

`Dataset` provides the simplest model of data loading.  Each time a data item is needed, it is reloaded from the original data sources and processed through all the non-random and random transforms to generate analyzable tensors. This mechanism has the smallest memory footprint and the smallest temporary disk footprint.

`CacheDataset` provides a mechanism to pre-load all original data and apply non-random transforms into analyzable tensors loaded in memory prior to starting analysis.  The `CacheDataset` requires all tensor representations of data requested to be loaded into memory at once. The subset of random transforms is applied to the cached components before use. This is the highest performance dataset if all data fit in core memory.

`PersistentDataset` processes original data sources through the non-random transforms on first use, and stores these intermediate tensor values to an on-disk persistence representation.  The intermediate processed tensors are loaded from disk on each use for processing by the random-transforms for each analysis request.  The `PersistentDataset` has a similar memory footprint to the simple `Dataset`, with performance characteristics close to the `CacheDataset` at the expense of disk storage.  Additionally, the cost of first time processing of data is distributed across each first use.

`LMDBDataset` is a variant of `PersistentDataset`. It uses an LMDB database as the persistent backend.

It's modified from the [Spleen 3D segmentation tutorial notebook](../3d_segmentation/spleen_segmentation_3d.ipynb).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/acceleration/dataset_type_performance.ipynb)

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm, lmdb]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

## Setup imports

In [1]:
import autorootcwd
import glob
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import torch
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import (
    CacheDataset,
    Dataset,
    DataLoader,
    LMDBDataset,
    PersistentDataset,
    decollate_batch,
)
from monai.inferers import sliding_window_inference
from src.losses.losses import DiceFocalLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric, MeanIoU
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.transforms import (
    EnsureChannelFirstd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureType,
    Resize,
    RandShiftIntensityd,
    RandFlipd,
    GaussianSmoothd
)
from monai.utils import set_determinism
import lightning.pytorch as pl

import numpy as np
from pathlib import Path
import csv
from dvclive.lightning import DVCLiveLogger
import nibabel as nib

# Print the MONAI, PyTorch and optional-dependency versions so benchmark results can be tied to an environment.
print_config()

MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.5.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /home/<username>/project/coronary-artery/.venv/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.2
Pillow version: 11.1.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.0+cu121
tqdm version: 4.67.1
lmdb version: 1.6.2
psutil version: 7.0.0
pandas version: 2.2.3
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For detail

## Setup data directory

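Set `data_root` below to the ImageCAS dataset directory. It should contain one sub-folder per split (`train`, `valid`, `test`), each holding one folder per case.  
`cache_root` points to a directory reserved for cached data.  
Unlike the original MONAI tutorial, this notebook does not read the `MONAI_DATA_DIRECTORY` environment variable and does not fall back to a temporary directory.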

In [3]:
# Machine-specific dataset location; each split folder ("train", "valid", "test") holds one
# sub-directory per case (see load_data_from_folder below).
data_root = "/home/seoooa/project/coronary-artery/data/imageCAS_test"
# NOTE(review): cache_root is not referenced by any visible cell — confirm where it is used.
cache_root = "/data/seoooa/cache/imageCAS_test"

In [4]:
def load_data_from_folder(folder_path):
    """Collect ``{"image": ..., "label": ...}`` path dicts from per-case sub-directories.

    Every immediate sub-directory of ``folder_path`` is visited in sorted order and
    contributes an entry only when it contains both ``img.nii.gz`` and ``label.nii.gz``.
    Non-directory entries are ignored.

    Args:
        folder_path: split directory (e.g. ``<data_root>/train``).

    Returns:
        list[dict]: one ``{"image", "label"}`` dict per qualifying case.
    """
    images, labels = [], []
    for case_dir in sorted(glob.glob(os.path.join(folder_path, "*"))):
        if not os.path.isdir(case_dir):
            continue
        img_hits = glob.glob(os.path.join(case_dir, "img.nii.gz"))
        label_hits = glob.glob(os.path.join(case_dir, "label.nii.gz"))
        # Skip cases that are missing either file.
        if not (img_hits and label_hits):
            continue
        images += img_hits
        labels += label_hits
    return [dict(image=img, label=lbl) for img, lbl in zip(images, labels)]

# Build the per-split lists of {"image", "label"} path dicts and report their sizes.
# NOTE: these lists are rebuilt further down by a load_data_from_folder variant that also
# collects the "seg" map; the training cells use that later version.
train_files = load_data_from_folder(os.path.join(data_root, "train"))
val_files = load_data_from_folder(os.path.join(data_root, "valid"))
test_files = load_data_from_folder(os.path.join(data_root, "test"))

print(f"Training samples: {len(train_files)}")
print(f"Validation samples: {len(val_files)}")
print(f"Test samples: {len(test_files)}")

Training samples: 100
Validation samples: 10
Test samples: 5


## Define a typical PyTorch training process

### Data Module

In [2]:
class CoronaryArteryDataModule(pl.LightningDataModule):
    """Lightning data module that serves pre-built train/validation/test datasets.

    Args:
        train_ds, val_ds, test_ds: MONAI datasets (any ``Dataset`` variant).
        batch_size: batch size used by all three loaders.
        num_workers: worker processes per loader. ``0`` loads in the main process;
            worker-only options (``persistent_workers``, ``prefetch_factor``) are then
            omitted because ``DataLoader`` rejects them without workers.
    """

    def __init__(self, train_ds, val_ds, test_ds, batch_size=1, num_workers=8):
        super().__init__()
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.test_ds = test_ds
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, dataset, shuffle, prefetch_factor=None):
        """Create a DataLoader with the settings shared by every split."""
        worker_options = {}
        if self.num_workers > 0:
            # Keep workers alive between epochs instead of re-spawning them each epoch.
            worker_options["persistent_workers"] = True
            if prefetch_factor is not None:
                worker_options["prefetch_factor"] = prefetch_factor
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            **worker_options,
        )

    def train_dataloader(self):
        """Shuffled training loader."""
        return self._make_loader(self.train_ds, shuffle=True, prefetch_factor=2)

    def val_dataloader(self):
        """Validation loader in dataset order."""
        return self._make_loader(self.val_ds, shuffle=False, prefetch_factor=2)

    def test_dataloader(self):
        """Test loader in dataset order (DataLoader's default prefetch)."""
        return self._make_loader(self.test_ds, shuffle=False)

### Segment Model

In [3]:

from src.models.proposed.segresnet import SegResNet

class CoronaryArterySegmentModel(pl.LightningModule):
    """Lightning module for 2-class (background / artery) coronary-artery segmentation.

    Wraps the project ``SegResNet`` (1 image channel in, 2 classes out). The network also
    consumes the ``seg`` map (``label_nc=8``). Training uses ``DiceFocalLoss``. Validation
    and test use sliding-window inference and report Dice, mean IoU and the
    85th-percentile Hausdorff distance, computed on volumes resized to 60^3 to bound its cost.

    Args:
        batch_size: stored for reference; the loader batch size comes from the data module.
        lr: Adam learning rate.
        patch_size: sliding-window ROI size for validation/test.
        log_dir: root folder; test NIfTI files and ``test_result.csv`` go to ``<log_dir>/test``.

    History attributes (read by ``train_process``):
        train_loss_history: mean training loss of each training epoch.
        train_step_loss_history: raw loss of every training step.
        val_dice_history: mean validation Dice of each validation run (sanity check excluded).
        epoch_times: wall-clock seconds of each training epoch.
    """

    # Affine written into saved NIfTI files (0.35 x 0.35 x 0.5 spacing, no rotation/offset).
    # NOTE(review): hard-coded rather than taken from the input image metadata — confirm.
    _SAVE_AFFINE = np.array([[0.35, 0, 0, 0], [0, 0.35, 0, 0], [0, 0, 0.5, 0], [0, 0, 0, 1]])
    # Spatial size both volumes are resized to (nearest) before the Hausdorff metric.
    _HAUSDORFF_SIZE = [60, 60, 60]
    # Number of sliding-window patches evaluated per forward pass.
    _SW_BATCH_SIZE = 4

    def __init__(self, batch_size=1, lr=1e-3, patch_size=(96, 96, 96), log_dir="nbs/result/dataset_performance/Dataset"):
        super().__init__()
        self._model = SegResNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            init_filters=16,
            blocks_down=(1, 2, 2, 4),
            blocks_up=(1, 1, 1),
            dropout_prob=0.2,
            label_nc=8,
        )

        self.loss_function = DiceFocalLoss(
            to_onehot_y=True,
            softmax=True,
        )

        # Convert predictions (argmax) and integer labels to 2-class one-hot CPU tensors for the metrics.
        self.post_pred = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(to_onehot=2)])

        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        self.hausdorff_metric = HausdorffDistanceMetric(
            include_background=False,
            percentile=85,
            directed=False,
            reduction="mean",
        )
        self.mean_iou_metric = MeanIoU(include_background=False, reduction="mean")

        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []
        self.test_step_outputs = []
        self.batch_size = batch_size
        self.lr = lr
        self.patch_size = patch_size
        self.result_folder = Path(log_dir)

        self.train_loss_history = []
        self.train_step_loss_history = []
        self.val_dice_history = []
        self.epoch_times = []
        self.epoch_start_time = None
        # Step losses of the current training epoch; averaged into train_loss_history at epoch end.
        self._epoch_step_losses = []

    def forward(self, x, seg):
        return self._model(x, seg)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters(), self.lr)
        return optimizer

    def on_train_epoch_start(self):
        self.epoch_start_time = time.time()
        self._epoch_step_losses.clear()

    def on_train_epoch_end(self):
        """Record the epoch's wall-clock time and mean training loss."""
        # NOTE(review): in Lightning 2.x end-of-epoch validation presumably runs before this hook,
        # so epoch_time likely includes validation on validation epochs — confirm.
        if self.epoch_start_time is not None:
            epoch_time = time.time() - self.epoch_start_time
            self.epoch_times.append(epoch_time)
            self.log("epoch_time", epoch_time)
        if self._epoch_step_losses:
            self.train_loss_history.append(float(np.mean(self._epoch_step_losses)))
            self._epoch_step_losses.clear()

    def training_step(self, batch, batch_idx):
        images, labels, segs = batch["image"], batch["label"], batch["seg"]
        output = self.forward(images, segs)
        loss = self.loss_function(output, labels)
        loss_value = loss.item()

        # Keep the raw step loss and feed the per-epoch mean.
        self.train_step_loss_history.append(loss_value)
        self._epoch_step_losses.append(loss_value)

        self.log(
            "train_loss",
            loss_value,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def _sliding_window_predict(self, images, segs):
        """Run sliding-window inference over the image and the seg map together.

        Both tensors are concatenated on the channel axis so every window crops them
        identically. The wrapped forward splits them back apart: the first channel is
        the image (the network has ``in_channels=1``), the rest are ``seg``.
        """
        inputs = torch.cat((images, segs), dim=1)
        return sliding_window_inference(
            inputs,
            self.patch_size,
            self._SW_BATCH_SIZE,
            lambda x: self.forward(x[:, :1, ...], x[:, 1:, ...]),
        )

    def _accumulate_hausdorff(self, outputs, labels):
        """Resize each one-hot (prediction, label) pair and add it to ``self.hausdorff_metric``."""
        resize_transform = Resize(spatial_size=self._HAUSDORFF_SIZE, mode="nearest")
        # Resize everything first so a resize failure adds nothing to the metric buffer.
        downsampled_outputs = [resize_transform(i) for i in outputs]
        downsampled_labels = [resize_transform(i) for i in labels]
        for pred, target in zip(downsampled_outputs, downsampled_labels):
            self.hausdorff_metric(y_pred=pred.unsqueeze(0), y=target.unsqueeze(0))

    def _aggregate_hausdorff(self):
        """Return the accumulated Hausdorff distance, or NaN when nothing was accumulated."""
        try:
            return self.hausdorff_metric.aggregate().item()
        except ValueError:
            # NOTE(review): MONAI 1.x presumably raises ValueError on an empty buffer
            # (every per-step computation failed) — confirm.
            return float("nan")

    def validation_step(self, batch, batch_idx):
        images, labels, segs = batch["image"], batch["label"], batch["seg"]

        outputs = self._sliding_window_predict(images, segs)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]

        self.dice_metric(y_pred=outputs, y=labels)
        self.mean_iou_metric(y_pred=outputs, y=labels)

        try:
            self._accumulate_hausdorff(outputs, labels)
        except Exception as e:
            # Best effort: a failed Hausdorff computation must not abort validation.
            print(f"[ERROR] Hausdorff metric calculation error: {e}")

        d = {"val_loss": loss, "val_number": len(outputs)}
        self.validation_step_outputs.append(d)
        return d

    def on_validation_epoch_end(self):
        step_outputs = self.validation_step_outputs
        val_loss = sum(out["val_loss"].sum().item() for out in step_outputs)
        num_items = sum(out["val_number"] for out in step_outputs)

        mean_val_dice = self.dice_metric.aggregate().item()
        mean_val_hausdorff = self._aggregate_hausdorff()
        mean_val_iou = self.mean_iou_metric.aggregate().item()

        self.dice_metric.reset()
        self.hausdorff_metric.reset()
        self.mean_iou_metric.reset()

        mean_val_loss = torch.tensor(val_loss / num_items if num_items else float("nan"))
        self.log_dict(
            {
                "val_dice": mean_val_dice,
                "val_hausdorff": mean_val_hausdorff,
                "val_iou": mean_val_iou,
                "val_loss": mean_val_loss,
            }
        )

        # The pre-fit sanity check runs with untrained weights; keep it out of the history and best score.
        if not self.trainer.sanity_checking:
            self.val_dice_history.append(mean_val_dice)
            if mean_val_dice > self.best_val_dice:
                self.best_val_dice = mean_val_dice
                self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}, "
            f"hausdorff: {mean_val_hausdorff:.4f}, "
            f"iou: {mean_val_iou:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        step_outputs.clear()

    def save_result(self, inputs, outputs, labels, filename_prefix="result"):
        """Save input, predicted foreground and label foreground as NIfTI files under ``<log_dir>/test``."""
        save_folder = self.result_folder / "test"
        os.makedirs(save_folder, exist_ok=True)

        volumes = {
            "inputs": inputs.detach().cpu().numpy().squeeze(),
            # outputs/labels are 2-class one-hot; channel 1 is the foreground.
            "outputs": outputs.detach().cpu().numpy().squeeze()[1],
            "labels": labels.detach().cpu().numpy().squeeze()[1],
        }
        paths = {}
        for name, volume in volumes.items():
            paths[name] = save_folder / f"{filename_prefix}_{name}.nii.gz"
            nib.save(nib.Nifti1Image(volume, self._SAVE_AFFINE), paths[name])

        print(f"Result saved to: {save_folder}")
        print(f"Inputs: {paths['inputs']}")
        print(f"Outputs: {paths['outputs']}")
        print(f"Labels: {paths['labels']}")

    def test_step(self, batch, batch_idx):
        images, labels, segs = batch["image"], batch["label"], batch["seg"]

        outputs = self._sliding_window_predict(images, segs)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]

        filename = batch["image"].meta["filename_or_obj"][0]
        # The patient id is the name of the directory containing the image file.
        patient_id = Path(filename).parent.name

        self.save_result(images, outputs[0], labels[0], filename_prefix=f"Subj_{patient_id}")

        self.dice_metric(y_pred=outputs, y=labels)
        self.mean_iou_metric(y_pred=outputs, y=labels)

        try:
            self._accumulate_hausdorff(outputs, labels)
            hausdorff_score = self.hausdorff_metric.aggregate().item()
        except Exception as e:
            print(f"[ERROR] Hausdorff metric calculation error in test step: {e}")
            # NaN rather than 0 (which would read as a perfect score); excluded from the mean later.
            hausdorff_score = float("nan")
        finally:
            # Always reset so a failed case cannot leak partial values into the next one.
            self.hausdorff_metric.reset()

        dice_score = self.dice_metric.aggregate().item()
        mean_iou_score = self.mean_iou_metric.aggregate().item()
        self.dice_metric.reset()
        self.mean_iou_metric.reset()

        d = {
            "test_dice": dice_score,
            "test_hausdorff": hausdorff_score,
            "test_iou": mean_iou_score,
            "patient_id": patient_id,
        }
        self.test_step_outputs.append(d)
        return d

    def on_test_epoch_end(self):
        """Log mean test metrics and write per-case results plus an AVG ± STD row to CSV."""
        results = self.test_step_outputs
        dice_scores = [x["test_dice"] for x in results]
        hausdorff_scores = [x["test_hausdorff"] for x in results]
        iou_scores = [x["test_iou"] for x in results]

        mean_dice = np.mean(dice_scores)
        # NaN-aware: cases whose Hausdorff computation failed are stored as NaN.
        mean_hausdorff = np.nanmean(hausdorff_scores)
        mean_iou = np.mean(iou_scores)

        self.log_dict(
            {
                "test/mean_dice": mean_dice,
                "test/mean_hausdorff": mean_hausdorff,
                "test/mean_iou": mean_iou,
            }
        )

        result_file = self.result_folder / "test" / "test_result.csv"
        result_file.parent.mkdir(parents=True, exist_ok=True)
        fieldnames = ["dice_score", "hausdorff_score", "iou_score", "patient_id"]
        # Explicit UTF-8 because the summary row contains "±".
        with open(result_file, "w", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(
                {
                    "dice_score": result["test_dice"],
                    "hausdorff_score": result["test_hausdorff"],
                    "iou_score": result["test_iou"],
                    "patient_id": result["patient_id"],
                }
                for result in results
            )
            writer.writerow(
                {
                    "dice_score": f"{mean_dice:.4f} ± {np.std(dice_scores):.4f}",
                    "hausdorff_score": f"{mean_hausdorff:.4f} ± {np.nanstd(hausdorff_scores):.4f}",
                    "iou_score": f"{mean_iou:.4f} ± {np.std(iou_scores):.4f}",
                    "patient_id": "AVG ± STD",
                }
            )

        print("\nTest Result Summary:")
        print(f"Mean Dice Score: {mean_dice:.4f}")
        print(f"Mean Hausdorff Distance: {mean_hausdorff:.4f}")
        print(f"Mean IoU Score: {mean_iou:.4f}")
        print(f"Detailed result saved to: {result_file}")

        results.clear()

### Train

In [4]:
from lightning.pytorch.callbacks import (
    BatchSizeFinder,
    LearningRateFinder,
    StochasticWeightAveraging,
)

def train_process(
    train_ds,
    val_ds,
    test_ds,
    log_dir,
    *,
    max_epochs=20,
    batch_size=1,
    num_workers=8,
    lr=1e-3,
    patch_size=(96, 96, 96),
):
    """Train, checkpoint and test one ``CoronaryArterySegmentModel``, timing each phase.

    Args:
        train_ds, val_ds, test_ds: datasets wrapped by ``CoronaryArteryDataModule``.
        log_dir: Lightning ``default_root_dir``; also receives ``final_model.ckpt``.
        max_epochs: number of training epochs (default 20).
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        lr: Adam learning rate.
        patch_size: sliding-window ROI for validation/test.

    Returns:
        tuple: ``(max_epochs, total_time, train_loss_history, val_dice_history, epoch_times)``.
        ``total_time`` is in seconds and covers fit, checkpointing and test. The three lists
        are the histories recorded by the model itself.
    """
    data_module = CoronaryArteryDataModule(
        train_ds=train_ds,
        val_ds=val_ds,
        test_ds=test_ds,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    model = CoronaryArterySegmentModel(
        batch_size=batch_size,
        lr=lr,
        patch_size=patch_size,
        log_dir=log_dir,
    )

    callbacks = [
        # NOTE(review): SWA starts at epoch 100, so with max_epochs <= 100 (e.g. the default 20)
        # it never activates — confirm this is intended for the benchmark.
        StochasticWeightAveraging(
            swa_lrs=[1e-4],
            annealing_epochs=5,
            swa_epoch_start=100,
        )
    ]
    # dvc_logger = DVCLiveLogger(log_model=True, dir=log_dir, report="html")

    trainer = pl.Trainer(
        devices=[0],  # single GPU, index 0
        # strategy="ddp_notebook",
        max_epochs=max_epochs,
        enable_checkpointing=True,
        benchmark=True,
        accumulate_grad_batches=5,
        precision="bf16-mixed",
        check_val_every_n_epoch=5,
        num_sanity_val_steps=1,
        callbacks=callbacks,
        default_root_dir=log_dir,
        # enable_progress_bar=False
    )

    print("\nStarting training...")
    # perf_counter is monotonic, so interval measurements are unaffected by clock adjustments.
    total_start = time.perf_counter()

    trainer.fit(model, data_module)
    train_end = time.perf_counter()
    train_time = train_end - total_start
    print(f"\nTraining completed in {train_time:.2f} seconds ({train_time/3600:.2f} hours)")

    trainer.save_checkpoint(os.path.join(log_dir, "final_model.ckpt"))
    print("\nStarting testing...")

    trainer.test(model, data_module)
    test_end = time.perf_counter()
    test_time = test_end - train_end
    total_time = test_end - total_start

    print("\nTime Summary:")
    print(f"Training Time: {train_time:.2f} seconds ({train_time/3600:.2f} hours)")
    print(f"Testing Time: {test_time:.2f} seconds ({test_time/3600:.2f} hours)")
    print(f"Total Time: {total_time:.2f} seconds ({total_time/3600:.2f} hours)")

    return (
        trainer.max_epochs,
        total_time,
        model.train_loss_history,  # histories recorded by the model itself
        model.val_dice_history,
        model.epoch_times,
    )

# Start of speed testing

The `PersistentDataset`, `CacheDataset`, and `Dataset` are compared for speed over 20 training epochs (the `max_epochs` setting in `train_process`).

## Set ImageCAS dataset path

In [5]:
# Same machine-specific paths as in the setup section, repeated so this part can run on its own.
data_root = "/home/seoooa/project/coronary-artery/data/imageCAS_test"
# NOTE(review): cache_root is not referenced by any visible cell — confirm where it is used.
cache_root = "/data/seoooa/cache/imageCAS_test"

In [None]:
def load_data_from_folder(folder_path):
    """Collect ``{"image", "label", "seg"}`` path dicts from per-case sub-directories.

    Each immediate sub-directory of ``folder_path`` is visited in sorted order. A case is
    kept only if it contains all of ``img.nii.gz``, ``label.nii.gz`` and
    ``heart_combined.nii.gz``; the last one becomes the ``"seg"`` entry.

    Args:
        folder_path: split directory (e.g. ``<data_root>/train``).

    Returns:
        list[dict]: one ``{"image", "label", "seg"}`` dict per complete case.
    """
    wanted = {"image": "img.nii.gz", "label": "label.nii.gz", "seg": "heart_combined.nii.gz"}
    found = {key: [] for key in wanted}

    for case_dir in sorted(glob.glob(os.path.join(folder_path, "*"))):
        if not os.path.isdir(case_dir):
            continue
        hits = {key: glob.glob(os.path.join(case_dir, name)) for key, name in wanted.items()}
        # Only complete cases (all three files present) are collected.
        if all(hits.values()):
            for key, paths in hits.items():
                found[key].extend(paths)

    return [
        {"image": img, "label": lbl, "seg": seg}
        for img, lbl, seg in zip(found["image"], found["label"], found["seg"])
    ]

# Build the per-split lists of {"image", "label", "seg"} path dicts used by every benchmark below.
train_files = load_data_from_folder(os.path.join(data_root, "train"))
val_files = load_data_from_folder(os.path.join(data_root, "valid"))
test_files = load_data_from_folder(os.path.join(data_root, "test"))

print(f"Training samples: {len(train_files)}")
print(f"Validation samples: {len(val_files)}")
print(f"Test samples: {len(test_files)}")

## Setup transforms for training and validation

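Deterministic transforms during training:
* LoadImaged
* EnsureChannelFirstd
* Orientationd
* ScaleIntensityRanged
* AsDiscreted (one-hot encoding of the `seg` map)
* CropForegroundd

`Spacingd` is currently disabled (commented out) in both pipelines.

Non-deterministic transforms:
* RandCropByPosNegLabeld
* RandFlipd
* RandShiftIntensityd

`GaussianSmoothd` is deterministic, but in the training pipeline it runs after the random transforms.

All the validation transforms are deterministic.
The results of the deterministic transforms that precede the first random transform are cached to accelerate training.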

In [7]:
def transformations():
    """Build the training and validation MONAI transform pipelines.

    Both pipelines start with the same deterministic prefix:
    - load image, label and seg
    - channel-first, RAS orientation
    - CT intensity window [-150, 550] rescaled to [0, 1]
    - 8-class one-hot ``seg``
    - foreground crop driven by the image

    Training then samples 4 random 96^3 patches per volume, applies random flips,
    smooths the seg map and shifts intensity. Validation/test only smooths the seg map,
    because whole volumes are evaluated with a sliding window.

    NOTE: ``Spacingd`` resampling to (0.35, 0.35, 0.5) is intentionally disabled in both pipelines.

    Returns:
        tuple[Compose, Compose]: ``(train_transforms, val_transforms)``.
    """
    all_keys = ["image", "label", "seg"]

    def deterministic_prefix():
        # A fresh list of fresh transform instances per call, so the two pipelines share nothing.
        # With the default image_only=True, LoadImaged returns MetaTensors and no separate meta dict.
        return [
            LoadImaged(keys=all_keys),
            EnsureChannelFirstd(keys=all_keys),
            Orientationd(keys=all_keys, axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-150,
                a_max=550,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            AsDiscreted(keys=["seg"], to_onehot=8),
            CropForegroundd(keys=all_keys, source_key="image"),
        ]

    train_augmentation = [
        RandCropByPosNegLabeld(
            keys=all_keys,
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(keys=all_keys, spatial_axis=[0], prob=0.10),
        RandFlipd(keys=all_keys, spatial_axis=[1], prob=0.10),
        # NOTE(review): this deterministic smoothing sits after the first random transform, so
        # Cache/PersistentDataset presumably cannot cache it. Moving it before the crop would allow
        # caching, but would change results slightly at patch borders — confirm the trade-off.
        GaussianSmoothd(keys=["seg"], sigma=1.0),
        RandShiftIntensityd(keys="image", offsets=0.05, prob=0.5),
    ]

    train_transforms = Compose(deterministic_prefix() + train_augmentation)
    # No random cropping for validation: whole volumes are scored with a sliding window.
    val_transforms = Compose(deterministic_prefix() + [GaussianSmoothd(keys=["seg"], sigma=1.0)])
    return train_transforms, val_transforms

In [8]:
import csv

def save_results_to_csv(filename, epoch_loss_values, metric_values, epoch_times, total_time, max_epochs, init_time=None, disk_usage_before=None, disk_usage_after=None):
    """Write a benchmark run's summary and history to a two-section CSV (overwriting ``filename``).

    Section 1 is ``key, value`` rows:
    - ``max_epochs`` and ``total_time``
    - ``init_time``, if given
    - for each given disk-usage triple, its total/used/free in whole GiB (labelled "GB")

    Section 2 is an ``epoch, loss, metric, epoch_time`` table with one row per index up to
    the LONGEST of the three history lists. A list that is shorter leaves its cell empty,
    so metrics recorded less often than every epoch are not truncated.

    Args:
        filename: output CSV path.
        epoch_loss_values, metric_values, epoch_times: history lists; lengths may differ.
        total_time: total run time.
        max_epochs: configured number of epochs.
        init_time: optional dataset initialisation time.
        disk_usage_before, disk_usage_after: optional ``(total, used, free)`` byte triples,
            e.g. from ``shutil.disk_usage``.
    """
    from itertools import zip_longest

    bytes_per_gib = 2**30
    with open(filename, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["max_epochs", max_epochs])
        writer.writerow(["total_time", total_time])
        if init_time is not None:
            writer.writerow(["init_time", init_time])

        for stage, usage in (("before", disk_usage_before), ("after", disk_usage_after)):
            if usage is None:
                continue
            for field, value in zip(("total", "used", "free"), usage):
                writer.writerow([f"disk_usage_{stage}_{field}_GB", value // bytes_per_gib])

        writer.writerow(["epoch", "loss", "metric", "epoch_time"])
        history = zip_longest(epoch_loss_values, metric_values, epoch_times, fillvalue="")
        for epoch, (loss, metric, epoch_time) in enumerate(history, start=1):
            writer.writerow([epoch, loss, metric, epoch_time])

## Enable deterministic training and regular `Dataset`

Load each original dataset and transform each time it is needed.

In [None]:
# Baseline benchmark: plain Dataset loads and transforms every item on every access.
set_determinism(seed=0)
log_dir = f"nbs/result/dataset_performance/Dataset"
os.makedirs(log_dir, exist_ok=True)

train_trans, val_trans = transformations()
# The test split reuses the (non-random) validation pipeline.
train_ds = Dataset(data=train_files, transform=train_trans)
val_ds = Dataset(data=val_files, transform=val_trans)
test_ds = Dataset(data=test_files, transform=val_trans)

# train_process returns the configured epoch count, the fit+test wall time and the
# loss / validation-Dice / epoch-time histories recorded by the model.
(
    max_epochs,
    total_time,
    epoch_loss_values,
    metric_values,
    epoch_times,
) = train_process(train_ds, val_ds, test_ds, log_dir)

print(f"total training time of {max_epochs} epochs" f" with regular Dataset: {total_time:.4f}")

save_results_to_csv(
    "nbs/result/dataset_performance/regular_results.csv",
    epoch_loss_values, metric_values, epoch_times, total_time, max_epochs
)

## Enable deterministic training and `PersistentDataset`

Use persistent storage for the non-random-transformed training and validation data: it is computed once and stored persistently across runs.

#### For SSD Disk

In [None]:
# PersistentDataset benchmark with the log dir and on-disk cache under nbs/ (SSD).
set_determinism(seed=0)
log_dir = f"nbs/result/dataset_performance/PersistentDataset"
os.makedirs(log_dir, exist_ok=True)

persistent_cache = Path("nbs/result/dataset_performance/PersistentDataset/persistent_cache")
os.makedirs(persistent_cache, exist_ok=True)

import shutil
# NOTE(review): shutil.disk_usage reports totals for the whole filesystem containing the path,
# so the before/after difference also includes other writes (e.g. checkpoints in log_dir).
# It is not the cache size alone.
disk_usage_before = shutil.disk_usage(persistent_cache)
print(f"----------Disk Status before persistent cache:")
print(f"Total: {disk_usage_before[0] // (2**30)} GB")
print(f"Used: {disk_usage_before[1] // (2**30)} GB")
print(f"Free: {disk_usage_before[2] // (2**30)} GB")

# One cache sub-directory per split.
train_cache = persistent_cache / "train"
valid_cache = persistent_cache / "valid"
test_cache = persistent_cache / "test"

os.makedirs(train_cache, exist_ok=True)
os.makedirs(valid_cache, exist_ok=True)
os.makedirs(test_cache, exist_ok=True)

train_trans, val_trans = transformations()
persistent_init_start = time.time()

# NOTE(review): per the notes at the top of this notebook, PersistentDataset runs the non-random
# transforms on first use. This timing therefore presumably covers only object construction;
# the caching cost lands in the first training epoch. If the cache directories already hold
# entries from an earlier run, they are presumably reused, so a rerun would not pay that cost.
train_persitence_ds = PersistentDataset(data=train_files, transform=train_trans, cache_dir=str(train_cache))
val_persitence_ds = PersistentDataset(data=val_files, transform=val_trans, cache_dir=str(valid_cache))
test_persitence_ds = PersistentDataset(data=test_files, transform=val_trans, cache_dir=str(test_cache))

persistence_init_time = time.time() - persistent_init_start

(
    persistence_epoch_num,
    persistence_total_time,
    persistence_epoch_loss_values,
    persistence_metric_values,
    persistence_epoch_times,
) = train_process(train_persitence_ds, val_persitence_ds, test_persitence_ds, log_dir)
print(
    f"total training time of {persistence_epoch_num}"
    f" epochs with persistent storage Dataset: {persistence_total_time:.4f}"
)

disk_usage_after = shutil.disk_usage(persistent_cache)
print(f"----------Disk Status after persistent cache:")
print(f"Total: {disk_usage_after[0] // (2**30)} GB")
print(f"Used: {disk_usage_after[1] // (2**30)} GB")
print(f"Free: {disk_usage_after[2] // (2**30)} GB")

save_results_to_csv(
    "nbs/result/dataset_performance/persistent_results.csv",
    persistence_epoch_loss_values, 
    persistence_metric_values, 
    persistence_epoch_times, 
    persistence_total_time, 
    persistence_epoch_num,
    init_time=persistence_init_time,
    disk_usage_before=disk_usage_before,
    disk_usage_after=disk_usage_after
)

#### For HDD Disk

In [None]:
# PersistentDataset benchmark with the log dir and on-disk cache on the /data mount (HDD).
# shutil is imported in the setup cell.
set_determinism(seed=0)
log_dir = "/data/seoooa/cache/dataset_performance/PersistentDataset"
os.makedirs(log_dir, exist_ok=True)

persistent_cache = Path("/data/seoooa/cache/dataset_performance/PersistentDataset/persistent_cache")
os.makedirs(persistent_cache, exist_ok=True)

# NOTE: shutil.disk_usage reports the whole filesystem containing the path, not the cache size alone.
disk_usage_before = shutil.disk_usage(persistent_cache)
print(f"----------Disk Status before persistent cache:")
print(f"Total: {disk_usage_before[0] // (2**30)} GB")
print(f"Used: {disk_usage_before[1] // (2**30)} GB")
print(f"Free: {disk_usage_before[2] // (2**30)} GB")

# One cache sub-directory per split.
train_cache = persistent_cache / "train"
valid_cache = persistent_cache / "valid"
test_cache = persistent_cache / "test"

os.makedirs(train_cache, exist_ok=True)
os.makedirs(valid_cache, exist_ok=True)
os.makedirs(test_cache, exist_ok=True)

train_trans, val_trans = transformations()
persistent_init_start = time.time()

# PersistentDataset fills its cache lazily on first use; this timing covers construction only.
train_persitence_ds = PersistentDataset(data=train_files, transform=train_trans, cache_dir=str(train_cache))
val_persitence_ds = PersistentDataset(data=val_files, transform=val_trans, cache_dir=str(valid_cache))
test_persitence_ds = PersistentDataset(data=test_files, transform=val_trans, cache_dir=str(test_cache))

persistence_init_time = time.time() - persistent_init_start

(
    persistence_epoch_num,
    persistence_total_time,
    persistence_epoch_loss_values,
    persistence_metric_values,
    persistence_epoch_times,
) = train_process(train_persitence_ds, val_persitence_ds, test_persitence_ds, log_dir)
print(
    f"total training time of {persistence_epoch_num}"
    f" epochs with persistent storage Dataset: {persistence_total_time:.4f}"
)

disk_usage_after = shutil.disk_usage(persistent_cache)
print(f"----------Disk Status after persistent cache:")
print(f"Total: {disk_usage_after[0] // (2**30)} GB")
print(f"Used: {disk_usage_after[1] // (2**30)} GB")
print(f"Free: {disk_usage_after[2] // (2**30)} GB")

save_results_to_csv(
    # Separate file so the HDD run does not overwrite the SSD run's persistent_results.csv.
    "nbs/result/dataset_performance/persistent_hdd_results.csv",
    persistence_epoch_loss_values,
    persistence_metric_values,
    persistence_epoch_times,
    persistence_total_time,
    persistence_epoch_num,
    init_time=persistence_init_time,
    disk_usage_before=disk_usage_before,
    disk_usage_after=disk_usage_after,
)

## Enable deterministic training and `LMDBDataset`

Compute the non-random transforms of the training, validation, and test data once, and store the results persistently in an LMDB database so they can be reused across runs.

In [None]:
import lmdb
from pathlib import Path
import os

def check_lmdb_cache(cache_dir: "str | os.PathLike[str]", db_name: str = "monai_cache") -> bool:
    """Return whether an LMDB cache database file already exists in ``cache_dir``.

    Args:
        cache_dir: directory passed as ``cache_dir`` to ``LMDBDataset``; ``str`` or ``Path``.
        db_name: database name; the file checked is ``<cache_dir>/<db_name>.lmdb``.
            Defaults to ``"monai_cache"``, matching the file name this helper always checked.

    Returns:
        ``True`` if ``<cache_dir>/<db_name>.lmdb`` exists, otherwise ``False``.
    """
    return os.path.exists(os.path.join(cache_dir, f"{db_name}.lmdb"))

# Directory holding the LMDB cache; created (with parents) when missing.
LMDB_cache = Path("nbs/result/dataset_performance/LMDBDataset/lmdb_cache")
LMDB_cache.mkdir(parents=True, exist_ok=True)

# Show whether a cache database from an earlier run is already present.
check_lmdb_cache(LMDB_cache)

In [None]:
import lmdb
import tempfile
from pathlib import Path

set_determinism(seed=0)
log_dir = "nbs/result/dataset_performance/LMDBDataset"
os.makedirs(log_dir, exist_ok=True)

LMDB_cache = Path("nbs/result/dataset_performance/LMDBDataset/lmdb_cache")
os.makedirs(LMDB_cache, exist_ok=True)

import shutil

# Snapshot filesystem usage before caching. shutil.disk_usage returns (total, used, free) in
# bytes for the whole filesystem holding LMDB_cache (not just the cache); // 2**30 -> whole GiB.
disk_usage_before = shutil.disk_usage(LMDB_cache)
print("----------Disk Status before LMDB cache:")
print(f"Total: {disk_usage_before[0] // (2**30)} GB")
print(f"Used: {disk_usage_before[1] // (2**30)} GB")
print(f"Free: {disk_usage_before[2] // (2**30)} GB")

train_trans, val_trans = transformations()
# Time only the construction of the three LMDBDatasets.
lmdb_init_start = time.time()

# Options forwarded to LMDB.
lmdb_kwargs = {
    "map_async": True,
    "map_size": 1024 * 1024 * 1024 * 10,  # 10 GiB maximum database size
    "writemap": True,
    # Open read-only when a cache database already exists.
    # NOTE(review): LMDBDataset may override `readonly` while (re)filling the cache —
    # confirm this flag actually takes effect.
    "readonly": check_lmdb_cache(LMDB_cache),
}

# All three splits use the same cache_dir. Each dataset gets its own copy of lmdb_kwargs,
# so keys one LMDBDataset adds or changes in its options dict cannot leak into the next one.
train_lmdb_ds = LMDBDataset(
    data=train_files,
    transform=train_trans,
    cache_dir=LMDB_cache,
    lmdb_kwargs=dict(lmdb_kwargs),
)
val_lmdb_ds = LMDBDataset(
    data=val_files,
    transform=val_trans,
    cache_dir=LMDB_cache,
    lmdb_kwargs=dict(lmdb_kwargs),
)
test_lmdb_ds = LMDBDataset(
    data=test_files,
    transform=val_trans,
    cache_dir=LMDB_cache,
    lmdb_kwargs=dict(lmdb_kwargs),
)

lmdb_init_time = time.time() - lmdb_init_start

(
    lmdb_epoch_num,
    lmdb_total_time,
    lmdb_epoch_loss_values,
    lmdb_metric_values,
    lmdb_epoch_times,
) = train_process(train_lmdb_ds, val_lmdb_ds, test_lmdb_ds, log_dir)
print(f"total training time of {lmdb_epoch_num}" f" epochs with LMDB storage Dataset: {lmdb_total_time:.4f}")

# Snapshot filesystem usage after caching (same units as above).
disk_usage_after = shutil.disk_usage(LMDB_cache)
print("----------Disk Status after LMDB cache:")
print(f"Total: {disk_usage_after[0] // (2**30)} GB")
print(f"Used: {disk_usage_after[1] // (2**30)} GB")
print(f"Free: {disk_usage_after[2] // (2**30)} GB")

save_results_to_csv(
    "nbs/result/dataset_performance/lmdb_results.csv",
    lmdb_epoch_loss_values,
    lmdb_metric_values,
    lmdb_epoch_times,
    lmdb_total_time,
    lmdb_epoch_num,
    init_time=lmdb_init_time,
    disk_usage_before=disk_usage_before,
    disk_usage_after=disk_usage_after,
)

## Enable deterministic training and `CacheDataset`

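Precompute the non-random transforms of the original data and keep the results in memory. The cell below uses `cache_rate=0.5`, so about half of the items in each split are cached; the rest are transformed on every access.

With `runtime_cache="processes"`, the cache is computed at runtime as items are first accessed, so the cache initialization time `cache_init_time` is negligible.
Keep `runtime_cache=False` (the default, used below) to precompute the cache before training starts.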

In [None]:
set_determinism(seed=0)
# Same output root as every other benchmark in this notebook ("nbs/result/...").
log_dir = "nbs/result/dataset_performance/CacheDataset"
os.makedirs(log_dir, exist_ok=True)

train_trans, val_trans = transformations()
# Time the construction of the three CacheDatasets, which includes precomputing the cache.
cache_init_start = time.time()
# cache_rate=0.5 caches about half of each split, using 8 workers to fill the cache.
# NOTE(review): copy_cache=False skips copying cached items before the random transforms
# (per MONAI docs); confirm the transforms never modify cached data in place.
cache_train_ds = CacheDataset(
    data=train_files, transform=train_trans, cache_rate=0.5, num_workers=8, copy_cache=False
)
cache_val_ds = CacheDataset(
    data=val_files, transform=val_trans, cache_rate=0.5, num_workers=8, copy_cache=False
)
cache_test_ds = CacheDataset(
    data=test_files, transform=val_trans, cache_rate=0.5, num_workers=8, copy_cache=False
)
cache_init_time = time.time() - cache_init_start

(
    cache_epoch_num,
    cache_total_time,
    cache_epoch_loss_values,
    cache_metric_values,
    cache_epoch_times,
) = train_process(cache_train_ds, cache_val_ds, cache_test_ds, log_dir)
print(f"total training time of {cache_epoch_num}" f" epochs with CacheDataset: {cache_total_time:.4f}")

# In-memory cache: no disk usage is recorded for this run.
save_results_to_csv(
    "nbs/result/dataset_performance/cache_results.csv",
    cache_epoch_loss_values,
    cache_metric_values,
    cache_epoch_times,
    cache_total_time,
    cache_epoch_num,
    init_time=cache_init_time,
)

## Plot training loss and validation metrics

In [10]:
def load_results_from_csv(filename):
    """Load one benchmark run from a results CSV file.

    Rows are dispatched on their first cell:

    * ``max_epochs`` (int), ``total_time`` and ``init_time`` (float): value in the second cell;
    * first cell starting with ``disk_usage_before`` / ``disk_usage_after``: an int value stored
      under the first cell with the ``disk_usage_before_`` / ``disk_usage_after_`` prefix and
      ``_GB`` removed;
    * ``epoch``: header row, skipped;
    * anything else: a per-epoch row ``<epoch>, <loss>, <metric>, <epoch_time>``. Empty or
      missing cells become ``None``.

    Blank lines are ignored.

    Args:
        filename: path of the CSV file.

    Returns:
        dict with ``max_epochs``, ``total_time``, ``epoch_loss_values``, ``metric_values``
        and ``epoch_times``. It also has ``init_time``, ``disk_usage_before`` and
        ``disk_usage_after`` only when those rows appear in the file.
    """
    import csv  # local import keeps the function self-contained

    def _optional_float(row, index):
        # Missing or empty cells (e.g. epochs without a validation metric) map to None.
        if index < len(row) and row[index]:
            return float(row[index])
        return None

    epoch_loss_values = []
    metric_values = []
    epoch_times = []
    total_time = None
    max_epochs = None
    init_time = None
    disk_usage_before = None
    disk_usage_after = None

    # newline="" is what the csv module expects for files it reads.
    with open(filename, "r", newline="") as f:
        for row in csv.reader(f):
            if not row:
                continue  # blank line
            key = row[0]
            if key == "max_epochs":
                max_epochs = int(row[1])
            elif key == "total_time":
                total_time = float(row[1])
            elif key == "init_time":
                init_time = float(row[1])
            elif key.startswith("disk_usage_before"):
                if disk_usage_before is None:
                    disk_usage_before = {}
                disk_usage_before[key.replace("disk_usage_before_", "").replace("_GB", "")] = int(row[1])
            elif key.startswith("disk_usage_after"):
                if disk_usage_after is None:
                    disk_usage_after = {}
                disk_usage_after[key.replace("disk_usage_after_", "").replace("_GB", "")] = int(row[1])
            elif key == "epoch":
                continue  # header row
            else:
                epoch_loss_values.append(_optional_float(row, 1))
                metric_values.append(_optional_float(row, 2))
                epoch_times.append(_optional_float(row, 3))

    result = {
        "max_epochs": max_epochs,
        "total_time": total_time,
        "epoch_loss_values": epoch_loss_values,
        "metric_values": metric_values,
        "epoch_times": epoch_times,
    }

    # Optional fields are only present when the file recorded them.
    if init_time is not None:
        result["init_time"] = init_time
    if disk_usage_before is not None:
        result["disk_usage_before"] = disk_usage_before
    if disk_usage_after is not None:
        result["disk_usage_after"] = disk_usage_after

    return result

In [11]:
# Reload every benchmark run from disk so the plots below do not require re-training.

# Regular Dataset (baseline); also supplies max_epochs for the time plot.
regular_results = load_results_from_csv("nbs/result/dataset_performance/regular_results.csv")
max_epochs, total_time, epoch_loss_values, metric_values, epoch_times = (
    regular_results[field]
    for field in ("max_epochs", "total_time", "epoch_loss_values", "metric_values", "epoch_times")
)

# PersistentDataset: required fields first, then optional ones (None when not recorded).
persistent_results = load_results_from_csv("nbs/result/dataset_performance/persistent_results.csv")
(
    persistence_total_time,
    persistence_epoch_loss_values,
    persistence_metric_values,
    persistence_epoch_times,
) = (persistent_results[field] for field in ("total_time", "epoch_loss_values", "metric_values", "epoch_times"))
persistence_init_time, persistence_disk_before, persistence_disk_after = (
    persistent_results.get(field) for field in ("init_time", "disk_usage_before", "disk_usage_after")
)

# LMDBDataset: same layout as the persistent run.
lmdb_results = load_results_from_csv("nbs/result/dataset_performance/lmdb_results.csv")
(
    lmdb_total_time,
    lmdb_epoch_loss_values,
    lmdb_metric_values,
    lmdb_epoch_times,
) = (lmdb_results[field] for field in ("total_time", "epoch_loss_values", "metric_values", "epoch_times"))
lmdb_init_time, lmdb_disk_before, lmdb_disk_after = (
    lmdb_results.get(field) for field in ("init_time", "disk_usage_before", "disk_usage_after")
)

# CacheDataset: in-memory, so only the init time is optional (no disk usage fields).
cache_results = load_results_from_csv("nbs/result/dataset_performance/cache_results.csv")
(
    cache_total_time,
    cache_epoch_loss_values,
    cache_metric_values,
    cache_epoch_times,
) = (cache_results[field] for field in ("total_time", "epoch_loss_values", "metric_values", "epoch_times"))
cache_init_time = cache_results.get("init_time")

In [None]:
# 4x2 grid: one row per dataset type, with epoch-average training loss on the left and
# validation mean Dice on the right.
plt.figure("train", (12, 18))

# (title, values, color) per panel, in subplot order. Each panel plots its own series,
# so the regular metric panel uses the regular run's metric_values.
curve_panels = [
    ("Regular Epoch Average Loss", epoch_loss_values, "red"),
    ("Regular Val Mean Dice", metric_values, "red"),
    ("PersistentDataset Epoch Average Loss", persistence_epoch_loss_values, "blue"),
    ("PersistentDataset Val Mean Dice", persistence_metric_values, "blue"),
    ("LMDBDataset Epoch Average Loss", lmdb_epoch_loss_values, "yellow"),
    ("LMDBDataset Val Mean Dice", lmdb_metric_values, "yellow"),
    ("Cache Epoch Average Loss", cache_epoch_loss_values, "green"),
    ("Cache Val Mean Dice", cache_metric_values, "green"),
]

for panel_index, (title, values, color) in enumerate(curve_panels, start=1):
    plt.subplot(4, 2, panel_index)
    plt.title(title)
    # 1-based epoch numbers, one per stored value, so x and y always have the same length.
    x = [i + 1 for i in range(len(values))]
    plt.xlabel("epoch")
    plt.grid(alpha=0.4, linestyle=":")
    plt.plot(x, values, color=color)

plt.show()

## Plot total time and per-epoch time

In [None]:
def _plot_total_time_bar(name, train_time, init_time, label, color, init_label, init_color):
    """Draw one stacked bar of total wall time for a dataset type.

    The full bar (``color``) has height ``init_time + train_time``. The init share is overlaid
    from the bottom in ``init_color`` when it exceeds 1 second. ``init_time`` may be ``None``
    (not recorded in the results CSV) and is then treated as 0.
    """
    init_time = init_time or 0.0
    plt.bar(name, init_time + train_time, 1, label=label, color=color)
    if init_time > 1:
        plt.bar(name, init_time, 1, label=init_label, color=init_color)


plt.figure("train", (12, 6))

# Left: total wall time per dataset type (training time plus any dataset init time).
plt.subplot(1, 2, 1)
plt.title(f"Total Train Time ({max_epochs} epochs)")
plt.bar("regular", total_time, 1, label="Regular Dataset", color="red")
_plot_total_time_bar("lmdb", lmdb_total_time, lmdb_init_time, "LMDB Dataset", "orange", "LMDB cache init", "yellow")
_plot_total_time_bar(
    "persistent", persistence_total_time, persistence_init_time, "Persistent Dataset", "blue", "Persistent Init", "pink"
)
_plot_total_time_bar("cache", cache_total_time, cache_init_time, "Cache Dataset", "green", "Cache Init", "grey")
plt.ylabel("secs")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")

# Right: wall time of every epoch. Each series gets its own x axis so runs that stopped
# after a different number of epochs can still be drawn together.
plt.subplot(1, 2, 2)
plt.title("Epoch Time")
plt.xlabel("epoch")
plt.ylabel("secs")
for times, label, color in (
    (epoch_times, "Regular Dataset", "red"),
    (persistence_epoch_times, "Persistent Dataset", "blue"),
    (lmdb_epoch_times, "LMDB Dataset", "yellow"),
    (cache_epoch_times, "Cache Dataset", "green"),
):
    x = [i + 1 for i in range(len(times))]
    plt.plot(x, times, label=label, color=color)
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")

plt.show()

## Plot disk usage of `PersistentDataset` and `LMDBDataset`

In [None]:
import matplotlib.pyplot as plt


def _used_gb_increase(before, after):
    """Return after['used'] - before['used'], or 0 when either snapshot is missing/empty."""
    if not (after and before):
        return 0
    return after['used'] - before['used']


# Start from matplotlib's default style with a square figure.
plt.style.use('default')
plt.figure(figsize=(6, 6))

# Growth of the "used" disk figure recorded around each caching run.
datasets = ['Persistent Dataset', 'LMDB Dataset']
disk_usage_changes = [
    _used_gb_increase(persistence_disk_before, persistence_disk_after),
    _used_gb_increase(lmdb_disk_before, lmdb_disk_after),
]

bars = plt.bar(datasets, disk_usage_changes, color=['skyblue', 'lightcoral'], width=0.5)

plt.xlabel('Dataset Type', fontsize=12)
plt.ylabel('Disk Usage Increase (GB)', fontsize=12)
plt.title('Comparison of Disk Usage Increase by Dataset Type', fontsize=14)

# Label each bar with its value, centred on the bar and anchored at its top.
for rect in bars:
    top = rect.get_height()
    centre_x = rect.get_x() + rect.get_width() / 2
    plt.text(centre_x, top, f'{top:.1f}GB', ha='center', va='bottom')

plt.tight_layout()

plt.show()

# Print the same numbers as text.
print("\nDisk Usage Increase Analysis:")
print("-" * 50)
for label, delta in zip(datasets, disk_usage_changes):
    print(f"{label}: {delta:.1f} GB")

# Say which dataset grew the disk less. The Korean messages read:
# "<dataset> uses <difference> GB less disk" / "both datasets use the same amount of disk".
persistent_change, lmdb_change = disk_usage_changes
if persistent_change == lmdb_change:
    print("\n두 데이터셋의 디스크 사용량이 동일합니다.")
else:
    more_efficient = datasets[0] if persistent_change < lmdb_change else datasets[1]
    difference = abs(persistent_change - lmdb_change)
    print(f"\n{more_efficient}가 디스크 사용량이 {difference:.1f}GB 더 적습니다.")

## Cleanup data directory

Remove the data directory if a temporary one was used.

In [None]:
# With no data `directory` configured, root_dir is the temporary directory, so delete it.
temporary_dir_used = directory is None
if temporary_dir_used:
    shutil.rmtree(root_dir)