In [None]:
# !pip install SimpleITK
# !pip install matplotlib
# !pip install scikit-learn
# !pip install segmentation_models

In [None]:
import SimpleITK as sitk
import os
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Collect every ADC volume (.mha) in the folder, sorted so that the i-th entry lines
# up with the i-th Z-ADC and label file (assumes matching filename order across folders).
adc_folder_path = 'BONBID2023_Train/1ADC_ss'
adc_files = sorted(
    os.path.join(adc_folder_path, name)
    for name in os.listdir(adc_folder_path)
    if name.endswith('.mha')
)

In [None]:
# Collect every Z-ADC volume (.mha) in the folder, sorted to pair with adc_files/label_files.
zadc_folder_path = 'BONBID2023_Train/2Z_ADC'
zadc_files = sorted(
    os.path.join(zadc_folder_path, name)
    for name in os.listdir(zadc_folder_path)
    if name.endswith('.mha')
)

In [None]:
# Collect every lesion label volume (.mha) in the folder, sorted to pair with the image files.
label_folder_path = 'BONBID2023_Train/3LABEL'
label_files = sorted(
    os.path.join(label_folder_path, name)
    for name in os.listdir(label_folder_path)
    if name.endswith('.mha')
)

# **Read and display data**

Read and display ADC data

In [None]:
# Keep the SimpleITK images themselves (they carry spacing/origin metadata) and a
# NumPy array view of each one for processing.
adc_stik_data = [sitk.ReadImage(path) for path in adc_files]
adc_data = [sitk.GetArrayFromImage(volume) for volume in adc_stik_data]

In [None]:
# Show up to the first 25 slices of the first ADC volume on a 5x5 grid.
fig, axes = plt.subplots(5, 5, figsize=(14, 14))
first_adc_volume = adc_data[0]
adc_slice_count = first_adc_volume.shape[0]
for slice_idx, ax in enumerate(axes.flat):
    if slice_idx >= adc_slice_count:
        continue
    ax.imshow(first_adc_volume[slice_idx], cmap='gray')

Read and display the Z-ADC data

In [None]:
# Read every Z-ADC volume and keep only its NumPy array.
zadc_data = [sitk.GetArrayFromImage(sitk.ReadImage(path)) for path in zadc_files]

In [None]:
# Show up to the first 25 slices of the first Z-ADC volume on a 5x5 grid.
fig, axes = plt.subplots(5, 5, figsize=(14, 14))
first_zadc_volume = zadc_data[0]
zadc_slice_count = first_zadc_volume.shape[0]
for slice_idx, ax in enumerate(axes.flat):
    if slice_idx >= zadc_slice_count:
        continue
    ax.imshow(first_zadc_volume[slice_idx], cmap="jet")

Read and display label data

In [None]:
# Read every lesion label volume and keep only its NumPy array.
label_data = [sitk.GetArrayFromImage(sitk.ReadImage(path)) for path in label_files]

In [None]:
# Show up to the first 25 slices of the first label volume on a 5x5 grid.
fig, axes = plt.subplots(5, 5, figsize=(14, 14))
first_label_volume = label_data[0]
label_slice_count = first_label_volume.shape[0]
for slice_idx, ax in enumerate(axes.flat):
    if slice_idx >= label_slice_count:
        continue
    ax.imshow(first_label_volume[slice_idx], cmap="copper")

# **Prepare data for training**

In [None]:
# Flatten each list of 3-D volumes into one list of 2-D slices: n volumes with m_k
# slices each become sum(m_k) arrays of shape (H, W). The in-plane size presumably
# varies between volumes (slice_image below accepts 64, 128, 160 and 256 square slices).
flattened_adc_data = []
for volume in adc_data:
    flattened_adc_data.extend(volume)

flattened_zadc_data = []
for volume in zadc_data:
    flattened_zadc_data.extend(volume)

flattened_label_data = []
for volume in label_data:
    flattened_label_data.extend(volume)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the 2-D slices as the test set (fixed seed for reproducibility).
# NOTE(review): the split is per slice, not per patient volume, so neighbouring slices
# of the same scan can end up in both train and test, which likely inflates test scores.
# Consider splitting volume indices first and flattening afterwards (later cells use
# hard-coded test indices, so check them if the split changes).
(
    flattened_zadc_data_train,
    flattened_zadc_data_test,
    flattened_label_data_train,
    flattened_label_data_test,
) = train_test_split(
    flattened_zadc_data,
    flattened_label_data,
    test_size=0.1,
    random_state=42,
)

In [None]:
def slice_image(image):
    """
    Slices the given image into 64x64 patches, in row-major order.

    Patch origins along each axis are 0, 64, 128, ... up to ``size - 64``. When the
    size is not a multiple of 64, one extra origin at ``size - 64`` is added so the
    whole image is covered; the last row/column of patches then overlaps its
    neighbour. For 160x160 this gives origins (0, 64, 96) on both axes, i.e. 9 patches.
    256x256 gives 16 patches, 128x128 gives 4, and 64x64 gives 1.

    Args:
    image (numpy.ndarray): The image to be sliced, expected shape is (H, W).

    Returns:
    list: A list of 64x64 image patches (views into ``image``).

    Raises:
    ValueError: If the image is not 64x64, 128x128, 160x160 or 256x256.
    """
    patch = 64
    h, w = image.shape

    if (h, w) not in {(64, 64), (128, 128), (160, 160), (256, 256)}:
        raise ValueError("Unsupported image size. Expected (64, 64), (128, 128), (160, 160), or (256, 256).")

    def _starts(size):
        # Regular grid origins, plus an edge-aligned origin if the grid leaves a remainder.
        starts = list(range(0, size - patch + 1, patch))
        if starts[-1] != size - patch:
            starts.append(size - patch)
        return starts

    return [image[i:i + patch, j:j + patch] for i in _starts(h) for j in _starts(w)]

def process_images_and_labels(images, labels):
    """
    Processes lists of images and labels, slicing each into 64x64 patches.

    Each image is sliced together with its label, so patch k of an image always lines
    up with patch k of its label in the returned lists.

    Args:
    images (list): List of numpy.ndarray images.
    labels (list): List of numpy.ndarray labels corresponding to the images.

    Returns:
    tuple: Two lists containing the sliced images and corresponding labels.

    Raises:
    ValueError: If the number of images and labels differs, or an image and its label
        have different shapes (either would silently misalign patches and labels).
        Also propagated from ``slice_image`` for unsupported sizes.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"Got {len(images)} images but {len(labels)} labels; they must pair one-to-one."
        )

    all_image_patches = []
    all_label_patches = []

    for index, (image, label) in enumerate(zip(images, labels)):
        if image.shape != label.shape:
            raise ValueError(
                f"Image {index} has shape {image.shape} but its label has shape {label.shape}."
            )
        all_image_patches.extend(slice_image(image))
        all_label_patches.extend(slice_image(label))

    return all_image_patches, all_label_patches


In [None]:
# Cut every training and test slice (and its label) into 64x64 patches.
(
    flattened_cropped_zadc_data_list,
    flattened_cropped_label_data_list,
) = process_images_and_labels(flattened_zadc_data_train, flattened_label_data_train)

(
    flattened_test_cropped_zadc_data_list,
    flattened_test_cropped_label_data_list,
) = process_images_and_labels(flattened_zadc_data_test, flattened_label_data_test)

In [None]:
def display_patches(original_image, patches):
    """
    Display the 64x64 patches in a grid to verify their placement.

    The grid has ceil(H/64) rows and ceil(W/64) columns, so sizes that are not a
    multiple of 64 (e.g. 160) get an extra row/column for the overlapping edge patches.
    Patches are drawn in row-major order; grid cells without a patch stay blank.

    Args:
    original_image (numpy.ndarray): The original image before slicing, shape (H, W).
    patches (list): A list of 64x64 image patches, in row-major order.
    """
    h, w = original_image.shape
    n_rows = -(-h // 64)  # ceiling division
    n_cols = -(-w // 64)
    # squeeze=False keeps `axs` 2-D even for a 1x1 grid (64x64 images); otherwise
    # plt.subplots returns a single Axes and indexing it fails.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)

    for patch_index, ax in enumerate(axs.flat):
        if patch_index < len(patches):
            ax.imshow(patches[patch_index])
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Sanity-check the patch layout on one training slice (index 875).
sample_train_index = 875
original_image = flattened_zadc_data_train[sample_train_index]
patches = slice_image(original_image)
display_patches(original_image, patches)

In [None]:
def stitch_image(patches, original_shape, patch_size=64):
    """
    Stitches square patches back together into the original image.

    Uses the same layout as ``slice_image``: patch origins along each axis are 0,
    patch_size, 2*patch_size, ... plus a final origin at ``size - patch_size`` when the
    size is not a multiple of patch_size (e.g. 0, 64, 96 for 160 px). Patches are
    consumed in row-major order. Where patches overlap, the values are averaged, which
    blends overlapping model predictions instead of letting the last patch win.

    Args:
    patches (list | numpy.ndarray): Patches of shape (patch_size, patch_size), row-major.
    original_shape (tuple): The shape of the original image (height, width).
    patch_size (int): Side length of each patch; must match what was used for slicing.

    Returns:
    numpy.ndarray: The reconstructed image (float64), shape ``original_shape``.

    Raises:
    ValueError: If the image is smaller than one patch, or the number of patches does
        not match the layout for ``original_shape``.
    """
    h, w = original_shape

    def _starts(size):
        # Regular grid origins, plus an edge-aligned origin if the grid leaves a remainder.
        starts = list(range(0, size - patch_size + 1, patch_size))
        if starts and starts[-1] != size - patch_size:
            starts.append(size - patch_size)
        return starts

    positions = [(i, j) for i in _starts(h) for j in _starts(w)]
    if not positions:
        raise ValueError(f"Image shape {tuple(original_shape)} is smaller than one {patch_size}x{patch_size} patch.")
    if len(patches) != len(positions):
        raise ValueError(
            f"Expected {len(positions)} patches for shape {tuple(original_shape)}, got {len(patches)}."
        )

    total = np.zeros(original_shape)
    counts = np.zeros(original_shape)
    for (i, j), patch in zip(positions, patches):
        total[i:i + patch_size, j:j + patch_size] += patch
        counts[i:i + patch_size, j:j + patch_size] += 1

    # Every pixel is covered by at least one patch, so counts is never zero.
    return total / counts

In [None]:
# Sanity-check the patch layout on one test slice (index 3).
sample_test_index = 3
original_image = flattened_zadc_data_test[sample_test_index]
patches = slice_image(original_image)
display_patches(original_image, patches)

In [None]:
# Train data: stack the 64x64 patches into one array and add a trailing channel
# axis, giving shape (num_patches, 64, 64, 1) as the model input expects.
flattened_cropped_zadc_data = np.array(flattened_cropped_zadc_data_list).reshape(
    (len(flattened_cropped_zadc_data_list), 64, 64, 1)
)
print(flattened_cropped_zadc_data.shape)

# Labels get the same (num_patches, 64, 64, 1) layout.
flattened_cropped_label_data = np.array(flattened_cropped_label_data_list).reshape(
    (len(flattened_cropped_label_data_list), 64, 64, 1)
)
print(flattened_cropped_label_data.shape)

In [None]:
# Test data: stack the 64x64 patches into one array and add a trailing channel
# axis, giving shape (num_patches, 64, 64, 1).
flattened_test_cropped_zadc_data = np.array(flattened_test_cropped_zadc_data_list).reshape(
    (len(flattened_test_cropped_zadc_data_list), 64, 64, 1)
)
print(flattened_test_cropped_zadc_data.shape)

# Labels get the same (num_patches, 64, 64, 1) layout.
flattened_test_cropped_label_data = np.array(flattened_test_cropped_label_data_list).reshape(
    (len(flattened_test_cropped_label_data_list), 64, 64, 1)
)
print(flattened_test_cropped_label_data.shape)

# **U-Net with a ResNet18 encoder**

In [None]:
import os
# Select the backend for segmentation_models. It is set before the import below,
# presumably because the library reads it at import time (confirm against its docs).
os.environ["SM_FRAMEWORK"] = "tf.keras"

from tensorflow import keras
import segmentation_models as sm
# Importing segmentation_models prints which framework it selected; with the
# setting above this is expected to report `tf.keras` — NOTE(review): confirm.

# **Model 1: Dice Loss**

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import os

# Define paths.
# NOTE(review): '/Epochs/...' is an absolute path at the filesystem root and may need
# elevated permissions; a relative 'Epochs/ResNet18dice' was probably intended.
checkpoint_dir = '/Epochs/ResNet18dice'
checkpoint_filepath = os.path.join(checkpoint_dir, 'model_checkpoint_{epoch:02d}.keras')
os.makedirs(checkpoint_dir, exist_ok=True)
# Most recently created checkpoint, or None if none exist yet (handy for resuming).
latest_checkpoint = max(
    (os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir) if f.startswith('model_checkpoint')),
    key=os.path.getctime,
    default=None,
)
log_filepath = 'training_log_ResNet18dice.csv'

# Append so that a resumed run extends the same CSV log.
csv_logger = CSVLogger(log_filepath, append=True)

# Save the full model after every epoch. `monitor` must name a metric that is actually
# logged: the IoU metric is recorded as 'iou_score', so its validation value is
# 'val_iou_score' ('val_mean_io_u' is never produced and would break save_best_only=True).
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    monitor='val_iou_score',
    mode='max',
    save_best_only=False,
    save_freq='epoch'
)

In [None]:

import segmentation_models as sm
from sklearn.model_selection import train_test_split

BACKBONE = 'resnet18'

# Hold out 20% of the training patches for validation (fixed seed).
x_train1, x_val1, y_train1, y_val1 = train_test_split(flattened_cropped_zadc_data, flattened_cropped_label_data, test_size=0.2, random_state=42)

# U-Net with a ResNet18 encoder, trained from scratch on single-channel 64x64 patches.
model1 = sm.Unet(BACKBONE, encoder_weights=None, input_shape=(64, 64, 1))
# DiceLoss and FScore are classes in segmentation_models (they are instantiated with ()
# elsewhere in this notebook); Keras needs callable loss/metric *instances*, so passing the
# bare classes would make Keras call their constructors with (y_true, y_pred).
model1.compile(
    'Adam',
    loss=sm.losses.DiceLoss(),
    metrics=[sm.metrics.iou_score, sm.metrics.FScore()],
)

history1 = model1.fit(
    x=x_train1,
    y=y_train1,
    batch_size=32,
    epochs=500,
    validation_data=(x_val1, y_val1),
    callbacks=[model_checkpoint_callback, csv_logger],
)

In [None]:
# Plot training vs. validation curves for loss, IoU score and F1 score.
import matplotlib.pyplot as plt

panels = [
    ('loss', 'Model loss', 'Dice Loss'),
    ('iou_score', 'Model MeanIoU', 'MeanIoU'),
    ('f1-score', 'Model F1 Score', 'F1 Score'),
]

plt.figure(figsize=(16, 6))
for position, (metric_key, panel_title, axis_label) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.plot(history1.history[metric_key])
    plt.plot(history1.history['val_' + metric_key])
    plt.title(panel_title)
    plt.ylabel(axis_label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')

plt.suptitle('Model 1 : ResNet18 Dice Loss', fontsize=16)
# Reserve the top 5% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig('resnetDice.png', dpi=300)
plt.show()

In [None]:
# TEST: segment one held-out slice patch by patch and compare it with its label.
IMAGE_NUMBER = 230
original_image = flattened_zadc_data_test[IMAGE_NUMBER]
label_image = flattened_label_data_test[IMAGE_NUMBER]
patches = slice_image(original_image)

# The model takes a batch shaped (num_patches, 64, 64, 1) ...
patches_data = np.array(patches).reshape((len(patches), 64, 64, 1))

# ... and the channel axis is dropped again so stitch_image gets (num_patches, 64, 64).
predictions = model1.predict(patches_data)
predictions = predictions.reshape((len(predictions), 64, 64))
reconstructed_image = stitch_image(predictions, original_image.shape)

plt.figure(figsize=(20, 20))
panels = (
    (original_image, 'ZADC image', None),
    (label_image, 'Label image', "grey"),
    (reconstructed_image, 'Predicted image', "grey"),
)
for position, (shown, caption, colormap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(shown, cmap=colormap)
    plt.title(caption)

plt.savefig('resNetDicePrediction.png', dpi=300)
plt.show()

In [None]:
# Standalone loss and metric objects used to evaluate the trained model on the test set.
dice_loss, focal_loss = sm.losses.DiceLoss(), sm.losses.BinaryFocalLoss()
metrics = [sm.metrics.IOUScore(), sm.metrics.FScore()]

In [None]:
# Re-compile so evaluate() reports Dice loss; compiling keeps the trained weights.
model1.compile('Adam', loss=dice_loss, metrics=metrics)

evaluation = model1.evaluate(
    flattened_test_cropped_zadc_data,
    flattened_test_cropped_label_data,
    batch_size=32,
)

# evaluate() returns [loss, iou_score, f1-score] in compile order.
print("Evaluation results on test data (Dice):")
for label, value in zip(("Dice Loss", "IOU Score", "F1 Score"), evaluation):
    print(f"{label}: {value}")

In [None]:
# Re-compile so evaluate() reports binary focal loss for the Dice-trained Model 1.
model1.compile('Adam', loss=focal_loss, metrics=metrics)

evaluation = model1.evaluate(
    flattened_test_cropped_zadc_data,
    flattened_test_cropped_label_data,
    batch_size=32,
)

# The header names the model's training loss (Dice); the first value is the focal loss.
print("Evaluation results on test data (Dice):")
for label, value in zip(("Focal Loss", "IOU Score", "F1 Score"), evaluation):
    print(f"{label}: {value}")

# **Model 2: Dice + Focal Loss**

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import os
import re

# Define paths.
# NOTE(review): '/Epochs/...' is an absolute path at the filesystem root and may need
# elevated permissions; a relative 'Epochs/ResNet18dicefocal' was probably intended.
checkpoint_dir2 = '/Epochs/ResNet18dicefocal'
checkpoint_filepath2 = os.path.join(checkpoint_dir2, 'model_checkpoint_{epoch:02d}.keras')
os.makedirs(checkpoint_dir2, exist_ok=True)
# Most recently created checkpoint, or None if none exist yet (handy for resuming).
latest_checkpoint2 = max(
    (os.path.join(checkpoint_dir2, f) for f in os.listdir(checkpoint_dir2) if f.startswith('model_checkpoint')),
    key=os.path.getctime,
    default=None,
)
log_filepath2 = 'training_log_ResNet18dicefocal.csv'

# Append so that a resumed run extends the same CSV log.
csv_logger2 = CSVLogger(log_filepath2, append=True)

# Save the full model after every epoch. The IoU metric is logged as 'iou_score', so the
# validation value to monitor is 'val_iou_score' ('val_mean_io_u' is never produced).
model_checkpoint_callback2 = ModelCheckpoint(
    filepath=checkpoint_filepath2,
    save_weights_only=False,
    monitor='val_iou_score',
    mode='max',
    save_best_only=False,
    save_freq='epoch'
)

In [None]:

import segmentation_models as sm
from sklearn.model_selection import train_test_split

BACKBONE = 'resnet18'

# Same 80/20 train/validation split of the training patches as Model 1 (same seed).
x_train2, x_val2, y_train2, y_val2 = train_test_split(
    flattened_cropped_zadc_data,
    flattened_cropped_label_data,
    test_size=0.2,
    random_state=42,
)

# U-Net with a ResNet18 encoder, trained from scratch on single-channel 64x64 patches.
model2 = sm.Unet(BACKBONE, encoder_weights=None, input_shape=(64, 64, 1))

# Equal-weight sum of Dice loss and binary focal loss.
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

metrics = [sm.metrics.IOUScore(), sm.metrics.FScore()]

model2.compile('Adam', loss=total_loss, metrics=metrics)

fit_options = dict(
    batch_size=32,
    epochs=500,
    validation_data=(x_val2, y_val2),
    callbacks=[model_checkpoint_callback2, csv_logger2],
)
history2 = model2.fit(x=x_train2, y=y_train2, **fit_options)

In [None]:
# Plot training vs. validation curves for loss, IoU score and F1 score of Model 2.
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 6))

plt.subplot(1, 3, 1)
plt.plot(history2.history['loss'])
plt.plot(history2.history['val_loss'])
plt.title('Model loss')
# Model 2 is trained on Dice + focal loss, so this curve is the combined loss.
plt.ylabel('Dice + Focal Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.subplot(1, 3, 2)
plt.plot(history2.history['iou_score'])
plt.plot(history2.history['val_iou_score'])
plt.title('Model MeanIoU')
plt.ylabel('MeanIoU')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.subplot(1, 3, 3)
plt.plot(history2.history['f1-score'])
plt.plot(history2.history['val_f1-score'])
plt.title('Model F1 Score')
plt.ylabel('F1 Score')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

# Add a main title to the entire figure
plt.suptitle('Model 2 : ResNet18 Dice + Focal Loss', fontsize=16)

# Reserve the top 5% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])

# Save the figure
plt.savefig('resnetDiceFocal.png', dpi=300)

plt.show()

In [None]:
# TEST: segment one held-out slice patch by patch and compare it with its label.
IMAGE_NUMBER = 230
original_image = flattened_zadc_data_test[IMAGE_NUMBER]
label_image = flattened_label_data_test[IMAGE_NUMBER]
patches = slice_image(original_image)

# The model takes a batch shaped (num_patches, 64, 64, 1) ...
patches_data = np.array(patches).reshape((len(patches), 64, 64, 1))

# ... and the channel axis is dropped again so stitch_image gets (num_patches, 64, 64).
predictions = model2.predict(patches_data)
predictions = predictions.reshape((len(predictions), 64, 64))
reconstructed_image = stitch_image(predictions, original_image.shape)

plt.figure(figsize=(20, 20))
panels = (
    (original_image, 'ZADC image', None),
    (label_image, 'Label image', "grey"),
    (reconstructed_image, 'Predicted image', "grey"),
)
for position, (shown, caption, colormap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(shown, cmap=colormap)
    plt.title(caption)

plt.savefig('resnetDiceFocalPrediction.png', dpi=300)
plt.show()

In [None]:
# Standalone loss and metric objects used to evaluate Model 2 on the test set.
dice_loss, focal_loss = sm.losses.DiceLoss(), sm.losses.BinaryFocalLoss()
metrics = [sm.metrics.IOUScore(), sm.metrics.FScore()]

In [None]:
# Re-compile so evaluate() reports Dice loss; compiling keeps the trained weights.
model2.compile('Adam', loss=dice_loss, metrics=metrics)

evaluation = model2.evaluate(
    flattened_test_cropped_zadc_data,
    flattened_test_cropped_label_data,
    batch_size=32,
)

# evaluate() returns [loss, iou_score, f1-score] in compile order.
print("Evaluation results on test data (Dice+Focal):")
for label, value in zip(("Dice Loss", "IOU Score", "F1 Score"), evaluation):
    print(f"{label}: {value}")

In [None]:
# Re-compile so evaluate() reports binary focal loss for Model 2.
model2.compile('Adam', loss=focal_loss, metrics=metrics)

evaluation = model2.evaluate(
    flattened_test_cropped_zadc_data,
    flattened_test_cropped_label_data,
    batch_size=32,
)

# The header names the model's training loss; the first value is the focal loss.
print("Evaluation results on test data (Dice+Focal):")
for label, value in zip(("Focal Loss", "IOU Score", "F1 Score"), evaluation):
    print(f"{label}: {value}")

# **Model 3: 0.75 Dice + 0.25 Focal Loss**

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import os
import re

# Define paths.
# NOTE(review): '/Epochs/...' is an absolute path at the filesystem root and may need
# elevated permissions; a relative 'Epochs/ResNet1875dice25focal' was probably intended.
checkpoint_dir3 = '/Epochs/ResNet1875dice25focal'
checkpoint_filepath3 = os.path.join(checkpoint_dir3, 'model_checkpoint_{epoch:02d}.keras')
os.makedirs(checkpoint_dir3, exist_ok=True)
# Most recently created checkpoint, or None if none exist yet (handy for resuming).
latest_checkpoint3 = max(
    (os.path.join(checkpoint_dir3, f) for f in os.listdir(checkpoint_dir3) if f.startswith('model_checkpoint')),
    key=os.path.getctime,
    default=None,
)
log_filepath3 = 'training_log_ResNet1875dice25focal.csv'

# Append so that a resumed run extends the same CSV log.
csv_logger3 = CSVLogger(log_filepath3, append=True)

# Save the full model after every epoch. The IoU metric is logged as 'iou_score', so the
# validation value to monitor is 'val_iou_score' ('val_mean_io_u' is never produced).
model_checkpoint_callback3 = ModelCheckpoint(
    filepath=checkpoint_filepath3,
    save_weights_only=False,
    monitor='val_iou_score',
    mode='max',
    save_best_only=False,
    save_freq='epoch'
)

In [None]:

import segmentation_models as sm
from sklearn.model_selection import train_test_split

BACKBONE = 'resnet18'

# Same 80/20 train/validation split of the training patches as Models 1 and 2 (same seed).
x_train3, x_val3, y_train3, y_val3 = train_test_split(flattened_cropped_zadc_data, flattened_cropped_label_data, test_size=0.2, random_state=42)

# U-Net with a ResNet18 encoder, trained from scratch on single-channel 64x64 patches.
model3 = sm.Unet(BACKBONE, encoder_weights=None, input_shape=(64, 64, 1))

# Weighted combination: 0.75 * Dice loss + 0.25 * binary focal loss.
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss()
total_loss = (0.75 * dice_loss) + (0.25 * focal_loss)

metrics = [sm.metrics.IOUScore(), sm.metrics.FScore()]

model3.compile(
    'Adam',
    loss=total_loss,
    metrics=metrics,
)

# Train with the same epoch budget and callbacks as Models 1 and 2 so the three
# loss variants are compared on equal terms (and checkpoints/logs are written).
history3 = model3.fit(
    x=x_train3,
    y=y_train3,
    batch_size=32,
    epochs=500,
    validation_data=(x_val3, y_val3),
    callbacks=[model_checkpoint_callback3, csv_logger3],
)

In [None]:
# Plot training vs. validation curves for loss, IoU score and F1 score of Model 3.
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 6))

plt.subplot(1, 3, 1)
plt.plot(history3.history['loss'])
plt.plot(history3.history['val_loss'])
plt.title('Model loss')
# Model 3 is trained on 0.75 * Dice + 0.25 * focal loss, so this curve is that combined loss.
plt.ylabel('0.75 Dice + 0.25 Focal Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.subplot(1, 3, 2)
plt.plot(history3.history['iou_score'])
plt.plot(history3.history['val_iou_score'])
plt.title('Model MeanIoU')
plt.ylabel('MeanIoU')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.subplot(1, 3, 3)
plt.plot(history3.history['f1-score'])
plt.plot(history3.history['val_f1-score'])
plt.title('Model F1 Score')
plt.ylabel('F1 Score')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

# Add a main title to the entire figure
plt.suptitle('Model 3 : ResNet18 0.75 Dice + 0.25 Focal Loss', fontsize=16)

# Reserve the top 5% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])

# Save the figure
plt.savefig('resnet75Dice25Focal.png', dpi=300)

plt.show()

In [None]:
# TEST: segment one held-out slice patch by patch and compare it with its label.
IMAGE_NUMBER = 230
original_image = flattened_zadc_data_test[IMAGE_NUMBER]
label_image = flattened_label_data_test[IMAGE_NUMBER]
patches = slice_image(original_image)

# The model takes a batch shaped (num_patches, 64, 64, 1) ...
patches_data = np.array(patches).reshape((len(patches), 64, 64, 1))

# ... and the channel axis is dropped again so stitch_image gets (num_patches, 64, 64).
predictions = model3.predict(patches_data)
predictions = predictions.reshape((len(predictions), 64, 64))
reconstructed_image = stitch_image(predictions, original_image.shape)

plt.figure(figsize=(20, 20))
panels = (
    (original_image, 'ZADC image', None),
    (label_image, 'Label image', "grey"),
    (reconstructed_image, 'Predicted image', "grey"),
)
for position, (shown, caption, colormap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(shown, cmap=colormap)
    plt.title(caption)

plt.savefig('resnet75Dice25FocalPrediction.png', dpi=300)
plt.show()

In [None]:
# Standalone loss and metric objects used to evaluate Model 3 on the test set.
dice_loss, focal_loss = sm.losses.DiceLoss(), sm.losses.BinaryFocalLoss()
metrics = [sm.metrics.IOUScore(), sm.metrics.FScore()]

In [None]:
# Re-compile so evaluate() reports Dice loss; compiling keeps the trained weights.
model3.compile('Adam', loss=dice_loss, metrics=metrics)

evaluation = model3.evaluate(
    flattened_test_cropped_zadc_data,
    flattened_test_cropped_label_data,
    batch_size=32,
)

# evaluate() returns [loss, iou_score, f1-score] in compile order.
print("Evaluation results on test data (0.75 Dice + 0.25 Focal):")
for label, value in zip(("Dice Loss", "IOU Score", "F1 Score"), evaluation):
    print(f"{label}: {value}")

In [None]:
# Re-compile so evaluate() reports binary focal loss for Model 3.
model3.compile('Adam', loss=focal_loss, metrics=metrics)

evaluation = model3.evaluate(
    flattened_test_cropped_zadc_data,
    flattened_test_cropped_label_data,
    batch_size=32,
)

# The header names the model's training loss; the first value is the focal loss.
print("Evaluation results on test data (0.75 Dice + 0.25 Focal):")
for label, value in zip(("Focal Loss", "IOU Score", "F1 Score"), evaluation):
    print(f"{label}: {value}")

# **Gradio UI for MRI segmentation**

In [35]:
import gradio as gr

def predict_mri(image):
    """
    Segment an uploaded MRI slice with the trained Model 3.

    The slice is cut into 64x64 patches with ``slice_image``, every patch is run through
    ``model3``, and the per-patch outputs are stitched back to the input size with
    ``stitch_image``.

    Args:
    image (numpy.ndarray | None): The uploaded MRI image. With ``gr.Image(image_mode="L")``
        this is expected to be a 2-D grayscale array (TODO confirm); a trailing singleton
        channel axis is dropped if present. ``None`` when nothing has been uploaded.

    Returns:
    numpy.ndarray | None: The segmented output image with the input's height and width,
    or ``None`` when no image was provided.

    Raises:
    ValueError: From ``slice_image`` if the image is not 64, 128, 160 or 256 px square.
    """
    # Gradio passes None when the input is cleared or nothing was uploaded.
    if image is None:
        return None

    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]

    # NOTE(review): the model was trained on Z-ADC values, while the UI supplies grayscale
    # pixel intensities; inputs may need rescaling to the training range.
    patches = slice_image(image)

    # Batch the patches as (num_patches, 64, 64, 1) for the model.
    patches_data = np.array(patches).reshape((len(patches), 64, 64, 1))
    predictions = model3.predict(patches_data)

    # Drop the channel axis so stitch_image receives (num_patches, 64, 64).
    predictions = predictions.reshape((len(predictions), 64, 64))
    return stitch_image(predictions, original_shape=image.shape)

# Gradio app: upload a grayscale MRI slice, view the model's segmentation.
mri_input = gr.Image(image_mode="L")  # expecting grayscale MRI scans
segmentation_output = gr.Image(image_mode="L")  # segmentation shown as a grayscale image

demo = gr.Interface(
    fn=predict_mri,
    inputs=mri_input,
    outputs=segmentation_output,
    title="MRI Scan Segmentation",
)

# Inside a notebook __name__ is "__main__", so running this cell starts the app.
if __name__ == "__main__":
    demo.launch()


Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.
