In [None]:
import os
import re
import pickle
import random
import pandas as pd
import gzip
import shutil
import numpy as np
import imageio
import imgaug as ia
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from sklearn.utils import shuffle

from collections import defaultdict
from sortedcontainers import SortedList
from PIL import Image, ImageColor

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    UpSampling2D,
    Flatten,
    Dense,
    Dropout,
    BatchNormalization,
    Concatenate,
    Activation,
    GlobalAveragePooling2D,
    Reshape,
    Lambda,
    RandomRotation
)
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import load_model

# File Loading

In [None]:
import os
import pickle
import requests

# Define file paths (all relative to the notebook's working directory)
BASE_PATH = "data"  # Store in 'data' folder
FILENAME = "image_dicts_256_wgrayscale_andcutoffs.pkl"
FILE_PATH = os.path.join(BASE_PATH, FILENAME)  # Local copy of the pickled image dictionaries
EXCEL_FILE_PATH = os.path.join(BASE_PATH, "sample_groups.xlsx")  # Train/validation/test sample lists

# GitHub release URL the pickle is downloaded from when FILE_PATH does not exist yet
URL = "https://github.com/tylervasse/DOCI-Prediction/releases/download/v1.0/image_dicts_256_wgrayscale_andcutoffs.pkl"

def download_file(url, output_path):
    """
    Downloads a file from a URL if it doesn't already exist.

    The payload is streamed into ``<output_path>.part`` and only renamed to
    ``output_path`` once the transfer has finished, so an interrupted or failed
    download is never mistaken for a complete file on the next run.

    Parameters:
        url (str): Address of the file to download.
        output_path (str): Local path where the file should be stored.
    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.RequestException: On connection problems or timeouts.
    """
    if os.path.exists(output_path):
        print(f"File already exists at {output_path}")
        return

    print(f"Downloading {output_path}...")

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = output_path + ".part"
    try:
        response = requests.get(url, stream=True, timeout=60)
        try:
            # Fail loudly on 4xx/5xx instead of saving the error page as data.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        finally:
            response.close()
        os.replace(partial_path, output_path)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated file behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    print("Download complete.")

def load_image_dicts(file_path):
    """
    Reads the pickled image dictionaries from disk.
    Parameters:
        file_path (str): Location of the pickle file.
    Returns:
        The unpickled object (expected to be a list of image dictionaries),
        or an empty list when the file is missing or cannot be read.
    """
    # NOTE: pickle can execute arbitrary code while loading; only open files
    # that come from a trusted source.
    try:
        with open(file_path, "rb") as handle:
            contents = pickle.load(handle)
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        return []
    except Exception as e:
        print(f"Error loading file: {e}")
        return []
    else:
        return contents

# Ensure the file is downloaded before loading (nothing is fetched when FILE_PATH already exists)
download_file(URL, FILE_PATH)

# Load image dictionaries; load_image_dicts returns [] when the file is missing or unreadable
image_dicts = load_image_dicts(FILE_PATH)

In [None]:
# Samples to drop; an image is excluded when any of these strings occurs in its name
exclude_list = ["SSW-23-14395_C2", "SSW-23-05363_A7"]

# One pass is enough. Building the list outside of a loop also keeps
# image_dicts2 defined when image_dicts is empty.
image_dicts2 = [
    entry for entry in image_dicts
    if not any(exclude_word in entry["name"] for exclude_word in exclude_list)
]

image_dicts = image_dicts2

EXCEL_FILE_PATH = 'data/sample_groups.xlsx'  # Ensure this file is in the correct directory

def split_data(sample_names, train_ratio=0.6, val_ratio=0.4):
    """
    Randomly splits sample names into training, validation, and test groups.

    The input is left untouched: a copy is shuffled with the global ``random``
    generator (seed it beforehand for reproducible splits).

    Parameters:
        sample_names (iterable): Sample names to distribute.
        train_ratio (float): Fraction of samples assigned to training.
        val_ratio (float): Fraction of samples assigned to validation.
    Returns:
        tuple: Lists of training, validation, and test sample names. The test
        list receives whatever remains after the other two groups, which is
        empty for the default ratios whenever both group sizes are whole numbers.
    """
    # Shuffle a copy so the caller's sequence keeps its order.
    shuffled = list(sample_names)
    random.shuffle(shuffled)

    # Group sizes are truncated towards zero; leftovers end up in the test group.
    total_samples = len(shuffled)
    train_end = int(total_samples * train_ratio)
    val_end = train_end + int(total_samples * val_ratio)

    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]

def load_sample_groups(excel_file_path):
    """
    Reads the predefined train/validation/test sample assignment from an Excel sheet.
    Parameters:
        excel_file_path (str): Path to the Excel file holding the columns
            'Train Samples', 'Validation Samples' and 'Test Samples'.
    Returns:
        tuple: Three lists with the cleaned training, validation, and test sample
        names. Three empty lists are returned if the file cannot be read.
    """

    def _clean_column(frame, column):
        # Skip empty cells, then trim whitespace, drop apostrophes and turn spaces into underscores
        return [
            str(entry).strip().replace("'", "").replace(" ", "_")
            for entry in frame[column].dropna().tolist()
        ]

    try:
        sheet = pd.read_excel(excel_file_path)
        return tuple(
            _clean_column(sheet, column)
            for column in ('Train Samples', 'Validation Samples', 'Test Samples')
        )
    except FileNotFoundError:
        print(f"Error: Sample groups file not found at {excel_file_path}")
        return [], [], []
    except Exception as e:
        print(f"Error reading Excel file: {e}")
        return [], [], []

# Load sample groups (all three lists are empty if the Excel file could not be read)
train_samples, val_samples, test_samples = load_sample_groups(EXCEL_FILE_PATH)

In [None]:
# Load PCA categories from Excel
def load_pca_categories(file_path):
    """
    Reads the PCA-derived tissue categories from an Excel sheet.
    Parameters:
        file_path (str): Path to the Excel file holding the columns
            'PCA_follicular', 'PCA_papillary' and 'PCA_normal'.
    Returns:
        tuple: Three lists with the cleaned follicular, papillary, and normal
        sample names. Three empty lists are returned if the file cannot be read.
    """

    def _clean_column(frame, column):
        # Skip empty cells, then trim whitespace and drop apostrophes
        return [
            str(entry).strip().replace("'", "")
            for entry in frame[column].dropna().tolist()
        ]

    try:
        sheet = pd.read_excel(file_path)
        return tuple(
            _clean_column(sheet, column)
            for column in ('PCA_follicular', 'PCA_papillary', 'PCA_normal')
        )
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        return [], [], []
    except Exception as e:
        print(f"Error loading PCA categories: {e}")
        return [], [], []

# Replace the path with the actual file path to your Excel file.
# All three lists are empty if the file could not be read.
pca_file_path = 'data/pca_results.xlsx'
pca_follicular, pca_papillary, pca_normal = load_pca_categories(pca_file_path)

def categorize_images(image_data, train_samples, val_samples, test_samples):
    """
    Distributes image records over the training, validation, and test sets.

    The sample name of a record is the first two underscore-separated parts of
    its 'name' entry. Records whose sample name is in none of the lists are dropped.

    Parameters:
        image_data (list): List of dictionaries containing image metadata.
        train_samples (list): Sample names designated for training.
        val_samples (list): Sample names designated for validation.
        test_samples (list): Sample names designated for testing.
    Returns:
        tuple: Lists of image records for training, validation, and testing.
    """
    train_set, val_set, test_set = [], [], []

    # Checked in this order; a record goes to the first group that lists its sample.
    destinations = (
        (train_samples, train_set),
        (val_samples, val_set),
        (test_samples, test_set),
    )

    for record in image_data:
        name_parts = record['name'].split('_')
        sample_name = name_parts[0] + "_" + name_parts[1]
        for members, bucket in destinations:
            if sample_name in members:
                bucket.append(record)
                break

    return train_set, val_set, test_set

# Function to add PCA predictions
def add_pca_predictions(rel_set, pca_follicular, pca_papillary, pca_normal):
    """
    Labels every image record with a 'tissue_type' derived from the PCA sample lists.

    The dictionaries are modified in place. The returned list is a new list
    holding the same dictionary objects.

    Parameters:
        rel_set (list): List of dictionaries containing image metadata.
        pca_follicular (list): Sample names classified as follicular.
        pca_papillary (list): Sample names classified as papillary.
        pca_normal (list): Sample names classified as normal.
    Returns:
        list: The records, each with 'tissue_type' set to "Follicular",
        "Papillary", "Normal" or "Unknown".
    """
    # Checked in this order; the first list containing the sample decides the label.
    categories = (
        (pca_follicular, "Follicular"),
        (pca_papillary, "Papillary"),
        (pca_normal, "Normal"),
    )

    labelled = []
    for record in rel_set:
        name_parts = record["name"].split("_")
        sample_id = name_parts[0] + "_" + name_parts[1]
        record["tissue_type"] = next(
            (label for members, label in categories if sample_id in members),
            "Unknown",
        )
        labelled.append(record)
    return labelled

# Function to filter out specific tissue types
def filter_out_relevant_predictions(rel_set, rel_cat):
    """
    Removes every record of one tissue type.
    Parameters:
        rel_set (list): List of dictionaries that carry a 'tissue_type' entry.
        rel_cat (str): The tissue type to discard.
    Returns:
        list: New list with all records whose tissue type differs from rel_cat.
    """
    kept = []
    for record in rel_set:
        if record["tissue_type"] != rel_cat:
            kept.append(record)
    return kept

def include_filters(set1, include_list):
    """
    Keeps only the records that belong to the requested DOCI channels.
    Parameters:
        set1 (list): List of dictionaries with a 'name' entry.
        include_list (iterable): DOCI numbers n to keep; a record is kept when its
            name contains "DOCI_<n>.tif" for at least one of them.
    Returns:
        list: The matching records, in their original order.
    """
    wanted_suffixes = [f"DOCI_{str(number)}.tif" for number in include_list]
    selected = []
    for record in set1:
        if any(suffix in record["name"] for suffix in wanted_suffixes):
            selected.append(record)
    return selected

# Categorize images into training, validation, and test sets
train_set, val_set, test_set = categorize_images(image_dicts, train_samples, val_samples, test_samples)

# Add PCA predictions (sets the 'tissue_type' key on each image dictionary in place)
train_combined_with_predictions2 = add_pca_predictions(train_set, pca_follicular, pca_papillary, pca_normal)
val_combined_with_predictions2 = add_pca_predictions(val_set, pca_follicular, pca_papillary, pca_normal)
test_combined_with_predictions2 = add_pca_predictions(test_set, pca_follicular, pca_papillary, pca_normal)

# Filter out 'Papillary' samples (every other tissue type, including "Unknown", is kept)
train_set = filter_out_relevant_predictions(train_combined_with_predictions2, "Papillary")
val_set = filter_out_relevant_predictions(val_combined_with_predictions2, "Papillary")
test_set = filter_out_relevant_predictions(test_combined_with_predictions2, "Papillary")

# Optional: restrict the datasets to a subset of DOCI channels (currently disabled)
#include_list = [2, 3, 4, 5, 6, 7, 8, 12, 16, 18, 21]
#train_set = include_filters(train_set, include_list)
#val_set = include_filters(val_set, include_list)
#test_set = include_filters(test_set, include_list)

# Shuffle the datasets (fixed random_state keeps the order reproducible)
train_set = shuffle(train_set, random_state=42)
val_set = shuffle(val_set, random_state=42)
test_set = shuffle(test_set, random_state=42)

In [None]:
# Function to extract the base name before "DOCI_n"
def get_base_name(name):
    """
    Returns the part of a file name that precedes the first "_DOCI" marker.
    Parameters:
        name (str): The full file name.
    Returns:
        str: Everything before the first "_DOCI", or the whole name if the marker is absent.
    """
    base_name, _marker, _rest = name.partition('_DOCI')
    return base_name

# Function to extract the DOCI number (n) from the name
def get_doci_number(name):
    """
    Reads the number n that follows the first "_DOCI_" marker in a file name.
    Parameters:
        name (str): The full file name.
    Returns:
        int: The DOCI number, or -1 when the name contains no "_DOCI_<digits>" part.
    """
    found = re.search(r'_DOCI_(\d+)', name)
    if not found:
        return -1
    return int(found.group(1))

# Function to create a multi-channel mask with one channel per tissue type
def create_mask_voxel(mask, tissue_type_index, num_classes=3):
    """
    Places a single mask into one channel of an otherwise all-zero voxel.
    Parameters:
        mask (numpy.ndarray): The mask for a single slice.
        tissue_type_index (int): Channel that receives the mask.
        num_classes (int, optional): Number of channels of the result. Default is 3.
    Returns:
        numpy.ndarray: Array of shape (height, width, num_classes) whose channel
        tissue_type_index holds the mask; all other channels are zero.
    """
    height, width = mask.shape[0], mask.shape[1]
    voxel = np.zeros((height, width, num_classes))
    voxel[..., tissue_type_index] = mask
    return voxel


# Function to process a dataset (train, val, test)
def process_dataset(dataset, tissue_types):
    """
    Groups per-channel image records by sample and stacks them into voxels.

    Records are grouped by the base name in front of "_DOCI". Within a group the
    grayscale images and grayscale cutoffs are ordered by DOCI number and stacked
    along a new last axis.

    Parameters:
        dataset (list): Dictionaries with the keys 'name', 'tissue_type', 'mask',
            'grayscale' and 'image_grayscale_cutoff'.
        tissue_types (list): The tissue type labels that are allowed.
    Returns:
        list: One dictionary per sample with the keys 'name', 'grayscale_voxel',
        'grayscale_image_cutoff_voxel', 'tissue_type' and 'mask'.
    Raises:
        ValueError: If a sample's tissue type is not contained in tissue_types.
    """
    grouped_samples = defaultdict(lambda: {'tissue_type': None, 'masks': [], 'names': [],
                                          'grayscale': [], 'image_grayscale_cutoff': []})

    # Group samples based on the base name
    for sample in dataset:
        group = grouped_samples[get_base_name(sample['name'])]
        group['tissue_type'] = sample['tissue_type']  # The last record of a group wins
        group['masks'].append(sample['mask'])
        group['names'].append(sample['name'])  # Kept to extract the DOCI number later
        group['grayscale'].append(sample['grayscale'])
        group['image_grayscale_cutoff'].append(sample['image_grayscale_cutoff'])

    voxelized_samples = []
    for base_name, group in grouped_samples.items():
        # Reject unexpected labels (e.g. "Unknown") here rather than further downstream
        if group['tissue_type'] not in tissue_types:
            raise ValueError(
                f"Tissue type {group['tissue_type']!r} of sample {base_name!r} is not in {tissue_types}"
            )

        # Channel order: ascending DOCI number (names without a number sort first)
        names = group['names']
        order = sorted(range(len(names)), key=lambda i: get_doci_number(names[i]))

        # Stack the channels along the last axis to create the voxels
        grayscale_voxel = np.stack([group['grayscale'][i] for i in order], axis=-1)
        grayscale_image_cutoff_voxel = np.stack(
            [group['image_grayscale_cutoff'][i] for i in order], axis=-1)

        voxelized_samples.append({
            'name': base_name,
            'grayscale_voxel': grayscale_voxel,
            'grayscale_image_cutoff_voxel': grayscale_image_cutoff_voxel,
            'tissue_type': group['tissue_type'],
            # Mask of the first record seen for this sample (input order, not DOCI order);
            # presumably all channels of a sample share one mask - TODO confirm
            'mask': group['masks'][0],
        })

    return voxelized_samples

# Define the tissue types.
# process_dataset raises ValueError for a sample whose tissue_type is missing from this
# list (add_pca_predictions assigns "Unknown" to samples found in none of the PCA lists).
tissue_types = ['Normal', 'Follicular', 'Papillary', 'Anaplastic']

# Process the train, val, and test sets to create voxels
train_combined = process_dataset(train_set, tissue_types)
val_combined = process_dataset(val_set, tissue_types)
test_combined = process_dataset(test_set, tissue_types)

# Image Processing

In [None]:
# Define the augmentation function
def augment(image, image_cutoff, mask):
    """
    Applies one random augmentation (flips, 90-degree rotation, zoom, and noise)
    to an image, its cutoff version, and the associated mask.

    The flips, the rotation and the zoom use the same random parameters for all
    three inputs so they stay aligned. Noise is only added to the image and the
    cutoff image, never to the mask. The outputs are always 256 x 256 pixels.

    Parameters:
        image (tf.Tensor): The input image tensor.
        image_cutoff (tf.Tensor): The cutoff version of the input image tensor.
        mask (tf.Tensor): The mask tensor, without channel axis.

    Returns:
        tuple: Augmented image, image_cutoff, and mask as float32 tensors.
    """
    
    image = tf.cast(image, tf.float32)
    image_cutoff = tf.cast(image_cutoff, tf.float32)
    mask = tf.cast(mask, tf.float32)

    # Add a trailing channel axis so the mask can go through the same tf.image ops
    # (assumes a 2D mask, e.g. (256, 256) -> (256, 256, 1) - TODO confirm)
    mask = tf.expand_dims(mask, axis=-1)
    
    # Random augmentation parameters (same for image, image_cutoff, and mask)
    flip_lr = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=tf.float32)
    flip_ud = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=tf.float32)
    # Number of 90-degree rotations: integer in {0, 1, 2, 3} (maxval is exclusive)
    angles = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    zoom_factor = tf.random.uniform([], 0.9, 1.1, dtype=tf.float32)

    # Apply the same horizontal flip (when flip_lr > 0.5)
    image = tf.cond(flip_lr > 0.5, lambda: tf.image.flip_left_right(image), lambda: image)
    image_cutoff = tf.cond(flip_lr > 0.5, lambda: tf.image.flip_left_right(image_cutoff), lambda: image_cutoff)
    mask = tf.cond(flip_lr > 0.5, lambda: tf.image.flip_left_right(mask), lambda: mask)

    # Apply the same vertical flip (when flip_ud > 0.5)
    image = tf.cond(flip_ud > 0.5, lambda: tf.image.flip_up_down(image), lambda: image)
    image_cutoff = tf.cond(flip_ud > 0.5, lambda: tf.image.flip_up_down(image_cutoff), lambda: image_cutoff)
    mask = tf.cond(flip_ud > 0.5, lambda: tf.image.flip_up_down(mask), lambda: mask)

    # Apply the same rotation (multiples of 90 degrees)
    image = tf.image.rot90(image, k=angles)
    image_cutoff = tf.image.rot90(image_cutoff, k=angles)
    mask = tf.image.rot90(mask, k=angles)

    # Apply the same zoom: rescale height and width by zoom_factor
    new_height = tf.cast(zoom_factor * tf.cast(tf.shape(image)[0], tf.float32), tf.int32)
    new_width = tf.cast(zoom_factor * tf.cast(tf.shape(image)[1], tf.float32), tf.int32)
    
    image = tf.image.resize(image, [new_height, new_width])
    image_cutoff = tf.image.resize(image_cutoff, [new_height, new_width])
    # Nearest-neighbor keeps the mask free of interpolated in-between values
    mask = tf.image.resize(mask, [new_height, new_width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Crop or pad back to the fixed output size of (256, 256)
    image = tf.image.resize_with_crop_or_pad(image, 256, 256)
    image_cutoff = tf.image.resize_with_crop_or_pad(image_cutoff, 256, 256)
    mask = tf.image.resize_with_crop_or_pad(mask, 256, 256)

    # Add random noise conditionally to image and image_cutoff (when add_noise > 0.5)
    add_noise = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=tf.float32)

    def apply_noise(img):
        # Zero-mean Gaussian noise with a standard deviation of 0.001
        noise = tf.random.normal(shape=tf.shape(img), mean=0.0, stddev=0.001, dtype=tf.float32)
        return img + noise
    
    # Both images share the decision, but each one gets its own noise sample
    image = tf.cond(add_noise > 0.5, lambda: apply_noise(image), lambda: image)
    image_cutoff = tf.cond(add_noise > 0.5, lambda: apply_noise(image_cutoff), lambda: image_cutoff)

    # Remove the channel axis that was added to the mask above
    mask = tf.squeeze(mask, axis=-1)

    return image, image_cutoff, mask

def load_image_and_masks(image, image_cutoff, mask, tissue_type_label, augment_data=False, num_augmentations=1):
    """
    Scales an image/cutoff pair by 1/255 and optionally appends augmented copies.

    Every tissue type receives the same number of augmentations.

    Parameters:
        image (numpy.ndarray): Input grayscale image.
        image_cutoff (numpy.ndarray): Input image cutoff.
        mask (numpy.ndarray): Input mask (returned unscaled).
        tissue_type_label (str): One of 'Normal', 'Follicular', 'Papillary', 'Anaplastic'.
        augment_data (bool): Whether to append augmented copies.
        num_augmentations (int): Number of augmented copies per sample if augment_data is True.

    Returns:
        tuple: Lists of images, image cutoffs, masks, and tissue type indices. The
        first entry of each list is the (scaled) original, followed by the
        augmented versions.

    Raises:
        ValueError: If tissue_type_label is not one of the known tissue types.
    """
    # Scale the intensities by 1/255
    image = np.array(image) / 255.0
    image_cutoff = np.array(image_cutoff) / 255.0

    # The label is stored as the position of the tissue type in this list
    tissue_types = ['Normal', 'Follicular', 'Papillary', 'Anaplastic']
    tissue_type_index = tissue_types.index(tissue_type_label)

    # The untouched sample always comes first
    images, image_cutoffs, masks, labels = [image], [image_cutoff], [mask], [tissue_type_index]

    if not augment_data:
        return images, image_cutoffs, masks, labels

    for _ in range(num_augmentations):
        new_image, new_cutoff, new_mask = augment(image, image_cutoff, mask)
        images.append(new_image)
        image_cutoffs.append(new_cutoff)
        masks.append(new_mask)
        labels.append(tissue_type_index)

    return images, image_cutoffs, masks, labels


# Rebinds tissue_types to three entries (the list defined for the voxel grouping
# step also contained 'Anaplastic'); num_classes follows from its length.
tissue_types = ['Normal', 'Follicular', 'Papillary']
num_classes = len(tissue_types)


# Function to resize voxel data (for both images and cutoffs)
def resize_voxel(voxel, target_size=(256, 256), is_mask=False):
    """
    Resizes voxel data to the specified target size.

    Parameters:
        voxel (numpy.ndarray or tf.Tensor): Input data, 2D (height, width) or with
            a trailing channel axis. Any other type is returned unchanged.
        target_size (tuple): Target size for resizing (height, width).
        is_mask (bool): Whether the voxel is a mask (for nearest-neighbor resizing).

    Returns:
        numpy.ndarray: Resized data. 2D inputs come back 2D; masks with a single
        channel are returned without the channel axis; multi-channel data keeps
        its channels.
    """

    if isinstance(voxel, tf.Tensor):
        voxel = voxel.numpy()

    if not isinstance(voxel, np.ndarray):
        return voxel

    # tf.image.resize needs a channel axis, so add one to 2D inputs temporarily
    was_2d = voxel.ndim == 2
    if was_2d:
        voxel = np.expand_dims(voxel, axis=-1)

    if is_mask:
        # Nearest-neighbor avoids introducing values that are not in the mask
        voxel = tf.image.resize(voxel, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR).numpy()
        # Only a single channel can be squeezed away; multi-channel masks stay as they are
        if voxel.shape[-1] == 1:
            voxel = np.squeeze(voxel, axis=-1)
    else:
        # Normal resizing for image and cutoff voxels
        voxel = tf.image.resize(voxel, target_size).numpy()
        if was_2d:
            voxel = np.squeeze(voxel, axis=-1)

    return voxel


# Function to process a dataset (train, val, test)
def process_dataset(dataset, augment_data=False, num_augmentations=1, target_size=(256, 256)):
    """
    Turns voxelized samples into arrays of images, cutoffs, masks, and labels.

    Note: this definition replaces the earlier process_dataset(dataset, tissue_types)
    that grouped the raw image records into voxels.

    Parameters:
        dataset (list): Voxelized samples with the keys 'grayscale_voxel',
            'grayscale_image_cutoff_voxel', 'mask' and 'tissue_type'.
        augment_data (bool): Whether to add augmented copies of every sample.
        num_augmentations (int): Number of augmented copies per sample if augment_data is True.
        target_size (tuple): Target size (height, width) for images and masks.
    Returns:
        tuple: Arrays of images, image cutoffs, masks, and labels.
    """
    all_images, all_cutoffs, all_masks, all_labels = [], [], [], []

    for record in dataset:
        # Original sample first, then its augmented versions (if requested)
        sample_images, sample_cutoffs, sample_masks, sample_labels = load_image_and_masks(
            record['grayscale_voxel'],
            record['grayscale_image_cutoff_voxel'],
            record['mask'],
            record['tissue_type'],
            augment_data=augment_data,
            num_augmentations=num_augmentations,
        )

        # Bring everything to the target size before collecting it
        resized_images = [resize_voxel(item, target_size) for item in sample_images]
        resized_cutoffs = [resize_voxel(item, target_size) for item in sample_cutoffs]
        resized_masks = [resize_voxel(item, target_size, is_mask=True) for item in sample_masks]

        all_images += resized_images
        all_cutoffs += resized_cutoffs
        all_masks += resized_masks
        all_labels += sample_labels

    return np.array(all_images), np.array(all_cutoffs), np.array(all_masks), np.array(all_labels)

# Process training data with augmentation (each sample plus 6 augmented copies)
train_images, train_images_cutoffs, train_masks, train_labels = process_dataset(
    train_combined, augment_data=True, num_augmentations=6)

# Process validation data without augmentation
val_images, val_images_cutoffs, val_masks, val_labels = process_dataset(
    val_combined, augment_data=False, num_augmentations=0)

# Process test data without augmentation
test_images, test_images_cutoffs, test_masks, test_labels = process_dataset(
    test_combined, augment_data=False, num_augmentations=0)

# Add a trailing channel axis so the masks match the single-channel model output
train_masks = np.expand_dims(train_masks, axis=-1)
val_masks = np.expand_dims(val_masks, axis=-1)
test_masks = np.expand_dims(test_masks, axis=-1)

def normalize_images(images):
    """
    Scales 8-bit data to the range [0, 1].

    The data is only rescaled when its maximum is exactly 255; anything else is
    assumed to be normalized already and is returned unchanged. Empty input is
    returned unchanged as well.

    Parameters:
        images (numpy.ndarray): Array of images (array-likes are accepted too).
    Returns:
        numpy.ndarray: float32 array divided by 255 if the maximum was 255,
        otherwise the input itself.
    """
    # np.max raises on empty input, and there is nothing to scale anyway
    if np.size(images) == 0:
        return images

    if np.max(images) == 255:
        images = np.asarray(images).astype('float32') / 255.0
    return images

# Make sure images and masks are NumPy arrays before normalizing
train_images = np.array(train_images)
train_masks = np.array(train_masks)
val_images = np.array(val_images)
val_masks = np.array(val_masks)
test_images = np.array(test_images)
test_masks = np.array(test_masks)

# Normalize to [0, 1] range
# (normalize_images only rescales an array whose maximum is exactly 255)
train_images = normalize_images(train_images)
val_images = normalize_images(val_images)
test_images = normalize_images(test_images)

train_masks = normalize_images(train_masks)
val_masks= normalize_images(val_masks)
test_masks = normalize_images(test_masks)

# Verify the normalization
print(f"Train Images Min: {train_images.min()}, Max: {train_images.max()}")
print(f"Val Images Min: {val_images.min()}, Max: {val_images.max()}")
print(f"Test Images Min: {test_images.min()}, Max: {test_images.max()}")

print(f"Train Masks Min: {train_masks.min()}, Max: {train_masks.max()}")
print(f"Val Masks Min: {val_masks.min()}, Max: {val_masks.max()}")
print(f"Test Masks Min: {test_masks.min()}, Max: {test_masks.max()}")

# U-net model

In [None]:
# Define the squeeze-and-excite block
def squeeze_excite_block(input_tensor, ratio=16):
    """
    Creates a squeeze-and-excite block that recalibrates the input tensor's channel-wise feature maps
    by adaptively reweighting them.
    Parameters:
        input_tensor (tf.Tensor): The input tensor to the squeeze-and-excite block.
        ratio (int, optional): The reduction ratio for channel-wise dimensionality reduction. Default is 16.
    Returns:
        tf.Tensor: The recalibrated tensor after applying squeeze-and-excite.
    """

    filters = input_tensor.shape[-1]
    # Keep at least one unit: filters // ratio is 0 for inputs with fewer channels than ratio
    reduced_units = max(1, filters // ratio)

    # Squeeze: one average value per channel
    se = tf.keras.layers.GlobalAveragePooling2D()(input_tensor)
    # Excite: bottleneck followed by per-channel weights in (0, 1)
    se = tf.keras.layers.Dense(reduced_units, activation='relu')(se)
    se = tf.keras.layers.Dense(filters, activation='sigmoid')(se)
    # Reshape so the weights broadcast over height and width
    se = tf.keras.layers.Reshape([1, 1, filters])(se)
    return tf.keras.layers.multiply([input_tensor, se])

# Define the encoder block
def encoder_block(x, filters):
    """
    Builds one encoder stage: a 3x3 convolution with batch normalization, followed
    by a strided 3x3 convolution that halves the spatial resolution.
    Parameters:
        x (tf.Tensor): The input tensor to the encoder block.
        filters (int): The number of filters of both convolutions.
    Returns:
        tuple: A tuple containing:
        - tf.Tensor: The full-resolution features (used as skip connection).
        - tf.Tensor: The down-sampled features passed on to the next stage.
    """
    features = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')(x)
    features = BatchNormalization()(features)

    # Down-sampling is done by a stride-2 convolution instead of pooling
    downsampled = Conv2D(
        filters=filters, kernel_size=(3, 3), strides=(2, 2), padding='same', activation='relu'
    )(features)

    return features, downsampled

# Define the decoder block
def decoder_block(x, skip, filters):
    """
    Builds one decoder stage: 2x up-sampling, concatenation with the skip
    connection, and two 3x3 convolutions with batch normalization in between.
    Parameters:
        x (tf.Tensor): The input tensor to the decoder block.
        skip (tf.Tensor): The skip connection tensor from the matching encoder block.
        filters (int): The number of filters of both convolutions.
    Returns:
        tf.Tensor: The output tensor of the decoder stage.
    """
    upsampled = UpSampling2D(size=(2, 2))(x)
    merged = Concatenate()([upsampled, skip])

    refined = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')(merged)
    refined = BatchNormalization()(refined)
    refined = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')(refined)
    return refined

# Combined Dice + BCE Loss
def combined_loss(y_true, y_pred):
    """
    Averages the Binary Cross-Entropy (BCE) loss and the Dice loss with equal weights.
    Parameters:
        y_true (tf.Tensor): Ground truth tensor.
        y_pred (tf.Tensor): Predicted tensor.
    Returns:
        tf.Tensor: 0.5 * BCE + 0.5 * Dice.
    """
    # Predictions are clipped to [0, 1] before the cross-entropy is computed
    clipped_pred = tf.clip_by_value(y_pred, 0.0, 1.0)
    bce_term = binary_crossentropy(y_true, clipped_pred)
    dice_term = dice_loss(y_true, y_pred)
    return 0.5 * bce_term + 0.5 * dice_term

# Dice Loss
def dice_loss(y_true, y_pred):
    """
    Computes the Dice loss for segmentation tasks.

    The Dice coefficient is computed over all elements of the tensors at once
    (not per sample) and the loss is 1 - Dice, bounded below by 0.

    Parameters:
        y_true (tf.Tensor): Ground truth tensor.
        y_pred (tf.Tensor): Predicted tensor.
    Returns:
        tf.Tensor: Scalar Dice loss value.
    """

    y_pred = tf.clip_by_value(y_pred, 0.0, 1.0)
    # Masks may arrive as integers or float64; mixing dtypes in the product below would fail
    y_true = tf.cast(y_true, y_pred.dtype)

    numerator = 2 * tf.reduce_sum(y_true * y_pred)
    denominator = tf.reduce_sum(y_true + y_pred)
    # epsilon keeps the ratio defined when both tensors are all zero
    dice = 1 - (numerator + tf.keras.backend.epsilon()) / (denominator + tf.keras.backend.epsilon())
    return tf.maximum(dice, 0)


# Define the segmentation model architecture
def build_u2net_with_rule(input_shape):
    """
    Builds and compiles a small U-Net-like segmentation model.

    Two encoder stages (32 and 64 filters) feed a 128-filter bottleneck with
    dropout, followed by two decoder stages with skip connections and a
    single-channel sigmoid output.

    Parameters:
        input_shape (tuple): The shape of the input tensor, including the number of channels.
    Returns:
        tf.keras.Model: The model, compiled with Adam (learning rate 1e-4), the
        combined Dice and Binary Cross-Entropy (BCE) loss, and the accuracy metric.
    """
    image_input = Input(input_shape, name='image_input')

    # Encoder: remember the full-resolution features of every stage for the decoder
    skip_connections = []
    features = image_input
    for stage_filters in (32, 64):
        skip, features = encoder_block(features, stage_filters)
        skip_connections.append((skip, stage_filters))

    # Bottleneck
    features = Conv2D(128, (3, 3), padding='same', activation='relu')(features)
    features = BatchNormalization()(features)
    features = Dropout(0.7)(features)

    # Decoder: mirror the encoder stages, deepest first
    for skip, stage_filters in reversed(skip_connections):
        features = decoder_block(features, skip, stage_filters)

    # Segmentation mask output
    seg_output = Conv2D(1, (1, 1), activation='sigmoid', name='seg_output')(features)

    model = Model(inputs=image_input, outputs=seg_output)
    model.compile(
        optimizer=Adam(learning_rate=1e-4),
        loss=combined_loss,
        metrics=['accuracy'],
    )
    return model

# Define input shape: 256 x 256 pixels with 23 channels
# NOTE(review): presumably one channel per stacked DOCI image; an earlier comment
# mentioned an "appended one-hot column", which nothing in this notebook adds - confirm
input_shape = (256, 256, 23)

# Build the model
unet_segmentation_model = build_u2net_with_rule(input_shape)

In [None]:
# Stop when val_loss has not improved for 10 epochs and restore the best weights
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)
# Keep only the weights of the epoch with the lowest val_loss on disk
checkpoint = ModelCheckpoint('best_models.weights.h5', monitor='val_loss', save_best_only=True, save_weights_only=True, verbose=1)
# Halve the learning rate after 5 epochs without improvement (never below 1e-6)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1)

# Combine callbacks
callbacks = [early_stopping, checkpoint, reduce_lr]

# Train on the augmented training set; validation data drives all three callbacks
unet_segmentation_model.fit(
    x=train_images,
    y=train_masks,
    validation_data=(val_images, val_masks),
    batch_size=6,
    epochs=100,
    callbacks=callbacks
)

In [None]:
# Restore the best checkpoint first so that the saved model and the reported test
# metrics both refer to the best weights. The file name must match the path given
# to ModelCheckpoint ('best_models.weights.h5').
unet_segmentation_model.load_weights('best_models.weights.h5')
unet_segmentation_model.save('best_model.h5')
test_loss, test_accuracy = unet_segmentation_model.evaluate(test_images, test_masks)
print(f"Test accuracy: {test_accuracy * 100:.2f}%")

# Results Visualization

In [None]:
def visualize_comparison(test_image, predicted_mask, actual_mask, save_path="1papillary.png", channel=2):
    """
    Visualizes a side-by-side comparison of a test image, its predicted mask, and the ground truth mask.
    Parameters:
        test_image (numpy.ndarray): The input image of shape (height, width, channels).
        predicted_mask (numpy.ndarray): The predicted mask of shape (height, width, channels);
            its first channel is shown.
        actual_mask (numpy.ndarray): The ground truth mask, either 2D (height, width) or
            3D (height, width, channels).
        save_path (str or None): File the figure is written to. The same file is overwritten
            on every call, so pass a distinct path per image to keep all figures, or None
            to skip saving.
        channel (int): Index of the image channel that is displayed. Default is 2 (third channel).
    Raises:
        ValueError: If actual_mask is neither a 2D nor a 3D array.
    Returns:
        None: Displays the comparison plots for visual inspection.
    """

    actual_mask = np.asarray(actual_mask)
    if actual_mask.ndim not in (2, 3):
        raise ValueError("Actual mask must be a 2D (height, width) or 3D (height, width, channels) array.")
    # Show single-channel masks as a plain 2D image
    if actual_mask.ndim == 3 and actual_mask.shape[-1] == 1:
        actual_mask = actual_mask[:, :, 0]

    # Visualize the comparison
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # Plot the selected channel of the image
    axs[0].imshow(test_image[:, :, channel], cmap='gray')
    axs[0].set_title("DOCI Image")
    axs[0].axis('off')

    # Plot the predicted mask
    axs[1].imshow(predicted_mask[:, :, 0], cmap='gray')
    axs[1].set_title("Predicted Mask")
    axs[1].axis('off')

    # Plot the actual mask
    axs[2].imshow(actual_mask, cmap='gray')
    axs[2].set_title("Ground Truth Mask")
    axs[2].axis('off')

    if save_path:
        fig.savefig(save_path, dpi=400, bbox_inches='tight')

    plt.show()
    # Release the figure; this function is called in loops over whole datasets
    plt.close(fig)

## Test Images

In [None]:
for i in range(len(test_images)):
    test_image = test_images[i]
    actual_mask = test_masks[i]

    # Predict the mask for the test image (a batch axis is added for predict and removed again)
    predicted_mask = unet_segmentation_model.predict(np.expand_dims(test_image, axis=0))[0]

    # Visualize the third channel of the image, the predicted mask, and the actual mask
    visualize_comparison(test_image, predicted_mask, actual_mask)

## Train Images

In [None]:
# Note: train_images also contains the augmented copies, so this shows all of them
for i in range(len(train_images)):
    train_image = train_images[i]
    actual_mask = train_masks[i]

    # Predict the mask for the training image (a batch axis is added for predict and removed again)
    predicted_mask = unet_segmentation_model.predict(np.expand_dims(train_image, axis=0))[0]

    # Visualize the third channel of the image, the predicted mask, and the actual mask
    visualize_comparison(train_image, predicted_mask, actual_mask)

## Validation Images

In [None]:
# Function to find the correct mask layer (only one has any true values)
def find_true_mask_layer(mask):
    """
    Finds the first mask channel that contains any non-zero (true) value.
    Parameters:
        mask (numpy.ndarray): A 3D array representing the mask (height, width, channels).
    Returns:
        int: Index of the first channel with a non-zero value, or 0 if every channel is empty.
    """
    channel_indices = range(mask.shape[-1])
    # Channels are scanned in order; 0 is the fallback for an all-zero mask
    return next((index for index in channel_indices if mask[:, :, index].any()), 0)

# Assuming val_images, val_masks, and the trained model are already available
for i in range(len(val_images)):
    val_image = val_images[i]
    actual_mask = val_masks[i]

    # Predict the mask for the validation image (a batch axis is added for predict and removed again)
    predicted_mask = unet_segmentation_model.predict(np.expand_dims(val_image, axis=0))[0]

    # Visualize the third channel of the image, the predicted mask, and the actual mask
    visualize_comparison(val_image, predicted_mask, actual_mask)