## Dataset Extraction and Data Splits

In this notebook, we create the 2D dataset from the original 3D BraTS volumes by exporting a selection of axial slices from each volume across the four BraTS modalities (T1, T1ce, T2 and FLAIR).

In [None]:
import os
import sys
import torch
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import random_split

# Add braintumor_ddpm to path
sys.path.append(os.path.dirname(os.getcwd()))

# braintumor_ddpm imports
from braintumor_ddpm.data.brats import BRATS
from braintumor_ddpm.data.datasets import SegmentationDataset
from braintumor_ddpm.utils.convert_data import convert_brats_to_2d, move_data

## 1. Dataset Extraction

In our experiments, we extract the data to two formats: multi-page TIFF and NIfTI. The sections below show how to export each one.

### 1.1. Extraction to TIFF-based files

We train the diffusion model and the pixel-level classifiers, and retrain the noise predictor network, on multi-page TIFF files in which each page corresponds to one modality. To extract them, call the `export_stack()` method of the `BRATS` class. The `slices` parameter takes a list of slices to export; when it is not specified, it defaults to the slices used in our experiments.

In [None]:
# Paths to the output directory and to the original 3D BraTS training data
output_directory = r"REPLACE WITH PATH POINTING TO OUTPUT DIRECTORY"
path = r"REPLACE WITH PATH POINTING TO BRATS TRAINING DATA"

# Extract 2D data from the original 3D dataset
brats_dataset = BRATS(path=path)
brats_dataset.export_stack(output_path=output_directory, slices=None)
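As a rough illustration of what a custom `slices` list might look like, the sketch below builds evenly spaced axial indices for a BraTS volume, which has 155 axial slices. The helper `pick_slices` and the values 70, 100 and 10 are hypothetical and are not the defaults used in our experiments. Whether `export_stack()` expects 0-based indices is also an assumption here.

```python
NUM_AXIAL_SLICES = 155  # BraTS volumes are 240 x 240 x 155 voxels


def pick_slices(start, stop, step, depth=NUM_AXIAL_SLICES):
    """Return evenly spaced, 0-based axial slice indices inside [0, depth)."""
    if not (0 <= start < stop <= depth) or step <= 0:
        raise ValueError("slice range must lie inside the volume and step must be positive")
    return list(range(start, stop, step))


# Hypothetical selection, NOT the default used in the experiments
slices = pick_slices(start=70, stop=100, step=10)
print(slices)  # [70, 80, 90]
```

A list like this could then be passed as `slices=slices` instead of `slices=None`.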

### 1.2. Extraction to NIfTI-based files

Because we also compare against a supervised baseline, nnUNet V1, we have to export the same slices in a 2D format compatible with it. We therefore export the axial slices to NIfTI files suitable for training the baseline. Running the cell below exports the entire dataset by calling the `convert_brats_to_2d()` function.

In [None]:
# nnUNet compatible dataset extraction
nnunet_data_dir = os.path.join(output_directory, "nnUNet_raw")
convert_brats_to_2d(dataset_path=path,
                    target_dir=nnunet_data_dir,
                    slices=None)
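For orientation, nnUNet V1 identifies each training case by a case identifier. It expects one image file per modality, named with a four-digit modality suffix, plus a single label file without a suffix. The sketch below only illustrates that naming convention. The helper `nnunet_v1_filenames` and the example case identifier are hypothetical, and whether `convert_brats_to_2d()` writes exactly these names is an assumption.

```python
def nnunet_v1_filenames(case_id, num_modalities=4):
    """Return the image and label file names nnUNet V1 expects for one case."""
    # One image per modality, suffixed _0000, _0001, ...
    images = [f"{case_id}_{m:04d}.nii.gz" for m in range(num_modalities)]
    # A single label file carrying the bare case identifier
    label = f"{case_id}.nii.gz"
    return images, label


# Hypothetical case identifier: patient 12, axial slice 75
images, label = nnunet_v1_filenames("BraTS_00012s075")
print(images[0])  # BraTS_00012s075_0000.nii.gz
print(label)      # BraTS_00012s075.nii.gz
```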

## 2. Data Splits

The data splits for the downstream tasks are configured automatically within the `braintumor-ddpm` pipeline. The following code blocks reproduce the same splits in a form compatible with nnUNet V1 and export the split metadata for each experiment.

In [None]:
# Specify extracted 2D data paths 
images_path = os.path.join(output_directory, "Stacked 2D BRATS Data", "scans")
masks_path = os.path.join(output_directory, "Stacked 2D BRATS Data", "masks")

# Create a segmentation dataset object
dataset = SegmentationDataset(images_dir=images_path,
                              masks_dir=masks_path,
                              image_size=128,
                              device='cpu',
                              verbose=False)

# Main split, using the same seed as in our experiments
train_pool, test = random_split(dataset=dataset,
                                lengths=[757, 8000],
                                generator=torch.Generator().manual_seed(42))

print(f"Training pool: {len(train_pool)} images, Test data: {len(test)} images")
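When `random_split` receives integer lengths, they must sum to the dataset length, otherwise it raises a `ValueError`. The hard-coded `[757, 8000]` therefore assumes that the exported dataset contains exactly 8757 slices, which may not hold if a custom `slices` list was used during extraction. A minimal sketch of deriving the second length instead, with a hypothetical helper `split_lengths`:

```python
def split_lengths(total, train_pool_size=757):
    """Return [train_pool_size, remainder] so that the lengths sum to `total`."""
    if not 0 <= train_pool_size <= total:
        raise ValueError(
            f"train pool of {train_pool_size} does not fit a dataset of {total} items"
        )
    return [train_pool_size, total - train_pool_size]


# With the dataset size implied by the split above (757 + 8000)
lengths = split_lengths(total=8757)
print(lengths)  # [757, 8000]
```

In the cell above, this would correspond to passing `lengths=split_lengths(len(dataset))`.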

Uncomment the lines marked below to export the data for the upper-bound supervised model together with the test set of 8000 scans.

In [None]:
# Export the same splits for the nnUNet comparison; the rest of the training pool serves as validation data
all_seeds = [16, 42, 88, 128, 256]
train_images = [10, 20, 30, 40, 50]
task_id = 600
export_test = False

# Uncomment below to export upper-bound as well as test set
# all_seeds = [42]
# train_images = [757]
# task_id = 700
# export_test = True

output_folder = os.path.join(nnunet_data_dir, "split_metadata")
data_folder = nnunet_data_dir

def to_case_id(mask_path):
    """Convert a mask file path to its nnUNet case identifier."""
    parts = os.path.basename(mask_path).split('_')
    slice_id = int(parts[-1].split('.')[0])
    return f"BraTS_{int(parts[1]):05d}s{slice_id:03d}"


for train_size in tqdm(train_images):
    for seed in all_seeds:

        # Split to acquire training data
        train, valid = random_split(dataset=train_pool, lengths=[train_size, 757 - train_size],
                                    generator=torch.Generator().manual_seed(seed))

        # random_split on a Subset yields indices relative to train_pool,
        # so map them back to positions in the full dataset
        train_ids = [train_pool.indices[i] for i in train.indices]
        if export_test:
            # test was split from the full dataset, so its indices are already absolute
            valid_ids = list(test.indices)
        else:
            valid_ids = [train_pool.indices[i] for i in valid.indices]

        # Create training and validation splits
        data_split = {
            'training': [to_case_id(dataset.dataset[i]['mask']) for i in train_ids],
            'testing': [to_case_id(dataset.dataset[i]['mask']) for i in valid_ids],
        }
        
        # Create Split folder and save split metadata
        split_folder = os.path.join(output_folder, f"{train_size}_samples")
        os.makedirs(split_folder, exist_ok=True)
        name = os.path.join(split_folder, f'data_split_size_{train_size}_seed_{seed}.json')
        with open(name, 'w') as jf:
            json.dump(data_split, jf)
        
        # Move data to data folder
        move_data(target_dir=data_folder,
                  split=data_split,
                  images_dir=os.path.join(nnunet_data_dir, "all_images"),
                  labels_dir=os.path.join(nnunet_data_dir, "all_labels"),
                  seed=f"_{seed}_{train_size}_samples",
                  task_id=task_id)
        task_id += 1
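One detail worth keeping in mind: `train_pool` is itself a `Subset`, so the `.indices` of any subset split from it are positions within `train_pool`, not within the full dataset. The plain-Python sketch below shows how the two index lists compose. The helper `resolve_indices` is hypothetical, and the index lists are toy values standing in for the real `Subset.indices` attributes.

```python
def resolve_indices(parent_indices, child_indices):
    """Map positions inside a parent subset back to positions in the full dataset."""
    return [parent_indices[i] for i in child_indices]


# Toy setup: a full dataset of 10 items, 5 of which form the training pool
pool_indices = [7, 2, 9, 4, 0]   # plays the role of train_pool.indices
train_in_pool = [3, 1]           # plays the role of train.indices
valid_in_pool = [0, 2, 4]        # plays the role of valid.indices

train_ids = resolve_indices(pool_indices, train_in_pool)
valid_ids = resolve_indices(pool_indices, valid_in_pool)
test_ids = [i for i in range(10) if i not in pool_indices]

print(train_ids)  # [4, 2]
print(valid_ids)  # [7, 9, 0]

# The resolved splits do not overlap with each other or with the test set
assert set(train_ids).isdisjoint(valid_ids)
assert set(train_ids).isdisjoint(test_ids)

# Using the unresolved positions directly would collide with test items
print(set(train_in_pool) & set(test_ids))  # {1, 3}
```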