In [1]:
import sys
import os

# Make the project root importable so the local packages imported below
# (`data`, `config`, `models`, `training`, `utils`) resolve. The notebook is
# assumed to be run from `examples/`, so the root is the parent of the CWD.
project_root = os.path.abspath(os.pardir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)


In [2]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
from data.nifti_loader import MedicalImageDatasetSplitter,MonaiDatasetCreator,MonaiDataLoaderManager
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, IntSlider, fixed
from config import config_loader
from utils.visualization import show_slice_advanced

import torch.nn as nn
from models.cnn_backbones import Small3DCNN,DenseNet3D,ResNet3D,BasicBlock3D
from training.trainer import train_model, load_model, save_model,test_model


## Basic data preprocessing and loading 

In [3]:
# Resolve the base config relative to the working directory
# (assumed to be examples/, one level below the project root).
config_path = os.path.abspath(os.path.join(os.getcwd(), "..", "config", "base_config.yaml"))

# Load & process config. Per the log output, this also creates the experiment
# directories, and it does so on every load, not just the first.
config = config_loader.load_config(config_path)

# Data pipeline: splitter -> MONAI datasets -> dataloaders.
dataset_splitter = MedicalImageDatasetSplitter(config)
dataset_creator = MonaiDatasetCreator(dataset_splitter)
dataloader_manager = MonaiDataLoaderManager(dataset_creator, config)

# Get all dataloaders and unpack the individual splits.
dataloaders = dataloader_manager.get_dataloaders()
train_loader = dataloaders['train']
val_loader = dataloaders['val']
test_loader = dataloaders['test']

# Class information. The model cell reads `NUM_CLASSES`, so define it here too;
# this keeps the notebook runnable even if the ablation/override cell is skipped.
# `num_classes` is kept for backward compatibility.
class_to_idx, idx_to_class = dataset_splitter.get_class_info()
num_classes = dataset_splitter.get_num_classes()
NUM_CLASSES = num_classes

2025-04-29 18:22:19,226 - INFO - Created directory: results/exp_1_20250429
2025-04-29 18:22:19,227 - INFO - Created directory: results/exp_1_20250429/logs
2025-04-29 18:22:19,229 - INFO - Created directory: results/exp_1_20250429/checkpoints
2025-04-29 18:22:19,230 - INFO - Created directory: results/exp_1_20250429/explanations
2025-04-29 18:22:19,231 - INFO - Created directory: results/exp_1_20250429/predictions
2025-04-29 18:22:19,232 - INFO - Created directory: results/exp_1_20250429/figures


Number of empty labels: 0
Checking filtering status...
No filtering applied.
Loaded dataset with 652 total samples
Training: 456 samples
Validation: 98 samples
Testing: 98 samples
Classes: {'AD': 0, 'CN': 1, 'LMCI': 2}


In [4]:
# Fetch a single batch (first batch) from the train loader
# Fetch one batch from the train loader and pick ONE RANDOM sample from it to
# visualise. batch['image'] is assumed to be [B, C, H, W, D].
batch = next(iter(train_loader))

# Unseeded draw: re-running this cell may show a different sample.
idx = np.random.randint(0, batch['image'].shape[0])
sample_volume = batch['image'][idx]  # assumed shape [C, H, W, D]

# Convert to numpy and drop the channel dimension if present.
if isinstance(sample_volume, torch.Tensor):
    sample_volume = sample_volume.cpu().numpy()
if sample_volume.ndim == 4:
    # Assumes channel-first layout; keep channel 0 only.
    sample_volume = sample_volume[0]

# The slice viewers below index sample_volume[:, :, z], so insist on 3D here
# with a readable message instead of an opaque unpacking error.
if sample_volume.ndim != 3:
    raise ValueError(
        f"Expected a 3D volume [H, W, D] after dropping channels, got shape {sample_volume.shape}"
    )
H, W, D = sample_volume.shape


In [5]:
def show_slice(z: int):
    """
    Display slice `z` along the last axis of the global `sample_volume`.

    `sample_volume` is read at call time, so the plot reflects whichever
    volume was extracted most recently. It is expected to be 3D ([H, W, D]),
    since a 2D slice is taken with `[:, :, z]`.

    The signature is deliberately just `z`: `ipywidgets.interact` builds a
    control for every parameter, so extra parameters would break the sliders.

    Args:
        z: 0-based slice index along the depth axis (shown 1-based in the title).
    """
    # Derive the depth from the volume actually being drawn rather than the
    # separately-assigned global `D`, so the title cannot drift out of sync
    # when cells are re-run out of order.
    depth = sample_volume.shape[2]
    _, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(sample_volume[:, :, z], cmap='gray')
    ax.set_title(f"Slice {z+1}/{depth}")
    ax.axis('off')
    plt.show()


In [6]:
# Create an IntSlider for the z-dimension
slider = widgets.IntSlider(
    value=0,
    min=0,
    max=D - 1,
    step=1,
    description='Z Slice:',
    continuous_update=False
)

# Link the slider to the display function
widgets.interact(show_slice, z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=172), Output()),…

In [7]:
# Slider for the z-dimension
slider = IntSlider(
    value=0,
    min=0,
    max=sample_volume.shape[2] - 1,
    step=1,
    description='Z Slice:',
    continuous_update=False
)

# Link the slider to show_slice, passing sample_volume as a fixed argument
interact(show_slice_advanced, volume=fixed(sample_volume), z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=172), Output()),…

### Overriding config values for ablation 

In [8]:
# 1) Load the default config 

# Start from a freshly loaded base config so the overrides below apply to the
# on-disk defaults. Per the log output, loading re-creates the experiment
# directories each time.
config_path = os.path.abspath(os.path.join(os.getcwd(), "..", "config", "base_config.yaml"))
config = config_loader.load_config(config_path)

# Ablation overrides: edit or extend these before rebuilding the loaders.
config['data']['batch_size'] = 8
config['data']['perform_slicing'] = False
config['data']['image_size'] = [128, 128, 128]

# Rebuild the full data pipeline on top of the overridden config.
dataset_splitter = MedicalImageDatasetSplitter(config)
dataset_creator = MonaiDatasetCreator(dataset_splitter)
dataloader_manager = MonaiDataLoaderManager(dataset_creator, config)
dataloaders = dataloader_manager.get_dataloaders()

train_loader, val_loader, test_loader = dataloaders['train'], dataloaders['val'], dataloaders['test']

# Class mapping and count (NUM_CLASSES is what the model cell consumes).
class_to_idx, idx_to_class = dataset_splitter.get_class_info()
NUM_CLASSES = dataset_splitter.get_num_classes()

2025-04-29 18:22:22,704 - INFO - Created directory: results/exp_1_20250429
2025-04-29 18:22:22,705 - INFO - Created directory: results/exp_1_20250429/logs
2025-04-29 18:22:22,705 - INFO - Created directory: results/exp_1_20250429/checkpoints
2025-04-29 18:22:22,706 - INFO - Created directory: results/exp_1_20250429/explanations
2025-04-29 18:22:22,707 - INFO - Created directory: results/exp_1_20250429/predictions
2025-04-29 18:22:22,708 - INFO - Created directory: results/exp_1_20250429/figures


Number of empty labels: 0
Checking filtering status...
No filtering applied.
Loaded dataset with 652 total samples
Training: 456 samples
Validation: 98 samples
Testing: 98 samples
Classes: {'AD': 0, 'CN': 1, 'LMCI': 2}


In [9]:
# Fetch a single batch (first batch) from the train loader
# Fetch one batch from the (re-configured) train loader and pick ONE RANDOM
# sample from it to visualise. batch['image'] is assumed to be [B, C, H, W, D].
batch = next(iter(train_loader))

# Unseeded draw: re-running this cell may show a different sample.
idx = np.random.randint(0, batch['image'].shape[0])
sample_volume = batch['image'][idx]  # assumed shape [C, H, W, D]

# Convert to numpy and drop the channel dimension if present.
if isinstance(sample_volume, torch.Tensor):
    sample_volume = sample_volume.cpu().numpy()
if sample_volume.ndim == 4:
    # Assumes channel-first layout; keep channel 0 only.
    sample_volume = sample_volume[0]

# The slice viewers below index sample_volume[:, :, z], so insist on 3D here
# with a readable message instead of an opaque unpacking error.
if sample_volume.ndim != 3:
    raise ValueError(
        f"Expected a 3D volume [H, W, D] after dropping channels, got shape {sample_volume.shape}"
    )
H, W, D = sample_volume.shape

In [10]:
# Inspect the keys of the batch fetched in the previous cell. Pulling a new
# batch here would load another set of 3D volumes for no reason, and would
# re-bind `batch` so it no longer matches the displayed `sample_volume`.
batch.keys()

dict_keys(['image', 'label'])

In [11]:
# Create an IntSlider for the z-dimension
slider = widgets.IntSlider(
    value=0,
    min=0,
    max=D - 1,
    step=1,
    description='Z Slice:',
    continuous_update=False
)

# Link the slider to the display function
widgets.interact(show_slice, z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=127), Output()),…

In [12]:
# Slider for the z-dimension
slider = IntSlider(
    value=0,
    min=0,
    max=sample_volume.shape[2] - 1,
    step=1,
    description='Z Slice:',
    continuous_update=False
)

# Link the slider to show_slice, passing sample_volume as a fixed argument
interact(show_slice_advanced, volume=fixed(sample_volume), z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=127), Output()),…

## Training a predefined 3D CNN

In [16]:
# Seed the global RNGs before building the model so weight initialisation, and
# any later draws from these global generators, are repeatable across runs.
# NOTE(review): GPU kernels and the loaders' internal randomness may still
# introduce run-to-run variation; confirm if exact reproducibility is required.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Small3DCNN(num_classes=NUM_CLASSES).to(device)

# Optimiser hyper-parameters, surfaced so ablations can tweak them in one place.
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [17]:
# Train the model on the available device (either GPU or CPU)
trained_model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
    device=device
)

# Accessing the training history
print("Training History:")
print(f"Train Loss: {history['train_loss']}")
print(f"Train Accuracy: {history['train_accuracy']}")
print(f"Validation Loss: {history['val_loss']}")
print(f"Validation Accuracy: {history['val_accuracy']}")

Epoch 1/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.23it/s]
Epoch 1/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]


Epoch [1/10], Train Loss: 1.0993, Train Accuracy: 0.3772, Val Loss: 1.0821, Val Accuracy: 0.3776


Epoch 2/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:45<00:00,  1.26it/s]
Epoch 2/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.52it/s]


Epoch [2/10], Train Loss: 1.0868, Train Accuracy: 0.3794, Val Loss: 1.0713, Val Accuracy: 0.3878


Epoch 3/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:48<00:00,  1.18it/s]
Epoch 3/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]


Epoch [3/10], Train Loss: 1.0766, Train Accuracy: 0.4035, Val Loss: 1.0667, Val Accuracy: 0.3878


Epoch 4/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.22it/s]
Epoch 4/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.54it/s]


Epoch [4/10], Train Loss: 1.0714, Train Accuracy: 0.4145, Val Loss: 1.0755, Val Accuracy: 0.3878


Epoch 5/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:48<00:00,  1.17it/s]
Epoch 5/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]


Epoch [5/10], Train Loss: 1.0729, Train Accuracy: 0.4035, Val Loss: 1.0664, Val Accuracy: 0.3776


Epoch 6/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.21it/s]
Epoch 6/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.57it/s]


Epoch [6/10], Train Loss: 1.0387, Train Accuracy: 0.4868, Val Loss: 1.0543, Val Accuracy: 0.3878


Epoch 7/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:45<00:00,  1.24it/s]
Epoch 7/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]


Epoch [7/10], Train Loss: 1.0273, Train Accuracy: 0.4430, Val Loss: 1.0623, Val Accuracy: 0.4898


Epoch 8/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:48<00:00,  1.16it/s]
Epoch 8/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]


Epoch [8/10], Train Loss: 1.0213, Train Accuracy: 0.4605, Val Loss: 1.0631, Val Accuracy: 0.4490


Epoch 9/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:45<00:00,  1.25it/s]
Epoch 9/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.48it/s]


Epoch [9/10], Train Loss: 0.9979, Train Accuracy: 0.5307, Val Loss: 1.0574, Val Accuracy: 0.3776


Epoch 10/10 (Training): 100%|████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.23it/s]
Epoch 10/10 (Validation): 100%|██████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.57it/s]

Epoch [10/10], Train Loss: 0.9737, Train Accuracy: 0.5395, Val Loss: 1.0725, Val Accuracy: 0.4592
Training History:
Train Loss: [1.0992579041865833, 1.0868231737822818, 1.0765698562588608, 1.071381718443151, 1.0728912960019028, 1.0386509131966977, 1.027273230385362, 1.0213280079657572, 0.9979470717279535, 0.9737115450072706]
Train Accuracy: [0.37719298245614036, 0.3793859649122807, 0.40350877192982454, 0.4144736842105263, 0.40350877192982454, 0.4868421052631579, 0.44298245614035087, 0.4605263157894737, 0.5307017543859649, 0.5394736842105263]
Validation Loss: [1.0820871683267446, 1.0713307032218347, 1.066682600058042, 1.075472043110774, 1.0663803632442768, 1.0542541192128108, 1.0622751712799072, 1.063136366697458, 1.0573864441651564, 1.0724868132517889]
Validation Accuracy: [0.37755102040816324, 0.3877551020408163, 0.3877551020408163, 0.3877551020408163, 0.37755102040816324, 0.3877551020408163, 0.4897959183673469, 0.4489795918367347, 0.37755102040816324, 0.45918367346938777]





In [18]:
# Evaluate the trained model on the held-out test split; returns the test loss,
# the accuracy, and (presumably) per-sample predictions and ground-truth labels.
# NOTE(review): unlike train_model above, no `device` is passed here; confirm
# that test_model places the model and batches on the intended device by default.
test_loss, test_acc, preds, labels = test_model(trained_model, test_loader, criterion)


Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]

Test Loss: 1.0444, Test Accuracy: 0.4796





In [None]:
# Persist the trained weights. Set SAVE_PATH to a real file path (e.g. under the
# experiment's `checkpoints` directory created during config loading) to enable
# saving. It defaults to None so that "Run All" never writes to a placeholder location.
SAVE_PATH = None
if SAVE_PATH is not None:
    save_model(trained_model, SAVE_PATH)