# 3D U-Net Model Evaluation for Lung CT Image Segmentation

## Overview

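This notebook evaluates a 3D U-Net that takes a masked inspiratory lung CT and predicts the corresponding deformed expiratory CT (plus a second, mask-like output); air-trapping (disease) maps are derived from the difference between the expiratory and inspiratory images.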

## Key Features

- **Model Architecture**: Utilizes a 3D U-Net network.
- **Evaluation**: Computes MSE, air-trapping percentages, Dice scores and ICC between reference and predicted images.
- **Disease Map Outputs**: Generates disease map outputs for visualisation.


In [None]:
# Parameters

spec = 5                          # model variant; part of the model/output folder names
test = 0                          # 1: quick run on every 128th test case, 0: full test split
batch_size = 1
number_of_epochs = 100
filename = '_defEx_InspMask2'     # suffix for saved plot file names (earlier runs: '_DIV_OK')

save = 0                          # 1: write CSV metrics and plots to disk

In [None]:
# IPython `cd` magic: switch the working directory to the data share.
cd /data-synology/anlee/

In [None]:
# IPython `ls` magic: list the working directory contents.
ls

In [None]:
# includes & imports
import os
# GPU selection is set before TensorFlow is imported below so that TensorFlow
# only sees the selected device.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # enumerate GPUs in PCI bus order
os.environ["CUDA_VISIBLE_DEVICES"]="1"  # expose only GPU 1; set to "" to run on CPU

import shutil
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage as ndi
#import nilearn as nil
import tensorflow as tf

# Sanity check: the GPUs TensorFlow can see (expected: the single device selected above).
print(tf.config.list_physical_devices('GPU'))

# Fixed random seeds for reproducibility.
np.random.seed(16)
tf.random.set_seed(16)
tf.keras.utils.set_random_seed(16)  # per Keras docs also seeds Python `random`, NumPy and TF

Build an input pipeline with image paths

In [None]:
# get image pathes for input and target images
from tqdm import tqdm

# Root of the data tree and the file-name suffixes searched for below.
root_directory = '/data-synology/anlee/COPDGene/'   # adjust to your data location
inputImageName = 'insp_ct_ds.nii'                   # inspiratory CT (alternative: 'insp_ct.nii')
outputImageName = 'exp_ct_deform_ds.nii'            # deformed expiratory CT (model target)
maskInspImageName = 'insp_mask_cat.nii'             # inspiratory lung label mask
#maskDefExpImageName = ('exp_mask_deform_cat.nii')  # Mask image


def search_images(directory, image_list, name):
    """Append to ``image_list`` every file path below ``directory`` ending in ``name``.

    The tree is walked with ``os.walk`` (wrapped in a tqdm progress bar) and
    matches are appended in walk order. NUL characters are stripped from each
    path. The list is mutated in place; nothing is returned.
    """
    for current_dir, _subdirs, filenames in tqdm(os.walk(directory)):
        matching = (fname for fname in filenames if fname.endswith(name))
        for fname in matching:
            full_path = os.path.join(current_dir, fname)
            # Strip stray NUL terminators that would break later file opens.
            image_list.append(full_path.replace('\0', ''))

# Collect the file paths of each image category. The lists are filled in
# os.walk order; pairing input/output/mask by list position presumably relies
# on every case folder holding exactly one file of each kind - verify if the
# counts printed below differ.
inputImagePath = []
outputImagePath = []
maskInspImagePath = []

search_images(root_directory, inputImagePath, inputImageName)
search_images(root_directory, outputImagePath, outputImageName)
search_images(root_directory, maskInspImagePath, maskInspImageName)

# Show a few paths per category. The "not found" message names the file
# pattern; it used to be identical for all three searches, so it did not say
# which one failed.
for header, pattern, found in (
    ("Image input paths:", inputImageName, inputImagePath),
    ("Image output paths:", outputImageName, outputImagePath),
    ("Image mask insp paths:", maskInspImageName, maskInspImagePath),
):
    if found:
        print(header)
        for path in found[:5]:
            print(path)
    else:
        print(f"No '{pattern}' files found in the directory tree.")

# The three categories are paired by position, so their counts must match.
if not (len(inputImagePath) == len(outputImagePath) == len(maskInspImagePath)):
    print(f"Warning: found {len(inputImagePath)} input, {len(outputImagePath)} output "
          f"and {len(maskInspImagePath)} mask files; pairing by position will fail.")

In [None]:
from sklearn.model_selection import train_test_split

# Split case-wise into train (70%), validation (15%) and test (15%).
# All path lists go into the same train_test_split call, so they are permuted
# together and stay aligned; random_state=42 makes the split reproducible
# (presumably it must match the split used for training - do not change).
inputImagePath = list(inputImagePath)
outputImagePath = list(outputImagePath)
maskInspImagePath = list(maskInspImagePath)

train_input, test_input, train_output, test_output, train_insp_mask, test_insp_mask = train_test_split(
    inputImagePath, outputImagePath, maskInspImagePath, test_size=0.3, random_state=42)

# Halve the held-out 30% into validation and test.
val_input, test_input, val_output, test_output, val_insp_mask, test_insp_mask = train_test_split(
    test_input, test_output, test_insp_mask, test_size=0.5, random_state=42)

print('training data: ' + str(len(train_input)))
print('validation data: ' + str(len(val_input)))
print('test data: ' + str(len(test_input)))
print('train mask inspiratory data: ' + str(len(train_insp_mask)))
print('validation mask inspiratory data: ' + str(len(val_insp_mask)))
print('test mask inspiratory data: ' + str(len(test_insp_mask)))

# First few training paths of each kind. train_test_split raises on inputs
# that would leave the train split empty, so no "empty" branch is needed here.
for header, split_paths in (
    ("Image input paths:", train_input),
    ("Image output paths:", train_output),
    ("Image mask inspiratory paths:", train_insp_mask),
):
    print(header)
    for path in split_paths[:5]:
        print(path)


In [None]:
# load image function to load images scaled by 3000 (HU)
def load_image_scale(file_path):
    """Load the NIfTI volume at ``file_path`` and divide its intensities by 3000.

    The whole volume returned by nibabel's ``get_fdata()`` is used (no slice
    selection). NOTE(review): the /3000 factor presumably maps HU values to
    roughly unit range - confirm against the training pipeline.
    """
    volume = nib.load(file_path).get_fdata()
    return volume / 3000

In [None]:
# load image function to load images not scaled
def load_image(file_path):
    """Return the whole NIfTI volume at ``file_path`` as given by ``get_fdata()``, unscaled."""
    return nib.load(file_path).get_fdata()

In [None]:
# load image function to load image mask
def load_image_mask(file_path):
    """Load a label volume and binarise it.

    Voxels whose label lies strictly between 0 and 6 (i.e. labels 1-5 for
    integer labels) become 1, everything else 0. The whole volume is used.
    """
    labels = nib.load(file_path).get_fdata()
    inside = (labels > 0) & (labels < 6)
    return np.where(inside, 1, 0)

In [None]:
# Choose the cases to evaluate.
#   test == 1 -> quick run on every 128th case of the test split
#   otherwise -> the full test split (train paths are also collected)
if test == 1:
    test_input_smaller = test_input[0::128]
    test_output_smaller = test_output[0::128]
    test_mask_insp_smaller = test_insp_mask[0::128]

    test_dataset_paths = (test_input_smaller, test_output_smaller)
    test_mask_insp_paths = test_mask_insp_smaller
else:
    train_dataset_paths = (train_input, train_output)

    test_dataset_paths = (test_input, test_output)
    test_mask_insp_paths = test_insp_mask

# Number of evaluated cases.
length = len(test_dataset_paths[0])
print('length test:')
print(length)

# Full batches per pass over the selected cases.
steps_per_epoch = length // batch_size
print('steps_per_epoch: ' + str(steps_per_epoch))

# Batches the generators actually yield: they also emit a trailing partial
# batch, so round up. Used as the tqdm `total` of the evaluation loops.
test_steps = (length + batch_size - 1) // batch_size
print('test_steps: ' + str(test_steps))

# Total steps for `number_of_epochs` passes with this configuration.
number_of_steps_total = int(steps_per_epoch * number_of_epochs)
print('number_of_steps_total: ' + str(number_of_steps_total))

In [None]:
from random import shuffle

# data generator v1
# Output: yield np.array(test_input_images), [np.array(test_output_images),np.array(test_output_mask)]


def data_generator(paths, mask_insp_paths, batch_size):
    """Yield batches of ``(inputs, [targets, disease_masks])`` in path order.

    Args:
        paths: pair ``(input_paths, output_paths)`` of inspiratory CT paths and
            deformed expiratory CT paths, matched by position.
        mask_insp_paths: inspiratory lung-mask paths aligned with ``paths``.
        batch_size: maximum cases per batch; the last batch may be smaller.

    Per case:
        * input  - inspiratory image (``load_image_scale``) times the binary lung mask
        * target - deformed expiratory image (``load_image``, unscaled) times the mask
        * disease mask - 1 inside the lung where ``target - inspiratory < 1/30``
    A trailing channel axis is added to every array. No shuffling is done.

    NOTE(review): the inspiratory image is divided by 3000 but the expiratory
    image is not; the 1/30 (= 100/3000) threshold only makes sense if the
    expiratory files are already stored on the /3000 scale - confirm.
    """
    input_paths, output_paths = paths
    total = len(input_paths)

    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        inputs, targets, disease_masks = [], [], []

        for idx in range(start, stop):
            lung = load_image_mask(mask_insp_paths[idx])
            insp = load_image_scale(input_paths[idx])
            target = load_image(output_paths[idx]) * lung

            # Low expiratory-minus-inspiratory change marks disease; the mask
            # zeroes everything outside the lung.
            change = target - insp
            disease = np.where(change < 1/30, 1, 0) * lung

            inputs.append(np.expand_dims(insp * lung, -1))
            targets.append(np.expand_dims(target, -1))
            disease_masks.append(np.expand_dims(disease, -1))

        yield np.array(inputs), [np.array(targets), np.array(disease_masks)]


Image generator


In [None]:
# Generator over the evaluation split; consumed by the prediction-plot loop below.
test_generator = data_generator(
    paths=test_dataset_paths,
    mask_insp_paths=test_mask_insp_paths,
    batch_size=batch_size,
)

In [None]:
# Peek at the first batch to verify the array shapes.
gen = data_generator(test_dataset_paths, test_mask_insp_paths, batch_size)
inputs, (outputs, outputs2) = next(gen)

print(*(arr.shape for arr in (inputs, outputs, outputs2)))

In [None]:
# Slice 96 along the first spatial axis of the first case's disease mask.
mask_slice = outputs2[0, 96, :, :, 0]
plt.imshow(np.rot90(mask_slice, 1), cmap='gray')
plt.axis('off')
plt.colorbar(orientation='vertical')

In [None]:
# Computed step count (overridden in the next cell) and batch size.
for value in (number_of_steps_total, batch_size):
    print(value)

In [None]:
# Override the computed value: the model folder name loaded below encodes the
# training step count (presumably 638000 for the model evaluated here).
number_of_steps_total = 638_000

In [None]:
def dice_coefficient(y_true, y_pred, threshold=100, scale=3000):
    """Soft Dice *loss*: return ``1 - dice`` between ``y_true`` and ``y_pred``.

    ``y_true`` is cast to float32. A small epsilon is added to both the
    numerator and the denominator so empty inputs do not divide by zero.
    ``threshold`` and ``scale`` are accepted but unused. The name must stay
    as is, because the model is loaded with
    ``custom_objects={'dice_coefficient': ...}``.
    """
    smooth = 1e-6
    truth = tf.cast(y_true, tf.float32)

    overlap = tf.reduce_sum(truth * y_pred)
    denominator = tf.reduce_sum(truth) + tf.reduce_sum(y_pred) + smooth

    return 1 - (2. * overlap + smooth) / denominator

In [None]:
from tensorflow.keras.models import load_model

# The model folder name encodes the training configuration.
model_path = (
    '/data-synology/tkeller/Outputs/'
    f'3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/'
)

# Keras needs the custom loss to deserialize the saved model.
custom_objects = {'dice_coefficient': dice_coefficient}
model = load_model(model_path, custom_objects=custom_objects)

print('Model loaded successfully.')

In [None]:
# adjust calc functions and why do we predict the mask/??

In [None]:
# Calcualte MSE
def compute_mse(ground_truth, prediction):
    """Return the mean of the squared element-wise difference of two arrays."""
    residual = ground_truth - prediction
    return np.mean(residual ** 2)

In [None]:
# calcualte airtrapping
def compute_airtrapping_per_image(inspiratory, defExp_prediction_image, mask, threshold=100, scale=3000):
    """Return the air-trapping percentage of a (masked) image pair.

    The masked subtraction ``(defExp - inspiratory) * mask`` is multiplied by
    ``scale`` (default 3000, undoing the /3000 applied when the images were
    loaded). Among its non-zero voxels, those below ``threshold`` (default
    100) count as air trapping.

    Args:
        inspiratory: inspiratory image (array or tensor).
        defExp_prediction_image: real or predicted deformed expiratory image.
        mask: binary lung mask, broadcastable to the images.
        threshold: cut-off on the rescaled difference (default 100).
        scale: factor applied to the difference before thresholding (default 3000).

    Returns:
        Percentage (0-100) of non-zero voxels below ``threshold``, or 0 when
        the subtraction image has no non-zero voxels.
    """
    subtraction_image = (defExp_prediction_image - inspiratory) * mask
    sub_prediction_image = subtraction_image * scale

    # Exact zeros are treated as "outside the lung" (the inputs are masked),
    # so they are excluded from both counts.
    nonzero = sub_prediction_image != 0
    airtrapping_voxels = np.sum((sub_prediction_image < threshold) & nonzero)
    total_voxels = np.count_nonzero(sub_prediction_image)

    if total_voxels == 0:
        return 0
    return (airtrapping_voxels / total_voxels) * 100

In [None]:
# Dice score
# input: deformed expiratory, inpsriatory, predicted
def calculate_dice_score(inspiratory, output, masked_predictions, mask, threshold=100, scale=3000):
    """Dice overlap between the real and predicted air-trapping maps.

    Each map marks voxels inside ``mask`` where the rescaled difference
    ``(expiratory - inspiratory) * scale`` is below ``threshold``; the
    reference map uses ``output`` (real deformed expiratory image), the
    predicted map uses ``masked_predictions``.

    Args:
        inspiratory: inspiratory image.
        output: real deformed expiratory image.
        masked_predictions: predicted expiratory image.
        mask: binary lung mask.
        threshold: cut-off on the rescaled difference (default 100).
        scale: factor applied to the differences (default 3000).

    Returns:
        Dice score in [0, 1], or NaN when both maps are empty. The previous
        code produced the same NaN through 0/0 with a RuntimeWarning.
    """
    # Reference map: real deformed expiratory minus inspiratory.
    mt = np.where((output - inspiratory) * scale < threshold, 1, 0) * mask
    # Predicted map: predicted expiratory minus inspiratory.
    mp = np.where((masked_predictions - inspiratory) * scale < threshold, 1, 0) * mask

    denominator = np.sum(mt) + np.sum(mp)
    if denominator == 0:
        return float('nan')
    return (2 * np.sum(mt * mp)) / denominator

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

# plot Input image, Predicted image, Predicted mask, and save
def plot_predictions_with_masks(batch_input, image_predictions, mask_predictions, counter, index=0):
    """Plot, save and show the middle slice of input, predicted image and predicted mask.

    Args:
        batch_input: batch of input volumes; ``batch_input[index]`` is sliced
            at the middle of its first axis.
        image_predictions: image outputs of the model for the same batch.
        mask_predictions: mask outputs of the model for the same batch.
        counter: number used in the file name ``prediction_{counter}.png``.
        index: which item of the batch to plot (default 0).

    The figure is saved under the model's ``outputsEx`` folder, built from the
    globals ``number_of_steps_total``, ``batch_size`` and ``spec``. It is
    then shown and closed, so repeated calls do not accumulate open figures.
    """
    save_dir = f'/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/outputsEx'
    # exist_ok replaces the race-prone exists() check + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    input_image = batch_input[index]
    middle_slice = input_image.shape[0] // 2

    panels = (
        ("Input Image", input_image),
        ("Predicted Image", image_predictions[index]),
        ("Predicted Mask", mask_predictions[index]),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, volume) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(np.rot90(volume[middle_slice, :, :], 1), cmap='gray')
        plt.axis('off')
        plt.colorbar(orientation='vertical')

    plt.savefig(os.path.join(save_dir, f'prediction_{counter}.png'), dpi=300)
    plt.show()
    plt.close()


In [None]:
# Show (and save) predictions for the first five test batches.
max_plots = 5
counter = 0
for batch_input, batch_output in test_generator:
    image_predictions, mask_predictions = model.predict(batch_input)
    plot_predictions_with_masks(batch_input, image_predictions, mask_predictions, counter)
    counter += 1
    # Check after plotting so no extra batch is pulled from the generator.
    if counter >= max_plots:
        break

In [None]:
from random import shuffle

# data generator v2
# Output: yield np.array(test_input_images), np.array(test_output_images),np.array(test_output_mask)

def data_generator2(paths, mask_insp_paths, batch_size):
    """Yield evaluation batches as a flat 3-tuple ``(inputs, targets, lung_masks)``.

    Args:
        paths: pair ``(input_paths, output_paths)`` of inspiratory CT paths and
            deformed expiratory CT paths, matched by position.
        mask_insp_paths: inspiratory lung-mask paths aligned with ``paths``.
        batch_size: maximum cases per batch; the last batch may be smaller.

    Per case, ``inputs`` holds the inspiratory image (``load_image_scale``)
    times the binary lung mask, ``targets`` the deformed expiratory image
    (``load_image``) times the mask, and ``lung_masks`` the binary mask
    itself. A trailing channel axis is added to every array. No shuffling.

    Fix: a stray, over-indented ``yield`` inside the per-case loop made the
    cell fail with an IndentationError; only the per-batch yield remains.
    """
    test_input_paths, test_output_paths = paths

    num_samples = len(test_input_paths)
    indices = list(range(num_samples))

    for i in range(0, num_samples, batch_size):
        batch_indices = indices[i:i+batch_size]

        test_input_images = []
        test_output_images = []
        test_output_mask = []

        for idx in batch_indices:
            # load binary lung mask
            mask = load_image_mask(mask_insp_paths[idx])
            # load and mask the (scaled) inspiratory input image
            insp_image = load_image_scale(test_input_paths[idx])
            input_image = insp_image * mask
            # load and mask the deformed expiratory target image
            defexp_image = load_image(test_output_paths[idx])
            output_image = defexp_image * mask

            test_input_images.append(np.expand_dims(input_image, -1))
            test_output_images.append(np.expand_dims(output_image, -1))
            test_output_mask.append(np.expand_dims(mask, -1))

        # One flat 3-tuple per batch; callers unpack
        # `for batch_input, batch_output, batch_masks in ...`.
        yield np.array(test_input_images), np.array(test_output_images), np.array(test_output_mask)



In [None]:
# Fresh generator for the metric loop below (generators are single-use).
test_generator2 = data_generator2(
    paths=test_dataset_paths, mask_insp_paths=test_mask_insp_paths, batch_size=batch_size
)

In [None]:
# Per-case evaluation metrics, filled in generator order.
mse_values = []
airtrapping_percentages_fake = []   # from the predicted expiratory image
airtrapping_percentages_real = []   # from the real deformed expiratory image
dice_scores = []

for batch_input, batch_output, batch_masks in tqdm(test_generator2, total=test_steps, desc="Processing batches"):
    # The model returns (image prediction, mask prediction); only the image is scored.
    batch_predictions, _ = model(batch_input, training=False)

    # Restrict predictions to the lung mask so they match the masked targets.
    masked_predictions = [pred * lung for pred, lung in zip(batch_predictions, batch_masks)]

    for insp, target, pred, lung in zip(batch_input, batch_output, masked_predictions, batch_masks):
        mse_values.append(compute_mse(target, pred))
        airtrapping_percentages_fake.append(compute_airtrapping_per_image(insp, pred, lung))
        airtrapping_percentages_real.append(compute_airtrapping_per_image(insp, target, lung))
        dice_scores.append(calculate_dice_score(insp, target, pred, lung))

In [None]:
# Fresh generator for the per-volume worst-slice search below.
test_generator3 = data_generator2(
    paths=test_dataset_paths, mask_insp_paths=test_mask_insp_paths, batch_size=batch_size
)

In [None]:
from tqdm import tqdm

# Per-volume air-trapping summary: whole-volume true vs. predicted percentage
# plus the slice (along axis 1 of the batch) where the two disagree most.
overall_info_3d_images = []

for batch_idx, (batch_input, batch_output, batch_masks) in enumerate(tqdm(test_generator3, total=test_steps, desc="Processing batches")):
    predictions, mask_predictions = model(batch_input, training=False)

    for item_idx in range(batch_input.shape[0]):  # each 3D image in the batch

        # Air trapping over the entire 3D image (all slices).
        overall_true_airtrapping = compute_airtrapping_per_image(batch_input[item_idx], batch_output[item_idx], batch_masks[item_idx])
        overall_predicted_airtrapping = compute_airtrapping_per_image(batch_input[item_idx], predictions[item_idx], batch_masks[item_idx])

        # Start below any possible |difference| (>= 0) so the first slice is
        # always recorded. Previously worst_slice_idx stayed None when all
        # slices agreed exactly, and indexing with None later inserts an axis
        # instead of selecting a slice.
        worst_difference = -1
        worst_slice_idx = None
        worst_true_airtrapping_slice = 0
        worst_predicted_airtrapping_slice = 0

        for slice_idx in range(batch_input.shape[1]):
            inspiratory_slice = batch_input[item_idx, slice_idx, :, :]
            prediction_slice = predictions[item_idx, slice_idx, :, :]
            mask_slice = batch_masks[item_idx, slice_idx, :, :]

            true_airtrapping_slice = compute_airtrapping_per_image(inspiratory_slice, batch_output[item_idx, slice_idx, :, :], mask_slice)
            predicted_airtrapping_slice = compute_airtrapping_per_image(inspiratory_slice, prediction_slice, mask_slice)

            difference = abs(true_airtrapping_slice - predicted_airtrapping_slice)

            if difference > worst_difference:
                worst_difference = difference
                worst_slice_idx = slice_idx
                worst_true_airtrapping_slice = true_airtrapping_slice
                worst_predicted_airtrapping_slice = predicted_airtrapping_slice

        if worst_slice_idx is None:
            # Volume without slices: keep the old neutral value.
            worst_difference = 0

        overall_info_3d_images.append({
            'batch_idx': batch_idx,
            'item_idx': item_idx,
            'overall_true_airtrapping': overall_true_airtrapping,
            'overall_predicted_airtrapping': overall_predicted_airtrapping,
            'worst_slice_idx': worst_slice_idx,
            'worst_difference': worst_difference,
            'worst_true_airtrapping_slice': worst_true_airtrapping_slice,
            'worst_predicted_airtrapping_slice': worst_predicted_airtrapping_slice,
        })


In [None]:
# Absolute whole-volume disagreement between true and predicted air trapping.
for image_detail in overall_info_3d_images:
    image_detail['overall_difference'] = abs(
        image_detail.get('overall_true_airtrapping', 0)
        - image_detail.get('overall_predicted_airtrapping', 0)
    )

# The 20 volumes with the largest disagreement (adjust the slice as needed).
sorted_images_by_overall_difference = sorted(
    overall_info_3d_images,
    key=lambda info: info['overall_difference'],
    reverse=True,
)
images_with_largest_overall_difference = sorted_images_by_overall_difference[:20]

# (label, dict key, suffix) for each reported line.
report_fields = (
    ("Overall True Air Trapping", 'overall_true_airtrapping', "%"),
    ("Overall Predicted Air Trapping", 'overall_predicted_airtrapping', "%"),
    ("Worst Slice Index", 'worst_slice_idx', ""),
    ("Worst Slice True Air Trapping", 'worst_true_airtrapping_slice', "%"),
    ("Worst Slice Predicted Air Trapping", 'worst_predicted_airtrapping_slice', "%"),
    ("Overall Difference", 'overall_difference', "%\n"),
)

for image in images_with_largest_overall_difference:
    print(f"Batch {image['batch_idx']}, Image {image['item_idx']}:")
    for label, key, suffix in report_fields:
        print(f"  {label}: {image.get(key, 'N/A')}{suffix}")


In [None]:
# Batches containing at least one of the worst volumes.
desired_batches = {img['batch_idx'] for img in images_with_largest_overall_difference}
print(desired_batches)

In [None]:
# Fresh generator for re-visiting the worst batches below.
test_generator4 = data_generator2(
    paths=test_dataset_paths, mask_insp_paths=test_mask_insp_paths, batch_size=batch_size
)

In [None]:
#for image in images_with_largest_overall_difference:
#    print(f"Batch {image['batch_idx']}, Image {image['item_idx']}:")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import gaussian_filter
from matplotlib import cm
from matplotlib.colors import ListedColormap

# Reversed Reds/Blues colormaps for the overlays. pyplot.get_cmap is used
# because matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
insp_map = ListedColormap(plt.get_cmap('Reds', 256)(np.linspace(1, 0, 256)))
diff_map = ListedColormap(plt.get_cmap('Blues', 256)(np.linspace(1, 0, 256)))

save_dir = (f"/home/tkeller/google-drive/LossPlots/Worst_Predcitions_DiceLoss_{spec}/")
os.makedirs(save_dir, exist_ok=True)

# Walk a fresh generator once; its position is the batch index recorded when
# the worst cases were collected.
for current_batch_idx, (batch_input, batch_output, batch_masks) in enumerate(test_generator4):
    if current_batch_idx not in desired_batches:
        continue

    filtered_images = [img for img in images_with_largest_overall_difference if img['batch_idx'] == current_batch_idx]

    # Predict once per batch (previously repeated for every selected image).
    predictions, mask_predictions = model.predict(batch_input)

    for image_details in filtered_images:
        item_idx = image_details['item_idx']
        worst_slice_idx = image_details['worst_slice_idx']
        if worst_slice_idx is None:
            # No worst slice recorded; indexing with None would add an axis
            # instead of selecting a slice, so skip this volume.
            continue
        worst_true_airtrapping = image_details['worst_true_airtrapping_slice']
        true_airtrapping_overall = image_details['overall_true_airtrapping']
        worst_predicted_airtrapping = image_details['worst_predicted_airtrapping_slice']
        predicted_airtrapping_overall = image_details['overall_predicted_airtrapping']

        # 2-D views of the worst slice; intensity differences rescaled by 3000.
        insp_ct_ds = np.rot90(np.squeeze(batch_input[item_idx, worst_slice_idx, :, :] * 3000))
        prediction_sub = np.rot90(np.squeeze(predictions[item_idx, worst_slice_idx, :, :] - batch_input[item_idx, worst_slice_idx, :, :]) * 3000)
        insp_whole = np.rot90(np.squeeze(batch_masks[item_idx, worst_slice_idx, :, :]))
        # Separate name so the batch-level mask prediction is not overwritten.
        mask_pred_slice = np.rot90(np.squeeze(mask_predictions[item_idx, worst_slice_idx, :, :]))

        # Gaussian-smoothed overlays restricted to the lung mask.
        insp_slc_mask_smooth = gaussian_filter(insp_ct_ds, sigma=1) * insp_whole
        diff_slc_pred_mask_smooth = np.abs(gaussian_filter(prediction_sub, sigma=1)) * insp_whole

        plt.figure(figsize=(24, 10))

        # NOTE(review): this panel shows the inspiratory input, titled "Ground Truth".
        plt.subplot(2, 3, 1)
        plt.imshow(insp_ct_ds, cmap='gray', vmin=-400-1500, vmax=-400+1500)
        plt.imshow(np.ma.masked_where(insp_whole == 0, insp_slc_mask_smooth), cmap=insp_map, alpha=0.5)
        plt.title(f"Ground Truth - Batch: {current_batch_idx}, Slice: {worst_slice_idx}\nTrue Air Trapping: {worst_true_airtrapping:.2f}%", fontsize=12)
        plt.axis('off')
        plt.colorbar(orientation='vertical')

        # Predicted subtraction with the air-trapping overlay.
        plt.subplot(2, 3, 2)
        plt.imshow(prediction_sub, cmap='gray', vmin=-400-1500, vmax=-400+1500)
        plt.imshow(np.ma.masked_where(insp_whole == 0, diff_slc_pred_mask_smooth), cmap=diff_map, alpha=0.5)
        plt.title(f"Prediction - Batch: {current_batch_idx}, Slice: {worst_slice_idx}\nPredicted Air Trapping: {worst_predicted_airtrapping:.2f}%", fontsize=12)
        plt.axis('off')
        plt.colorbar(orientation='vertical')

        # Mask output of the model.
        plt.subplot(2, 3, 3)
        plt.imshow(mask_pred_slice, cmap='jet')
        plt.title(f"Mask Only - Batch: {current_batch_idx}, Slice: {worst_slice_idx}", fontsize=12)
        plt.axis('off')
        plt.colorbar(orientation='vertical')

        # Histogram of the in-lung values of panel 1.
        plt.subplot(2, 3, 4)
        gt_values = insp_ct_ds[insp_whole != 0].flatten()
        plt.hist(gt_values, bins=50, color='purple', alpha=0.7)
        plt.axvline(x=100, color='k', linestyle='--', label='100 HU')
        plt.axvline(x=-860, color='k', linestyle='--', label='-860 HU')
        plt.legend()  # the labels above were never displayed without it
        plt.title("Ground Truth Histogram", fontsize=12)

        # Histogram of the in-lung values of panel 2.
        plt.subplot(2, 3, 5)
        pred_values = prediction_sub[insp_whole != 0].flatten()
        plt.hist(pred_values, bins=50, color='purple', alpha=0.7)
        plt.axvline(x=100, color='k', linestyle='--', label='100 HU')
        plt.axvline(x=-860, color='k', linestyle='--', label='-860 HU')
        plt.legend()
        plt.title("Prediction Histogram", fontsize=12)

        plt.subplot(2, 3, 6)
        plt.text(0.5, 0.5, f"Overall True Air Trapping: {true_airtrapping_overall:.2f}%\n"
                           f"Overall Predicted Air Trapping: {predicted_airtrapping_overall:.2f}%",
                 ha='center', va='center', fontsize=20)
        plt.axis('off')

        plt.tight_layout()

        # Own name: assigning to `filename` here used to overwrite the global
        # suffix used later for the correlation / Bland-Altman plot files.
        plot_filename = f"Batch_{current_batch_idx}_Slice_{worst_slice_idx}.png"
        save_path = os.path.join(save_dir, plot_filename)

        plt.savefig(save_path)
        plt.show()
        plt.close()

In [None]:
# Spot-check: first five per-case MSE values.
first_mse = mse_values[:5]
for position, value in enumerate(first_mse):
    print(f"MSE at position {position}:", value)

In [None]:
# Spot-check: first five predicted air-trapping percentages.
first_predicted = airtrapping_percentages_fake[:5]
for position, percentage in enumerate(first_predicted):
    print(f"Airtrapping percentage for Prediciton {position}: {percentage}%")

In [None]:
# Spot-check: first five reference air-trapping percentages.
first_real = airtrapping_percentages_real[:5]
for position, percentage in enumerate(first_real):
    print(f"Airtrapping percentage for Subtraction {position}: {percentage}%")

In [None]:
# Spot-check: first five per-case Dice scores. The old label was copied from
# the air-trapping cell and printed them as "Airtrapping percentage ... %".
for idx, dicescore in enumerate(dice_scores[:5]):
    print(f"Dice score {idx}: {dicescore}")

In [None]:
import pandas as pd
# One row per case; "Air-trp 1" holds the reference (real) and "Air-trp 2"
# the predicted air-trapping percentage (filled in the next cell).
metric_columns = ["Predicted_Image", "MSE_Dice", "Air-trp 1 [%]", "Air-trp 2 [%]", "Dice Score"]
df_metrics = pd.DataFrame(columns=metric_columns)

In [None]:
# Row i: [case index, MSE, real air trapping, predicted air trapping, Dice].
per_case = zip(mse_values, airtrapping_percentages_fake, airtrapping_percentages_real, dice_scores)
for row_idx, (mse_val, at_pred, at_real, dice_val) in enumerate(per_case):
    df_metrics.loc[row_idx] = [row_idx, mse_val, at_real, at_pred, dice_val]

In [None]:
# First 50 rows of the per-case metrics table.
print(df_metrics.head(50))

In [None]:
# Create test/metrics/ and test/plots/ under the model folder. makedirs also
# creates test/ and tolerates existing folders; the old chained os.mkdir
# aborted at an existing test/ and never created the subfolders, which made
# the later CSV/plot saves fail.
if save == 1:
    output_root = f"/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/test/"
    for subdir in ("metrics/", "plots/"):
        try:
            os.makedirs(os.path.join(output_root, subdir), exist_ok=True)
        except OSError as e:
            print("An error occurred:", str(e))

In [None]:
# Persist the per-case metrics table.
if save == 1:
    metrics_csv = f'/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/test/metrics/metrics.csv'
    df_metrics.to_csv(metrics_csv, index=False)

In [None]:
# Summary statistics, reported as "mean ± SD".
mse_column = df_metrics['MSE_Dice']
dice_column = df_metrics['Dice Score']
mean_mse, std_mse = mse_column.mean(), mse_column.std()
mean_dice, std_dice = dice_column.mean(), dice_column.std()

mse_str = f"{mean_mse:.2e} ± {std_mse:.2e}"
dice_str = f"{mean_dice:.2f} ± {std_dice:.2f}"

print(f'MSE: {mse_str}')
print(f'Dice Score: {dice_str}')

In [None]:
import seaborn as sns

# Scatter of reference (1) vs. predicted (2) air trapping per case.
sns.set_style("whitegrid")
plt.figure(figsize=(10, 8))

true_at = df_metrics["Air-trp 1 [%]"]
pred_at = df_metrics["Air-trp 2 [%]"]
sns.scatterplot(x=true_at, y=pred_at, alpha=0.6)

plt.title("Correlation between Air-trapping 1 (True) and Air-trapping 2 (Prediction)", weight='bold')
plt.xlabel("Air-trapping 1 [%] (True)")
plt.ylabel("Air-trapping 2 [%] (Prediction)")

# Identity line spanning the current extent of both axes.
(x_lo, x_hi), (y_lo, y_hi) = plt.xlim(), plt.ylim()
limits = [min(x_lo, y_lo), max(x_hi, y_hi)]
plt.plot(limits, limits, 'r-', label='Identity Line')
plt.legend()

if save == 1:
    plt.savefig(f"/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/test/plots/Correlation_nsteps{number_of_steps_total}_batch{batch_size}{filename}.png")
plt.show()

In [None]:
# The Bland-Altman plot, also known as a difference plot, is used to visualize the agreement
# between two methods or two measurements. It plots the difference between the two measures
# against their average.

# Bland-Altman data: per-case mean and difference of reference (1) and predicted (2) air trapping.
average = (df_metrics["Air-trp 1 [%]"] + df_metrics["Air-trp 2 [%]"]) / 2
difference = df_metrics["Air-trp 1 [%]"] - df_metrics["Air-trp 2 [%]"]

mean_diff = difference.mean()
std_diff = difference.std()
# 95% limits of agreement.
upper_loa = mean_diff + 1.96*std_diff
lower_loa = mean_diff - 1.96*std_diff

sns.set_style("whitegrid")
plt.figure(figsize=(10, 8))

sns.scatterplot(x=average, y=difference, alpha=0.6)

# Mean difference and limits of agreement.
plt.axhline(mean_diff, color='red', linestyle='--', label=f'Mean diff: {mean_diff: .2f}')
plt.axhline(upper_loa, color='blue', linestyle='--', label='Mean diff + 1.96*SD')
plt.axhline(lower_loa, color='blue', linestyle='--', label='Mean diff - 1.96*SD')

# Annotate the limits at the right-hand edge of the data.
x_position = max(average)
plt.text(x_position, upper_loa, f'+1.96 SD: {upper_loa:.2f}', verticalalignment='bottom', horizontalalignment='right', color='blue')
plt.text(x_position, lower_loa, f'-1.96 SD: {lower_loa:.2f}', verticalalignment='bottom', horizontalalignment='right', color='blue')

plt.title("Bland-Altman Plot between Air-trapping 1 (True) and Air-trapping 2 (Prediction)", weight='bold')
plt.xlabel("Average of Air-trapping 1 [%] (True) and Air-trapping 2 [%] (Prediction)")
plt.ylabel("Difference between Air-trapping 1 [%] (True) and Air-trapping 2 [%] (Prediction)")
plt.legend()

if save == 1:
    # Same model folder (with the `_{spec}` suffix) as every other output;
    # the suffix was missing, so this pointed at a different folder.
    plt.savefig(f"/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/test/plots/Bland_Altman_nsteps{number_of_steps_total}_batch{batch_size}{filename}.png")

plt.show()


In [None]:
# ICC
import pingouin as pg

# Long format: one row per (case, method) pair, as pingouin expects.
df_melted = df_metrics.melt(
    id_vars=['Predicted_Image'],
    value_vars=['Air-trp 1 [%]', 'Air-trp 2 [%]'],
    var_name='Method',
    value_name='Measurement',
)

icc = pg.intraclass_corr(
    data=df_melted, targets='Predicted_Image', raters='Method', ratings='Measurement'
).round(3)
print(icc)

# Value of the ICC1 row, printed and kept rounded to two decimals.
icc1 = icc[icc['Type'] == 'ICC1']['ICC'].values[0]
print(icc1)
icc_val = float(f"{icc1:.2f}")


In [None]:
# table of final statistics

In [None]:
# Final summary lines.
for summary_line in (f'MSE: {mse_str}', f'ICC: {icc_val}', f'Dice Score: {dice_str}'):
    print(summary_line)

In [None]:
# Summary table of the headline metrics.
summary_rows = [('MSE_Dice', mse_str), ('ICC', icc_val), ('Dice Score', dice_str)]
results_df = pd.DataFrame(summary_rows, columns=['Metric', 'Value'])

print(results_df)

In [None]:
# Cases ordered by Dice score, lowest (worst) first; show the 20 worst.
df_sorted_asce = df_metrics.sort_values('Dice Score')
print(df_sorted_asce.head(20))

In [None]:
# IDs of the 50 cases with the lowest Dice score.
predicted_image_ids = np.array(df_sorted_asce['Predicted_Image'].head(50))

print(predicted_image_ids)

In [None]:
# Persist the summary table in the model's test folder.
if save == 1:
    results_csv = f'/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/test/model_results.csv'
    results_df.to_csv(results_csv, index=False)

In [None]:
save2 = 1  # 1: create the DiseaseMap output folder for the disease-map cells below

In [None]:
# Create the DiseaseMap folder. makedirs tolerates an existing folder and
# creates missing parents; the old bare `except:` swallowed every error
# (even KeyboardInterrupt) and hid the cause.
if save2 == 1:
    disease_map_dir = f'/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/DiseaseMap/'
    try:
        os.makedirs(disease_map_dir, exist_ok=True)
    except OSError as err:
        print('Couldnt create directory', err)

In [None]:
# Fresh generator for the disease-map cell below (the previous one is exhausted).
test_generator3 = data_generator2(
    paths=test_dataset_paths, mask_insp_paths=test_mask_insp_paths, batch_size=batch_size
)

In [None]:
from scipy.ndimage import gaussian_filter
from matplotlib import cm
from matplotlib.colors import ListedColormap

# Positions (0-based, in generator order) of the test batches to visualise.
# Each selected batch is run through the model once and plotted once; other
# batches are skipped, and iteration stops after the largest listed position.
desired_indices = [1, 2, 3, 4, 5]
count = 0  # position of the current batch in the generator

slice_idx = 96  # slice taken along the first axis of each volume
scale = 3000    # factor applied to network-space values before display (vmin/vmax below use these units)


def _fading_cmap(name, anchor, n=512):
    """Build a single-colour overlay colormap.

    Every entry takes the RGB of colormap `name` (resampled to `n` entries and
    reversed) at position `anchor`; alpha falls linearly from 1 at the low end of
    the value range to 0 at the high end, so low values are drawn opaque and high
    values fade out.
    """
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9
    colours = plt.get_cmap(name, n)(np.linspace(1, 0, n))
    colours[:, 3] = np.linspace(1, 0, n)
    colours[:, 0:3] = colours[anchor, 0:3]
    return ListedColormap(colours)


# The colormaps do not depend on the data, so build them once outside the loop.
insp_map = _fading_cmap('Reds', 256)   # red parts: emphysema (low inspiratory intensity)
diff_map = _fading_cmap('Blues', 128)  # blue parts: air-trapping (small insp/exp difference)

wanted = set(desired_indices)
last_wanted = max(desired_indices, default=-1)

for count, (batch_input, batch_output, batch_masks) in enumerate(test_generator3):
    if count in wanted:
        print(f'Processing index {count}')

        # One forward pass per selected batch (the original re-ran the model on the
        # same batch once per entry of desired_indices and plotted duplicates).
        predictions, mask2 = model(batch_input, training=False)

        j = 0  # first image of the batch
        masked_prediction = predictions[j] * batch_masks[j]

        # Select the slices used for the disease map
        insp_ct_ds = (batch_input[j] * scale)[slice_idx]
        exp_ct_deform_ds = (batch_output[j] - batch_input[j])[slice_idx]  # ground-truth difference
        prediction_sub = (masked_prediction - batch_input[j])[slice_idx]  # predicted difference
        insp_whole = batch_masks[j][slice_idx]

        # Drop singleton dimensions and rotate 90 degrees clockwise
        insp_ct_ds_2d = np.rot90(np.squeeze(insp_ct_ds), k=-1)
        exp_ct_deform_ds_2d = np.rot90(np.squeeze(exp_ct_deform_ds), k=-1)
        sub_ct_deform_pred_2d = np.rot90(np.squeeze(prediction_sub), k=-1)
        insp_whole_2d = np.rot90(np.squeeze(insp_whole), k=-1)

        # Smooth, take absolute differences, and restrict to the mask
        insp_slc_mask_smooth = gaussian_filter(insp_ct_ds_2d, sigma=1) * insp_whole_2d
        diff_slc_mask_smooth = np.abs(gaussian_filter(exp_ct_deform_ds_2d * scale, sigma=1)) * insp_whole_2d
        diff_slc_pred_mask_smooth = np.abs(gaussian_filter(sub_ct_deform_pred_2d * scale, sigma=1)) * insp_whole_2d

        # Flip both axes (180-degree rotation) for display
        insp_slc = np.flipud(np.fliplr(insp_ct_ds_2d))
        insp_mask = np.flipud(np.fliplr(insp_whole_2d))
        insp_slc_mask = np.flipud(np.fliplr(insp_slc_mask_smooth))
        diff_slc_pred_mask = np.flipud(np.fliplr(diff_slc_pred_mask_smooth))  # prediction
        diff_slc_mask = np.flipud(np.fliplr(diff_slc_mask_smooth))  # ground truth

        # Ground truth on the left, prediction on the right
        plt.figure(figsize=(20, 10))
        panels = (
            (diff_slc_mask, f"Ground-Truth Index {count}"),
            (diff_slc_pred_mask, f"Prediction Index {count}"),
        )
        for position, (diff_slice, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            # NOTE(review): only the grey base layer passes origin='lower'; the overlays
            # use the default origin and each imshow call resets the y-limits, so the
            # last call decides the displayed vertical orientation — confirm intended view.
            plt.imshow(insp_slc, cmap='gray', vmin=-400-1500, vmax=-400+1500, origin='lower')
            plt.imshow(np.ma.masked_where(insp_mask == 0, diff_slice), cmap=diff_map, vmin=0, vmax=100, alpha=0.5)
            plt.imshow(np.ma.masked_where(insp_mask == 0, insp_slc_mask), cmap=insp_map, vmin=-1000, vmax=-925, alpha=0.5)
            plt.title(title, fontweight='bold', fontsize=20)
            plt.axis('off')

        plt.tight_layout()
        if save2 == 1:
            plt.savefig(f'/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss_{spec}/DiseaseMap/DiseaseMap_defEx_{count}.png')
        plt.show()
        plt.close()

    # Stop once the highest requested position has been handled (avoids loading more volumes)
    if count >= last_wanted:
        break

In [None]:
# import nibabel as nib
# aff = np.array([[0,-1,0,0],[-1,0,0,0],[0,0,1,0],[0,0,0,1]])
# img = nib.Nifti1Image(np.float32(img), aff)
# nib.save(img, 'filename.nii')

In [None]:
# Fresh test-set generator for extracting a single batch in the next cell.
# NOTE(review): this uses data_generator, whereas the disease-map generator was built
# with data_generator2; the next cell unpacks three values (input, output, masks) from
# it — confirm data_generator yields that triple.
test_generator3 = data_generator(test_dataset_paths, test_mask_insp_paths, batch_size)

In [None]:
# Extract a single batch from the generator and run the model on it, so that
# `predictions` corresponds to this `batch_input` when the volumes are saved below
# (previously `predictions` was a stale value left over from an earlier cell).
try:
    batch_input, batch_output, batch_masks = next(test_generator3)
    # Generate predictions for this batch; the model returns two outputs and the
    # second one is not used for saving.
    predictions, mask2 = model(batch_input, training=False)
    print(f"Batch shapes -- Input: {batch_input.shape}, Output: {batch_output.shape}, Masks: {batch_masks.shape}, Predictions: {predictions.shape}")
except Exception as e:
    print(f"Error extracting batch from generator or generating predictions: {e}")


In [None]:
# Save one sample of a batch (input, two difference volumes, mask) as NIfTI files
def save_nifti_with_affine(batch_input, batch_output, predictions, mask, index, output_dir):
    """
    Save one sample of a batch as four NIfTI volumes that share a fixed affine.

    Files written into ``output_dir`` (created if missing):
    - ``input_{index}.nii``: batch_input[index]
    - ``output_deformation_{index}.nii``: batch_output[index] - batch_input[index]
    - ``prediction_deformation_{index}.nii``: predictions[index] - batch_input[index]
    - ``mask_{index}.nii``: mask[index]

    Parameters:
    - batch_input, batch_output, predictions, mask: batch-first arrays. NumPy arrays or
      array-like tensors (e.g. TensorFlow eager tensors) are accepted. The per-sample
      layout is stored as-is; presumably (depth, height, width[, channel]) — TODO confirm.
    - index: int, position within the batch of the sample to save.
    - output_dir: str, directory to write the files to.
    """
    # Voxel-to-world affine: x = -j, y = -i, z = k. It swaps the first two voxel
    # axes and negates both (not a plain flip of x and y).
    affine = np.array([[0, -1, 0, 0],
                       [-1, 0, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]], dtype=np.float64)

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    def save_image(data, path):
        """Write `data` as a float32 NIfTI image with the shared affine to `path`."""
        # np.asarray also converts framework tensors (e.g. model outputs) to NumPy
        data_float32 = np.asarray(data, dtype=np.float32)
        img = nib.Nifti1Image(data_float32, affine)
        nib.save(img, path)

    save_image(batch_input[index], os.path.join(output_dir, f'input_{index}.nii'))
    save_image(batch_output[index] - batch_input[index], os.path.join(output_dir, f'output_deformation_{index}.nii'))
    save_image(predictions[index] - batch_input[index], os.path.join(output_dir, f'prediction_deformation_{index}.nii'))
    save_image(mask[index], os.path.join(output_dir, f'mask_{index}.nii'))


In [None]:
import nibabel as nib

# Position within the batch of the sample to export (0 = first image)
batch_index = 0

# Destination folder; save_nifti_with_affine creates it when missing.
# NOTE(review): unlike the DiseaseMap directory created earlier (…_DiceLoss_{spec}/DiseaseMap/),
# this path has no `_{spec}` suffix — confirm which run directory is intended.
output_dir = f'/data-synology/tkeller/Outputs/3D_nsteps{number_of_steps_total}_batch{batch_size}_DiceLoss/DiseaseMap/'

# Export input, ground-truth difference, predicted difference and mask for that sample
save_nifti_with_affine(
    batch_input,
    batch_output,
    predictions,
    batch_masks,
    index=batch_index,
    output_dir=output_dir,
)