In [11]:
import torch
import numpy as np
import pandas as pd
import random
from data.nifti_loader import MedicalImageDatasetSplitter,MonaiDatasetCreator,MonaiDataLoaderManager
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, IntSlider, fixed
from config import config_loader

## Basic data preprocessing and loading 

In [2]:
config_path = "config/base_config.yaml"

# Load and process the config (this also creates the experiment output directories)
config = config_loader.load_config(config_path)

# Build the data pipeline: splitter -> MONAI datasets -> dataloaders
dataset_splitter = MedicalImageDatasetSplitter(config)
dataset_creator = MonaiDatasetCreator(dataset_splitter)
dataloader_manager = MonaiDataLoaderManager(dataset_creator, config)
dataloaders = dataloader_manager.get_dataloaders()

# Unpack the per-split loaders
train_loader, val_loader, test_loader = (
    dataloaders['train'],
    dataloaders['val'],
    dataloaders['test'],
)

# Class metadata: label <-> index mappings and the number of classes
class_to_idx, idx_to_class = dataset_splitter.get_class_info()
num_classes = dataset_splitter.get_num_classes()

2025-04-29 14:02:06,951 - INFO - Created directory: results/exp_1_20250429
2025-04-29 14:02:06,952 - INFO - Created directory: results/exp_1_20250429/logs
2025-04-29 14:02:06,953 - INFO - Created directory: results/exp_1_20250429/checkpoints
2025-04-29 14:02:06,955 - INFO - Created directory: results/exp_1_20250429/explanations
2025-04-29 14:02:06,956 - INFO - Created directory: results/exp_1_20250429/predictions
2025-04-29 14:02:06,957 - INFO - Created directory: results/exp_1_20250429/figures


Number of empty labels: 0
Checking filtering status...
No filtering applied.
Loaded dataset with 652 total samples
Training: 456 samples
Validation: 98 samples
Testing: 98 samples
Classes: {'AD': 0, 'CN': 1, 'LMCI': 2}


In [9]:
# Fetch a single batch (the first one the iterator yields) from the train loader.
# NOTE(review): if train_loader shuffles, this batch itself varies between runs
# unless torch's RNG is also seeded — confirm the loader's shuffle setting.
batch = next(iter(train_loader))

# Pick one sample from the batch reproducibly; change SAMPLE_SEED to view another one.
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)
idx = int(rng.integers(batch['image'].shape[0]))

# Extract the chosen sample; batch['image'] is expected to be [B, C, H, W, D],
# so a single sample is expected to be [C, H, W, D].
sample_volume = batch['image'][idx]

# Convert to numpy (detach in case the tensor tracks gradients) and drop the
# channel dimension if present.
if isinstance(sample_volume, torch.Tensor):
    sample_volume = sample_volume.detach().cpu().numpy()
if sample_volume.ndim == 4:
    # assume channel-first; keep only channel 0
    sample_volume = sample_volume[0]

# Fail loudly here rather than with an opaque unpacking error below.
if sample_volume.ndim != 3:
    raise ValueError(
        f"Expected a 3D volume [H, W, D] after dropping channels, got shape {sample_volume.shape}"
    )

# sample_volume is now a 3D array of shape [H, W, D]
H, W, D = sample_volume.shape


In [12]:
def show_slice(z: int):
    """
    Display the z-th slice (third axis) of the global ``sample_volume``.

    The slice count shown in the title is read from ``sample_volume`` itself
    instead of the global ``D``. It therefore cannot go stale if
    ``sample_volume`` is reassigned without ``D`` being recomputed.

    Parameters:
        z: 0-based index along the third axis of ``sample_volume``.
    """
    depth = sample_volume.shape[2]
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(sample_volume[:, :, z], cmap='gray')
    ax.set_title(f"Slice {z+1}/{depth}")
    ax.axis('off')
    plt.show()


def show_slice_advanced(volume, z: int):
    """
    Display the z-th slice of a 3D volume.

    Parameters:
        volume: 3D array of shape (H, W, D).
        z:      0-based index of the slice along the third axis.

    Raises:
        ValueError: if ``volume`` is not 3-dimensional.
        IndexError: if ``z`` is outside ``[0, D)``. Negative indices are
            rejected because they would silently wrap around and make the
            "Slice k/D" title misleading.
    """
    if volume.ndim != 3:
        raise ValueError(f"Expected a 3D volume (H, W, D), got shape {volume.shape}")
    depth = volume.shape[2]
    if not 0 <= z < depth:
        raise IndexError(f"Slice index z={z} out of range [0, {depth})")

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(volume[:, :, z], cmap='gray')
    ax.set_title(f"Slice {z+1}/{depth}")
    ax.axis('off')
    plt.show()

In [17]:
# Build a non-continuous slider spanning every z-index of sample_volume, then
# wire it to show_slice so the plot re-renders when the slider is released.
slider = widgets.IntSlider(
    description='Z Slice:',
    value=0, min=0, max=D - 1, step=1,
    continuous_update=False,
)
widgets.interact(show_slice, z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=172), Output()),…

In [18]:
# Same viewer, but the volume is bound explicitly via fixed() instead of
# being read from a global inside the plotting function.
slider = IntSlider(
    description='Z Slice:',
    value=0, min=0, max=sample_volume.shape[2] - 1, step=1,
    continuous_update=False,
)
interact(show_slice_advanced, volume=fixed(sample_volume), z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=172), Output()),…

### Overriding config values for ablation 

In [20]:
# 1) Load a fresh copy of the default config (this also creates the
#    experiment output directories).
config_path = "config/base_config.yaml"
config = config_loader.load_config(config_path)

# 2) Override values for this ablation run; add further entries as needed.
# NOTE(review): the logged output shows this run reuses the same experiment
# directory as the baseline (results/exp_1_20250429), so artifacts from both
# runs may land in the same place — consider overriding the experiment
# name / output path here as well.
data_overrides = {
    'batch_size': 8,
    'perform_slicing': False,
    'image_size': [128, 128, 128],
}
for override_key, override_value in data_overrides.items():
    config['data'][override_key] = override_value

# Rebuild the data pipeline with the modified config
dataset_splitter = MedicalImageDatasetSplitter(config)
dataset_creator = MonaiDatasetCreator(dataset_splitter)
dataloader_manager = MonaiDataLoaderManager(dataset_creator, config)
dataloaders = dataloader_manager.get_dataloaders()

# Unpack the per-split loaders
train_loader, val_loader, test_loader = (
    dataloaders['train'],
    dataloaders['val'],
    dataloaders['test'],
)

# Class metadata: label <-> index mappings and the number of classes
class_to_idx, idx_to_class = dataset_splitter.get_class_info()
num_classes = dataset_splitter.get_num_classes()

2025-04-29 14:20:29,493 - INFO - Created directory: results/exp_1_20250429
2025-04-29 14:20:29,494 - INFO - Created directory: results/exp_1_20250429/logs
2025-04-29 14:20:29,495 - INFO - Created directory: results/exp_1_20250429/checkpoints
2025-04-29 14:20:29,497 - INFO - Created directory: results/exp_1_20250429/explanations
2025-04-29 14:20:29,498 - INFO - Created directory: results/exp_1_20250429/predictions
2025-04-29 14:20:29,499 - INFO - Created directory: results/exp_1_20250429/figures


Number of empty labels: 0
Checking filtering status...
No filtering applied.
Loaded dataset with 652 total samples
Training: 456 samples
Validation: 98 samples
Testing: 98 samples
Classes: {'AD': 0, 'CN': 1, 'LMCI': 2}


In [21]:
# Fetch a single batch (the first one the iterator yields) from the train loader.
# NOTE(review): if train_loader shuffles, this batch itself varies between runs
# unless torch's RNG is also seeded — confirm the loader's shuffle setting.
batch = next(iter(train_loader))

# Pick one sample from the batch reproducibly; change SAMPLE_SEED to view another one.
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)
idx = int(rng.integers(batch['image'].shape[0]))

# Extract the chosen sample; batch['image'] is expected to be [B, C, H, W, D],
# so a single sample is expected to be [C, H, W, D].
sample_volume = batch['image'][idx]

# Convert to numpy (detach in case the tensor tracks gradients) and drop the
# channel dimension if present.
if isinstance(sample_volume, torch.Tensor):
    sample_volume = sample_volume.detach().cpu().numpy()
if sample_volume.ndim == 4:
    # assume channel-first; keep only channel 0
    sample_volume = sample_volume[0]

# Fail loudly here rather than with an opaque unpacking error below.
if sample_volume.ndim != 3:
    raise ValueError(
        f"Expected a 3D volume [H, W, D] after dropping channels, got shape {sample_volume.shape}"
    )

# sample_volume is now a 3D array of shape [H, W, D]
H, W, D = sample_volume.shape

In [22]:
# Sanity check: the image batch is expected to be [B, C, H, W, D]
batch["image"].shape

torch.Size([8, 1, 128, 128, 128])

In [23]:
# Build a non-continuous slider spanning every z-index of sample_volume, then
# wire it to show_slice so the plot re-renders when the slider is released.
slider = widgets.IntSlider(
    description='Z Slice:',
    value=0, min=0, max=D - 1, step=1,
    continuous_update=False,
)
widgets.interact(show_slice, z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=127), Output()),…

In [24]:
# Same viewer, but the volume is bound explicitly via fixed() instead of
# being read from a global inside the plotting function.
slider = IntSlider(
    description='Z Slice:',
    value=0, min=0, max=sample_volume.shape[2] - 1, step=1,
    continuous_update=False,
)
interact(show_slice_advanced, volume=fixed(sample_volume), z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=127), Output()),…