## nnU-Net: A Self-Configuring Neural Network for Biomedical Image Segmentation

nnU-Net is a self-configuring framework for biomedical image segmentation. Given a new dataset, it configures preprocessing, network architecture, and training settings automatically from the dataset's properties (such as image sizes and voxel spacings), so it can be applied to a new task without manual tuning or expert input. It has achieved state-of-the-art results across a wide range of medical image segmentation tasks.

### nnU-Net Implementation for Wrist Bone MRI Segmentation

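A quick guide to setting up and running nnU-Net for wrist bone MRI segmentation. The commands below use the nnU-Net v1 command-line interface (`nnUNet_plan_and_preprocess`, `nnUNet_train`, `nnUNet_predict`); nnU-Net v2 renames these commands (for example `nnUNetv2_train`) and identifies datasets by dataset ID rather than task ID.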

### 1. Directory Setup

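Organize the data in the folder structure that nnU-Net v1 expects:

- `nnUNetFrame/`
  - `dataset/`
    - `nnUNet_raw_data/`
      - `Task501_WristBonesMRI/`
        - `imagesTr/`: training images
        - `imagesTs/`: test images
        - `labelsTr/`: training labels
    - `nnUNet_preprocessed/`

Task folders follow the `TaskXXX_NAME` pattern, where `XXX` is the three-digit task ID (501 here) passed to the commands below. Image files carry a modality suffix (`case_0000.nii.gz` for a single MRI sequence); label files do not (`case.nii.gz`).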
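The layout above can be created with a short Python sketch. It assumes nnU-Net v1, which finds its folders through the `nnUNet_raw_data_base`, `nnUNet_preprocessed`, and `RESULTS_FOLDER` environment variables. The helper name `create_nnunet_v1_layout` and the `nnUNet_trained_models` folder used for `RESULTS_FOLDER` are hypothetical choices, not part of nnU-Net.

```python
import os
from pathlib import Path


def create_nnunet_v1_layout(root, task_name="Task501_WristBonesMRI"):
    """Create the nnU-Net v1 folder tree under `root` and return the task folder.

    Hypothetical helper; folder names mirror the layout in the guide.
    """
    root = Path(root)
    raw_base = root / "dataset"
    task_dir = raw_base / "nnUNet_raw_data" / task_name
    preprocessed = raw_base / "nnUNet_preprocessed"
    # Hypothetical location for trained models; any writable folder works
    results = root / "nnUNet_trained_models"

    for sub in ("imagesTr", "imagesTs", "labelsTr"):
        (task_dir / sub).mkdir(parents=True, exist_ok=True)
    preprocessed.mkdir(parents=True, exist_ok=True)
    results.mkdir(parents=True, exist_ok=True)

    # nnU-Net v1 reads these variables. Setting them here only affects this
    # Python process and the subprocesses it launches.
    os.environ["nnUNet_raw_data_base"] = str(raw_base)
    os.environ["nnUNet_preprocessed"] = str(preprocessed)
    os.environ["RESULTS_FOLDER"] = str(results)
    return task_dir


# Example (creates folders relative to the current directory):
# task_dir = create_nnunet_v1_layout("nnUNetFrame")
```

When the `nnUNet_*` commands are run from a terminal rather than from Python, export the same three variables in the shell instead.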

### 2. Create dataset.json

Place a `dataset.json` file inside `Task501_WristBonesMRI/` that describes the dataset: its name, imaging modality, label names, and the lists of training and test cases.
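A minimal sketch of generating `dataset.json` in the nnU-Net v1 format, assuming a single MRI sequence, images named `<case>_0000.nii.gz`, and labels named `<case>.nii.gz`. The helper `write_dataset_json`, the description string, and any label names passed in are hypothetical placeholders. nnU-Net v1 also provides a `generate_dataset_json` utility in `nnunet.dataset_conversion.utils`, which may be preferable if it is available in your installation.

```python
import json
from pathlib import Path

SUFFIX = "_0000.nii.gz"  # assumed single-modality suffix on image files


def _case_ids(folder):
    """Case identifiers of the images in `folder`, with the suffix stripped."""
    return sorted(p.name[: -len(SUFFIX)] for p in Path(folder).glob("*" + SUFFIX))


def write_dataset_json(task_dir, label_names, modality="MRI", name="WristBonesMRI"):
    """Write a v1-style dataset.json into `task_dir` and return its contents."""
    task_dir = Path(task_dir)
    train_ids = _case_ids(task_dir / "imagesTr")
    test_ids = _case_ids(task_dir / "imagesTs")

    # Label 0 is background; the remaining labels are numbered from 1
    labels = {"0": "background"}
    labels.update({str(i): label for i, label in enumerate(label_names, start=1)})

    dataset = {
        "name": name,
        "description": "Wrist bone segmentation from MRI",
        "tensorImageSize": "3D",
        "modality": {"0": modality},
        "labels": labels,
        "numTraining": len(train_ids),
        "numTest": len(test_ids),
        # Entries omit the _0000 suffix; nnU-Net v1 appends it per modality
        "training": [
            {"image": f"./imagesTr/{c}.nii.gz", "label": f"./labelsTr/{c}.nii.gz"}
            for c in train_ids
        ],
        "test": [f"./imagesTs/{c}.nii.gz" for c in test_ids],
    }
    (task_dir / "dataset.json").write_text(json.dumps(dataset, indent=4))
    return dataset
```

For the ten classes evaluated later in this notebook, a call could look like `write_dataset_json("nnUNetFrame/dataset/nnUNet_raw_data/Task501_WristBonesMRI", [f"bone_{i}" for i in range(1, 11)])`, with the placeholder names replaced by the actual bone names.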

### 3. Preprocessing

Plan the experiment configurations and preprocess the data; the `--verify_dataset_integrity` flag checks the dataset for common errors first:

`nnUNet_plan_and_preprocess -t 501 --verify_dataset_integrity`
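Before preprocessing, a quick pairing check can catch missing files early. This hypothetical helper is not part of nnU-Net and only compares file names, assuming `<case>_0000.nii.gz` images and `<case>.nii.gz` labels; `--verify_dataset_integrity` remains the authoritative check, since it also inspects the images and labels themselves.

```python
from pathlib import Path


def find_unpaired_cases(task_dir):
    """Return (images without labels, labels without images) as sorted lists.

    Hypothetical helper; assumes single-modality images named <case>_0000.nii.gz.
    """
    task_dir = Path(task_dir)
    image_suffix = "_0000.nii.gz"
    label_suffix = ".nii.gz"
    image_ids = {p.name[: -len(image_suffix)]
                 for p in (task_dir / "imagesTr").glob("*" + image_suffix)}
    label_ids = {p.name[: -len(label_suffix)]
                 for p in (task_dir / "labelsTr").glob("*" + label_suffix)}
    # Set differences give the cases present on only one side
    return sorted(image_ids - label_ids), sorted(label_ids - image_ids)
```

Both lists should be empty before running `nnUNet_plan_and_preprocess`.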

### 4. Training the Model
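Train fold 0 of the 3D full-resolution configuration. The arguments are the configuration (`3d_fullres`), the trainer class (`nnUNetTrainerV2`), the task ID (`501`), and the cross-validation fold (`0`):

`nnUNet_train 3d_fullres nnUNetTrainerV2 501 0`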
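Training only fold 0 produces a single model. nnU-Net v1 is built around five-fold cross-validation (folds 0 to 4), and looping over all folds from Python could look like the sketch below. The helper `train_all_folds` is hypothetical and defaults to a dry run that only prints the commands.

```python
import subprocess


def train_all_folds(task_id=501, configuration="3d_fullres",
                    trainer="nnUNetTrainerV2", folds=range(5), dry_run=True):
    """Build (and, if dry_run is False, run) one nnUNet_train command per fold."""
    commands = [
        ["nnUNet_train", configuration, trainer, str(task_id), str(fold)]
        for fold in folds
    ]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True raises CalledProcessError if a fold's training fails
            subprocess.run(cmd, check=True)
    return commands
```

Run sequentially on one GPU, five folds take roughly five times as long as a single fold.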

### 5. Validation
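nnU-Net v1 validates each fold on its held-out cases automatically at the end of training. To rerun only the validation step for fold 0, pass `--validation_only` to `nnUNet_train`:

`nnUNet_train 3d_fullres nnUNetTrainerV2 501 0 --validation_only`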
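Validation metrics are written to a `summary.json` file in the fold's output folder; in nnU-Net v1 this is typically `fold_0/validation_raw/summary.json` under the trained-model directory for `Task501_WristBonesMRI`. The sketch below reads per-class mean Dice from such a file. Both the location and the assumed JSON layout (`results` → `mean` → label → `Dice`) are our assumptions and should be checked against your own output; the helper name is hypothetical.

```python
import json
from pathlib import Path


def mean_dice_from_summary(summary_path):
    """Return {label: mean Dice} from a v1-style validation summary.json.

    Assumes the layout {"results": {"mean": {"<label>": {"Dice": value, ...}}}}.
    """
    summary = json.loads(Path(summary_path).read_text())
    means = summary["results"]["mean"]
    # Skip background ("0") and any entry that does not report Dice
    return {int(label): metrics["Dice"]
            for label, metrics in means.items()
            if label != "0" and "Dice" in metrics}
```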

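### 6. Inference and Evaluation

Predict segmentations for the test images with the fold-0 model:

`nnUNet_predict -i nnUNetFrame/dataset/nnUNet_raw_data/Task501_WristBonesMRI/imagesTs -o nnUNetFrame/dataset/nnUNet_raw_data/Task501_WristBonesMRI/output_predictions -t 501 -m 3d_fullres -f 0`

The code below compares the predictions with the test labels and reports, for each of the 10 classes, the mean and standard deviation of the Dice coefficient and of the average symmetric surface distance (ASSD).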
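For reference, the evaluation code below computes the following metrics. For class $c$, $P_c$ and $G_c$ are the predicted and ground-truth voxel sets, $S_P$ and $S_G$ are their surface voxels (each mask minus its binary erosion), and $d(x, S)$ is the Euclidean distance from voxel $x$ to the nearest voxel of $S$, measured in voxel units:

$$
\mathrm{Dice}_c = \frac{2\,|P_c \cap G_c|}{|P_c| + |G_c|}
$$

$$
\mathrm{ASSD}_c = \frac{1}{2}\left(\frac{1}{|S_P|}\sum_{p \in S_P} d(p, S_G) + \frac{1}{|S_G|}\sum_{g \in S_G} d(g, S_P)\right)
$$

Dice is undefined when a class is absent from both volumes, and ASSD is undefined when either surface is empty; such cases come out as NaN and are skipped by `np.nanmean` and `np.nanstd` when averaging.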


import os

import nibabel as nib
import numpy as np
from scipy.ndimage import binary_erosion
from scipy.spatial import cKDTree

def load_nifti(file_path):
    """Load a NIfTI label map as an integer array."""
    nifti_data = nib.load(file_path)
    # get_fdata() returns floats; round so that label comparisons are exact
    return np.rint(nifti_data.get_fdata()).astype(np.int32)

def dice_coefficient(pred, gt, class_idx):
    """Dice overlap for one class; NaN if the class is absent from both volumes."""
    pred_binary = pred == class_idx
    gt_binary = gt == class_idx
    denominator = pred_binary.sum() + gt_binary.sum()
    if denominator == 0:
        return np.nan  # undefined; skipped later by np.nanmean
    intersection = np.logical_and(pred_binary, gt_binary).sum()
    return 2.0 * intersection / denominator

def surface_points(mask):
    """Indices of boundary voxels: the mask minus its binary erosion."""
    return np.argwhere(mask & ~binary_erosion(mask))

def assd(pred, gt, class_idx):
    """Average symmetric surface distance for one class.

    Voxel spacing is not used, so the result is in voxels, not millimetres.
    """
    pred_surface_points = surface_points(pred == class_idx)
    gt_surface_points = surface_points(gt == class_idx)

    if len(pred_surface_points) == 0 or len(gt_surface_points) == 0:
        return np.nan  # no surface points to compare

    pred_tree = cKDTree(pred_surface_points)
    gt_tree = cKDTree(gt_surface_points)

    # Mean nearest-neighbour distance in each direction, then their average
    pred_to_gt = np.mean(gt_tree.query(pred_surface_points)[0])
    gt_to_pred = np.mean(pred_tree.query(gt_surface_points)[0])

    return (pred_to_gt + gt_to_pred) / 2.0

def calculate_metrics(test_outputs_dir, labels_dir, num_classes=10):
    classes = range(1, num_classes + 1)
    dice_scores_all = {i: [] for i in classes}
    assd_scores_all = {i: [] for i in classes}

    # Sorted so that cases are processed in a reproducible order
    for file_name in sorted(os.listdir(test_outputs_dir)):
        if not file_name.endswith('.nii.gz'):
            continue
        pred_file = os.path.join(test_outputs_dir, file_name)
        # The ground-truth label is expected to have the same file name
        gt_file = os.path.join(labels_dir, file_name)
        if not os.path.exists(gt_file):
            raise FileNotFoundError(f"No ground-truth label found for {file_name}")

        pred_img = load_nifti(pred_file)
        gt_img = load_nifti(gt_file)

        for class_idx in classes:
            dice_scores_all[class_idx].append(dice_coefficient(pred_img, gt_img, class_idx))
            assd_scores_all[class_idx].append(assd(pred_img, gt_img, class_idx))

    # nanmean/nanstd ignore cases where a metric was undefined (NaN)
    dice_means = {i: np.nanmean(dice_scores_all[i]) for i in classes}
    dice_stds = {i: np.nanstd(dice_scores_all[i]) for i in classes}

    assd_means = {i: np.nanmean(assd_scores_all[i]) for i in classes}
    assd_stds = {i: np.nanstd(assd_scores_all[i]) for i in classes}

    return dice_means, dice_stds, assd_means, assd_stds

# Defining directories (adjust to where the predictions and test labels are stored)
test_outputs_dir = '/workspace/test_outputs/'
labels_dir = '/workspace/data/labelsTs/'
num_classes = 10

# Calculating metrics
dice_means, dice_stds, assd_means, assd_stds = calculate_metrics(test_outputs_dir, labels_dir, num_classes)

# Output results
for i in range(1, num_classes + 1):
    print(f"Class {i}: Dice Mean = {dice_means[i]:.4f}, Dice Std Dev = {dice_stds[i]:.4f}")
    print(f"Class {i}: ASSD Mean = {assd_means[i]:.4f}, ASSD Std Dev = {assd_stds[i]:.4f}")
    print()


Class 1: Dice Mean = 0.8724, Dice Std Dev = 0.0694
Class 1: ASSD Mean = 0.9988, ASSD Std Dev = 0.5400

Class 2: Dice Mean = 0.9380, Dice Std Dev = 0.0201
Class 2: ASSD Mean = 0.6302, ASSD Std Dev = 0.1740

Class 3: Dice Mean = 0.9038, Dice Std Dev = 0.0371
Class 3: ASSD Mean = 0.5799, ASSD Std Dev = 0.1713

Class 4: Dice Mean = 0.8749, Dice Std Dev = 0.0663
Class 4: ASSD Mean = 0.7305, ASSD Std Dev = 0.3008

Class 5: Dice Mean = 0.8800, Dice Std Dev = 0.0430
Class 5: ASSD Mean = 0.6599, ASSD Std Dev = 0.2237

Class 6: Dice Mean = 0.9234, Dice Std Dev = 0.0244
Class 6: ASSD Mean = 0.5070, ASSD Std Dev = 0.1259

Class 7: Dice Mean = 0.8926, Dice Std Dev = 0.0354
Class 7: ASSD Mean = 0.5968, ASSD Std Dev = 0.1480

Class 8: Dice Mean = 0.9299, Dice Std Dev = 0.0315
Class 8: ASSD Mean = 0.5193, ASSD Std Dev = 0.1952

Class 9: Dice Mean = 0.8529, Dice Std Dev = 0.0811
Class 9: ASSD Mean = 0.8634, ASSD Std Dev = 0.4996

Class 10: Dice Mean = 0.9012, Dice Std Dev = 0.0331
Class 10: ASSD Mean =

# Compact summary (mean ± standard deviation) of the metrics computed above
for i in sorted(dice_means):
    mean_dice = dice_means[i]
    std_dice = dice_stds[i]
    mean_assd = assd_means[i]
    std_assd = assd_stds[i]

    print(f"Class {i} mean dice: {mean_dice:.4f} ± {std_dice:.4f}")
    print(f"Class {i} assd: {mean_assd:.2f} ± {std_assd:.4f}\n")


Class 1 mean dice: 0.8724 ± 0.0694
Class 1 assd: 1.00 ± 0.5400

Class 2 mean dice: 0.9380 ± 0.0201
Class 2 assd: 0.63 ± 0.1740

Class 3 mean dice: 0.9038 ± 0.0371
Class 3 assd: 0.58 ± 0.1713

Class 4 mean dice: 0.8749 ± 0.0663
Class 4 assd: 0.73 ± 0.3008

Class 5 mean dice: 0.8800 ± 0.0430
Class 5 assd: 0.66 ± 0.2237

Class 6 mean dice: 0.9234 ± 0.0244
Class 6 assd: 0.51 ± 0.1259

Class 7 mean dice: 0.8926 ± 0.0354
Class 7 assd: 0.60 ± 0.1480

Class 8 mean dice: 0.9299 ± 0.0315
Class 8 assd: 0.52 ± 0.1952

Class 9 mean dice: 0.8529 ± 0.0811
Class 9 assd: 0.86 ± 0.4996

Class 10 mean dice: 0.9012 ± 0.0331
Class 10 assd: 0.45 ± 0.1361
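The ASSD values above are in voxel units, because the evaluation code never uses the voxel spacing. A sketch of a physical-unit variant, assuming the prediction and label share the same voxel grid: it scales surface voxel indices by the spacing before the nearest-neighbour search and otherwise averages the two directed mean distances exactly as `assd()` does. The name `assd_physical` is hypothetical.

```python
import numpy as np
from scipy.ndimage import binary_erosion
from scipy.spatial import cKDTree


def assd_physical(pred_mask, gt_mask, spacing):
    """ASSD between two boolean masks, in the units of `spacing` (e.g. mm)."""
    spacing = np.asarray(spacing, dtype=float)

    def surface_coords(mask):
        mask = np.asarray(mask, dtype=bool)
        surface = mask & ~binary_erosion(mask)
        # Scale voxel indices by the per-axis spacing to get physical coordinates
        return np.argwhere(surface) * spacing

    pred_pts = surface_coords(pred_mask)
    gt_pts = surface_coords(gt_mask)
    if len(pred_pts) == 0 or len(gt_pts) == 0:
        return np.nan  # no surface points to compare

    d_pred_to_gt = cKDTree(gt_pts).query(pred_pts)[0]
    d_gt_to_pred = cKDTree(pred_pts).query(gt_pts)[0]
    # Average of the two directed mean surface distances
    return (d_pred_to_gt.mean() + d_gt_to_pred.mean()) / 2.0
```

For a given case, the spacing can be read with `nib.load(gt_file).header.get_zooms()[:3]` (its axes follow the data array), and the masks obtained as `pred_img == class_idx` and `gt_img == class_idx`.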

