# Import Required Libraries
Import necessary libraries such as numpy, tensorflow, keras, matplotlib, and others.

In [None]:
# Import necessary libraries
import os
import numpy as np
import tensorflow as tf
import keras
from matplotlib import pyplot as plt
import glob
import random
from tensorflow.keras import backend as K
import segmentation_models_3D as sm
from keras.metrics import MeanIoU
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import f1_score, precision_score, recall_score
from keras.models import load_model
from PIL import Image
import nibabel as nib
import splitfolders
from tensorflow.keras.utils import to_categorical
from tifffile import imsave

# Define Constants and Paths
Define constants and paths for training and validation datasets.

In [None]:
# Define Constants and Paths

# Dataset roots. The defaults are the original author's local paths; set the
# BRATS_TRAIN_DATASET_PATH / BRATS_VALIDATION_DATASET_PATH environment variables
# to run the notebook against another location without editing this cell.
TRAIN_DATASET_PATH = os.environ.get(
    'BRATS_TRAIN_DATASET_PATH',
    r'C:\Users\yagiz\OneDrive\Masaüstü\kodlar\UnetsegmentationDeneme\unet-segmentation-project\data\train\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData',
)
VALIDATION_DATASET_PATH = os.environ.get(
    'BRATS_VALIDATION_DATASET_PATH',
    r'C:\Users\yagiz\OneDrive\Masaüstü\kodlar\UnetsegmentationDeneme\unet-segmentation-project\data\validation\BraTS2020_ValidationData\MICCAI_BraTS2020_ValidationData',
)

# Define paths for input images and masks for training
TRAIN_IMG_DIR = os.path.join(TRAIN_DATASET_PATH, "input_data_3channels/images/")
TRAIN_MASK_DIR = os.path.join(TRAIN_DATASET_PATH, "input_data_3channels/masks/")

# Define paths for input images for validation
VAL_IMG_DIR = os.path.join(VALIDATION_DATASET_PATH, "input_data_3channels/images/")

# Define batch size for training and validation
BATCH_SIZE = 2

# Load and Preprocess Data
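Load the preprocessed `.npy` image and mask volumes in batches and visualize one sample. No scaling or reshaping happens here: the `.npy` files are loaded as-is and cast to float32.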

In [None]:
# Load and Preprocess Data

# Function to load .npy files from a given directory
def load_img(img_dir, img_list):
    """Load the ``.npy`` files named in *img_list* from *img_dir* as one array.

    Names that do not end in ``.npy`` are silently skipped. Every loaded array
    is cast to float32, and the results are combined with ``np.array`` so the
    first axis indexes the files (an empty list yields an empty array).
    """
    arrays = [
        np.load(os.path.join(img_dir, name)).astype(np.float32)
        for name in img_list
        if name.endswith('.npy')
    ]
    return np.array(arrays)

# Function to load and preprocess images and masks in batches
def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size, dtype=np.float32):
    """Yield ``(images, masks)`` batches forever, cycling through the file lists.

    Args:
        img_dir: Directory containing the image ``.npy`` files.
        img_list: Image file names. Entry ``i`` is paired with ``mask_list[i]``,
            so both lists must be in the same order.
        mask_dir: Directory containing the mask ``.npy`` files.
        mask_list: Mask file names aligned with *img_list*.
        batch_size: Number of list entries per batch. The last batch of each
            pass may be smaller, and ``load_img`` skips non-``.npy`` names,
            which can also shrink a batch.
        dtype: dtype of the yielded arrays (this was accepted but ignored
            before; the default float32 keeps the previous output).

    Yields:
        Tuple ``(X, Y)`` of numpy arrays stacked along a new first axis.

    Raises:
        ValueError: On the first ``next()`` if *img_list* is empty or
            *batch_size* < 1. Either case would otherwise loop forever
            without producing a usable batch.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    n_items = len(img_list)
    if n_items == 0:
        raise ValueError(f'img_list is empty; nothing to load from {img_dir}')
    while True:
        for start in range(0, n_items, batch_size):
            # Clamp to the image count so masks are sliced exactly like images.
            stop = min(start + batch_size, n_items)
            X = load_img(img_dir, img_list[start:stop]).astype(dtype, copy=False)
            Y = load_img(mask_dir, mask_list[start:stop]).astype(dtype, copy=False)
            yield (X, Y)

# Function to load and preprocess validation images in batches
def val_imageLoader(img_dir, img_list, batch_size, dtype=np.float32):
    """Yield image-only batches forever, cycling through *img_list*.

    Args:
        img_dir: Directory containing the image ``.npy`` files.
        img_list: Image file names, loaded in list order.
        batch_size: Number of list entries per batch. The last batch of each
            pass may be smaller, and ``load_img`` skips non-``.npy`` names.
        dtype: dtype of the yielded arrays (this was accepted but ignored
            before; the default float32 keeps the previous output).

    Yields:
        A numpy array of images stacked along a new first axis. No masks are
        yielded.

    Raises:
        ValueError: On the first ``next()`` if *img_list* is empty or
            *batch_size* < 1. Either case would otherwise loop forever
            without producing a usable batch.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    n_items = len(img_list)
    if n_items == 0:
        raise ValueError(f'img_list is empty; nothing to load from {img_dir}')
    while True:
        for start in range(0, n_items, batch_size):
            stop = min(start + batch_size, n_items)
            yield load_img(img_dir, img_list[start:stop]).astype(dtype, copy=False)

# Initialize data generators for training and validation
# os.listdir() returns entries in arbitrary order, so the lists are sorted (and
# restricted to .npy files) to keep image i paired with mask i.
# NOTE(review): assumes image and mask files share a parallel naming scheme
# (e.g. image_N.npy / mask_N.npy) so that sorting aligns them; confirm.
train_img_list = sorted(f for f in os.listdir(TRAIN_IMG_DIR) if f.endswith('.npy'))
train_mask_list = sorted(f for f in os.listdir(TRAIN_MASK_DIR) if f.endswith('.npy'))
val_img_list = sorted(f for f in os.listdir(VAL_IMG_DIR) if f.endswith('.npy'))
if len(train_img_list) != len(train_mask_list):
    raise ValueError(
        f'{len(train_img_list)} training images but {len(train_mask_list)} masks; '
        'images and masks must pair one-to-one'
    )

train_img_datagen = imageLoader(TRAIN_IMG_DIR, train_img_list, TRAIN_MASK_DIR, train_mask_list, BATCH_SIZE)
val_img_datagen = val_imageLoader(VAL_IMG_DIR, val_img_list, BATCH_SIZE)

# Test the data generator
img, msk = next(train_img_datagen)

# Visualize a random image and its corresponding mask from the batch
img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
# argmax over the last mask axis (presumably one-hot classes) gives a label volume.
test_mask = np.argmax(msk[img_num], axis=3)

# random.randint includes both endpoints, so the upper bound must be shape - 1.
n_slice = random.randint(0, test_mask.shape[2] - 1)

# Titles for up to four channels, in the order this notebook labelled them.
# NOTE(review): the data folder is "input_data_3channels" and the model is built
# with IMG_CHANNELS=3, so likely only three channels exist. Only as many panels
# as there are channels are drawn; confirm which modality each channel holds.
channel_titles = ['Image flair', 'Image t1', 'Image t1ce', 'Image t2']
plt.figure(figsize=(12, 8))
for channel in range(min(test_img.shape[-1], len(channel_titles))):
    plt.subplot(2, 3, channel + 1)
    plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
    plt.title(channel_titles[channel])
plt.subplot(235)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()

# Data Generator Functions
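Define functions to load images and masks in batches for training and validation. Note: this cell repeats the loader definitions, generator setup, and sample visualization from the previous cell, so running either one is sufficient.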

In [None]:
# Function to load .npy files from a given directory
def load_img(img_dir, img_list):
    """Load the ``.npy`` files named in *img_list* from *img_dir* as one array.

    Names that do not end in ``.npy`` are silently skipped. Every loaded array
    is cast to float32, and the results are combined with ``np.array`` so the
    first axis indexes the files (an empty list yields an empty array).
    """
    arrays = [
        np.load(os.path.join(img_dir, name)).astype(np.float32)
        for name in img_list
        if name.endswith('.npy')
    ]
    return np.array(arrays)

# Function to load and preprocess images and masks in batches
def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size, dtype=np.float32):
    """Yield ``(images, masks)`` batches forever, cycling through the file lists.

    Args:
        img_dir: Directory containing the image ``.npy`` files.
        img_list: Image file names. Entry ``i`` is paired with ``mask_list[i]``,
            so both lists must be in the same order.
        mask_dir: Directory containing the mask ``.npy`` files.
        mask_list: Mask file names aligned with *img_list*.
        batch_size: Number of list entries per batch. The last batch of each
            pass may be smaller, and ``load_img`` skips non-``.npy`` names,
            which can also shrink a batch.
        dtype: dtype of the yielded arrays (this was accepted but ignored
            before; the default float32 keeps the previous output).

    Yields:
        Tuple ``(X, Y)`` of numpy arrays stacked along a new first axis.

    Raises:
        ValueError: On the first ``next()`` if *img_list* is empty or
            *batch_size* < 1. Either case would otherwise loop forever
            without producing a usable batch.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    n_items = len(img_list)
    if n_items == 0:
        raise ValueError(f'img_list is empty; nothing to load from {img_dir}')
    while True:
        for start in range(0, n_items, batch_size):
            # Clamp to the image count so masks are sliced exactly like images.
            stop = min(start + batch_size, n_items)
            X = load_img(img_dir, img_list[start:stop]).astype(dtype, copy=False)
            Y = load_img(mask_dir, mask_list[start:stop]).astype(dtype, copy=False)
            yield (X, Y)

# Function to load and preprocess validation images in batches
def val_imageLoader(img_dir, img_list, batch_size, dtype=np.float32):
    """Yield image-only batches forever, cycling through *img_list*.

    Args:
        img_dir: Directory containing the image ``.npy`` files.
        img_list: Image file names, loaded in list order.
        batch_size: Number of list entries per batch. The last batch of each
            pass may be smaller, and ``load_img`` skips non-``.npy`` names.
        dtype: dtype of the yielded arrays (this was accepted but ignored
            before; the default float32 keeps the previous output).

    Yields:
        A numpy array of images stacked along a new first axis. No masks are
        yielded.

    Raises:
        ValueError: On the first ``next()`` if *img_list* is empty or
            *batch_size* < 1. Either case would otherwise loop forever
            without producing a usable batch.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    n_items = len(img_list)
    if n_items == 0:
        raise ValueError(f'img_list is empty; nothing to load from {img_dir}')
    while True:
        for start in range(0, n_items, batch_size):
            stop = min(start + batch_size, n_items)
            yield load_img(img_dir, img_list[start:stop]).astype(dtype, copy=False)

# Initialize data generators for training and validation
# os.listdir() returns entries in arbitrary order, so the lists are sorted (and
# restricted to .npy files) to keep image i paired with mask i.
# NOTE(review): assumes image and mask files share a parallel naming scheme
# (e.g. image_N.npy / mask_N.npy) so that sorting aligns them; confirm.
train_img_list = sorted(f for f in os.listdir(TRAIN_IMG_DIR) if f.endswith('.npy'))
train_mask_list = sorted(f for f in os.listdir(TRAIN_MASK_DIR) if f.endswith('.npy'))
val_img_list = sorted(f for f in os.listdir(VAL_IMG_DIR) if f.endswith('.npy'))
if len(train_img_list) != len(train_mask_list):
    raise ValueError(
        f'{len(train_img_list)} training images but {len(train_mask_list)} masks; '
        'images and masks must pair one-to-one'
    )

train_img_datagen = imageLoader(TRAIN_IMG_DIR, train_img_list, TRAIN_MASK_DIR, train_mask_list, BATCH_SIZE)
val_img_datagen = val_imageLoader(VAL_IMG_DIR, val_img_list, BATCH_SIZE)

# Test the data generator
img, msk = next(train_img_datagen)

# Visualize a random image and its corresponding mask from the batch
img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
# argmax over the last mask axis (presumably one-hot classes) gives a label volume.
test_mask = np.argmax(msk[img_num], axis=3)

# random.randint includes both endpoints, so the upper bound must be shape - 1.
n_slice = random.randint(0, test_mask.shape[2] - 1)

# Titles for up to four channels, in the order this notebook labelled them.
# NOTE(review): the data folder is "input_data_3channels" and the model is built
# with IMG_CHANNELS=3, so likely only three channels exist. Only as many panels
# as there are channels are drawn; confirm which modality each channel holds.
channel_titles = ['Image flair', 'Image t1', 'Image t1ce', 'Image t2']
plt.figure(figsize=(12, 8))
for channel in range(min(test_img.shape[-1], len(channel_titles))):
    plt.subplot(2, 3, channel + 1)
    plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
    plt.title(channel_titles[channel])
plt.subplot(235)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()

# Visualize Data
Visualize the images and masks to understand the data better.

In [None]:
# Visualize Data

# Visualize a random image and its corresponding mask from the batch loaded above
img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
# argmax over the last mask axis (presumably one-hot classes) gives a label volume.
test_mask = np.argmax(msk[img_num], axis=3)

# random.randint includes both endpoints, so the upper bound must be shape - 1.
n_slice = random.randint(0, test_mask.shape[2] - 1)

# Titles for up to four channels, in the order this notebook labelled them.
# NOTE(review): the data folder is "input_data_3channels" and the model is built
# with IMG_CHANNELS=3, so likely only three channels exist. Only as many panels
# as there are channels are drawn; confirm which modality each channel holds.
channel_titles = ['Image flair', 'Image t1', 'Image t1ce', 'Image t2']
plt.figure(figsize=(12, 8))
for channel in range(min(test_img.shape[-1], len(channel_titles))):
    plt.subplot(2, 3, channel + 1)
    plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
    plt.title(channel_titles[channel])
plt.subplot(235)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()

# Define Model and Training Functions
Import the U-Net model and define the loss function, metrics, plotting helpers, and optimizer, then build and compile the model.

In [None]:
# Define Model and Training Functions

# Define the U-Net model
from Unet_model import simple_unet_model

# Define loss functions
# Class-weighted Dice loss: each of the 4 classes gets the same weight (0.25).
wt0 = wt1 = wt2 = wt3 = np.float32(0.25)
dice_loss = sm.losses.DiceLoss(
    class_weights=np.array([wt0, wt1, wt2, wt3], dtype=np.float32),
)
# Categorical focal loss, added on top of the Dice term.
focal_loss = sm.losses.CategoricalFocalLoss()
# NOTE(review): keep the explicit `1 *`. It appears to shape the combined loss's
# name, which seems to be what the 'dice_loss_plus_1focal_loss' key passed to
# load_model's custom_objects in the evaluation cell refers to; confirm before
# simplifying.
total_loss = dice_loss + (1 * focal_loss)

# Define custom metrics
def dice_metric(y_true, y_pred):
    """Soft Dice coefficient over the whole batch.

    Computes ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) + eps)``
    on the raw predictions (no thresholding), so probabilities contribute
    fractionally.
    """
    overlap = K.sum(y_true * y_pred)
    denominator = K.sum(y_true) + K.sum(y_pred) + K.epsilon()
    return (2. * overlap) / denominator

def iou_metric(y_true, y_pred):
    """Soft intersection-over-union over the whole batch.

    Uses ``union = sum(y_true) + sum(y_pred) - intersection`` on the raw
    predictions (no thresholding), with ``K.epsilon()`` guarding against a
    zero denominator.
    """
    overlap = K.sum(y_true * y_pred)
    combined = K.sum(y_true) + K.sum(y_pred)
    return overlap / (combined - overlap + K.epsilon())

def f1_score_metric(y_true, y_pred):
    """F1 score over the whole batch after rounding predictions to 0/1.

    ``F1 = 2*TP / (2*TP + FP + FN + eps)``, with all counts summed across
    every element and class.
    """
    truth = K.cast(y_true, 'float32')
    hard_pred = K.cast(K.round(y_pred), 'float32')
    true_pos = K.sum(truth * hard_pred)
    false_pos = K.sum(hard_pred) - true_pos
    false_neg = K.sum(truth) - true_pos
    return 2 * true_pos / (2 * true_pos + false_pos + false_neg + K.epsilon())

def precision_metric(y_true, y_pred):
    """Precision ``TP / (TP + FP + eps)`` after rounding predictions to 0/1.

    Counts are summed across every element and class in the batch.
    """
    truth = K.cast(y_true, 'float32')
    hard_pred = K.cast(K.round(y_pred), 'float32')
    true_pos = K.sum(truth * hard_pred)
    false_pos = K.sum(hard_pred) - true_pos
    return true_pos / (true_pos + false_pos + K.epsilon())

def recall_metric(y_true, y_pred):
    """Recall (sensitivity) ``TP / (TP + FN + eps)`` after rounding predictions.

    Counts are summed across every element and class in the batch.
    """
    truth = K.cast(y_true, 'float32')
    hard_pred = K.cast(K.round(y_pred), 'float32')
    true_pos = K.sum(truth * hard_pred)
    false_neg = K.sum(truth) - true_pos
    return true_pos / (true_pos + false_neg + K.epsilon())

# Define function to calculate all metrics
def calculate_all_metrics(y_true, y_pred):
    """Compute segmentation scores for a pair of label arrays.

    Args:
        y_true: Ground-truth label array of any shape (e.g. an argmax'ed mask).
        y_pred: Predicted label array with the same number of elements.

    Returns:
        Dict with these entries:

        - 'f1', 'precision', 'sensitivity': sklearn scores with
          ``average='weighted'`` over all label values present.
        - 'specificity', 'iou', 'dice': foreground-vs-background scores, where
          foreground means any non-zero label.

        For inputs containing only 0/1 labels, the second group matches the
        earlier label-1 formulation exactly.
    """
    y_true_flat = np.asarray(y_true).ravel()
    y_pred_flat = np.asarray(y_pred).ravel()

    f1 = f1_score(y_true_flat, y_pred_flat, average='weighted', zero_division=0)
    precision = precision_score(y_true_flat, y_pred_flat, average='weighted', zero_division=0)
    sensitivity = recall_score(y_true_flat, y_pred_flat, average='weighted', zero_division=0)

    # Boolean foreground masks. Previously the intersection/FP terms tested
    # `== 1` while the Dice denominator summed raw label values, so labels 2
    # and 3 were counted with weight 2 and 3 in one place and ignored in the
    # others. Using "!= 0" everywhere keeps all terms consistent.
    true_fg = y_true_flat != 0
    pred_fg = y_pred_flat != 0

    tn = np.sum(~true_fg & ~pred_fg)
    fp = np.sum(~true_fg & pred_fg)
    specificity = tn / (tn + fp + 1e-7)

    intersection = np.sum(true_fg & pred_fg)
    union = np.sum(true_fg | pred_fg)
    iou = intersection / (union + 1e-7)

    dice = 2 * intersection / (np.sum(true_fg) + np.sum(pred_fg) + 1e-7)

    return {
        'f1': f1,
        'precision': precision,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'iou': iou,
        'dice': dice
    }

# Define function to plot metrics
def plot_metrics(history):
    """Plot every training metric in *history*, overlaying its validation curve.

    Args:
        history: Mapping of metric name to per-epoch values, such as the
            ``History.history`` dict returned by ``model.fit``.

    A ``val_<name>`` key is drawn on the same figure as ``<name>`` and does not
    get its own figure (previously it also produced a duplicate plot labelled
    "Train val_<name>"). A ``val_`` key with no training counterpart is still
    plotted on its own.
    """
    for metric_name in history:
        if metric_name.startswith('val_') and metric_name[len('val_'):] in history:
            continue  # already overlaid on its training metric's figure
        plt.figure(figsize=(10, 6))
        plt.plot(history[metric_name], 'y', label=f'Train {metric_name}')
        if f'val_{metric_name}' in history:
            plt.plot(history[f'val_{metric_name}'], 'r', label=f'Validation {metric_name}')
        plt.title(f'Training and Validation {metric_name}')
        plt.xlabel('Epoch')
        plt.ylabel(metric_name)
        plt.legend()
        plt.show()

def plot_precision_recall_f1(history):
    """Plot per-epoch precision and F1 score against recall.

    Args:
        history: Mapping of metric name to per-epoch values. It must contain
            'precision_metric', 'recall_metric' and 'f1_score_metric'; if any
            is missing, nothing is plotted.

    Each point is one epoch; points are connected in epoch order.
    """
    required = ('precision_metric', 'recall_metric', 'f1_score_metric')
    if not all(key in history for key in required):
        return
    recall = history['recall_metric']
    plt.figure(figsize=(10, 6))
    # Markers are required: with a single epoch a line-only style draws nothing.
    plt.plot(recall, history['precision_metric'], 'b-o', label='Precision vs Recall')
    plt.plot(recall, history['f1_score_metric'], 'g-o', label='F1 Score vs Recall')
    plt.title('Precision and F1 Score vs Recall')
    plt.xlabel('Recall')
    plt.ylabel('Score')
    plt.legend()
    plt.show()

# Define learning rate and optimizer
LR = 0.0001
optim = keras.optimizers.Adam(LR)

# Metrics reported during training. The evaluation cell reuses this `metrics`
# list when re-compiling the loaded model, so avoid rebinding the name.
metrics = [
    'accuracy',
    sm.metrics.IOUScore(threshold=0.5),
    sm.metrics.FScore(threshold=0.5),
    precision_metric,
    recall_metric,
    f1_score_metric,
    iou_metric,
    dice_metric
]

# Initialize U-Net model: 128x128x128 volumes, 3 input channels, 4 output classes
model = simple_unet_model(IMG_HEIGHT=128, IMG_WIDTH=128, IMG_DEPTH=128, IMG_CHANNELS=3, num_classes=4)

# Compile the model
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" line.
model.summary()

# Train the Model
Train the U-Net model using the defined data generators and visualize the training process.

In [None]:
# Train the Model

# Steps = number of batches a generator yields per full pass over its list
# (ceiling division, so the final partial batch belongs to the epoch), never 0.
steps_per_epoch = max(1, -(-len(train_img_list) // BATCH_SIZE))

# Keras needs targets to compute the validation loss, so validation_data must
# yield (images, masks) pairs; val_img_datagen yields images only. Validate
# only when a validation masks folder with one mask per image is available.
VAL_MASK_DIR = os.path.join(VALIDATION_DATASET_PATH, "input_data_3channels/masks/")
val_pair_imgs = sorted(f for f in val_img_list if f.endswith('.npy'))
val_pair_masks = (
    sorted(f for f in os.listdir(VAL_MASK_DIR) if f.endswith('.npy'))
    if os.path.isdir(VAL_MASK_DIR)
    else []
)
val_steps_per_epoch = max(1, -(-len(val_pair_imgs) // BATCH_SIZE))

fit_kwargs = {}
if val_pair_masks and len(val_pair_masks) == len(val_pair_imgs):
    fit_kwargs['validation_data'] = imageLoader(
        VAL_IMG_DIR, val_pair_imgs, VAL_MASK_DIR, val_pair_masks, BATCH_SIZE
    )
    fit_kwargs['validation_steps'] = val_steps_per_epoch
else:
    print('No matching validation masks found; training without validation data.')

# Train the model
epochs = 1
history = model.fit(train_img_datagen,
                    steps_per_epoch=steps_per_epoch,
                    epochs=epochs,
                    verbose=1,
                    **fit_kwargs)

# Save the trained model (create the output folder if it does not exist yet)
os.makedirs('saved_models', exist_ok=True)
model_filename = f'saved_models/brats_3d_{epochs}epochs_simple_unet_weighted_dice.hdf5'
model.save(model_filename)

# Plot training (and, if available, validation) metrics
plot_metrics(history.history)

# Plot precision, recall, and F1 score
plot_precision_recall_f1(history.history)

# Evaluate the Model
Evaluate the trained model on validation data and calculate metrics such as IoU, Dice coefficient, etc.

In [None]:
# Evaluate the Model

# Load the trained model
model_filename = f'saved_models/brats_3d_{epochs}epochs_simple_unet_weighted_dice.hdf5'
my_model = load_model(model_filename, 
                      custom_objects={'dice_loss_plus_1focal_loss': total_loss,
                                      'iou_score': sm.metrics.IOUScore(threshold=0.5), 
                                      'f_score': sm.metrics.FScore(threshold=0.5),
                                      'iou_metric': iou_metric,
                                      'dice_metric': dice_metric,
                                      'f1_score_metric': f1_score_metric,
                                      'precision_metric': precision_metric,
                                      'recall_metric': recall_metric})
my_model.compile(optimizer=keras.optimizers.Adam(LR), loss=total_loss, metrics=metrics)


def list_npy_files(directory):
    """Return the sorted ``.npy`` file names in *directory* ([] if it is missing)."""
    if not os.path.isdir(directory):
        return []
    return sorted(f for f in os.listdir(directory) if f.endswith('.npy'))


# Images and masks must come from the same subjects. Training masks belong to
# different cases than the validation images, so they must not be paired with
# VAL_IMG_DIR. Use validation masks when present; otherwise fall back to
# (matched) training pairs and say so.
eval_img_dir = VAL_IMG_DIR
eval_mask_dir = os.path.join(VALIDATION_DATASET_PATH, "input_data_3channels/masks/")
eval_img_list = list_npy_files(eval_img_dir)
eval_mask_list = list_npy_files(eval_mask_dir)
if not eval_mask_list or len(eval_mask_list) != len(eval_img_list):
    print('No matching validation masks found; evaluating on training image/mask pairs '
          '(this is NOT a held-out score).')
    eval_img_dir, eval_mask_dir = TRAIN_IMG_DIR, TRAIN_MASK_DIR
    eval_img_list = list_npy_files(eval_img_dir)
    eval_mask_list = list_npy_files(eval_mask_dir)

test_img_datagen = imageLoader(eval_img_dir, eval_img_list,
                               eval_mask_dir, eval_mask_list, BATCH_SIZE)

# Get a batch of evaluation images and masks
test_image_batch, test_mask_batch = next(test_img_datagen)

# argmax over axis 4 turns the per-class channels into label volumes
test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)
test_pred_batch = my_model.predict(test_image_batch)
test_pred_batch_argmax = np.argmax(test_pred_batch, axis=4)

# Calculate Mean IoU (update_state takes y_true first, then y_pred)
n_classes = 4
IOU_keras = MeanIoU(num_classes=n_classes)
IOU_keras.update_state(test_mask_batch_argmax, test_pred_batch_argmax)
print("Mean IoU =", IOU_keras.result().numpy())

# Visualize predictions on a random image from the batch
img_num = random.randint(0, test_image_batch.shape[0] - 1)
test_img = test_image_batch[img_num]
test_mask = test_mask_batch_argmax[img_num]
test_pred = test_pred_batch_argmax[img_num]

# random.randint includes both endpoints, so the upper bound must be shape - 1.
n_slice = random.randint(0, test_mask.shape[2] - 1)
plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Validation Image')
plt.imshow(test_img[:, :, n_slice, 0], cmap='gray')
plt.subplot(232)
plt.title('Validation Mask')
plt.imshow(test_mask[:, :, n_slice])
plt.subplot(233)
plt.title('Prediction on Validation Image')
plt.imshow(test_pred[:, :, n_slice])
plt.show()

# Calculate and print all metrics. A separate name is used so the `metrics`
# list passed to compile() above is not overwritten if cells are re-run.
eval_metrics = calculate_all_metrics(test_mask, test_pred)
for metric_name, value in eval_metrics.items():
    print(f'{metric_name}: {value:.4f}')

# Visualize Predictions
Visualize the model's predictions on test images and compare them with the ground truth masks.

In [None]:
# Visualize Predictions

# Function to preprocess image
def preprocess_image(image_path, target_size):
    """Load one image or volume as float32 and add a leading batch axis.

    Args:
        image_path: Path to a ``.npy`` array or to an image file readable by PIL.
        target_size: Desired size. Only its first two entries (width, height)
            are used, and only for PIL images; ``.npy`` arrays are not resized.

    Returns:
        float32 array with shape ``(1, *loaded_shape)``.

    ``.npy`` inputs are returned unscaled. That matches ``load_img``, which
    feeds the training generator with the raw ``.npy`` values. Dividing them by
    255 here meant inference saw differently scaled data than training. PIL
    images, assumed to be 8-bit, are still scaled from [0, 255] to [0, 1].
    """
    if image_path.endswith('.npy'):
        img = np.load(image_path).astype(np.float32)
    else:
        # Context manager closes the underlying file handle.
        with Image.open(image_path) as pil_img:
            # PIL's resize takes a (width, height) pair, not a 3D size.
            resized = pil_img.resize(tuple(target_size[:2]))
        img = np.asarray(resized, dtype=np.float32) / 255.0
    return np.expand_dims(img, axis=0)  # Add batch dimension

# Function to predict and visualize
def predict_and_visualize(model, image_path, mask_path=None, target_size=(128, 128, 128)):
    """Run *model* on one input, show its middle slice, and optionally score it.

    Args:
        model: Trained model. Its ``predict`` output is argmax'ed over axis 4
            (presumably the class axis of a (batch, H, W, D, classes) tensor;
            confirm).
        image_path: Input file, loaded through ``preprocess_image``.
        mask_path: Optional ground-truth mask file. When truthy, the mask is
            loaded the same way, argmax'ed over axis 4, compared with the
            prediction via ``calculate_all_metrics``, and the scores are printed.
        target_size: Forwarded to ``preprocess_image``.

    Returns:
        The metrics dict when *mask_path* is truthy, otherwise ``None``.
    """
    volume = preprocess_image(image_path, target_size)
    predicted_labels = np.argmax(model.predict(volume), axis=4)[0]

    mid = predicted_labels.shape[2] // 2  # middle index along the third spatial axis

    plt.figure(figsize=(12, 6))
    panels = (
        ('Original Image', volume[0, :, :, mid, 0], {'cmap': 'gray'}),
        ('Predicted Mask', predicted_labels[:, :, mid], {}),
    )
    for position, (title, data, imshow_kwargs) in enumerate(panels, start=121):
        plt.subplot(position)
        plt.title(title)
        plt.imshow(data, **imshow_kwargs)
    plt.show()

    if not mask_path:
        return None

    true_labels = np.argmax(preprocess_image(mask_path, target_size), axis=4)[0]
    scores = calculate_all_metrics(true_labels, predicted_labels)
    for metric_name, value in scores.items():
        print(f'{metric_name}: {value:.4f}')
    return scores

# Example usage: predict on one validation volume.
# A ground-truth mask is passed only when a matching validation mask exists.
# Masks from the training folder belong to different subjects and would give
# meaningless scores for a validation image.
image_path = os.path.join(VAL_IMG_DIR, 'image_0.npy')  # Update with your image path
candidate_mask_path = os.path.join(VALIDATION_DATASET_PATH, 'input_data_3channels/masks/', 'mask_0.npy')
mask_path = candidate_mask_path if os.path.isfile(candidate_mask_path) else None
# Separate name so the `metrics` list used by compile() is not overwritten.
example_metrics = predict_and_visualize(my_model, image_path, mask_path)