## BMEG 457: Brain MRI Segmentation (BraTS)

This notebook implements a brain tumor segmentation pipeline using the BraTS 2023 adult glioma dataset.

If `kagglehub` is installed, the first code cell downloads the dataset from [Kaggle BraTS 2023](https://www.kaggle.com/datasets/shakilrana/brats-2023-adult-glioma) automatically; otherwise, download it manually and place it in a local `data/` directory.

**Dataset structure:**

```text
data/
    ├── ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData/
    │     └── BraTS-GLI-*/
    └── ASNR-MICCAI-BraTS2023-GLI-Challenge-ValidationData/
          └── BraTS-GLI-*/
```

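**Modalities:** FLAIR, T1, T1c (contrast-enhanced T1), and T2, plus the segmentation labels. In the BraTS 2023 file names these appear as the suffixes `t2f`, `t1n`, `t1c`, `t2w`, and `seg`.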

In [None]:
import glob
import os
import random

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

try:
    import kagglehub
except ModuleNotFoundError:
    kagglehub = None

# Seed every RNG the notebook touches (Python, NumPy, TensorFlow) so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Decide where the BraTS data lives: fetch it through kagglehub when that package
# imported successfully, otherwise fall back to a local "data" directory.
if kagglehub is None:
    dataset_path = "data"  # For local testing
    print("kagglehub not installed; using local data directory:", dataset_path)
else:
    path = kagglehub.dataset_download("shakilrana/brats-2023-adult-glioma")
    print("Path to dataset files:", path)
    dataset_path = path

In [None]:
## Setup: data paths and train/validation split

In [None]:
import random

import numpy as np
import tensorflow as tf

# Random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# We will work with the training folder (contains segmentation masks)
dataset_path = "data"
training_path = os.path.join(dataset_path, "ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData")

# List all case directories recursively (some may be nested)
train_dirs = sorted(glob.glob(os.path.join(training_path, "**", "BraTS-GLI-*"), recursive=True))
print("Total training cases found:", len(train_dirs))

if len(train_dirs) == 0:
    raise ValueError("No training cases were found. Check your dataset path and glob pattern.")

# Split cases into training and testing (70% train, 30% test)
train_cases, val_cases = train_test_split(train_dirs, test_size=0.3, random_state=SEED)
print("Number of training cases:", len(train_cases))
print("Number of test cases:", len(val_cases))

In [None]:
## Helper functions for data loading

In [None]:
## Data generator and dataset creation

In [None]:
def generate_examples(case_list, modality="flair"):
    """
    Yield one randomly chosen (image slice, segmentation slice) pair per case.

    For each case directory in ``case_list``:
      - Locates the ``modality`` volume and the ``"seg"`` volume with
        ``get_modality_file``; cases missing either file are skipped with a message.
      - Skips cases whose volumes are not 3-D, are empty, or whose image and
        segmentation shapes differ (the same slice index is used for both).
      - Picks a random index along the third axis (Python's global ``random``).
      - Divides the image slice by its maximum when that maximum is positive.
      - Resizes the image (default bilinear) and the mask (nearest-neighbour) to 128x128.
      - Remaps label 4 to 3 so every label fits a 4-class model.

    Note: ``get_modality_file`` and ``load_volume`` are looked up when the generator
    runs, so the cell defining them must have been executed first.

    Args:
        case_list: iterable of case directory paths.
        modality: keyword passed to ``get_modality_file`` for the image volume.

    Yields:
        (image_slice, seg_slice): float32 array of shape (128, 128, 1) and
        int32 array of shape (128, 128).
    """
    for case in case_list:
        image_file = get_modality_file(case, modality=modality)
        seg_file = get_modality_file(case, modality="seg")

        if image_file is None:
            print(f"Image file not found for case: {case}")
            continue
        if seg_file is None:
            print(f"Segmentation file not found for case: {case}")
            continue

        image_vol = load_volume(image_file)
        seg_vol = load_volume(seg_file)

        # One slice index is applied to both volumes, so they must be 3-D, non-empty
        # and identically shaped; otherwise indexing fails or image/mask misalign.
        if image_vol.ndim != 3 or seg_vol.ndim != 3:
            print(f"Expected 3-D volumes for case: {case}")
            continue
        if image_vol.shape != seg_vol.shape:
            print(f"Image/segmentation shape mismatch for case: {case}")
            continue
        if image_vol.size == 0:
            continue

        slice_idx = random.randint(0, image_vol.shape[2] - 1)
        image_slice = image_vol[:, :, slice_idx].astype(np.float32)  # astype copies
        seg_slice = seg_vol[:, :, slice_idx]

        # Scale by the slice maximum; all-zero slices are left untouched.
        peak = image_slice.max()
        if peak > 0:
            image_slice /= peak

        # Add a channel axis (H, W, 1) as tf.image.resize expects.
        image_slice = tf.image.resize(image_slice[..., np.newaxis], (128, 128)).numpy()
        # Nearest-neighbour keeps the mask made of the original integer labels.
        seg_slice = tf.image.resize(seg_slice[..., np.newaxis], (128, 128), method="nearest").numpy()
        seg_slice = np.squeeze(seg_slice, axis=-1).astype(np.int32)

        # Earlier BraTS releases encode enhancing tumour as label 4. Fold it into 3 so
        # labels stay in 0..3 for a 4-class head. This is a no-op for data already
        # labelled 0..3.
        seg_slice[seg_slice == 4] = 3

        yield image_slice, seg_slice


# Create tf.data.Dataset objects for training and validation.
BATCH_SIZE = 4  # Adjust as needed based on memory and TPU usage

# Element spec for generate_examples: image (128, 128, 1) float32, mask (128, 128) int32.
# `output_signature` replaces the deprecated output_types / output_shapes arguments.
EXAMPLE_SIGNATURE = (
    tf.TensorSpec(shape=(128, 128, 1), dtype=tf.float32),
    tf.TensorSpec(shape=(128, 128), dtype=tf.int32),
)

train_dataset = tf.data.Dataset.from_generator(
    lambda: generate_examples(train_cases, modality="flair"),
    output_signature=EXAMPLE_SIGNATURE,
)
train_dataset = train_dataset.shuffle(buffer_size=20).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

val_dataset = tf.data.Dataset.from_generator(
    lambda: generate_examples(val_cases, modality="flair"),
    output_signature=EXAMPLE_SIGNATURE,
)
val_dataset = val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Debug: pull one batch to check that the generator yields data. This actually runs
# generate_examples, which needs get_modality_file / load_volume. On a fresh kernel,
# execute the cell defining them before this one.
batch_count = sum(1 for _ in train_dataset.take(1))
print("Number of batches from training generator (should be > 0):", batch_count)
if batch_count == 0:
    print("WARNING: the training generator produced no examples; check the modality/seg file lookup.")

In [None]:
# TPU Setup (if using TPU)
# Select a tf.distribute strategy: TPUStrategy when a TPU can be resolved, otherwise
# whatever tf.distribute.get_strategy() returns (the strategy currently in scope).
# NOTE(review): this cell runs after the tf.data pipelines and BATCH_SIZE were defined.
# BATCH_SIZE is not scaled by strategy.num_replicas_in_sync, and generator-backed
# datasets may not be usable with TPUStrategy on remote TPU hosts. Confirm both
# before training on a TPU.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # Detect TPU; the except below assumes "no TPU" surfaces as ValueError
    print("Running on TPU:", tpu.master())
    # Connect to and initialize the TPU system before constructing the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # Any ValueError raised inside the try block (not only from detection) lands here.
    print("TPU not found. Using CPU/GPU strategy.")
    strategy = tf.distribute.get_strategy()

# Number of replicas the chosen strategy synchronizes across.
print("Number of replicas:", strategy.num_replicas_in_sync)

In [None]:
def plot_middle_slice(volume, ax, title=""):
    """
    Draw the central slice of ``volume`` along its third axis on ``ax`` in grayscale.

    Args:
        volume: array indexable as ``volume[:, :, k]`` with a ``shape`` attribute
            (e.g. a 3-D NumPy array).
        ax: matplotlib Axes to draw on; its axis decorations are turned off.
        title: text used as the Axes title.
    """
    center = volume.shape[2] // 2
    center_slice = volume[:, :, center]
    ax.imshow(center_slice, cmap="gray")
    ax.set_title(title)
    ax.axis("off")


def visualize_samples(case_list, n_samples=2, modality="flair", set_name=""):
    """
    Plot the middle slice of ``modality`` for randomly chosen cases, side by side.

    Args:
        case_list: sequence of case directory paths to sample from.
        n_samples: number of cases to show; clamped to ``len(case_list)`` so a short
            list does not make ``random.sample`` raise.
        modality: keyword passed to ``get_modality_file``.
        set_name: label shown above each panel's case name (e.g. "Training").

    Cases whose modality file cannot be found get a placeholder panel. When there
    is nothing to show, a message is printed and no figure is created.
    """
    n_samples = min(n_samples, len(case_list))
    if n_samples <= 0:
        print(f"No cases to visualize for {set_name or 'this split'}.")
        return

    samples = random.sample(case_list, n_samples)
    # squeeze=False always yields a 2-D array of Axes, even for a single panel.
    _, axes = plt.subplots(1, n_samples, figsize=(15, 5), squeeze=False)

    for ax, case in zip(axes[0], samples):
        title = f"{set_name}\nCase: {os.path.basename(case)}"
        mod_file = get_modality_file(case, modality=modality)
        if mod_file is None:
            ax.text(
                0.5,
                0.5,
                "Modality file not found",
                horizontalalignment="center",
                verticalalignment="center",
            )
            ax.set_title(title)
            ax.axis("off")
        else:
            plot_middle_slice(load_volume(mod_file), ax=ax, title=title)

    plt.tight_layout()
    plt.show()


# Optional debug aid: show what the first training case folder contains.
sample_case = train_cases[0]
print("Sample case folder:", sample_case)
print("Contents of sample case:", os.listdir(sample_case))

# Show the middle FLAIR slice of two random cases from each split (training first).
for split_cases, split_name in ((train_cases, "Training"), (val_cases, "Validation")):
    visualize_samples(split_cases, n_samples=2, modality="flair", set_name=split_name)

## Visualization of sample slices

In [None]:
def load_volume(file_path):
    """Open the NIfTI file at ``file_path`` with nibabel and return its ``get_fdata()`` array."""
    nifti_image = nib.load(file_path)
    return nifti_image.get_fdata()


# Keyword -> file-name tags accepted for it. BraTS 2023 names files with suffixes
# t1c / t1n / t2f / t2w / seg; older releases used flair / t1 / t1ce / t2 / seg.
# Keywords not listed here match only themselves.
_MODALITY_ALIASES = {
    "flair": {"flair", "t2f"},
    "t2f": {"t2f", "flair"},
    "t1": {"t1", "t1n"},
    "t1n": {"t1n", "t1"},
    "t1c": {"t1c", "t1ce"},
    "t1ce": {"t1ce", "t1c"},
    "t2": {"t2", "t2w"},
    "t2w": {"t2w", "t2"},
    "seg": {"seg"},
}


def _modality_tag(file_name):
    """Return the token after the last '-' or '_' of a lower-cased NIfTI file name, before '.nii'."""
    stem = file_name.split(".nii")[0]
    return stem.replace("_", "-").rsplit("-", 1)[-1]


def get_modality_file(case_dir, modality="flair"):
    """
    Return the NIfTI file under ``case_dir`` (searched recursively) for ``modality``.

    Matching is done on the file *name* (case-insensitive), not the full path. The
    modality tag is the token after the last '-' or '_' before the '.nii' extension,
    e.g. ``BraTS-GLI-00000-000-t2f.nii.gz`` -> ``t2f`` or
    ``BraTS2021_00000_flair.nii.gz`` -> ``flair``.

    Keywords are mapped to equivalent tags: 'flair' also accepts 't2f', 't1' also
    accepts 't1n', 't1c' also accepts 't1ce', and 't2' also accepts 't2w'. An exact
    tag match avoids the old ambiguity where 't1' matched t1c files. If no tag
    matches exactly, the first file whose name contains the keyword is returned.
    Candidates are sorted, so the result is deterministic.

    Expected modality keywords: 'flair', 't1', 't1c', 't2', or 'seg'.

    Returns:
        The matching file path, or None if nothing matches.
    """
    keyword = modality.lower()
    accepted_tags = _MODALITY_ALIASES.get(keyword, {keyword})
    nii_files = sorted(glob.glob(os.path.join(case_dir, "**", "*.nii*"), recursive=True))
    candidates = [(path, os.path.basename(path).lower()) for path in nii_files]

    for path, name in candidates:
        if _modality_tag(name) in accepted_tags:
            return path
    # Fallback for naming schemes without a clear tag: substring match on the name.
    for path, name in candidates:
        if keyword in name:
            return path
    return None


# Debug aid: confirm a segmentation mask can be located for the first training case.
# The lookup happens before anything is printed, matching the original order.
sample_case = train_cases[0]
seg_file_sample = get_modality_file(sample_case, modality="seg")

print("Sample case folder:", sample_case)
print("Contents of sample case:", os.listdir(sample_case))
print("Segmentation file found for sample case:", seg_file_sample)

In [None]:
def conv_block(x, filters):
    """Apply two successive 3x3, same-padded, ReLU Conv2D layers with ``filters`` channels to ``x``."""
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
    return x


def build_unet(input_shape=(128, 128, 1), num_classes=4):
    """
    Build an uncompiled 2-D U-Net.

    Structure:
      - Encoder: four ``conv_block`` stages (64, 128, 256, 512 filters), each followed
        by 2x2 max pooling.
      - Bottleneck: a 1024-filter ``conv_block``.
      - Decoder: four stages, each a 2x2 stride-2 transposed convolution, then
        concatenation with the matching encoder output, then a ``conv_block``.
      - Head: a 1x1 convolution with ``num_classes`` softmax channels.

    The spatial dimensions of ``input_shape`` should be divisible by 16 (four
    poolings), otherwise the skip concatenations get mismatched shapes.

    Returns:
        tf.keras.Model mapping ``input_shape`` to per-pixel class probabilities.
    """
    stage_filters = (64, 128, 256, 512)
    inputs = tf.keras.Input(input_shape)

    # Encoder: remember each stage's output (before pooling) for the skip connections.
    features = inputs
    skip_tensors = []
    for filters in stage_filters:
        features = conv_block(features, filters)
        skip_tensors.append(features)
        features = tf.keras.layers.MaxPooling2D((2, 2))(features)

    # Bottleneck
    features = conv_block(features, 1024)

    # Decoder: walk the stages back up, deepest first.
    for filters, skip in zip(reversed(stage_filters), reversed(skip_tensors)):
        features = tf.keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(features)
        features = tf.keras.layers.concatenate([features, skip])
        features = conv_block(features, filters)

    outputs = tf.keras.layers.Conv2D(num_classes, (1, 1), activation="softmax")(features)
    return tf.keras.Model(inputs, outputs)


# Create and compile the model inside strategy.scope() so its variables (and the
# optimizer's) are created under the strategy chosen in the setup cell.
with strategy.scope():
    # 4 output classes: background + 3 tumor sub-regions.
    model = build_unet(input_shape=(128, 128, 1), num_classes=4)
    adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
    # Labels are integer class ids per pixel, hence the *sparse* categorical loss.
    model.compile(optimizer=adam, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

model.summary()

## TPU Setup and distributed training strategy