To run the notebook on a different machine, adjust the following:
- The directories of the original preprocessed data (`images_dir` and `masks_dir`)
- Create 3 directories for ```nnUNet_raw```, ```nnUNet_preprocessed```, and ```nnUNet_results```
- The `nnUNet_raw` directory used by the ```create_nnunet_dataset_structure()``` function
- The directories assigned to the nnU-Net environment variables
- Run the environment-variable cell for your platform (Linux or Windows)

## Setup and Imports

In [1]:
! pip install numpy nibabel matplotlib pandas scipy tqdm plotly optuna SimpleITK hiddenlayer torch




[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import torch

In [4]:
! pip install causal-conv1d mamba-ssm

Collecting causal-conv1d
  Downloading causal_conv1d-1.5.2.tar.gz (23 kB)
  Installing build dependencies: started
  Installing build dependencies: still running...
  Installing build dependencies: still running...
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'error'


  error: subprocess-exited-with-error
  
  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [26 lines of output]
        cpu = _conversion_method_template(device=torch.device("cpu"))
      
      
      torch.__version__  = 2.8.0+cpu
      
      
      Traceback (most recent call last):
        File "F:\RSA\chimera-pcbr-main\chimera_venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 353, in <module>
          main()
        File "F:\RSA\chimera-pcbr-main\chimera_venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "F:\RSA\chimera-pcbr-main\chimera_venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 118, in get_requires_for_build_wheel
          return hook(config_settings)
                 ^^^^^^^^^

In [None]:
# Only needed to replace an existing CPU-only torch. `-y` skips the confirmation prompt, which a notebook
# cell cannot reliably answer.
# %pip uninstall -y torch

In [None]:
# `%pip` (unlike a `pip3` found on PATH) installs into the environment of the running kernel.
# %pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126

In [None]:
# must be done after installing pytorch
# `%pip` installs into the environment of the running kernel.
# %pip install nnunetv2

In [None]:
! git clone https://github.com/MrBlankness/LightM-UNet
! cd LightM-UNet/lightm-unet
! pip install -e .

In [None]:
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import pandas as pd
import torch
from scipy.ndimage import zoom
from tqdm import tqdm
import torch.nn.functional as F
from matplotlib.widgets import Slider
import plotly.graph_objects as go
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
from torch.optim import Adam
from concurrent.futures import ThreadPoolExecutor, as_completed

from scipy.ndimage import rotate
import optuna

from optuna.pruners import MedianPruner
from torch.cuda.amp import autocast, GradScaler

import SimpleITK as sitk

import subprocess

from tqdm import tqdm

import shutil
import json

import hiddenlayer

import mamba_ssm

## Dataset Import

In [None]:

# Folders with the preprocessed T1 scans and the matching lesion masks (adjust per machine)
images_dir = r"F:\aims_tbi\normalized_T1_scans"
masks_dir = r"F:\aims_tbi\resampled_1mm_Lesion_masks"

# Quick sanity check: number of entries in each folder
for _label, _folder in (("images", images_dir), ("masks", masks_dir)):
    print(f"Number of processed {_label}: {len(os.listdir(_folder))}")


In [None]:
# Define scan parameters
scan_id = 'scan_0001'
start_slice = 110
num_slices = 5

# Load the .nii.gz files
image_path = os.path.join(images_dir, f"{scan_id}_T1_normalized.nii.gz")
mask_path = os.path.join(masks_dir, f"{scan_id}_Lesion_resampled_1mm.nii.gz")

# Load image and mask using nibabel
image_nii = nib.load(image_path)
mask_nii = nib.load(mask_path)

image_array = image_nii.get_fdata().astype(np.float32)
mask_array = mask_nii.get_fdata().astype(np.uint8)

print(f"Loaded {scan_id}:")
print(f"Image shape: {image_array.shape}")
print(f"Mask shape: {mask_array.shape}")

# Create visualization
fig, axes = plt.subplots(2, num_slices, figsize=(num_slices * 3, 10))

# Ensure axes is 2D even for single slice
if num_slices == 1:
    axes = axes.reshape(-1, 1)

for i in range(num_slices):
    slice_idx = start_slice + i

    # Check if slice index is valid
    if slice_idx >= image_array.shape[2]:
        print(f"⚠️ Slice {slice_idx} is out of bounds (max: {image_array.shape[2]-1}), skipping...")
        continue

    # Extract slices (transpose for proper orientation)
    image_slice = image_array[:, :, slice_idx].T
    mask_slice = mask_array[:, :, slice_idx].T

    # Row 1: Processed T1 image only
    axes[0, i].imshow(image_slice, cmap='gray', origin='lower')
    axes[0, i].set_title(f'Processed T1\n(Normalized) Slice {slice_idx}', fontsize=10)
    axes[0, i].axis('off')

    # Row 2: Processed T1 + Lesion overlay
    axes[1, i].imshow(image_slice, cmap='gray', origin='lower')
    if np.any(mask_slice > 0):  # Only overlay if there are lesions in this slice
        axes[1, i].imshow(mask_slice, cmap='Reds', alpha=0.6, origin='lower')
    axes[1, i].set_title(f'Processed T1 + Lesion\nSlice {slice_idx}', fontsize=10)
    axes[1, i].axis('off')

plt.tight_layout()
plt.suptitle(f'Processed Images from NIfTI Files - {scan_id}', fontsize=16, y=0.98)
plt.show()

# Print intensity statistics
print(f"\n=== Statistics from NIfTI Files for {scan_id} ===")

# Image stats (brain voxels only)
brain_mask = image_array != 0  # Background is 0 after normalization
brain_voxels = image_array[brain_mask]

print("Processed T1 Image (from NIfTI file):")
print(f"  Mean: {np.mean(brain_voxels):.6f}")
print(f"  Std: {np.std(brain_voxels):.6f}")
print(f"  Min: {np.min(brain_voxels):.4f}")
print(f"  Max: {np.max(brain_voxels):.4f}")
print(f"  Shape: {image_array.shape}")

# Lesion statistics
lesion_voxels = np.count_nonzero(mask_array)
total_voxels = mask_array.size
lesion_percentage = (lesion_voxels / total_voxels) * 100

print(f"\nLesion Mask (from NIfTI file):")
print(f"  Lesion voxels: {lesion_voxels:,}")
print(f"  Total voxels: {total_voxels:,}")
print(f"  Lesion percentage: {lesion_percentage:.4f}%")
print(f"  Shape: {mask_array.shape}")

# If you have metadata, you can print it here (optional)
# print(f"\nMetadata Statistics:")
# print(f"  Brain voxels: ...")
# print(f"  Brain mean: ...")
# print(f"  Brain std: ...")
# print(f"

### Verify the mask values are limited to 0 and 1

In [None]:
def find_non_binary_masks(masks_dir, suffix="_Lesion_resampled_1mm.nii.gz"):
    """Report lesion masks whose voxel values are not limited to {0, 1}.

    The raw floating-point voxel data is inspected. Casting to uint8 before
    taking the unique values would turn e.g. 0.5 into 0 and wrap negative
    values, hiding exactly the values this check is meant to catch.

    Args:
        masks_dir: Folder containing the mask files.
        suffix: Only files whose name ends with this suffix are checked.

    Returns:
        dict mapping file name -> array of unique voxel values, for the
        offending masks only (empty dict when every mask is binary).
    """
    offenders = {}
    for fname in sorted(os.listdir(masks_dir)):
        if not fname.endswith(suffix):
            continue
        unique_vals = np.unique(nib.load(os.path.join(masks_dir, fname)).get_fdata())
        # NaN is not in {0, 1}, so it is reported as well
        if not np.isin(unique_vals, (0, 1)).all():
            print(f"{fname}: {unique_vals}")
            offenders[fname] = unique_vals
    return offenders


# Slow: loads every mask in the folder. Uncomment to run the check.
# non_binary_masks = find_non_binary_masks(masks_dir)

## nnUNet setup

### Dataset conversion for nnUNet compatibility

Create nnU-Net Dataset Structure

In [None]:
def create_nnunet_dataset_structure(nnunet_raw=None, dataset_name="Dataset600_TBILesion"):
    """Create nnU-Net compatible dataset structure.

    Args:
        nnunet_raw: nnU-Net raw-data root folder. When None, the ``nnUNet_raw``
            environment variable is used if set, so the dataset is written where
            nnU-Net will look for it; otherwise falls back to
            ``F:\\aims_tbi\\nnUNet_raw``.
        dataset_name: Dataset folder name (``DatasetXXX_Name``); choose an
            unused numeric ID.

    Returns:
        Path of the dataset folder. Its ``imagesTr``, ``labelsTr`` and
        ``imagesTs`` (optional, for test data) sub-folders are created if
        missing; existing content is left untouched.
    """
    if nnunet_raw is None:
        nnunet_raw = os.environ.get("nnUNet_raw", "F:\\aims_tbi\\nnUNet_raw")

    dataset_path = os.path.join(nnunet_raw, dataset_name)

    for sub_folder in ("imagesTr", "labelsTr", "imagesTs"):
        os.makedirs(os.path.join(dataset_path, sub_folder), exist_ok=True)

    return dataset_path

In [None]:
# Build the nnU-Net raw-data folder layout; later cells reuse `dataset_path`
dataset_path = create_nnunet_dataset_structure()
print("Created dataset structure at: " + dataset_path)

Convert Your Files to nnU-Net Format

In [None]:
def convert_nii_to_nnunet_format(images_dir, masks_dir, dataset_path, train_ratio=0.85):
    """
    Convert NIfTI images and masks to nnU-Net format.

    Scans are split into train/test separately for lesion and no-lesion cases
    (a scan counts as "lesion" when its mask has any voxel > 0), so both groups
    are represented in each split. Training cases are written to ``imagesTr`` /
    ``labelsTr``. Test images go to ``imagesTs``, and their labels go to
    ``test_labels_for_evaluation`` so nnU-Net never trains on them.

    Args:
        images_dir: Folder with ``<scan_id>_T1_normalized.nii.gz`` images.
        masks_dir: Folder with ``<scan_id>_Lesion_resampled_1mm.nii.gz`` masks.
        dataset_path: nnU-Net dataset folder. Its ``imagesTr``, ``labelsTr`` and
            ``imagesTs`` sub-folders must already exist.
        train_ratio: Fraction of each group (lesion / no-lesion) used for training.

    Returns:
        (n_train, n_test): number of training / test cases actually written.
        Cases that fail to convert are not counted, so ``n_train`` can be used
        directly as ``numTraining`` in dataset.json.
    """
    image_suffix = '_T1_normalized.nii.gz'
    mask_suffix = '_Lesion_resampled_1mm.nii.gz'

    # Sort: os.listdir order is arbitrary and platform dependent, which would make
    # the train/test split irreproducible between runs and machines.
    scan_ids = sorted(
        fname[:-len(image_suffix)] for fname in os.listdir(images_dir) if fname.endswith(image_suffix)
    )

    print(f"Found {len(scan_ids)} scans to convert")

    # Stratified split: lesion vs no-lesion (based on mask content)
    lesion_scans = []
    no_lesion_scans = []
    for scan_id in scan_ids:
        mask_path = os.path.join(masks_dir, f"{scan_id}{mask_suffix}")
        try:
            # NOTE(review): casting to uint8 before thresholding truncates fractional values
            # (e.g. 0.5 -> 0); this is only safe if the masks are strictly 0/1.
            mask_array = nib.load(mask_path).get_fdata().astype(np.uint8)
        except Exception as e:
            # One missing/unreadable mask should not abort the whole conversion
            print(f"Skipping {scan_id}: cannot read mask ({e})")
            continue
        if np.any(mask_array > 0):
            lesion_scans.append(scan_id)
        else:
            no_lesion_scans.append(scan_id)

    print(f"Lesion scans: {len(lesion_scans)}")
    print(f"No-lesion scans: {len(no_lesion_scans)}")

    n_train_lesion = int(len(lesion_scans) * train_ratio)
    n_train_no_lesion = int(len(no_lesion_scans) * train_ratio)

    train_ids = lesion_scans[:n_train_lesion] + no_lesion_scans[:n_train_no_lesion]
    test_ids = lesion_scans[n_train_lesion:] + no_lesion_scans[n_train_no_lesion:]

    print(f"Training scans: {len(train_ids)}")
    print(f"Test scans: {len(test_ids)}")

    # Create separate directory for test labels (for evaluation)
    test_labels_dir = os.path.join(dataset_path, "test_labels_for_evaluation")
    os.makedirs(test_labels_dir, exist_ok=True)

    label_dirs = {"Tr": os.path.join(dataset_path, "labelsTr"), "Ts": test_labels_dir}
    converted = {"Tr": 0, "Ts": 0}

    for split, ids in (("Tr", train_ids), ("Ts", test_ids)):
        for scan_id in ids:
            try:
                image_nii = nib.load(os.path.join(images_dir, f"{scan_id}{image_suffix}"))
                mask_nii = nib.load(os.path.join(masks_dir, f"{scan_id}{mask_suffix}"))

                image_array = image_nii.get_fdata().astype(np.float32)
                mask_array = (mask_nii.get_fdata().astype(np.uint8) > 0).astype(np.uint8)  # binarize

                # The mask reuses the image affine so both volumes share the same geometry
                affine = image_nii.affine

                # Save image (nnU-Net channel suffix _0000) and mask
                image_save_path = os.path.join(dataset_path, f"images{split}", f"{scan_id}_0000.nii.gz")
                nib.save(nib.Nifti1Image(image_array, affine), image_save_path)

                mask_save_path = os.path.join(label_dirs[split], f"{scan_id}.nii.gz")
                nib.save(nib.Nifti1Image(mask_array, affine), mask_save_path)
            except Exception as e:
                print(f"Error converting {scan_id}: {e}")
                continue

            converted[split] += 1
            total_converted = converted["Tr"] + converted["Ts"]
            if total_converted % 50 == 0:
                print(f"Converted {total_converted} scans...")

    print(f"Successfully converted {converted['Tr'] + converted['Ts']} scans")
    print(f"Test labels saved to: {test_labels_dir}")

    # Report what was actually written, not what was planned: numTraining must match labelsTr
    return converted["Tr"], converted["Ts"]

In [None]:
# Write imagesTr / labelsTr / imagesTs (+ held-out test labels); keep the split sizes for dataset.json
n_train, n_test = convert_nii_to_nnunet_format(
    images_dir=images_dir,
    masks_dir=masks_dir,
    dataset_path=dataset_path,
)

Create dataset.json File

In [None]:
def create_dataset_json(dataset_path, num_training):
    """Write the nnU-Net ``dataset.json`` describing the TBI lesion dataset.

    Args:
        dataset_path: Dataset folder the file is written into.
        num_training: Value stored under ``numTraining``.
    """
    json_path = os.path.join(dataset_path, "dataset.json")

    metadata = dict(
        channel_names={"0": "T1"},  # single input channel: the T1-weighted MRI scan
        labels={"background": 0, "lesion": 1},
        numTraining=num_training,
        file_ending=".nii.gz",
        dataset_name="TBI_Lesion_Segmentation",
        reference="AIMS TBI Challenge",
        licence="Your License",
        description="Traumatic Brain Injury Lesion Segmentation Dataset",
    )

    with open(json_path, 'w') as handle:
        json.dump(metadata, handle, indent=2)

    print(f"Created dataset.json with {num_training} training cases")
    print(f"Saved to: {json_path}")

In [None]:
# numTraining must equal the number of cases written to labelsTr
create_dataset_json(dataset_path=dataset_path, num_training=n_train)

Verify Dataset Structure

In [None]:
def verify_dataset_structure(dataset_path):
    """Verify the dataset structure is correct.

    Checks that:
      * ``imagesTr`` and ``labelsTr`` folders and ``dataset.json`` exist,
      * every training image (not just a sample) is named ``<case>_0000.nii.gz``,
      * images and labels pair up one-to-one by case ID,
      * ``numTraining`` in dataset.json equals the number of training cases.
    Only ``.nii.gz`` files are considered, so stray files (e.g. OS metadata) are ignored.

    Returns:
        True when every check passes, otherwise False (after printing the reason).
    """
    print("Verifying dataset structure...")

    # Check folder structure
    for folder in ("imagesTr", "labelsTr"):
        if not os.path.isdir(os.path.join(dataset_path, folder)):
            print(f"❌ Missing folder: {folder}")
            return False
        print(f"✅ Found folder: {folder}")

    # Check dataset.json
    json_path = os.path.join(dataset_path, "dataset.json")
    if not os.path.exists(json_path):
        print("❌ Missing dataset.json")
        return False
    print("✅ Found dataset.json")

    # Check file counts
    image_files = sorted(f for f in os.listdir(os.path.join(dataset_path, "imagesTr")) if f.endswith('.nii.gz'))
    label_files = sorted(f for f in os.listdir(os.path.join(dataset_path, "labelsTr")) if f.endswith('.nii.gz'))

    print(f"Training images: {len(image_files)}")
    print(f"Training labels: {len(label_files)}")

    if len(image_files) != len(label_files):
        print("❌ Mismatch between number of images and labels")
        return False

    # Check file naming convention on every image
    for file in image_files:
        if not file.endswith('_0000.nii.gz'):
            print(f"❌ Incorrect naming: {file} (should end with _0000.nii.gz)")
            return False

    # Equal counts are not enough: each image needs the label with the same case ID
    image_ids = {f[:-len('_0000.nii.gz')] for f in image_files}
    label_ids = {f[:-len('.nii.gz')] for f in label_files}
    if image_ids != label_ids:
        print(f"❌ Images without labels: {sorted(image_ids - label_ids)[:5]}; "
              f"labels without images: {sorted(label_ids - image_ids)[:5]}")
        return False

    # numTraining in dataset.json must agree with the files on disk
    try:
        with open(json_path) as f:
            dataset_json = json.load(f)
    except (OSError, ValueError) as e:
        print(f"❌ Could not read dataset.json: {e}")
        return False
    num_training = dataset_json.get("numTraining") if isinstance(dataset_json, dict) else None
    if num_training != len(label_files):
        print(f"❌ dataset.json numTraining={num_training}, but found {len(label_files)} training cases")
        return False

    print("✅ Dataset structure verification passed!")
    return True



In [None]:
# Fail loudly (instead of only displaying False) so Run All stops before preprocessing a broken dataset
assert verify_dataset_structure(dataset_path), "nnU-Net dataset structure check failed - see messages above"

### Environment variable setup

For Linux (uncomment the lines and adjust the paths)

In [None]:
# NOTE: `! export VAR=...` runs in a throw-away subshell, so the variable is gone before the next command
# and nnU-Net would never see it. Set the variables on the kernel process instead; commands launched later
# with `!` inherit them:
# os.environ['nnUNet_raw'] = "/media/fabian/nnUNet_raw"
# os.environ['nnUNet_preprocessed'] = "/media/fabian/nnUNet_preprocessed"
# os.environ['nnUNet_results'] = "/media/fabian/nnUNet_results"
# os.environ['nnUNet_n_proc_DA'] = "12"

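For Windows (adjust the paths; this cell sets the variables on the running Python kernel via `os.environ`, not in PowerShell)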

In [None]:
# Point nnU-Net at its raw / preprocessed / results folders. Changes to os.environ are inherited by
# commands started later with `!`.
os.environ.update({
    'nnUNet_raw': "F:\\aims_tbi\\nnUNet_raw",
    'nnUNet_preprocessed': "F:\\aims_tbi\\nnUNet_preprocessed",
    'nnUNet_results': "F:\\aims_tbi\\nnUNet_results",
    'nnUNet_n_proc_DA': "12",
})

In [None]:
# Verify the environment variables are set
print("Environment variables set:")
print(f"nnUNet_raw: {os.environ.get('nnUNet_raw')}")
print(f"nnUNet_preprocessed: {os.environ.get('nnUNet_preprocessed')}")
print(f"nnUNet_results: {os.environ.get('nnUNet_results')}")
print(f"nnUNet_n_proc_DA: {os.environ.get('nnUNet_n_proc_DA')}")

### Model Training

In [None]:
# Plan and preprocess dataset ID 600 (the "Dataset600_..." folder created under nnUNet_raw);
# --verify_dataset_integrity asks nnU-Net to check the raw dataset first. The shell command inherits the
# nnUNet_* variables set on the kernel via os.environ above.
! nnUNetv2_plan_and_preprocess -d 600 --verify_dataset_integrity

In [None]:
# nnU-Net training below expects a GPU; a CPU-only torch build (version tag "+cpu") reports False here.
cuda_ok = torch.cuda.is_available()
print(cuda_ok)
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)} | torch {torch.__version__} | CUDA {torch.version.cuda}")
else:
    print(f"⚠️ CUDA not available (torch {torch.__version__}) - install a CUDA build of PyTorch before training.")

In [None]:
# Trains on fold "all" (every training case, no held-out validation fold); use 0-4 instead to train a
# single cross-validation fold. nnUNetTrainerLightMUNet is presumably provided by the LightM-UNet package
# installed above - TODO confirm.
! nnUNetv2_train 600 3d_fullres all -tr nnUNetTrainerLightMUNet