Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License. 

# Customize Data Analysis in Auto3DSeg

In this notebook, we will provide a brief example of how to to customize your data analysis pipeline by writing new operations on new metadata.

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"

## Setup imports

In [1]:
import os
import torch
import tempfile
import nibabel as nib
import numpy as np

from copy import deepcopy
from tqdm import tqdm

from monai.auto3dseg.analyzer import Analyzer
from monai.auto3dseg import (
    SampleOperations,
    SegSummarizer,
    concat_val_to_np,
    datafold_read,
)
from monai.config import print_config
from monai.data import DataLoader, Dataset, create_test_image_3d
from monai.data.utils import no_collation
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Lambdad,
    LoadImaged,
    Orientationd,
    SqueezeDimd,
    ToDeviced,
)

from monai.utils.enums import DataStatsKeys


def _argmax_if_multichannel(x):
    return torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x


print_config()

  from torch.distributed.optim import ZeroRedundancyOptimizer


MONAI version: 1.3.2
Numpy version: 1.26.4
Pytorch version: 2.4.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/.virtualenvs/monai/lib/python3.12/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.0
Nibabel version: 5.2.1
scikit-image version: 0.24.0
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.17.0
gdown version: 5.2.0
TorchVision version: 0.19.0+cu121
tqdm version: 4.66.4
lmdb version: 1.5.1
psutil version: 6.0.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: 4.43.3
mlflow version: 2.15.0
pynrrd version: 1.0.0
clearml version: 1.16.2

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [6]:
import platform
from pathlib import Path
from mri_data.file_manager import scan_3Tpioneer_bids, filter_first_ses
from monai_training.preprocess import DataSetProcesser
from monai_training import training, preprocess

import json

## Simulate a dataset and Auto3D datalist using MONAI functions

In [32]:
hostname = platform.node()
if hostname == "rhinocampus" or hostname == "ryzen9":
    drive_root = Path("/media/smbshare")
else:
    drive_root = Path("/mnt/h")

projects_root = Path("/home/srs-9/Projects")
msmri_home = projects_root / "ms_mri"
training_work_dirs = msmri_home / "training_work_dirs"

dataroot = drive_root / "3Tpioneer_bids"
work_dir_name = "choroid_pineal_pituitary1"
work_dir = training_work_dirs / work_dir_name
modalities = ["flair", "t1"]
labels = ["choroid_t1_flair", "pineal", "pituitary"]

datalist_file = os.path.join(work_dir, "datalist.json")
with open(datalist_file, 'r') as f:
    datalist = json.load(f)

dataset = preprocess.parse_datalist(datalist_file, dataroot)
dataset.sort()

[32m2024-11-08 14:32:48.646[0m | [1mINFO    [0m | [36mmonai_training.preprocess[0m:[36mparse_datalist[0m:[36m282[0m - [1mLoading /home/srs-9/Projects/ms_mri/training_work_dirs/choroid_pineal_pituitary1/datalist.json[0m
[32m2024-11-08 14:32:48.648[0m | [1mINFO    [0m | [36mmonai_training.preprocess[0m:[36mparse_datalist[0m:[36m283[0m - [1m/home/srs-9/Projects/ms_mri/training_work_dirs/choroid_pineal_pituitary1/datalist.json exists: True[0m


## Perform analysis on a different image meta data

In [33]:
class DimsAnalyzer(Analyzer):
    def __init__(self, image_key="image", stats_name="user_stats"):
        self.image_key = image_key
        report_format = {"ndims": None}
        super().__init__(stats_name, report_format)

    def __call__(self, data):
        d = dict(data)
        report = deepcopy(self.get_report_format())
        report["ndims"] = d[self.image_key].ndim
        d[self.stats_name] = report
        return d


class DimsSummaryAnalyzer(Analyzer):
    def __init__(self, stats_name="user_stats"):
        report_format = {"ndims": None}
        super().__init__(stats_name, report_format)
        self.update_ops("ndims", SampleOperations())

    def __call__(self, data):
        report = deepcopy(self.get_report_format())
        v_np = concat_val_to_np(data, [self.stats_name, "ndims"])
        report["ndims"] = self.ops["ndims"].evaluate(v_np)
        return report


# it has the three default analyzers (ImageStats, FgImageStats, LabelStats)
summarizer = SegSummarizer("image", "label")
summarizer.add_analyzer(DimsAnalyzer(), DimsSummaryAnalyzer())

In [54]:
def my_analyzer(datalist, dataroot, my_summarizer):
    keys = ["image", "label"]
    transform_list = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),  # this creates label to be (1,H,W,D)
        Orientationd(keys=keys, axcodes="RAS"),
        EnsureTyped(keys=keys, data_type="tensor"),
        Lambdad(keys="label", func=_argmax_if_multichannel),
        SqueezeDimd(keys=["label"], dim=0),
        ToDeviced(keys=keys, device="cuda" if torch.cuda.is_available() else "cpu"),
        my_summarizer,
    ]

    transform = Compose(transforms=list(filter(None, transform_list)))

    files, _ = datafold_read(datalist=datalist, basedir=dataroot, fold=-1)
    dataset = Dataset(data=files, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=no_collation)
    result = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}

    for batch_data in tqdm(dataloader):
        d = batch_data[0]
        stats_by_cases = {
            DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
            DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH],
            DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS],
            DataStatsKeys.FG_IMAGE_STATS: d[DataStatsKeys.FG_IMAGE_STATS],
            DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS],
            DataStatsKeys.IMAGE_HISTOGRAM: d[DataStatsKeys.IMAGE_HISTOGRAM],
            "user_stats": d["user_stats"],
        }

        result[DataStatsKeys.BY_CASE].append(stats_by_cases)
    result[DataStatsKeys.SUMMARY] = my_summarizer.summarize(result[DataStatsKeys.BY_CASE])
    return result


result = my_analyzer(datalist, dataroot, summarizer)

100%|██████████| 23/23 [01:38<00:00,  4.27s/it]


In [58]:
print(result[DataStatsKeys.BY_CASE][0]["user_stats"])

{'ndims': 4}


In [63]:
work_dir / "datastats_custom.yml"
import yaml
with open(work_dir / "datastats_custom.yml", 'w') as f:
    yaml.dump(result, f, default_flow_style=False)

In [66]:
import json
with open(work_dir / "datastats_custom.json", 'w') as f:
    json.dump(result, f)

In [67]:
with open(work_dir / "datastats_custom.json", 'r') as f:
    result = json.load(f)

In [68]:
result['stats_summary']

{'image_stats': {'shape': {'max': [240, 312, 312],
   'mean': [226.82608695652175, 300.5217391304348, 300.5217391304348],
   'median': [240.0, 300.0, 300.0],
   'min': [189, 300, 300],
   'stdev': [17.03360406070266, 2.447173439907879, 2.447173439907879],
   'percentile': [[189, 300, 300],
    [198, 300, 300],
    [240, 300, 300],
    [240, 312, 312]],
   'percentile_00_5': [189, 300, 300],
   'percentile_10_0': [198, 300, 300],
   'percentile_90_0': [240, 300, 300],
   'percentile_99_5': [240, 312, 312]},
  'channels': {'max': 2,
   'mean': 2.0,
   'median': 2.0,
   'min': 2,
   'stdev': 0.0,
   'percentile': [2, 2, 2, 2],
   'percentile_00_5': 2,
   'percentile_10_0': 2,
   'percentile_90_0': 2,
   'percentile_99_5': 2},
  'cropped_shape': {'max': [240, 301, 307],
   'mean': [226.82608695652175, 297.1304347826087, 296.3478260869565],
   'median': [240.0, 299.0, 300.0],
   'min': [189, 288, 276],
   'stdev': [17.03360406070266, 3.937604080900033, 7.135999949138321],
   'percentile': [

In [57]:
result['stats_by_cases']

[{image_filepath: '/mnt/h/3Tpioneer_bids/sub-ms1010/ses-20180208/flair.t1.nii.gz',
  label_filepath: '/mnt/h/3Tpioneer_bids/sub-ms1010/ses-20180208/choroid_t1_flair-CH.pineal-CH.pituitary-CH.nii.gz',
  image_stats: {shape: [[240, 300, 300], [240, 300, 300]],
   channels: 2,
   cropped_shape: [[240, 294, 300], [240, 293, 300]],
   spacing: [0.8000057342241774, 0.8000000256874099, 0.8000000453256201],
   sizemm: [192.0013762138026, 240.00000770622296, 240.00001359768603],
   intensity: [{'max': 1.6661999225616455,
     'mean': 0.31402653455734253,
     'median': 0.053199999034404755,
     'min': 0.0,
     'stdev': 0.44537705183029175,
     'percentile': [0.0, 0.0, 1.0162999629974365, 1.5151000022888184],
     'percentile_00_5': 0.0,
     'percentile_10_0': 0.0,
     'percentile_90_0': 1.0162999629974365,
     'percentile_99_5': 1.5151000022888184},
    {'max': 2422.078125,
     'mean': 143.01145935058594,
     'median': 31.898056030273438,
     'min': -30.629852294921875,
     'stdev': 2

In [45]:
stats_by_cases = result['stats_by_cases']
stats_by_cases

[{image_filepath: '/mnt/h/3Tpioneer_bids/sub-ms2144/ses-20190422/flair.t1.nii.gz',
  label_filepath: '/mnt/h/3Tpioneer_bids/sub-ms2144/ses-20190422/choroid_t1_flair-ED.pineal-SRS.pituitary-CH.nii.gz',
  image_stats: {shape: [[240, 300, 300], [240, 300, 300]],
   channels: 2,
   cropped_shape: [[240, 299, 298], [240, 298, 299]],
   spacing: [0.8000000142247927, 0.800000011920929, 0.8000000083477894],
   sizemm: [192.00000341395025, 240.0000035762787, 240.0000025043368],
   intensity: [{'max': 1.666700005531311,
     'mean': 0.2958483397960663,
     'median': 0.03830000013113022,
     'min': 0.0,
     'stdev': 0.43877023458480835,
     'percentile': [0.0, 0.0, 1.0012999773025513, 1.5216000080108643],
     'percentile_00_5': 0.0,
     'percentile_10_0': 0.0,
     'percentile_90_0': 1.0012999773025513,
     'percentile_99_5': 1.5216000080108643},
    {'max': 4543.73388671875,
     'mean': 139.4381561279297,
     'median': 39.621681213378906,
     'min': -43.456295013427734,
     'stdev': 2

In [43]:
label_stats = result['stats_by_cases'][0]['image_filepath']
# label_stats['label'][-1]
label_stats

'/mnt/h/3Tpioneer_bids/sub-ms2144/ses-20190422/flair.t1.nii.gz'

In [25]:
result['stats_by_cases'][0]['label_stats']['label'][3]

{image_intensity: [{'max': 1.3459999561309814,
   'mean': 0.7435335516929626,
   'median': 0.7513999938964844,
   'min': 0.1824999898672104,
   'stdev': 0.21918557584285736,
   'percentile': [0.24910399317741394,
    0.44273996353149414,
    1.0373398065567017,
    1.2540154457092285],
   'percentile_00_5': 0.24910399317741394,
   'percentile_10_0': 0.44273996353149414,
   'percentile_90_0': 1.0373398065567017,
   'percentile_99_5': 1.2540154457092285},
  {'max': 1761.2386474609375,
   'mean': 715.3837280273438,
   'median': 655.4103393554688,
   'min': 188.7974853515625,
   'stdev': 265.1414489746094,
   'percentile': [287.5709533691406,
    440.6593322753906,
    1138.8695068359375,
    1581.540283203125],
   'percentile_00_5': 287.5709533691406,
   'percentile_10_0': 440.6593322753906,
   'percentile_90_0': 1138.8695068359375,
   'percentile_99_5': 1581.540283203125}],
 shape: [[18, 15, 9]],
 ncomponents: 1,
 foreground_percentage: 4.8935184167930856e-05}

## Add a new stat operation

In [8]:
op = SampleOperations()
# add a new operation
op.update({"sum": np.sum})


class NewDimsSummaryAnalyzer(Analyzer):
    def __init__(self, stats_name="user_stats"):
        report_format = {"ndims": None}
        super().__init__(stats_name, report_format)
        self.update_ops("ndims", op)

    def __call__(self, data):
        report = deepcopy(self.get_report_format())
        v_np = concat_val_to_np(data, [self.stats_name, "ndims"])
        report["ndims"] = self.ops["ndims"].evaluate(v_np)
        return report


summarizer = SegSummarizer("image", "label")
summarizer.add_analyzer(DimsAnalyzer(), NewDimsSummaryAnalyzer())
result = my_analyzer(sim_datalist, sim_dataroot, summarizer)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 30.27it/s]


In [9]:
print(result[DataStatsKeys.SUMMARY]["user_stats"])

{'ndims': {'max': 4, 'mean': 4.0, 'median': 4.0, 'min': 4, 'stdev': 0.0, 'percentile': [4, 4, 4, 4], 'sum': 4, 'percentile_00_5': 4, 'percentile_10_0': 4, 'percentile_90_0': 4, 'percentile_99_5': 4}}
