## Benchmarking Contrastive Learning for Multimodal Medical Imaging

## 1. Environment setup

### 1.1. Loading libraries

In [None]:
# Standard
import copy
import glob
import os
import pickle
import plistlib
import random
import re
import shutil
import xml.etree.ElementTree as ET
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
import pydicom
import requests
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image, ImageDraw
from bs4 import BeautifulSoup
from skimage.draw import polygon
from sklearn.metrics import (
    f1_score, matthews_corrcoef, accuracy_score, balanced_accuracy_score,
    jaccard_score, precision_recall_curve, average_precision_score,
    roc_auc_score, roc_curve, precision_score, recall_score, confusion_matrix,
    ConfusionMatrixDisplay
)
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.manifold import TSNE
from scipy.spatial.distance import directed_hausdorff

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
import lightly.data as data

# Make 'cividis' the default Matplotlib colormap for this session
plt.set_cmap('cividis')
# IPython magic: render figures inline in the notebook output
%matplotlib inline
# NOTE(review): newer IPython releases are said to deprecate this import location
# in favour of matplotlib_inline.backend_inline — confirm against the installed version.
from IPython.display import set_matplotlib_formats
# Request the 'svg' and 'pdf' formats for inline figures
set_matplotlib_formats('svg', 'pdf')

In [None]:
# For reproducible results
def set_seed(seed):
    """Seed every random number generator used in the notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then sets the cuDNN flags: ``deterministic`` on, ``benchmark`` off.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs; set_seed also sets the cuDNN deterministic/benchmark flags,
# so they do not need to be set again here.
set_seed(42)

# Use GPU
# os.cpu_count() returns None when the number of CPUs cannot be determined;
# fall back to 0 so NUM_WORKERS is always an integer.
NUM_WORKERS = os.cpu_count() or 0
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
print("Number of workers:", NUM_WORKERS)

### 1.2. Directory setup for data and results

In [None]:
# Base directories setup: every path is derived from the parent of the
# current working directory.
current_dir = os.getcwd()
root_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
notebooks_dir, data_dir, results_dir = (
    os.path.join(root_dir, folder) for folder in ("notebooks", "data", "results")
)
data_used_dir = os.path.join(data_dir, "dataUsed")

# Creating directories
def create_directory(path):
    """Create ``path`` (including missing parents) unless it already exists.

    Args:
        path: Directory path to create. Nothing happens if something already
            exists at that path.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)

# Subfolder structures: names used to build the data and results trees.

# Results tree: one folder per method, modality and downstream task
contrastive_methods = [
    "simclr",
    "moco",
    "byol",
    "supervised",
]
modalities = [
    "ultrasound",
    "mammography",
    "multimodal",
]
task_types = [
    "classification",
    "segmentation",
]

# Data tree: dataset roots, fine-tuning splits, classes and per-class subfolders
dataset_folders = [
    "UltrasoundDataset",
    "MammographyDataset",
    "MultimodalDataset",
]
finetune_splits = [
    "train",
    "validation",
    "test",
]
classes = [
    "benign",
    "malignant",
    "normal",
]
subfolders = [
    "images",
    "masks",  # For segmentation task
]

# Creating directories for cleaned data
def create_data_folders(base_data_dir):
    """Create the pretrain/finetune folder tree for every dataset.

    For each name in ``dataset_folders`` this creates:

    * ``pretrainData/{train,validation}``; the multimodal dataset also gets
      one subfolder per modality inside each split.
    * ``finetuneData/<split>/<class>/<subfolder>`` for every fine-tuning
      split, class and subfolder; the multimodal dataset also gets
      per-modality folders inside ``images`` and ``masks``.

    Args:
        base_data_dir: Directory under which the dataset folders are created.
    """
    pretrain_splits = ("train", "validation")
    # Extra per-modality folders that only the multimodal dataset receives
    multimodal_children = {
        "images": ("ultrasoundImages", "mammographyImages"),
        "masks": ("UltrasoundMasks", "MammographyMasks"),
    }

    for dataset in dataset_folders:
        is_multimodal = dataset == "MultimodalDataset"
        dataset_root = os.path.join(base_data_dir, dataset)
        pretrain_root = os.path.join(dataset_root, "pretrainData")
        finetune_root = os.path.join(dataset_root, "finetuneData")

        # Pretraining splits carry no class structure
        for split in pretrain_splits:
            pretrain_split = os.path.join(pretrain_root, split)
            create_directory(pretrain_split)
            if is_multimodal:
                for modality in multimodal_children["images"]:
                    create_directory(os.path.join(pretrain_split, modality))

        # Fine-tuning splits are organised as split / class / subfolder
        for split in finetune_splits:
            for cls in classes:
                class_root = os.path.join(finetune_root, split, cls)
                create_directory(class_root)
                for subfolder in subfolders:
                    leaf = os.path.join(class_root, subfolder)
                    create_directory(leaf)
                    if is_multimodal:
                        for child in multimodal_children.get(subfolder, ()):
                            create_directory(os.path.join(leaf, child))

# Results directory
def create_results_structure(base_results_dir):
    """Create the results tree: ``<phase>/<method>/<modality>``.

    ``<phase>`` is ``pretrainPhase`` followed by every entry of ``task_types``;
    each phase gets one folder per contrastive method, and each method one
    folder per modality.

    Args:
        base_results_dir: Directory under which the results tree is created.
    """
    # The pretraining phase and each downstream task share the same layout
    phase_names = ["pretrainPhase"] + list(task_types)

    for phase_name in phase_names:
        phase_root = os.path.join(base_results_dir, phase_name)
        create_directory(phase_root)

        for method in contrastive_methods:
            method_root = os.path.join(phase_root, method)
            create_directory(method_root)

            for modality in modalities:
                create_directory(os.path.join(method_root, modality))

# Build both folder trees: cleaned data first, then results
create_data_folders(data_used_dir)
create_results_structure(results_dir)

# Verification
verification_message = (
    f"Folder structure created under {data_used_dir} for data and {results_dir} for results."
)
print(verification_message)

## 2. Creating datasets (run once)

Ultrasound

- Pretrain: QAMEBI’s BUSI dataset + Breast Ultrasound Image (Mendeley) dataset + BrEaST dataset + Thammasat dataset

- Finetuning: BUSI dataset

Mammography
  
- Pretrain: CBIS-DDSM dataset

- Finetuning: INbreast dataset

Both types
- Pretrain: 50% of data points considered in the ultrasound and mammography datasets, taken randomly per class and per training phase (pretrain or finetune)

- Finetuning: portion of BUSI and CBIS-DDSM (50%) + portion of INbreast (50%)

### 2.1. Handling used datasets

#### 2.1.1. INbreast

Mammography image dataset

https://doi.org/10.1016/j.acra.2011.09.014

-----

Label adaptations based on:

https://doi.org/10.48550/arXiv.1705.08550


In [None]:
# Originally follows BI-RADS system, now following:
# 1          ->  normal
# 2/3        ->  benign
# 4a,b,c/5/6 ->  malignant

# INbreast has a total of 115 cases: 90 cases are from women with both breasts affected (four images per case) and 25 cases are from mastectomy patients (two images per case).

In [None]:
# Define directories
# NOTE: this rebinds `data_dir` to the INbreast release folder.
data_dir = '../data/INbreast Release 1.0'
csv_file_path = '../data/INbreast Release 1.0/INBreast.csv'

# Raw inputs shipped with the release
dicom_dir, xml_dir, roi_dir, reports_dir = (
    os.path.join(data_dir, folder)
    for folder in ('AllDICOMs', 'AllXML', 'AllROI', 'MedicalReports')
)

# Output folders for the cleaned images and masks
clean_inbreast_dir = os.path.join(data_dir, 'cleanmergedInbreast')
clean_inbreast_images_dir = os.path.join(clean_inbreast_dir, 'images')
clean_inbreast_masks_dir = os.path.join(clean_inbreast_dir, 'masks')

def create_directory(directory):
    """Create ``directory`` if needed and print what happened.

    Prints ``Created folder: ...`` when the directory (and any missing
    parents) was created, or ``Folder already exists: ...`` when something
    already exists at that path.

    Args:
        directory: Path of the directory to create.
    """
    if os.path.exists(directory):
        print(f'Folder already exists: {directory}')
        return
    os.makedirs(directory)
    print(f'Created folder: {directory}')

# Create the cleaned INbreast output folders (root, masks, images)
create_directory(clean_inbreast_dir)
create_directory(clean_inbreast_masks_dir)
create_directory(clean_inbreast_images_dir)

Folder already exists: ../data/INbreast Release 1.0\cleanmergedInbreast
Folder already exists: ../data/INbreast Release 1.0\cleanmergedInbreast\masks
Folder already exists: ../data/INbreast Release 1.0\cleanmergedInbreast\images


In [None]:
# Extract BIRADS number from medical report
def extract_birads_from_report(report_path):
    """Return the first BI-RADS number mentioned in a text report.

    The report is searched case-insensitively for ``birads`` or ``bi-rads``
    followed by a dash and a run of digits.

    Args:
        report_path: Path of the text report to read.

    Returns:
        The matched digits as an ``int``, or ``None`` when there is no match
        or the report cannot be processed (the error is printed).
    """
    birads_pattern = r'bi-?rads\s*-\s*(\d+)'
    try:
        with open(report_path, 'r') as report:
            found = re.search(birads_pattern, report.read(), re.IGNORECASE)
        if found is not None:
            return int(found.group(1))
    except Exception as e:
        # Best effort: report the problem and fall through to None
        print(f"Error reading report {report_path}: {e}")
    return None

# Re-classify based on BIRADS number
def classify_birads(birads_number):
    """Map a BI-RADS category to ``"normal"``, ``"benign"`` or ``"malignant"``.

    Mapping: 1 -> normal, 2/3 -> benign, 4/5/6 -> malignant. String values
    may carry a sub-category letter (``"4a"``, ``"4b"``, ``"4c"``), which is
    ignored.

    Args:
        birads_number: BI-RADS category as an int, a numeric string, or a
            string with a trailing a/b/c sub-category letter.

    Returns:
        The class name, or ``None`` when the value is missing, not numeric, or
        outside the categories 1-6.
    """
    if isinstance(birads_number, str):
        # Drop surrounding whitespace and the sub-category letter ("4a" -> "4").
        # Stripping the suffix, rather than testing whether '4' occurs anywhere
        # in the string, keeps values such as "14" from being read as category 4.
        birads_number = birads_number.strip().rstrip('abcABC').strip()
    try:
        birads_number = int(birads_number)  # Ensure BIRADS as int
    except (TypeError, ValueError, OverflowError):
        # None, NaN, infinities and non-numeric text are all "unclassifiable"
        return None

    if birads_number == 1:
        return "normal"
    if birads_number in (2, 3):
        return "benign"
    if birads_number in (4, 5, 6):
        return "malignant"
    return None

def dicom_to_png(dicom_path, output_image_path):
    """Convert a DICOM file to an 8-bit image saved at ``output_image_path``.

    Pixel values are min-max scaled to the 0-255 range before saving. An image
    with a single constant value has no range to scale, so it is written as
    all zeros instead of dividing by zero.

    Args:
        dicom_path: Path of the DICOM file to read.
        output_image_path: Destination path; the output format follows the
            file extension.

    Returns:
        The shape of the saved pixel array, or ``None`` if the conversion
        failed (the error is printed).
    """
    try:
        dicom_data = pydicom.dcmread(dicom_path)
        img_array = dicom_data.pixel_array.astype(np.float64)

        # Normalize
        lowest = img_array.min()
        value_range = img_array.max() - lowest
        if value_range > 0:
            img_array = (img_array - lowest) / value_range * 255
        else:
            # Constant image: avoid 0/0, which would produce NaNs
            img_array = np.zeros_like(img_array)
        img_array = img_array.astype(np.uint8)

        # Save image
        img = Image.fromarray(img_array)
        img.save(output_image_path)
        return img_array.shape  # Return image size
    except Exception as e:
        # Best effort: one unreadable file must not abort a batch conversion
        print(f"Error converting DICOM {dicom_path} to PNG: {e}")
        return None
    
# Convert DICOM to PNG
# Maps each written PNG filename to the shape returned by dicom_to_png
dicom_sizes = {}
for dicom_file in os.listdir(dicom_dir):
    if not dicom_file.endswith('.dcm'):
        continue

    dicom_path = os.path.join(dicom_dir, dicom_file)
    png_filename = os.path.splitext(dicom_file)[0] + ".png"
    output_image_path = os.path.join(clean_inbreast_images_dir, png_filename)

    img_size = dicom_to_png(dicom_path, output_image_path)
    if img_size is None:
        # Conversion failed; dicom_to_png already printed the error
        continue
    dicom_sizes[png_filename] = img_size

print("Conversion of DICOMs completed.")

In [None]:
# Load INbreast CSV data
inbreast_df = pd.read_csv(csv_file_path, delimiter=';')

# List image files in the images folder
image_files = [f for f in os.listdir(clean_inbreast_images_dir) if f.endswith('.png')]

# Counters for different matching scenarios
csv_match_count = 0
unmatched_images = []

# Store BIRADS classification for each image
birads_class_dict = {}

# Step 1: Check Images Against CSV File and Assign Class Based on BIRADS
for image_file in image_files:
    # Extract the image ID for CSV matching (assumed as the first part of the filename)
    image_id = image_file.split('_')[0]

    # A non-numeric prefix cannot correspond to any CSV row: record the file
    # as unmatched instead of aborting the whole cell with a ValueError.
    try:
        numeric_id = int(image_id)
    except ValueError:
        unmatched_images.append(image_file)
        continue

    # Check for a match in the CSV file using "File Name" column
    csv_match = inbreast_df[inbreast_df['File Name'] == numeric_id]

    if csv_match.empty:
        # If no match, add to unmatched images list
        unmatched_images.append(image_file)
        continue

    csv_birads = csv_match.iloc[0]['Bi-Rads']
    print(f"Image {image_file} matched in CSV with BIRADS: {csv_birads}")

    # Classify the BIRADS into a class (normal, benign, malignant)
    image_class = classify_birads(csv_birads)
    if image_class:
        # Store the classification
        birads_class_dict[image_file] = image_class
        csv_match_count += 1
    else:
        print(f"Invalid BIRADS value for {image_file}: {csv_birads}")

# Step 2: Final Count and Unmatched Images
total_images = len(image_files)
successful_matches = csv_match_count
print(f"\nFinal count out of {total_images} images:")
print(f" - Matched with CSV: {csv_match_count}")
print(f" - Unmatched images: {len(unmatched_images)}")

if unmatched_images:
    print("\nUnmatched image files:")
    for unmatched_file in unmatched_images:
        print(unmatched_file)

# Step 3: Rename Files Based on Class
for image_file, image_class in birads_class_dict.items():
    # Full path of the original image
    old_image_path = os.path.join(clean_inbreast_images_dir, image_file)

    if not os.path.exists(old_image_path):
        print(f"File not found: {old_image_path}")
        continue

    # Construct the new filename by appending the class before the extension
    name_part, ext = os.path.splitext(image_file)  # Split filename and extension

    # Re-running this cell must not append the class a second time
    if name_part.endswith(f"_{image_class}"):
        print(f"Already labelled, skipping: {image_file}")
        continue

    new_image_name = f"{name_part}_{image_class}.png"
    new_image_path = os.path.join(clean_inbreast_images_dir, new_image_name)

    # Rename the file
    os.rename(old_image_path, new_image_path)
    print(f"Renamed {image_file} to {new_image_name}")

print("\nFile renaming process completed.")

In [None]:
# Function adapted from: https://github.com/pablogiaccaglia/Breast-Cancer-Segmentation-Datasets/blob/master/INbreast/refactorINbreast.py
# Original Author: Pablo Giaccaglia
def loadInbreastMask(mask_path, imshape=(4084, 3328), filter=False):
    """Rasterise the ROIs of an XML plist annotation into a binary mask.

    Args:
        mask_path: Path of the XML plist file; the ROIs are read from
            ``['Images'][0]``.
        imshape: ``(rows, cols)`` shape of the mask to produce.
        filter: When true, ROIs with fewer than 18 points are skipped.

    Returns:
        A NumPy array of shape ``imshape`` holding 1 inside the ROIs and 0
        elsewhere.

    Raises:
        AssertionError: If a declared ROI or point count does not match the
            number of entries actually present.
    """

    def parse_point(point_string):
        # "(x, y)" text -> (row, col) tuple, i.e. (y, x), for NumPy indexing
        x_coord, y_coord = (float(token) for token in point_string.strip('()').split(','))
        return y_coord, x_coord

    mask = np.zeros(imshape)

    # Open the plist XML file
    with open(mask_path, 'rb') as mask_file:
        image_entry = plistlib.load(mask_file, fmt=plistlib.FMT_XML)['Images'][0]

    declared_roi_count = image_entry['NumberOfROIs']
    roi_list = image_entry['ROIs']
    assert len(roi_list) == declared_roi_count

    for roi in roi_list:
        declared_point_count = roi['NumberOfPoints']
        raw_points = roi['Point_px']
        assert declared_point_count == len(raw_points)

        vertices = [parse_point(raw_point) for raw_point in raw_points]

        # Apply filter to skip ROIs with fewer points
        if filter and len(vertices) < 18:
            continue

        if len(vertices) > 2:
            # Enough points for a region: fill the polygon, clipped to the mask shape
            rows = np.array([vertex[0] for vertex in vertices])
            cols = np.array([vertex[1] for vertex in vertices])
            fill_rows, fill_cols = polygon(rows, cols, shape=imshape)
            mask[fill_rows, fill_cols] = 1
            continue

        # Single points or lines: mark each point that lies inside the mask
        for row, col in vertices:
            pixel_row, pixel_col = int(row), int(col)
            if 0 <= pixel_row < imshape[0] and 0 <= pixel_col < imshape[1]:
                mask[pixel_row, pixel_col] = 1

    return mask

# Function adapted from: https://github.com/pablogiaccaglia/Breast-Cancer-Segmentation-Datasets/blob/master/INbreast/refactorINbreast.py
# Original Author: Pablo Giaccaglia
def generate_masks(xml_dir, images_dir, output_mask_dir):
    """Write one ``<id>_mask.png`` per XML annotation that has a matching image.

    An XML file ``<id>.xml`` is paired with the first PNG in ``images_dir``
    whose name, up to the first underscore (or the extension), equals
    ``<id>``. The mask is rendered at the image's height and width and saved
    with a grey colormap.

    Args:
        xml_dir: Directory holding the ``.xml`` annotation files.
        images_dir: Directory holding the PNG images.
        output_mask_dir: Directory the masks are written to (created if
            missing).
    """
    os.makedirs(output_mask_dir, exist_ok=True)

    # List the images once and index them by ID. Comparing the whole ID,
    # rather than a filename prefix, keeps an ID that is a prefix of another
    # from picking up the wrong image.
    image_by_id = {}
    for image_filename in os.listdir(images_dir):
        if image_filename.endswith('.png'):
            image_id = os.path.splitext(image_filename)[0].split('_')[0]
            image_by_id.setdefault(image_id, image_filename)

    for xml_filename in os.listdir(xml_dir):
        if not xml_filename.endswith('.xml'):
            continue

        xml_file_path = os.path.join(xml_dir, xml_filename)

        # Remove file type extension .xml
        base_id = xml_filename[:-4]  # Base ID from XML filename

        image_filename = image_by_id.get(base_id)
        if image_filename is None:
            print(f"No matching image found for XML file: {xml_filename}")
            continue

        image_path = os.path.join(images_dir, image_filename)
        image = cv2.imread(image_path)
        if image is None:
            print(f"Could not load image: {image_path}")
            continue

        image_shape = image.shape[:2]  # Get height and width

        # Generate the mask
        mask = loadInbreastMask(xml_file_path, imshape=image_shape)

        # Save the mask as a binary image
        mask_output_path = os.path.join(output_mask_dir, f"{base_id}_mask.png")
        plt.imsave(mask_output_path, mask, cmap='gray', format='png')

generate_masks(xml_dir, clean_inbreast_images_dir, clean_inbreast_masks_dir)

In [None]:
def group_masks_by_class(images_dir, masks_dir):
    """Group existing mask filenames by the class encoded in the image names.

    Every PNG in ``images_dir`` whose name contains ``_benign.png``,
    ``_malignant.png`` or ``_normal.png`` is considered. Its ID is the part of
    the name before the first underscore, and its mask is expected at
    ``<id>_mask.png`` in ``masks_dir``.

    Args:
        images_dir: Directory holding the class-suffixed images.
        masks_dir: Directory holding the ``<id>_mask.png`` files.

    Returns:
        A ``defaultdict(list)`` mapping class name to the mask filenames that
        were found on disk.
    """
    known_classes = ('benign', 'malignant', 'normal')
    grouped_masks = defaultdict(list)

    for image_filename in os.listdir(images_dir):
        if not image_filename.endswith('.png'):
            continue
        print(f"Processing image file: {image_filename}")

        # Determine class based on suffix (first match wins)
        image_class = next(
            (label for label in known_classes if f'_{label}.png' in image_filename),
            None,
        )
        if image_class is None:
            continue

        # The ID is everything before the first underscore
        mask_filename = f"{image_filename.split('_')[0]}_mask.png"

        if os.path.exists(os.path.join(masks_dir, mask_filename)):
            grouped_masks[image_class].append(mask_filename)
        else:
            print(f"Mask file not found for {mask_filename}")

    return grouped_masks

# Group the masks found on disk by image class and report the lesion classes
grouped_masks = group_masks_by_class(clean_inbreast_images_dir, clean_inbreast_masks_dir)
benign_count, malignant_count = (
    len(grouped_masks.get(label, [])) for label in ('benign', 'malignant')
)

print(f"Number of benign masks: {benign_count}")
print(f"Number of malignant masks: {malignant_count}")

In [None]:
# Count the PNG images per class suffix; anything without a known suffix is
# counted as unmatched.
class_totals = {'benign': 0, 'malignant': 0, 'normal': 0}
unmatched_images_count = 0

for image_filename in os.listdir(clean_inbreast_images_dir):
    if not image_filename.endswith('.png'):
        continue
    for label in class_totals:
        if f'_{label}.png' in image_filename:
            class_totals[label] += 1
            break
    else:
        # No class suffix matched
        unmatched_images_count += 1

benign_images_count = class_totals['benign']
malignant_images_count = class_totals['malignant']
normal_images_count = class_totals['normal']

print(f"Number of benign images: {benign_images_count}")
print(f"Number of malignant images: {malignant_images_count}")
print(f"Number of normal images: {normal_images_count}")
print(f"Number of unmatched images: {unmatched_images_count}")

In [None]:
def rename_masks_with_class(image_folder, mask_folder):
    """Insert the lesion class into mask filenames: ``<id>_<class>_mask.png``.

    The class of each ID is read from image names of the form
    ``<digits>_..._(benign|malignant)``. Masks named ``<digits>_mask.png``
    whose ID has a known class are renamed in place; all other masks are left
    untouched.

    Args:
        image_folder: Directory holding the class-suffixed images.
        mask_folder: Directory holding the masks to rename.
    """
    image_pattern = re.compile(r"(\d+)_.*_(benign|malignant)")
    mask_pattern = re.compile(r"(\d+)_mask\.png")

    # Map ID to class name
    id_to_class = {}
    for image_name in os.listdir(image_folder):
        image_hit = image_pattern.match(image_name)
        if image_hit:
            id_to_class[image_hit.group(1)] = image_hit.group(2)

    for mask_filename in os.listdir(mask_folder):
        mask_hit = mask_pattern.match(mask_filename)
        if mask_hit is None:
            continue

        mask_id = mask_hit.group(1)
        class_name = id_to_class.get(mask_id)
        if class_name is None:
            # No benign/malignant image is known for this ID
            continue

        os.rename(
            os.path.join(mask_folder, mask_filename),
            os.path.join(mask_folder, f"{mask_id}_{class_name}_mask.png"),
        )

# Add the benign/malignant class to the INbreast mask filenames
rename_masks_with_class(clean_inbreast_images_dir, clean_inbreast_masks_dir)

In [None]:
# Generate synthetic masks (empty arrays) serving as "normal" case masks in segmentation

# Fallback mask dimensions, used only when an image's own size cannot be read
height, width = 3328, 2560

# Create the all-black fallback mask (zeroes array)
black_mask = np.zeros((height, width), dtype=np.uint8)
c = 1

for filename in os.listdir(clean_inbreast_images_dir):
    if not filename.endswith("_normal.png"):
        continue

    # Unique ID (part before the first underscore)
    base_id = filename.split('_')[0]

    mask_filename = f"{base_id}_normal_syntheticMask.png"

    # Size the empty mask after its own image instead of assuming one fixed
    # resolution, so every synthetic mask lines up with the image it belongs to.
    image_path = os.path.join(clean_inbreast_images_dir, filename)
    try:
        with Image.open(image_path) as normal_image:
            image_width, image_height = normal_image.size
        synthetic_mask = np.zeros((image_height, image_width), dtype=np.uint8)
    except OSError as e:
        print(f"Could not read size of {filename}, using {height}x{width} fallback: {e}")
        synthetic_mask = black_mask

    # Save new mask
    mask_path = os.path.join(clean_inbreast_masks_dir, mask_filename)
    Image.fromarray(synthetic_mask).save(mask_path)

    print(f"{c} - Created mask: {mask_filename}")
    c += 1

In [None]:
def extract_id(filename):
    """Return the part of ``filename`` that precedes the first underscore.

    The whole name is returned when it contains no underscore.
    """
    image_id, _, _ = filename.partition('_')
    return image_id

# Sanity check: every image ID should have exactly one mask
image_files = [f for f in os.listdir(clean_inbreast_images_dir) if os.path.isfile(os.path.join(clean_inbreast_images_dir, f))]
mask_files = [f for f in os.listdir(clean_inbreast_masks_dir) if os.path.isfile(os.path.join(clean_inbreast_masks_dir, f))]

image_dict = {extract_id(f): f for f in image_files}
mask_dict = {extract_id(f): f for f in mask_files}

# The dictionaries keep a single file per ID, so a mask shared by several
# images can only be detected by counting the images per ID beforehand.
image_id_counts = Counter(extract_id(f) for f in image_files)

missing_masks = []
duplicate_masks = []

# Check correspondence
for image_id, image_file in image_dict.items():
    if image_id not in mask_dict:
        missing_masks.append(image_file)  # No corresponding mask
    elif image_id_counts[image_id] > 1:
        duplicate_masks.append(mask_dict[image_id])  # Mask used for more than one image

print(f"Total images: {len(image_files)}")
print(f"Total masks: {len(mask_files)}")
print(f"Images without corresponding masks: {len(missing_masks)}")
print(f"Masks associated with more than one image: {len(duplicate_masks)}")

# Missing or duplicate details
if missing_masks:
    print("Images missing masks:")
    for img in missing_masks:
        print(img)

if duplicate_masks:
    print("Duplicate masks:")
    for mask in duplicate_masks:
        print(mask)

#### 2.1.2. CBIS-DDSM

Mammography image dataset

https://doi.org/10.1038/sdata.2017.177

In [None]:
# Directories
jpeg_dir = '../data/CBISDDSMdataset/jpeg'
clean_cbisddsm_dir = '../data/CBISDDSMdataset/cleanmergedCbisddsm'
clean_cbisddsm_images_dir = os.path.join(clean_cbisddsm_dir, 'images')
clean_cbisddsm_masks_dir = os.path.join(clean_cbisddsm_dir, 'masks')

os.makedirs(clean_cbisddsm_images_dir, exist_ok=True)
os.makedirs(clean_cbisddsm_masks_dir, exist_ok=True)

# Load the DICOM index and keep only the columns used later.
# .copy() makes the selection an independent frame, so adding a column below
# does not trigger pandas' chained-assignment warning.
dicom_data = pd.read_csv('../data/CBISDDSMdataset/csv/dicom_info.csv')
dicom_data = dicom_data[["image_path", "PatientID", "SeriesDescription"]].copy()

# Strip the archive prefix so the paths are relative to jpeg_dir.
# regex=False: the prefix is a literal string, not a pattern.
dicom_data['clean_image_path'] = dicom_data['image_path'].str.replace("CBIS-DDSM/jpeg/", "", regex=False)

# Load the pathology descriptions (calcification and mass cases, train and test)
columns_to_keep = ["pathology", "image file path", "cropped image file path", "ROI mask file path"]
calc_test_data = pd.read_csv('../data/CBISDDSMdataset/csv/calc_case_description_test_set.csv')[columns_to_keep]
calc_train_data = pd.read_csv('../data/CBISDDSMdataset/csv/calc_case_description_train_set.csv')[columns_to_keep]
mass_test_data = pd.read_csv('../data/CBISDDSMdataset/csv/mass_case_description_test_set.csv')[columns_to_keep]
mass_train_data = pd.read_csv('../data/CBISDDSMdataset/csv/mass_case_description_train_set.csv')[columns_to_keep]
pathology_data = pd.concat([calc_train_data, calc_test_data, mass_train_data, mass_test_data])

In [None]:
# Copy CBIS-DDSM full mammograms (and, for malignant cases, their ROI mask)
# into the cleaned images/masks folders, tagging each filename with a running
# ID and the pathology class.

# Exclude cropped images immediately
dicom_filtered = dicom_data[dicom_data['SeriesDescription'] != 'cropped images']

# Separate full mammogram and ROI images
full_mammo_data = dicom_filtered[dicom_filtered['SeriesDescription'] == 'full mammogram images']
roi_data = dicom_filtered[dicom_filtered['SeriesDescription'] == 'ROI mask images']

# Running ID embedded in destination filenames; it only advances for rows
# that were actually copied (skipped rows `continue` before the increment).
counter = 1

for _, mammo_row in full_mammo_data.iterrows():
    
    mammo_patient_id = mammo_row['PatientID']
    full_mammo_image_path = mammo_row['clean_image_path']

    # Check for corresponding pathology data
    # (rows whose 'image file path' starts with this PatientID)
    matching_pathology = pathology_data[pathology_data['image file path'].str.startswith(mammo_patient_id)]
    if matching_pathology.empty:
        continue  # Skip if no pathology data is found

    # Determine the pathology class (only the first matching row is used)
    pathology_class = matching_pathology['pathology'].values[0].lower()
    # 'benign_without_callback' is relabelled as the 'normal' class
    if pathology_class == 'benign_without_callback':
        pathology_class = 'normal'

    # Malignant cases: copy both the full mammogram and its ROI mask
    if pathology_class == 'malignant':
        # Check for matching ROI in roi_data
        matching_roi = roi_data[roi_data['PatientID'].str.startswith(mammo_patient_id)]
        
        # Skip if more or less than one matching ROI
        if len(matching_roi) != 1:
            continue

        # ROI row
        roi_row = matching_roi.iloc[0]
        roi_image_path = roi_row['clean_image_path']
        
        # Destination filenames
        # NOTE(review): files are copied byte-for-byte from jpeg_dir but named
        # *.png — presumably the sources are JPEGs; confirm that downstream
        # readers detect the format from content rather than the extension.
        mammo_dest_filename = f"{os.path.splitext(os.path.basename(full_mammo_image_path))[0]}_id{counter}_malignant.png"
        roi_dest_filename = f"{os.path.splitext(os.path.basename(roi_image_path))[0]}_id{counter}_malignant.png"
        mammo_dest_path = os.path.join(clean_cbisddsm_images_dir, mammo_dest_filename)
        roi_dest_path = os.path.join(clean_cbisddsm_masks_dir, roi_dest_filename)

        # Copy full mammogram and ROI mask images
        shutil.copy(os.path.join(jpeg_dir, full_mammo_image_path), mammo_dest_path)
        shutil.copy(os.path.join(jpeg_dir, roi_image_path), roi_dest_path)
        print(f"Copied malignant full mammogram: {mammo_dest_path}")
        print(f"Copied malignant ROI mask: {roi_dest_path}")

    # Benign and Normal cases: Only copy full mammogram images to "images" subfolder
    # (no mask is copied for these classes)
    else:
        mammo_dest_filename = f"{os.path.splitext(os.path.basename(full_mammo_image_path))[0]}_id{counter}_{pathology_class}.png"
        mammo_dest_path = os.path.join(clean_cbisddsm_images_dir, mammo_dest_filename)

        # Copy full mammogram image
        shutil.copy(os.path.join(jpeg_dir, full_mammo_image_path), mammo_dest_path)
        print(f"Copied {pathology_class} full mammogram: {mammo_dest_path}")

    counter += 1

In [None]:
def rename_masks(mask_folder):
    """Append ``_mask`` to every mask filename: ``<stem>.<ext>`` -> ``<stem>_mask.png``.

    Files already ending in ``_mask.png`` are left alone, so the function can
    be run repeatedly. Only the final extension is removed, so dots inside the
    stem are preserved.

    Args:
        mask_folder: Directory whose files are renamed in place.
    """
    for filename in os.listdir(mask_folder):
        if filename.endswith('_mask.png'):  # Avoid re-renaming
            continue
        stem = os.path.splitext(filename)[0]
        os.rename(
            os.path.join(mask_folder, filename),
            os.path.join(mask_folder, stem + '_mask.png'),
        )

# Tag every copied CBIS-DDSM ROI file with the "_mask" suffix
rename_masks(clean_cbisddsm_masks_dir)

#### 2.1.3. BUSI

Ultrasound image dataset

https://doi.org/10.3390/healthcare10040729

In [None]:
# BUSI source folders (one per class) and cleaned output folders
busi_dir = "../data/Dataset_BUSI_with_GTdataset"
busi_benign_dir, busi_malignant_dir, busi_normal_dir = (
    os.path.join(busi_dir, class_name) for class_name in ("benign", "malignant", "normal")
)

clean_busi_dir = os.path.join(busi_dir, "cleanmergedBusi")
clean_busi_images_dir = os.path.join(clean_busi_dir, 'images')
clean_busi_masks_dir = os.path.join(clean_busi_dir, 'masks')

for output_dir in (clean_busi_masks_dir, clean_busi_images_dir):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
def extract_id(filename):
    """Return ``filename`` with every ``_mask`` and then every ``.png`` removed.

    An image and its primary mask therefore share the same ID.
    """
    for fragment in ('_mask', '.png'):
        filename = filename.replace(fragment, '')
    return filename

# Get all image and mask files from multiple directories
def get_all_files(*directories):
    """Collect image and mask filenames from the given directories.

    Only regular files are considered. A file counts as a mask when its name
    contains ``_mask`` and as an image otherwise.

    Args:
        *directories: Directories to scan (not recursively).

    Returns:
        ``(image_files, mask_files)``: two lists of bare filenames, in the
        order the directories were given.
    """
    image_files, mask_files = [], []

    for directory in directories:
        for entry in os.listdir(directory):
            if not os.path.isfile(os.path.join(directory, entry)):
                continue
            bucket = mask_files if '_mask' in entry else image_files
            bucket.append(entry)

    return image_files, mask_files

# List of all images and masks from the three directories
image_files, mask_files = get_all_files(busi_benign_dir, busi_malignant_dir, busi_normal_dir)

# Map IDs to filenames
image_dict = {extract_id(f): f for f in image_files}
mask_dict = {extract_id(f): f for f in mask_files}

# Count once up front instead of re-scanning the mask list for every image.
# The dictionaries keep a single file per ID, so images whose IDs collide can
# only be detected from the raw file list.
image_id_counts = Counter(extract_id(f) for f in image_files)
mask_name_counts = Counter(mask_dict.values())

# Initialize lists to track issues
missing_masks = []
duplicate_masks = []
extra_masks = []

# Check correspondence between images and masks
for image_id, image_file in image_dict.items():
    if image_id not in mask_dict:
        missing_masks.append(image_file)  # No corresponding mask
    elif mask_name_counts[mask_dict[image_id]] > 1 or image_id_counts[image_id] > 1:
        duplicate_masks.append(mask_dict[image_id])  # Mask used for more than one image

print(f"Total images: {len(image_files)}")
print(f"Total masks: {len(mask_files)}")

# Masks that don't have a corresponding image check
for mask_id, mask_file in mask_dict.items():
    if mask_id not in image_dict:
        extra_masks.append(mask_id)  # Store the mask's ID; files with this ID are skipped when copying

print(f"Extra masks (without corresponding images): {len(extra_masks)}")

# Filenames of the extra masks
if extra_masks:
    print("\nFilenames of extra masks:")
    for mask in extra_masks:
        print(f"{mask}_mask.png")

# Copy one class folder into the cleaned BUSI folders, skipping extra masks
def copy_files(class_dir, extra_masks):
    """Copy the files of ``class_dir`` into the cleaned BUSI folders.

    Files whose name contains ``_mask`` go to the masks folder, all others to
    the images folder. A file is skipped when its ID, as computed by
    ``extract_id``, is listed in ``extra_masks``.

    Args:
        class_dir: Source directory of one class.
        extra_masks: IDs of files that must not be copied.
    """
    for file in os.listdir(class_dir):
        # The ID decides whether this file belongs to an extra mask
        if extract_id(file) in extra_masks:
            print(f"Skipping {file} as it is associated with an extra mask.")
            continue

        source_path = os.path.join(class_dir, file)
        is_mask = '_mask' in file
        target_dir = clean_busi_masks_dir if is_mask else clean_busi_images_dir
        destination_path = os.path.join(target_dir, file)

        shutil.copy(source_path, destination_path)
        if is_mask:
            print(f"Copied mask: {file} -> {destination_path}")
        else:
            print(f"Copied image: {file} -> {destination_path}")

for class_dir in (busi_benign_dir, busi_malignant_dir, busi_normal_dir):
    copy_files(class_dir, extra_masks)

In [None]:
# Report how many entries ended up in the cleaned BUSI images and masks folders
lst = os.listdir(clean_busi_images_dir)
lst2 = os.listdir(clean_busi_masks_dir)

number_files, number_files2 = len(lst), len(lst2)
print(number_files)
print(number_files2)

#### 2.1.4. Mendeley

Ultrasound image dataset

https://doi.org/10.17632/wmy84gzngw.1

In [None]:
# Mendeley source folders and cleaned output folder
mendeley_dir = "../data/BreastUltrasoundImageMendeleydataset/originals"

# No normal class here
mendeley_benign_dir, mendeley_malignant_dir = (
    os.path.join(mendeley_dir, class_name) for class_name in ("benign", "malignant")
)

clean_mendeley_dir = os.path.join(mendeley_dir, "cleanmergedMendeley")
os.makedirs(clean_mendeley_dir, exist_ok=True)

In [None]:
def copy_and_rename_images(src_dir, label):
    """Convert the BMP files of ``src_dir`` to PNG files tagged with ``label``.

    Every file whose extension is ``.bmp`` (in any letter case) is opened
    with PIL and saved into ``clean_mendeley_dir`` as ``<stem>_<label>.png``.
    Other entries of ``src_dir`` are ignored.

    Args:
        src_dir: Folder holding the original BMP images of one class.
        label: Class name appended to each file stem (e.g. ``'benign'``).
    """
    for file in os.listdir(src_dir):
        # Split on the real extension so that "X.BMP" is renamed just like
        # "x.bmp"; a case-sensitive str.replace would leave it untouched and
        # unlabelled.
        stem, extension = os.path.splitext(file)
        if extension.lower() != '.bmp':
            continue

        # New filename with _benign or _malignant appended
        new_filename = f'{stem}_{label}.png'
        new_file_path = os.path.join(clean_mendeley_dir, new_filename)

        # Load and re-encode as PNG; the handle is closed on leaving the block
        with Image.open(os.path.join(src_dir, file)) as img:
            img.save(new_file_path, 'PNG')

        print(f"Copied and renamed: {file} -> {new_filename}")

# Convert both class folders, benign first
for class_dir, class_label in (
    (mendeley_benign_dir, 'benign'),
    (mendeley_malignant_dir, 'malignant'),
):
    copy_and_rename_images(class_dir, class_label)

print("Copying and renaming complete.")

In [None]:
# PNG files present in the merged folder
clean_merged_images = [
    name for name in os.listdir(clean_mendeley_dir)
    if name.lower().endswith('.png')
]
print(f"Number of images in cleanmergedMendeley folder: {len(clean_merged_images)}")

# BMP files present in the two source folders, for comparison
benign_images, malignant_images = (
    [name for name in os.listdir(class_dir) if name.lower().endswith('.bmp')]
    for class_dir in (mendeley_benign_dir, mendeley_malignant_dir)
)
total_images = len(benign_images) + len(malignant_images)
print(f"Total number of images in benign and malignant folders: {total_images}")

#### 2.1.5. QAMEBI

Ultrasound image dataset

https://doi.org/10.1016/j.compbiomed.2022.106438

https://doi.org/10.1016/j.ejrad.2022.110591

https://doi.org/10.1016/j.bbe.2022.07.004

In [None]:
# Source folders of the QAMEBI dataset, one per class
# (this dataset has no normal class)
qamebi_dir = "../data/QAMEBIdataset"
qamebi_benign_dir, qamebi_malignant_dir = (
    os.path.join(qamebi_dir, class_folder)
    for class_folder in ("benign", "malignant")
)

# Destination folder for the merged, converted images
clean_qamebi_dir = os.path.join(qamebi_dir, "cleanmergedQamebi")
os.makedirs(clean_qamebi_dir, exist_ok=True)

In [None]:
def copy_and_convert_images(src_dir, dest_dir):
    """Convert the "Image" BMP files of ``src_dir`` into PNG files.

    Only files whose name contains ``"Image"`` and ends with ``".bmp"`` are
    processed; every other entry is ignored. Each selected file is opened
    with PIL and saved into ``dest_dir`` under the same name with a
    ``.png`` extension.

    Args:
        src_dir: Folder holding the original BMP files.
        dest_dir: Folder receiving the converted PNG files.
    """
    bmp_extension = ".bmp"

    for file in os.listdir(src_dir):
        # Only process files with "Image" in the name
        if "Image" not in file or not file.endswith(bmp_extension):
            continue

        # Swap only the trailing extension; str.replace would also rewrite
        # a ".bmp" that happens to occur earlier in the name.
        new_filename = file[:-len(bmp_extension)] + ".png"
        destination_path = os.path.join(dest_dir, new_filename)

        # The context manager releases the source file handle after saving
        with Image.open(os.path.join(src_dir, file)) as img:
            img.save(destination_path)

        print(f"Copied and converted image: {file} -> {destination_path}")

# Convert both class folders into the single merged folder, benign first
for class_dir in (qamebi_benign_dir, qamebi_malignant_dir):
    copy_and_convert_images(class_dir, clean_qamebi_dir)

clean_images_count = sum(1 for name in os.listdir(clean_qamebi_dir) if "Image" in name)
print(f"Number of images in the clean merged folder: {clean_images_count}")

# Sum of images in the benign and malignant folders (filtering only those with "Image" in the name)
benign_images_count, malignant_images_count = (
    sum(1 for name in os.listdir(class_dir) if "Image" in name and name.endswith(".bmp"))
    for class_dir in (qamebi_benign_dir, qamebi_malignant_dir)
)
total_images_count = benign_images_count + malignant_images_count
print(f"Total number of 'Image' files in benign and malignant folders: {total_images_count}")

# The merged folder should hold exactly one PNG per source "Image" file
if clean_images_count != total_images_count:
    print("Warning: There is a mismatch between the number of images in the clean merged folder and the total 'Image' files from benign and malignant folders.")
else:
    print("Success: The number of images in the clean merged folder matches the sum of 'Image' files from benign and malignant folders.")

#### 2.1.6. Thammasat

Ultrasound image dataset

https://doi.org/10.1016/j.patcog.2018.01.032

In [None]:
# Load HTML
with open('thammasatHospitalSourcePageCodeSnippet.html', 'r', encoding='utf-8') as file:
    html_content = file.read()

# Parse the HTML content
soup = BeautifulSoup(html_content, 'html.parser')

# Maps image ID -> class label text ('missing' when the cell is empty/absent)
image_dict = {}

table = soup.find('table', {'id': 'imageList'})
if table is None:
    # Fail with an explicit message instead of an AttributeError on None
    raise ValueError("No <table id='imageList'> found in the saved HTML snippet")
rows = table.find_all('tr')[1:]  # Skip the header row

for row in rows:
    # A row without a fancybox link, or a link without href, carries no
    # image. Subscripting a missing tag raises TypeError and a missing
    # attribute raises KeyError, so test for both explicitly.
    image_tag = row.find('a', {'class': 'fancybox'})
    image_url = image_tag.get('href') if image_tag is not None else None
    if not image_url:
        continue  # Skip row if the image or URL is missing

    # The ID is the file name of the URL without its extension
    image_id = image_url.split('/')[-1].split('.')[0]

    # The class is in the 4th column (index 3); empty or absent -> 'missing'
    cells = row.find_all('td')
    class_label = cells[3].text.strip() if len(cells) > 3 else ''
    if not class_label:
        class_label = 'missing'

    image_dict[image_id] = class_label

for img_id, cls in image_dict.items():
    print(f"Image ID: {img_id}, Class: {cls}")

In [None]:
thammasat_dir = '../data/Thammasatdataset'
thammasat_images_dir = os.path.join(thammasat_dir, "images")
clean_thammasat_dir = os.path.join(thammasat_dir, "cleanmergedThammasat")

os.makedirs(thammasat_images_dir, exist_ok=True)
os.makedirs(clean_thammasat_dir, exist_ok=True)

# Base URL
base_url = 'http://www.onlinemedicalimages.com/media/com_record/'

# Seconds to wait for the server before giving up on one image. Without a
# timeout a single stalled connection would block the whole loop forever.
download_timeout = 30

# Download each image
for image_id, class_label in image_dict.items():

    image_url = base_url + image_id + '.jpg'

    # Local filename: the image ID with the class label appended
    filename = f"{image_id}_{class_label.replace(' ', '_')}.jpg"
    save_path = os.path.join(thammasat_images_dir, filename)

    try:
        print(f"Downloading {image_url} ...")
        response = requests.get(image_url, timeout=download_timeout)
        response.raise_for_status()

        with open(save_path, 'wb') as file:
            file.write(response.content)
        print(f"Saved {filename}")

    except requests.exceptions.RequestException as e:
        # Covers connection errors, timeouts and HTTP error statuses;
        # the loop carries on with the next image.
        print(f"Failed to download {image_url}: {e}")

print("Image download process completed.")

In [None]:
# Tally how many parsed images carry each class label
class_distribution = Counter(image_dict.values())
for cls, count in class_distribution.items():
    print(f"Class: {cls}, Count: {count}")

In [None]:
# Out of the 165 downloaded images only the classes below are considered; the
# "Fibroadenoma" and "Cyst" classifications are merged into the overall benign
# class -> check https://geekymedics.com/benign-breast-disease/

malignant_classes = ['malignant', 'Malignant Solid Mass']
benign_classes = ['Benign Solid Mass', 'Fibroadenoma', 'Cyst']

for filename in os.listdir(thammasat_images_dir):
    if not filename.endswith(('.png', '.jpg')):
        continue

    # Image ID from the filename is before the first underscore
    image_id = filename.split('_')[0]
    print(f"Processing file: {filename}, extracted ID: {image_id}")

    if image_id not in image_dict:
        print(f"No classification found for {image_id}")
        continue

    class_label = image_dict[image_id]
    print(f"Found class label for {image_id}: {class_label}")

    # Collapse the detailed label into the two classes that are kept
    if class_label in malignant_classes:
        target_class = 'malignant'
    elif class_label in benign_classes:
        target_class = 'benign'
    else:
        print(f"Skipping {filename}: class '{class_label}' is not used")
        continue

    new_filename = f"{image_id}_{target_class}.png"

    # Re-encode through PIL so the ".png" file really contains PNG data;
    # a plain byte copy would leave JPEG content under a PNG name.
    with Image.open(os.path.join(thammasat_images_dir, filename)) as img:
        if img.mode == 'CMYK':
            # PNG has no CMYK mode, convert before saving
            img = img.convert('RGB')
        img.save(os.path.join(clean_thammasat_dir, new_filename), 'PNG')
    print(f"Copied {filename} to {new_filename}")

print("Images have been copied and renamed successfully.")

#### 2.1.7. BrEaST or Breast-Lesions-USG (ultra)

Ultrasound image dataset

https://doi.org/10.7937/9WKK-Q141

In [None]:
# Root folder of the BrEaST download
BrEaST_dir = "../data/BrEaSTdataset/BrEaST-Lesions_USG-images_and_masks-Dec-15-2023"
# NOTE: despite the "_dir" suffix this is the path of an .xlsx file, not a folder
BrEaST_csv_dir = "../data/BrEaSTdataset/BrEaST-Lesions-USG-clinical-data-Dec-15-2023.xlsx"
# Folder the image filenames are resolved against when copying
BrEaST_imagesAndMasks_dir = os.path.join(BrEaST_dir, "BrEaST-Lesions_USG-images_and_masks")
# Destination folder for the relabelled copies
clean_BrEaST_dir = os.path.join(BrEaST_dir, "cleanmergedBrEaST")

os.makedirs(clean_BrEaST_dir, exist_ok=True)

In [None]:
# Load the Excel file
# (path repeated here - presumably so this cell can be run on its own)
BrEaST_csv_dir = "../data/BrEaSTdataset/BrEaST-Lesions-USG-clinical-data-Dec-15-2023.xlsx"
df = pd.read_excel(BrEaST_csv_dir)

In [None]:
# Copy every image listed in the clinical table, appending its class label
for index, row in df.iterrows():
    original_filename = row['Image_filename']  # Image filenames
    class_name = row['Classification']  # Class label

    # Full path to the original image
    original_image_path = os.path.join(BrEaST_imagesAndMasks_dir, original_filename)

    # Rows whose image is not on disk are only reported
    if not os.path.exists(original_image_path):
        print(f"File not found: {original_filename}")
        continue

    filename_stem = os.path.splitext(original_filename)[0]
    new_filename = f"{filename_stem}_{class_name}.png"
    new_image_path = os.path.join(clean_BrEaST_dir, new_filename)

    # Copy the image to the new directory with the new filename
    shutil.copy2(original_image_path, new_image_path)
    print(f"Copied: {original_filename} to {new_filename}")

### 2.2. Creating the 3 used custom datasets from gathered data

#### 2.2.1. Setting paths

In [None]:
# Each custom dataset gets a root folder with a pretrain and a finetune subfolder

# Ultrasound Dataset
ultrasound_dir = os.path.join(data_used_dir, 'UltrasoundDataset')
ultrasound_pretrain_dir, ultrasound_finetune_dir = (
    os.path.join(ultrasound_dir, phase) for phase in ('pretrainData', 'finetuneData')
)

# Mammography Dataset
mammography_dir = os.path.join(data_used_dir, 'MammographyDataset')
mammography_pretrain_dir, mammography_finetune_dir = (
    os.path.join(mammography_dir, phase) for phase in ('pretrainData', 'finetuneData')
)

# Multimodal Dataset
multimodal_dir = os.path.join(data_used_dir, 'MultimodalDataset')
multimodal_pretrain_dir, multimodal_finetune_dir = (
    os.path.join(multimodal_dir, phase) for phase in ('pretrainData', 'finetuneData')
)

# Actual paths to the cleanmerged folders (avoids re-running code above)
clean_qamebi_dir = "../data/QAMEBIdataset/cleanmergedQamebi"
clean_thammasat_dir = '../data/Thammasatdataset/cleanmergedThammasat'
clean_mendeley_dir = "../data/BreastUltrasoundImageMendeleydataset/originals/cleanmergedMendeley"
clean_BrEaST_dir = "../data/BrEaSTdataset/BrEaST-Lesions_USG-images_and_masks-Dec-15-2023/cleanmergedBrEaST"
clean_busi_masks_dir = "../data/Dataset_BUSI_with_GTdataset/cleanmergedBusi/masks"
clean_busi_images_dir = "../data/Dataset_BUSI_with_GTdataset/cleanmergedBusi/images"
clean_cbisddsm_masks_dir = '../data/CBISDDSMdataset/cleanmergedCbisddsm/masks'
clean_cbisddsm_images_dir = '../data/CBISDDSMdataset/cleanmergedCbisddsm/images'
clean_inbreast_masks_dir = '../data/INbreast Release 1.0/cleanmergedInbreast/masks'
clean_inbreast_images_dir = '../data/INbreast Release 1.0/cleanmergedInbreast/images'

# Lookup table: dataset key -> folder holding its images (or masks)
dataset_dirs = dict([
    ("BUSI", clean_busi_images_dir),
    ("BUSI_masks", clean_busi_masks_dir),
    ("QAMEBI", clean_qamebi_dir),
    ("Mendeley", clean_mendeley_dir),
    ("INBreast", clean_inbreast_images_dir),
    ("INBreast_masks", clean_inbreast_masks_dir),
    ("CBIS-DDSM", clean_cbisddsm_images_dir),
    ("CBIS-DDSM_masks", clean_cbisddsm_masks_dir),
    ("Thammasat", clean_thammasat_dir),
    ("BrEaST", clean_BrEaST_dir),
])

# Hold images for each class and their masks.
# Every dataset key maps to three independent, initially empty lists; keys
# ending in "_masks" use the class names with a "_mask" suffix.
image_variables = {
    dataset_key: {
        class_name + ('_mask' if dataset_key.endswith('_masks') else ''): []
        for class_name in ('benign', 'malignant', 'normal')
    }
    for dataset_key in (
        'BUSI',
        'BUSI_masks',
        'QAMEBI',
        'Mendeley',
        'Thammasat',
        'BrEaST',
        'INBreast',
        'INBreast_masks',
        'CBIS-DDSM',
        'CBIS-DDSM_masks',
    )
}

In [None]:
# Populate images and masks by class based on keywords
def populate_image_variables(dataset_dir, variable, mask=False):
    """Sort the image files of ``dataset_dir`` into ``variable`` by class.

    A file is considered when its name ends in .png/.jpg/.jpeg (any case).
    Its class is the first of 'benign', 'malignant', 'normal' found in the
    file name, ignoring case. The full path is appended to
    ``variable[<class>]``, or to ``variable[<class>_mask]`` when ``mask`` is
    true. Files matching no class are reported on stdout.
    """
    key_suffix = '_mask' if mask else ''

    for img_name in os.listdir(dataset_dir):
        if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue

        for keyword in ('benign', 'malignant', 'normal'):
            if re.search(keyword, img_name, re.IGNORECASE):
                variable[keyword + key_suffix].append(os.path.join(dataset_dir, img_name))
                break
        else:
            # No class keyword occurs anywhere in the file name
            print(f"Unmatched file in {dataset_dir}: {img_name}")

# Match INBreast masks with images based on ID and classify by type
def populate_inbreast_masks(image_dir, mask_dir, variable):
    """Classify the masks of ``mask_dir`` through their matching image name.

    Images and masks are paired on the text before the first underscore of
    their file names. A paired mask path is appended to the
    'benign_mask', 'malignant_mask' or 'normal_mask' list of ``variable``
    according to the first of those class words found in the lower-cased
    image name. Masks without a matching image are reported on stdout.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    # prefix -> image file name (a later file with the same prefix wins)
    image_name_map = {}
    for img in os.listdir(image_dir):
        if img.lower().endswith(image_extensions):
            image_name_map[img.split("_")[0]] = img

    for mask_name in os.listdir(mask_dir):
        if not mask_name.lower().endswith(image_extensions):
            continue

        image_name = image_name_map.get(mask_name.split("_")[0])
        if image_name is None:
            print(f"No match for mask file: {mask_name}")
            continue

        lowered_image_name = image_name.lower()
        for class_word in ("benign", "malignant", "normal"):
            if class_word in lowered_image_name:
                variable[f'{class_word}_mask'].append(os.path.join(mask_dir, mask_name))
                break


# Datasets whose class is read directly from the file names
for source_dir, dataset_key, is_mask in (
    (clean_busi_images_dir, 'BUSI', False),
    (clean_busi_masks_dir, 'BUSI_masks', True),
    (clean_qamebi_dir, 'QAMEBI', False),
    (clean_thammasat_dir, 'Thammasat', False),
    (clean_mendeley_dir, 'Mendeley', False),
    (clean_BrEaST_dir, 'BrEaST', False),
    (clean_inbreast_images_dir, 'INBreast', False),
):
    populate_image_variables(source_dir, image_variables[dataset_key], mask=is_mask)

# INBreast masks are classified through the image sharing their ID prefix
populate_inbreast_masks(clean_inbreast_images_dir, clean_inbreast_masks_dir, image_variables['INBreast_masks'])

for source_dir, dataset_key, is_mask in (
    (clean_cbisddsm_images_dir, 'CBIS-DDSM', False),
    (clean_cbisddsm_masks_dir, 'CBIS-DDSM_masks', True),
):
    populate_image_variables(source_dir, image_variables[dataset_key], mask=is_mask)

In [None]:
# Count entries in each key/subkey of image_variables
def count_entries_in_image_variables(variable):
    """Return a dict shaped like ``variable`` with every list replaced by its length."""
    return {
        main_key: {sub_key: len(images_list) for sub_key, images_list in sub_dict.items()}
        for main_key, sub_dict in variable.items()
    }

image_variable_counts = count_entries_in_image_variables(image_variables)

# One header line per dataset key, followed by its per-class counts
for main_key in image_variable_counts:
    print(f"\n{main_key}:")
    for sub_key, count in image_variable_counts[main_key].items():
        print(f"  {sub_key}: {count}")

#### 2.2.2. Ultrasound dataset

In [None]:
# Count images by class across datasets
def count_images_by_class(images):
    """Count paths per class word.

    Each string in ``images`` is counted under the first of 'benign',
    'malignant', 'normal' that it contains (case-sensitive); strings
    containing none of them are not counted. Returns a ``Counter``.
    """
    counts = Counter()
    for path in images:
        label = next(
            (word for word in ("benign", "malignant", "normal") if word in path),
            None,
        )
        if label is not None:
            counts[label] += 1
    return counts

# Count masks by class
def count_masks_by_class(masks):
    """Count mask paths per class word.

    A path is counted under 'benign', 'malignant' or 'normal', whichever of
    these (tested in that order, case-sensitive) it contains first; paths
    containing none are skipped. Returns a ``Counter``.
    """
    counts = Counter()
    for mask_path in masks:
        for class_word in ("benign", "malignant", "normal"):
            if class_word in mask_path:
                counts[class_word] += 1
                break
    return counts

# Random selection of specified amount for pretrain donation
pretrain_busi_images = {
    'benign': random.sample(image_variables['BUSI']['benign'], 234),
    'malignant': random.sample(image_variables['BUSI']['malignant'], 47),
    'normal': random.sample(image_variables['BUSI']['normal'], 100)
}

# Collect remaining images for finetune.
# The original listing order is kept (duplicates dropped) instead of going
# through set difference: the iteration order of a set of strings changes
# between interpreter sessions, which would make the later seeded
# train/val/test split differ from run to run.
finetune_busi_images = {}
for label, donated in pretrain_busi_images.items():
    donated_lookup = set(donated)
    finetune_busi_images[label] = [
        img for img in dict.fromkeys(image_variables['BUSI'][label])
        if img not in donated_lookup
    ]

# Get corresponding masks for finetune images: a mask belongs to an image
# when the image file name without ".png" occurs in the mask file name
finetune_busi_masks = {
    label: [
        mask for img in finetune_busi_images[label]
        for mask in image_variables['BUSI_masks'][f'{label}_mask']
        if os.path.basename(img).replace('.png', '') in os.path.basename(mask)
    ]
    for label in ('benign', 'malignant', 'normal')
}

In [None]:
all_pretrain_images = []

# Populate images for each class.
# dict.fromkeys drops duplicates like set() did but keeps the listing order;
# set iteration order varies between interpreter sessions and would make
# the later seeded train/validation split irreproducible.
for dataset in ['QAMEBI', 'Mendeley', 'Thammasat', 'BrEaST']:
    for class_label in ('benign', 'malignant', 'normal'):
        all_pretrain_images.extend(dict.fromkeys(image_variables[dataset][class_label]))

# Confirm count after filtering duplicates (class word looked up in the lower-cased path)
benign_count, malignant_count, normal_count = (
    sum(1 for img in all_pretrain_images if class_label in img.lower())
    for class_label in ('benign', 'malignant', 'normal')
)

print(f"before donation: Benign count: {benign_count}, Malignant count: {malignant_count}, Normal count: {normal_count}")

# Add the BUSI images donated to pretraining
all_pretrain_images.extend(pretrain_busi_images['benign'])
all_pretrain_images.extend(pretrain_busi_images['malignant'])
all_pretrain_images.extend(pretrain_busi_images['normal'])

# Same counts again, now including the donated BUSI images
benign_count, malignant_count, normal_count = (
    sum(1 for img in all_pretrain_images if class_label in img.lower())
    for class_label in ('benign', 'malignant', 'normal')
)

print(f"after donation: Benign count: {benign_count}, Malignant count: {malignant_count}, Normal count: {normal_count}")

In [None]:
def split_images(images, train_ratio, val_ratio, test_ratio=0):
    """Split ``images`` into train/validation (and optionally test) lists.

    Args:
        images: Sequence of items to split.
        train_ratio: Fraction of ``images`` assigned to the train part.
        val_ratio: Fraction for validation; only used together with
            ``test_ratio`` to divide what is left after the train part.
        test_ratio: Fraction for the test part; 0 means no test part.

    Returns:
        ``(train, val)`` when ``test_ratio`` is 0, otherwise
        ``(train, val, test)``. Inputs too small to split are not an error:
        fewer than two items all go to train, and a remainder of a single
        item goes to validation with an empty test part.
    """
    if len(images) < 2:
        # Parentheses are required: without them the conditional expression
        # binds only to the middle element and a 4-tuple is returned.
        return (images, [], []) if test_ratio > 0 else (images, [])

    train_images, temp_images = train_test_split(images, train_size=train_ratio, random_state=42)

    if test_ratio <= 0:
        return train_images, temp_images

    if len(temp_images) < 2:
        # A single leftover item cannot be split any further
        return train_images, temp_images, []

    val_images, test_images = train_test_split(
        temp_images, test_size=test_ratio / (val_ratio + test_ratio), random_state=42
    )
    return train_images, val_images, test_images

def match_masks(image_list, masks, masks_folder):
    """Return, for each image, the path of its ``<name>_mask.png`` mask.

    The expected mask of an image is the file ``<image stem>_mask.png``
    inside ``masks_folder``. Paths are built and compared with ``os.path``
    so the lookup works with either path separator.

    Args:
        image_list: Image paths.
        masks: Known mask paths to search in.
        masks_folder: Folder the expected mask paths are built from.

    Returns:
        A list aligned with ``image_list`` holding the matching entry of
        ``masks``, or ``None`` where no such mask is listed.
    """
    # normalised path -> path exactly as given, for O(1) lookups
    known_masks = {os.path.normpath(mask): mask for mask in masks}

    matched_masks = []
    for img in image_list:
        base_name = os.path.splitext(os.path.basename(img))[0]
        expected_mask_path = os.path.normpath(
            os.path.join(masks_folder, f"{base_name}_mask.png")
        )
        matched_masks.append(known_masks.get(expected_mask_path))

    return matched_masks

# Separate each class in all_pretrain_images (class word looked up in the lower-cased path)
pretrain_benign, pretrain_malignant, pretrain_normal = (
    [img for img in all_pretrain_images if class_word in img.lower()]
    for class_word in ('benign', 'malignant', 'normal')
)

# Same filter applied to the per-class finetune lists
finetune_benign, finetune_malignant, finetune_normal = (
    [img for img in finetune_busi_images[class_word] if class_word in img.lower()]
    for class_word in ('benign', 'malignant', 'normal')
)

# Pretrain: 80/20 train/validation split per class
train_benign, val_benign = split_images(pretrain_benign, train_ratio=0.8, val_ratio=0.2)
train_malignant, val_malignant = split_images(pretrain_malignant, train_ratio=0.8, val_ratio=0.2)
train_normal, val_normal = split_images(pretrain_normal, train_ratio=0.8, val_ratio=0.2)

# Finetune: 60/20/20 train/validation/test split per class
train_benign_ft, val_benign_ft, test_benign_ft = split_images(
    finetune_benign, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)
train_malignant_ft, val_malignant_ft, test_malignant_ft = split_images(
    finetune_malignant, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)
train_normal_ft, val_normal_ft, test_normal_ft = split_images(
    finetune_normal, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)

# Look up the mask of every finetune image, split by split
train_benign_masks_ft, val_benign_masks_ft, test_benign_masks_ft = (
    match_masks(split_part, finetune_busi_masks["benign"], clean_busi_masks_dir)
    for split_part in (train_benign_ft, val_benign_ft, test_benign_ft)
)

train_malignant_masks_ft, val_malignant_masks_ft, test_malignant_masks_ft = (
    match_masks(split_part, finetune_busi_masks["malignant"], clean_busi_masks_dir)
    for split_part in (train_malignant_ft, val_malignant_ft, test_malignant_ft)
)

train_normal_masks_ft, val_normal_masks_ft, test_normal_masks_ft = (
    match_masks(split_part, finetune_busi_masks["normal"], clean_busi_masks_dir)
    for split_part in (train_normal_ft, val_normal_ft, test_normal_ft)
)

In [None]:
def plot_pie_chart(class_name, train_count, val_count, test_count=None):
    """Show a pie chart of one class's split sizes.

    Draws Train and Validation slices, plus a Test slice when
    ``test_count`` is given. The title uses ``class_name.capitalize()``.
    """
    sizes = [train_count, val_count]
    labels = ['Train', 'Validation']
    if test_count is not None:
        sizes.append(test_count)
        labels.append('Test')
    colors = ['gold', 'lightcoral', 'lightskyblue']

    plt.figure(figsize=(5, 5))
    plt.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=140)
    plt.title(f'{class_name.capitalize()} Class Distribution')
    plt.axis('equal')
    plt.show()

# Plot for each class in pretrain
# Plot for each class in pretrain (train/validation only)
for class_title, train_part, val_part in (
    ('Benign', train_benign, val_benign),
    ('Malignant', train_malignant, val_malignant),
    ('Normal', train_normal, val_normal),
):
    plot_pie_chart(class_title, len(train_part), len(val_part))

# Plot for each class in finetune (train/validation/test)
for class_title, train_part, val_part, test_part in (
    ('Benign', train_benign_ft, val_benign_ft, test_benign_ft),
    ('Malignant', train_malignant_ft, val_malignant_ft, test_malignant_ft),
    ('Normal', train_normal_ft, val_normal_ft, test_normal_ft),
):
    plot_pie_chart(class_title, len(train_part), len(val_part), len(test_part))


In [None]:
# Assuming pretrain phase has been split into these variables
# Split sizes summed over the three classes
total_train_pretrain = sum(len(part) for part in (train_benign, train_malignant, train_normal))
total_val_pretrain = sum(len(part) for part in (val_benign, val_malignant, val_normal))

# Plotting the overall distribution
def plot_overall_pie_chart(train_count, val_count):
    """Show a two-slice pie chart (Train / Validation) for the pretraining phase."""
    plt.figure(figsize=(5, 5))
    plt.pie(
        [train_count, val_count],
        labels=['Train', 'Validation'],
        colors=['lightcoral', 'lightskyblue'],
        autopct='%1.1f%%',
        startangle=140,
    )
    plt.title('Overall Distribution in Pretraining Phase')
    plt.axis('equal')
    plt.show()

plot_overall_pie_chart(total_train_pretrain, total_val_pretrain)


In [None]:
# Assuming finetune phase has been split into these variables
# Split sizes summed over the three classes
total_train_finetune = sum(len(part) for part in (train_benign_ft, train_malignant_ft, train_normal_ft))
total_val_finetune = sum(len(part) for part in (val_benign_ft, val_malignant_ft, val_normal_ft))
total_test_finetune = sum(len(part) for part in (test_benign_ft, test_malignant_ft, test_normal_ft))

# Plotting the overall distribution
def plot_overall_pie_chart_finetune(train_count, val_count, test_count):
    """Show a three-slice pie chart (Train / Validation / Test) for the finetuning phase."""
    plt.figure(figsize=(5, 5))
    plt.pie(
        [train_count, val_count, test_count],
        labels=['Train', 'Validation', 'Test'],
        colors=['lightcoral', 'lightskyblue', 'gold'],
        autopct='%1.1f%%',
        startangle=140,
    )
    plt.title('Overall Distribution in Finetuning Phase')
    plt.axis('equal')
    plt.show()

plot_overall_pie_chart_finetune(total_train_finetune, total_val_finetune, total_test_finetune)

In [None]:
# Ultrasound Dataset

# Pretrain data: train and validation images only (no subfolders by class)
for split, images in [("train", train_benign + train_malignant + train_normal),
                      ("validation", val_benign + val_malignant + val_normal)]:
    split_dir = os.path.join(ultrasound_pretrain_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    for img_path in images:
        shutil.copy(img_path, split_dir)

# Finetune data: Organized into train, validation, test with class-specific subfolders for images and masks
for split, class_data in [("train", [(train_benign_ft, train_benign_masks_ft, "benign"),
                                     (train_malignant_ft, train_malignant_masks_ft, "malignant"),
                                     (train_normal_ft, train_normal_masks_ft, "normal")]),
                          ("validation", [(val_benign_ft, val_benign_masks_ft, "benign"),
                                          (val_malignant_ft, val_malignant_masks_ft, "malignant"),
                                          (val_normal_ft, val_normal_masks_ft, "normal")]),
                          ("test", [(test_benign_ft, test_benign_masks_ft, "benign"),
                                    (test_malignant_ft, test_malignant_masks_ft, "malignant"),
                                    (test_normal_ft, test_normal_masks_ft, "normal")])]:

    split_dir = os.path.join(ultrasound_finetune_dir, split)
    for images, masks, class_name in class_data:
        class_img_dir = os.path.join(split_dir, class_name, "images")
        class_mask_dir = os.path.join(split_dir, class_name, "masks")
        os.makedirs(class_img_dir, exist_ok=True)
        os.makedirs(class_mask_dir, exist_ok=True)

        # Copy images
        for img_path in images:
            shutil.copy(img_path, class_img_dir)

        # Copy corresponding masks. The mask lists hold None for images
        # without a matching mask; those entries have nothing to copy and
        # would make shutil.copy raise.
        for mask_path in masks:
            if mask_path is None:
                continue
            shutil.copy(mask_path, class_mask_dir)

#### 2.2.3. Mammography dataset

In [None]:
malignant_images = image_variables['CBIS-DDSM']['malignant']

# Randomly pick the 63 malignant images that are set aside (donated)
donated_malignant_images = random.sample(malignant_images, 63)

# Everything that was not donated stays available for pretraining
pretrain_cbis_malignant_images = list(
    filter(lambda img: img not in donated_malignant_images, malignant_images)
)

In [None]:
# Store the paths of the corresponding masks for the donated malignant images
finetune_cbis_malignant_masks = []

def extract_cbis_id(filename):
    """Return the digits of the first ``_id<digits>_`` token in ``filename``.

    The token may sit anywhere in the name but must be enclosed by
    underscores. Returns ``None`` when there is no such token.
    """
    found = re.search(r'_id(\d+)_', filename)
    if found is None:
        return None
    return found.group(1)

# The masks folder does not change while searching, so list it once
# instead of once per donated image
cbis_mask_names = os.listdir(clean_cbisddsm_masks_dir)

for image_path in donated_malignant_images:

    image_filename = os.path.basename(image_path)
    image_id = extract_cbis_id(image_filename)
    if not image_id:
        # Report images whose name carries no ID so that a mismatch between
        # the donated images and the collected masks can be traced
        print(f"No ID found in image filename {image_filename}")
        continue

    # First mask whose name contains "id<ID>_malignant"
    mask_name = next(
        (name for name in cbis_mask_names if re.search(f"id{image_id}_malignant", name)),
        None,
    )
    if mask_name is None:
        print(f"No mask found for image with ID {image_id}")
        continue

    mask_path = os.path.join(clean_cbisddsm_masks_dir, mask_name)
    finetune_cbis_malignant_masks.append(mask_path)

In [None]:
# Creating individual variables for INBreast dataset images and masks
# (each name refers to the list stored in image_variables, not to a copy)
finetune_inbreast_benign_images = image_variables['INBreast']['benign']
finetune_inbreast_malignant_images = image_variables['INBreast']['malignant']
finetune_inbreast_normal_images = image_variables['INBreast']['normal']
finetune_inbreast_benign_masks = image_variables['INBreast_masks']['benign_mask']
finetune_inbreast_malignant_masks = image_variables['INBreast_masks']['malignant_mask']
finetune_inbreast_normal_masks = image_variables['INBreast_masks']['normal_mask']

# The malignant CBIS-DDSM images set aside earlier go to the finetune side
finetune_cbis_malignant_images = donated_malignant_images

# Creating individual variables for CBIS-DDSM dataset images and masks
pretrain_cbis_benign_images = image_variables['CBIS-DDSM']['benign']
pretrain_cbis_normal_images = image_variables['CBIS-DDSM']['normal']

# Report the size of every list defined above
print(f"INBreast Benign Images: {len(finetune_inbreast_benign_images)}")
print(f"INBreast Malignant Images: {len(finetune_inbreast_malignant_images)}")
print(f"INBreast Normal Images: {len(finetune_inbreast_normal_images)}")
print(f"INBreast Benign Masks: {len(finetune_inbreast_benign_masks)}")
print(f"INBreast Malignant Masks: {len(finetune_inbreast_malignant_masks)}")
print(f"INBreast Normal Masks: {len(finetune_inbreast_normal_masks)}\n")

print(f"CBIS Normal Images: {len(pretrain_cbis_normal_images)}")
print(f"CBIS Benign Images: {len(pretrain_cbis_benign_images)}")
print(f"CBIS Malignant Images (Full Set Minus Donated): {len(pretrain_cbis_malignant_images)}")
print(f"CBIS Donated Malignant Images: {len(finetune_cbis_malignant_images)}")
print(f"CBIS Donated Malignant Masks: {len(finetune_cbis_malignant_masks)}")

In [None]:
# 1. Randomly drop 254 normal CBIS-DDSM images from the pretrain pool
discarded_normals_254 = random.sample(pretrain_cbis_normal_images, 254)
pretrain_cbis_normal_images = list(
    filter(lambda img: img not in discarded_normals_254, pretrain_cbis_normal_images)
)

# 2. Randomly drop 589 benign CBIS-DDSM images from the pretrain pool
discarded_benign_589 = random.sample(pretrain_cbis_benign_images, 589)
pretrain_cbis_benign_images = list(
    filter(lambda img: img not in discarded_benign_589, pretrain_cbis_benign_images)
)

# 3. Randomly drop a further 629 malignant CBIS-DDSM images from the pretrain pool
discarded_malignant_629 = random.sample(pretrain_cbis_malignant_images, 629)
pretrain_cbis_malignant_images = list(
    filter(lambda img: img not in discarded_malignant_629, pretrain_cbis_malignant_images)
)

In [None]:
def discard_images_and_masks(images_list, masks_list, num_discard):
    """Randomly drop ``num_discard`` images together with their masks.

    Files are identified by the part of their base name before the first
    underscore. ``num_discard`` images are drawn with ``random.sample``;
    every image and every mask sharing an ID with a drawn image is removed.

    Returns:
        ``(updated_images, updated_masks)`` - new lists, input order kept.
    """
    def leading_id(path):
        return os.path.basename(path).split('_')[0]

    discarded_ids = set(map(leading_id, random.sample(images_list, num_discard)))

    updated_images = []
    for img in images_list:
        if leading_id(img) not in discarded_ids:
            updated_images.append(img)

    updated_masks = []
    for mask in masks_list:
        if leading_id(mask) not in discarded_ids:
            updated_masks.append(mask)

    return updated_images, updated_masks

# 4. Discard 34 normal INBreast images and their corresponding masks
finetune_inbreast_normal_images, finetune_inbreast_normal_masks = discard_images_and_masks(
    images_list=finetune_inbreast_normal_images,
    masks_list=finetune_inbreast_normal_masks,
    num_discard=34,
)

# 5. Discard 40 benign images and their corresponding masks
finetune_inbreast_benign_images, finetune_inbreast_benign_masks = discard_images_and_masks(
    images_list=finetune_inbreast_benign_images,
    masks_list=finetune_inbreast_benign_masks,
    num_discard=40,
)

# Malignant finetune data: INBreast entries first, then the donated CBIS-DDSM ones
combined_finetune_malignant_masks = [*finetune_inbreast_malignant_masks, *finetune_cbis_malignant_masks]

combined_finetune_malignant_images = [*finetune_inbreast_malignant_images, *finetune_cbis_malignant_images]

In [None]:
### Verification
# Print the final size of every pretrain / finetune list built above

# CBIS-DDSM
print(f"Final Pretrain CBIS Normal Images: {len(pretrain_cbis_normal_images)}")
print(f"Final Pretrain CBIS Benign Images: {len(pretrain_cbis_benign_images)}")
print(f"Final Pretrain CBIS Malignant Images: {len(pretrain_cbis_malignant_images)}")
print(f"Final Finetune CBIS Malignant Images: {len(finetune_cbis_malignant_images)}")
print(f"Final Finetune CBIS Malignant Masks: {len(finetune_cbis_malignant_masks)}")
# INBreast
print(f"Final Finetune INBreast Normal Images: {len(finetune_inbreast_normal_images)}")
print(f"Final Finetune INBreast Benign Images: {len(finetune_inbreast_benign_images)}")
print(f"Final Finetune INBreast Malignant Images: {len(finetune_inbreast_malignant_images)}")
print(f"Final Finetune INBreast Benign Masks: {len(finetune_inbreast_benign_masks)}")
print(f"Final Finetune INBreast Normal Masks: {len(finetune_inbreast_normal_masks)}")
print(f"Final Finetune INBreast Malignant Masks: {len(finetune_inbreast_malignant_masks)}")
# INBreast + donated CBIS-DDSM malignant data
print(f"Final Finetune Combined Malignant Images: {len(combined_finetune_malignant_images)}")
print(f"Final Finetune Combined Malignant Masks: {len(combined_finetune_malignant_masks)}")

In [None]:
def split_images(images, train_ratio, val_ratio, test_ratio=0):
    """Split a list of images into train/validation or train/validation/test.

    Args:
        images: Sequence of items (image paths here) to split.
        train_ratio: Passed to ``train_test_split`` as ``train_size``.
        val_ratio: Share intended for validation. Only used to compute how
            the non-train remainder is divided when ``test_ratio > 0``.
        test_ratio: Share intended for testing. ``0`` (default) requests a
            two-way split.

    Returns:
        ``(train, val)`` when ``test_ratio`` is 0, otherwise
        ``(train, val, test)``. With fewer than two items nothing can be
        split, so everything goes to ``train`` and the other parts are empty
        lists.
    """
    three_way = test_ratio > 0

    if len(images) < 2:
        # Parenthesised on purpose: without the parentheses the conditional
        # binds tighter than the commas and a 4-tuple is returned regardless
        # of test_ratio, which breaks unpacking at every call site.
        return (images, [], []) if three_way else (images, [])

    train_images, temp_images = train_test_split(images, train_size=train_ratio, random_state=42)

    if not three_way:
        # Two-way split: the whole remainder is the validation set
        return train_images, temp_images

    # Divide the remainder between validation and test in proportion to their ratios
    val_images, test_images = train_test_split(
        temp_images, test_size=test_ratio / (val_ratio + test_ratio), random_state=42
    )
    return train_images, val_images, test_images

def match_masks_mammography(image_list, masks, masks_folder, dataset_type):
    """Return, for each image, the path of its mask or ``None`` if not found.

    The mask name is derived from the image *filename* (directory parts are
    ignored, and a ``.png`` extension is stripped from the class field, as the
    dataset-specific matchers in this notebook do):

    * ``"inbreast"``: ``<id>_normal_syntheticMask.png`` when the filename
      contains ``normal``, otherwise ``<id>_<class>_mask.png``; ``<id>`` is the
      text before the first underscore and ``<class>`` the last
      underscore-separated field.
    * ``"cbis"``: ``*_<id>_<class>_mask.png`` where ``<id>`` comes from the
      first underscore-separated field containing ``id`` (with ``id`` removed)
      and ``*`` stands for any prefix.

    Args:
        image_list: Image paths or filenames.
        masks: Collection of candidate mask paths.
        masks_folder: Folder prefix; candidates are ``f"{masks_folder}/<name>"``.
        dataset_type: ``"inbreast"`` or ``"cbis"``.

    Returns:
        A list aligned with ``image_list`` holding the matched mask path from
        ``masks`` or ``None``.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported values.
        IndexError: For ``"cbis"`` when no filename field contains ``id``.
    """
    matched_masks = []
    folder_prefix = f"{masks_folder}/"

    for img in image_list:
        img_name = os.path.basename(img)
        parts = img_name.split('_')
        img_class = parts[-1].replace(".png", "")  # The class part, without extension

        if dataset_type == "inbreast":
            img_id = parts[0]  # The ID part
            if 'normal' in img_name:  # Normal cases use synthetic masks
                expected_mask_name = f"{img_id}_normal_syntheticMask.png"
            else:
                expected_mask_name = f"{img_id}_{img_class}_mask.png"

            expected_mask_path = f"{folder_prefix}{expected_mask_name}"
            matched_masks.append(expected_mask_path if expected_mask_path in masks else None)

        elif dataset_type == "cbis":
            img_id = [part for part in parts if 'id' in part][0].replace('id', '')
            # The CBIS mask name has a variable prefix, so an exact membership
            # test against a literal "*" can never succeed. Instead accept any
            # mask that sits under masks_folder and ends with the fixed tail.
            mask_tail = f"_{img_id}_{img_class}_mask.png"
            min_length = len(folder_prefix) + len(mask_tail)
            matched_masks.append(next(
                (
                    mask for mask in masks
                    if len(mask) >= min_length
                    and mask.startswith(folder_prefix)
                    and mask.endswith(mask_tail)
                ),
                None,
            ))

        else:
            # Previously this fell through and failed on an unbound variable
            raise ValueError(
                f"Unsupported dataset_type {dataset_type!r}; expected 'inbreast' or 'cbis'"
            )

    return matched_masks

In [None]:
def match_masks_mammography_inbreast(image_list, masks):
    """Pair each INBreast image with the first mask whose filename matches it.

    For every image the filename is split on underscores: the first field is
    taken as the ID and the last field (minus ``.png``) as the class. A mask
    matches when its filename contains ``"<id>_<class>"``. Images without a
    matching mask get ``None`` in the result and a message is printed.

    Returns:
        A list with one entry per image, in the same order as ``image_list``.
    """
    matched_masks = []

    for img in image_list:
        name_fields = os.path.basename(img).split('_')
        img_id = name_fields[0]
        img_class = name_fields[-1].replace(".png", "")
        search_string = f"{img_id}_{img_class}"

        # First mask (in list order) whose filename contains the key; substring
        # match, so the key may appear anywhere in the mask filename
        hit = next(
            (candidate for candidate in masks if search_string in os.path.basename(candidate)),
            None,
        )
        matched_masks.append(hit)

        if hit is None:
            print(f"No match found for {img} -> expected mask with ID {img_id} and class {img_class}")

    return matched_masks

def match_masks_mammography_cbis(image_list, masks):
    """Pair each CBIS image with the first mask sharing its ID and class.

    The ID is the second underscore-separated field of the *filename* and the
    class is the last field (minus ``.png``). IDs are read from the basename,
    as the INBreast matcher does, so underscores in directory names cannot
    shift the ID field. A mask matches when its ID equals the image ID and the
    class string occurs anywhere in the mask path. Images without a match get
    ``None`` in the result and a message is printed.

    Args:
        image_list: Image paths or filenames containing at least one underscore.
        masks: Candidate mask paths. Masks whose filename has no underscore
            carry no ID and are never matched.

    Returns:
        A list with one entry per image, in the same order as ``image_list``.

    Raises:
        IndexError: If an image filename contains no underscore.
    """
    # Parse every mask's ID once instead of once per image
    mask_ids = []
    for mask in masks:
        mask_fields = os.path.basename(mask).split('_')
        mask_ids.append(mask_fields[1] if len(mask_fields) > 1 else None)

    matched_masks = []

    for img in image_list:
        parts = os.path.basename(img).split('_')
        img_id = parts[1]  # ID is right after the first underscore of the filename
        img_class = parts[-1].replace(".png", "")  # Class is the last part of the filename

        match = next(
            (
                mask for mask, mask_id in zip(masks, mask_ids)
                if mask_id == img_id and img_class in mask
            ),
            None,
        )
        matched_masks.append(match)

        if match is None:
            print(f"No match found for {img} -> expected mask pattern with ID {img_id} and class {img_class}")

    return matched_masks

In [None]:
# 1. Pretrain stuff
# Two-way (train/validation) split per class; pretraining uses images only, no masks.
mm_train_benign, mm_val_benign = split_images(pretrain_cbis_benign_images, 0.8, 0.2)
mm_train_malignant, mm_val_malignant = split_images(pretrain_cbis_malignant_images, 0.8, 0.2)
mm_train_normal, mm_val_normal = split_images(pretrain_cbis_normal_images, 0.8, 0.2)

# 2 Finetune stuff
# Three-way (train/validation/test) split per class. After each split the masks
# are looked up again from the full per-class mask list so that every mask list
# is index-aligned with its image list (unmatched images yield None entries).
# 2.1. Benign
mm_train_benign_ft, mm_val_benign_ft, mm_test_benign_ft = split_images(finetune_inbreast_benign_images, 0.6, 0.2, 0.2)
mm_train_benign_masks_ft = match_masks_mammography_inbreast(mm_train_benign_ft, finetune_inbreast_benign_masks)
mm_val_benign_masks_ft = match_masks_mammography_inbreast(mm_val_benign_ft, finetune_inbreast_benign_masks)
mm_test_benign_masks_ft = match_masks_mammography_inbreast(mm_test_benign_ft, finetune_inbreast_benign_masks)

# 2.2. Normal
mm_train_normal_ft, mm_val_normal_ft, mm_test_normal_ft = split_images(finetune_inbreast_normal_images, 0.6, 0.2, 0.2)
mm_train_normal_masks_ft = match_masks_mammography_inbreast(mm_train_normal_ft, finetune_inbreast_normal_masks)
mm_val_normal_masks_ft = match_masks_mammography_inbreast(mm_val_normal_ft, finetune_inbreast_normal_masks)
mm_test_normal_masks_ft = match_masks_mammography_inbreast(mm_test_normal_ft, finetune_inbreast_normal_masks)

# 2.3. Malignant (some cbis images/masks and some inbreast images/masks)
# Each source is split separately (the "_aux" lists) and merged further below.
mm_train_cbis_malignant_aux, mm_val_cbis_malignant_aux, mm_test_cbis_malignant_aux = split_images(finetune_cbis_malignant_images, 0.6, 0.2, 0.2)
mm_train_cbis_malignant_masks_aux = match_masks_mammography_cbis(mm_train_cbis_malignant_aux, finetune_cbis_malignant_masks)
mm_val_cbis_malignant_masks_aux = match_masks_mammography_cbis(mm_val_cbis_malignant_aux, finetune_cbis_malignant_masks)
mm_test_cbis_malignant_masks_aux = match_masks_mammography_cbis(mm_test_cbis_malignant_aux, finetune_cbis_malignant_masks)

mm_train_inbreast_malignant_aux, mm_val_inbreast_malignant_aux, mm_test_inbreast_malignant_aux = split_images(finetune_inbreast_malignant_images, 0.6, 0.2, 0.2)
mm_train_inbreast_malignant_masks_aux = match_masks_mammography_inbreast(mm_train_inbreast_malignant_aux, finetune_inbreast_malignant_masks)
mm_val_inbreast_malignant_masks_aux = match_masks_mammography_inbreast(mm_val_inbreast_malignant_aux, finetune_inbreast_malignant_masks)
mm_test_inbreast_malignant_masks_aux = match_masks_mammography_inbreast(mm_test_inbreast_malignant_aux, finetune_inbreast_malignant_masks)

# Merge sets of images (CBIS first, then INBreast)
mm_train_malignant_ft = mm_train_cbis_malignant_aux + mm_train_inbreast_malignant_aux
mm_val_malignant_ft = mm_val_cbis_malignant_aux + mm_val_inbreast_malignant_aux
mm_test_malignant_ft = mm_test_cbis_malignant_aux + mm_test_inbreast_malignant_aux

# Merge sets of masks in the same CBIS-then-INBreast order so indices stay aligned with the images
mm_train_malignant_masks_ft = mm_train_cbis_malignant_masks_aux + mm_train_inbreast_malignant_masks_aux
mm_val_malignant_masks_ft = mm_val_cbis_malignant_masks_aux + mm_val_inbreast_malignant_masks_aux
mm_test_malignant_masks_ft = mm_test_cbis_malignant_masks_aux + mm_test_inbreast_malignant_masks_aux

In [None]:
# Expected counts
# Hard-coded reference sizes for every mammography split; the assertions that
# follow compare the actual list lengths against these tables.
expected_pretrain_counts = {
    'benign': {'train': 513, 'val': 129},
    'normal': {'train': 83, 'val': 21},
    'malignant': {'train': 396, 'val': 99}
}

expected_finetune_counts = {
    'benign': {'train': 121, 'val': 41, 'test': 41},
    'normal': {'train': 19, 'val': 7, 'test': 7},
    'malignant': {
        'cbis': {'train': 37, 'val': 13, 'test': 13},
        'inbreast': {'train': 60, 'val': 20, 'test': 20}
    }
}

# Observed sizes of the mammography splits. Despite the "expected_" prefix these
# hold the lengths of the lists actually produced above, not reference values.
expected_train_benign, expected_val_benign = len(mm_train_benign), len(mm_val_benign)
expected_train_malignant, expected_val_malignant = len(mm_train_malignant), len(mm_val_malignant)
expected_train_normal, expected_val_normal = len(mm_train_normal), len(mm_val_normal)

# Finetune splits for benign
expected_train_benign_ft, expected_val_benign_ft, expected_test_benign_ft = (
    len(mm_train_benign_ft), len(mm_val_benign_ft), len(mm_test_benign_ft)
)

# Finetune splits for normal
expected_train_normal_ft, expected_val_normal_ft, expected_test_normal_ft = (
    len(mm_train_normal_ft), len(mm_val_normal_ft), len(mm_test_normal_ft)
)

# Finetune splits for malignant (both CBIS and INBreast)
# Uses the mammography ("mm_") lists like every other line in this cell; the
# unprefixed train/val/test_malignant_ft names belong to the ultrasound splits.
expected_train_malignant_ft, expected_val_malignant_ft, expected_test_malignant_ft = (
    len(mm_train_malignant_ft), len(mm_val_malignant_ft), len(mm_test_malignant_ft)
)

# Checks
# Every split size is compared against the hard-coded reference tables
# (expected_pretrain_counts / expected_finetune_counts) defined above.
# Pretrain splits
assert len(mm_train_benign) == expected_pretrain_counts['benign']['train'], "Mismatch in train_benign count for pretrain"
assert len(mm_val_benign) == expected_pretrain_counts['benign']['val'], "Mismatch in val_benign count for pretrain"
assert len(mm_train_normal) == expected_pretrain_counts['normal']['train'], "Mismatch in train_normal count for pretrain"
assert len(mm_val_normal) == expected_pretrain_counts['normal']['val'], "Mismatch in val_normal count for pretrain"
assert len(mm_train_malignant) == expected_pretrain_counts['malignant']['train'], "Mismatch in train_malignant count for pretrain"
assert len(mm_val_malignant) == expected_pretrain_counts['malignant']['val'], "Mismatch in val_malignant count for pretrain"

# Finetune splits for benign and normal
assert len(mm_train_benign_ft) == expected_finetune_counts['benign']['train'], "Mismatch in train_benign_ft count for finetune"
assert len(mm_val_benign_ft) == expected_finetune_counts['benign']['val'], "Mismatch in val_benign_ft count for finetune"
assert len(mm_test_benign_ft) == expected_finetune_counts['benign']['test'], "Mismatch in test_benign_ft count for finetune"

assert len(mm_train_normal_ft) == expected_finetune_counts['normal']['train'], "Mismatch in train_normal_ft count for finetune"
assert len(mm_val_normal_ft) == expected_finetune_counts['normal']['val'], "Mismatch in val_normal_ft count for finetune"
assert len(mm_test_normal_ft) == expected_finetune_counts['normal']['test'], "Mismatch in test_normal_ft count for finetune"

# Finetune splits for malignant from CBIS and INBreast
assert len(mm_train_cbis_malignant_aux) == expected_finetune_counts['malignant']['cbis']['train'], "Mismatch in train_cbis_malignant_aux count for finetune"
assert len(mm_val_cbis_malignant_aux) == expected_finetune_counts['malignant']['cbis']['val'], "Mismatch in val_cbis_malignant_aux count for finetune"
assert len(mm_test_cbis_malignant_aux) == expected_finetune_counts['malignant']['cbis']['test'], "Mismatch in test_cbis_malignant_aux count for finetune"

assert len(mm_train_inbreast_malignant_aux) == expected_finetune_counts['malignant']['inbreast']['train'], "Mismatch in train_inbreast_malignant_aux count for finetune"
assert len(mm_val_inbreast_malignant_aux) == expected_finetune_counts['malignant']['inbreast']['val'], "Mismatch in val_inbreast_malignant_aux count for finetune"
assert len(mm_test_inbreast_malignant_aux) == expected_finetune_counts['malignant']['inbreast']['test'], "Mismatch in test_inbreast_malignant_aux count for finetune"

# Final finetune malignant combined checks
# The merged list must be exactly as long as the CBIS and INBreast parts together.
assert len(mm_train_malignant_ft) == expected_finetune_counts['malignant']['cbis']['train'] + expected_finetune_counts['malignant']['inbreast']['train'], "Mismatch in combined train_malignant_ft count"
assert len(mm_val_malignant_ft) == expected_finetune_counts['malignant']['cbis']['val'] + expected_finetune_counts['malignant']['inbreast']['val'], "Mismatch in combined val_malignant_ft count"
assert len(mm_test_malignant_ft) == expected_finetune_counts['malignant']['cbis']['test'] + expected_finetune_counts['malignant']['inbreast']['test'], "Mismatch in combined test_malignant_ft count"

# Mask alignment checks
# NOTE: the matching helpers append None for an image without a mask, so the
# mask lists always have one entry per image; these length checks therefore do
# not detect missing masks (None entries are reported by a separate cell).
assert len(mm_train_benign_ft) == len(mm_train_benign_masks_ft), "Mismatch in benign train image-mask count"
assert len(mm_val_benign_ft) == len(mm_val_benign_masks_ft), "Mismatch in benign validation image-mask count"
assert len(mm_test_benign_ft) == len(mm_test_benign_masks_ft), "Mismatch in benign test image-mask count"

assert len(mm_train_normal_ft) == len(mm_train_normal_masks_ft), "Mismatch in normal train image-mask count"
assert len(mm_val_normal_ft) == len(mm_val_normal_masks_ft), "Mismatch in normal validation image-mask count"
assert len(mm_test_normal_ft) == len(mm_test_normal_masks_ft), "Mismatch in normal test image-mask count"

# Malignant masks for finetune
assert len(mm_train_malignant_ft) == len(mm_train_malignant_masks_ft), "Mismatch in malignant train image-mask count"
assert len(mm_val_malignant_ft) == len(mm_val_malignant_masks_ft), "Mismatch in malignant validation image-mask count"
assert len(mm_test_malignant_ft) == len(mm_test_malignant_masks_ft), "Mismatch in malignant test image-mask count"

print("All clear")

# 2. Assertions for Image-Mask Pair Matching
# Same length comparison as the mask alignment checks, but with messages that
# report both counts. Every operand is a mammography ("mm_") list.
# Finetune benign
assert len(mm_train_benign_ft) == len(mm_train_benign_masks_ft), \
    f"Expected {len(mm_train_benign_ft)} masks, found {len(mm_train_benign_masks_ft)} for train_benign_ft"
assert len(mm_val_benign_ft) == len(mm_val_benign_masks_ft), \
    f"Expected {len(mm_val_benign_ft)} masks, found {len(mm_val_benign_masks_ft)} for val_benign_ft"
assert len(mm_test_benign_ft) == len(mm_test_benign_masks_ft), \
    f"Expected {len(mm_test_benign_ft)} masks, found {len(mm_test_benign_masks_ft)} for test_benign_ft"

# Finetune normal
assert len(mm_train_normal_ft) == len(mm_train_normal_masks_ft), \
    f"Expected {len(mm_train_normal_ft)} masks, found {len(mm_train_normal_masks_ft)} for train_normal_ft"
# Compare against the mammography mask list (the one the message reports), not
# the unprefixed ultrasound val_normal_masks_ft.
assert len(mm_val_normal_ft) == len(mm_val_normal_masks_ft), \
    f"Expected {len(mm_val_normal_ft)} masks, found {len(mm_val_normal_masks_ft)} for val_normal_ft"
assert len(mm_test_normal_ft) == len(mm_test_normal_masks_ft), \
    f"Expected {len(mm_test_normal_ft)} masks, found {len(mm_test_normal_masks_ft)} for test_normal_ft"

# Finetune malignant
assert len(mm_train_malignant_ft) == len(mm_train_malignant_masks_ft), \
    f"Expected {len(mm_train_malignant_ft)} masks, found {len(mm_train_malignant_masks_ft)} for train_malignant_ft"
assert len(mm_val_malignant_ft) == len(mm_val_malignant_masks_ft), \
    f"Expected {len(mm_val_malignant_ft)} masks, found {len(mm_val_malignant_masks_ft)} for val_malignant_ft"
assert len(mm_test_malignant_ft) == len(mm_test_malignant_masks_ft), \
    f"Expected {len(mm_test_malignant_ft)} masks, found {len(mm_test_malignant_masks_ft)} for test_malignant_ft"

print("All clear: counts of images and masks match across variables.")

In [None]:
# Per-class split sizes for the pretraining phase (train / validation only)
plot_pie_chart('Benign', len(mm_train_benign), len(mm_val_benign))
plot_pie_chart('Malignant', len(mm_train_malignant), len(mm_val_malignant))
plot_pie_chart('Normal', len(mm_train_normal), len(mm_val_normal))

# Per-class split sizes for the finetuning phase (train / validation / test)
plot_pie_chart('Benign', len(mm_train_benign_ft), len(mm_val_benign_ft), len(mm_test_benign_ft))
plot_pie_chart('Normal', len(mm_train_normal_ft), len(mm_val_normal_ft), len(mm_test_normal_ft))
plot_pie_chart('Malignant', len(mm_train_malignant_ft), len(mm_val_malignant_ft), len(mm_test_malignant_ft))

In [None]:
# Total number of mammography pretraining images per split, summed over the three classes
mm_total_train_pretrain = sum(
    len(subset) for subset in (mm_train_benign, mm_train_malignant, mm_train_normal)
)
mm_total_val_pretrain = sum(
    len(subset) for subset in (mm_val_benign, mm_val_malignant, mm_val_normal)
)

def plot_overall_pie_chart(train_count, val_count):
    """Show a pie chart of the train vs. validation share for pretraining.

    Args:
        train_count: Number of training images.
        val_count: Number of validation images.
    """
    plt.figure(figsize=(5, 5))
    plt.pie(
        [train_count, val_count],
        labels=['Train', 'Validation'],
        colors=['lightcoral', 'lightskyblue'],
        autopct='%1.1f%%',
        startangle=140,
    )
    plt.title('Overall Distribution in Pretraining Phase')
    plt.axis('equal')  # equal aspect ratio so the pie is drawn as a circle
    plt.show()

plot_overall_pie_chart(mm_total_train_pretrain, mm_total_val_pretrain)

In [None]:
# Total number of mammography finetuning images per split, summed over the three classes
mm_total_train_finetune = sum(
    len(subset) for subset in (mm_train_benign_ft, mm_train_malignant_ft, mm_train_normal_ft)
)
mm_total_val_finetune = sum(
    len(subset) for subset in (mm_val_benign_ft, mm_val_malignant_ft, mm_val_normal_ft)
)
mm_total_test_finetune = sum(
    len(subset) for subset in (mm_test_benign_ft, mm_test_malignant_ft, mm_test_normal_ft)
)

def plot_overall_pie_chart_finetune(train_count, val_count, test_count):
    """Show a pie chart of the train / validation / test share for finetuning.

    Args:
        train_count: Number of training images.
        val_count: Number of validation images.
        test_count: Number of test images.
    """
    plt.figure(figsize=(5, 5))
    plt.pie(
        [train_count, val_count, test_count],
        labels=['Train', 'Validation', 'Test'],
        colors=['lightcoral', 'lightskyblue', 'gold'],
        autopct='%1.1f%%',
        startangle=140,
    )
    plt.title('Overall Distribution in Finetuning Phase')
    plt.axis('equal')  # equal aspect ratio so the pie is drawn as a circle
    plt.show()

plot_overall_pie_chart_finetune(mm_total_train_finetune, mm_total_val_finetune, mm_total_test_finetune)

In [None]:
# Pretrain data: train and validation images only (no subfolders by class)
# All three classes are copied into one flat folder per split.
# NOTE(review): shutil is not among the imports in the notebook's first cell;
# confirm it is imported elsewhere before this cell runs.
for split, images in [("train", mm_train_benign + mm_train_malignant + mm_train_normal),
                      ("validation", mm_val_benign + mm_val_malignant + mm_val_normal)]:
    split_dir = os.path.join(mammography_pretrain_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    for img_path in images:
        shutil.copy(img_path, split_dir)

# Finetune data: Organized into train, validation, test with class-specific subfolders for images and masks
# Resulting layout: <mammography_finetune_dir>/<split>/<class>/{images,masks}/
for split, class_data in [("train", [(mm_train_benign_ft, mm_train_benign_masks_ft, "benign"),
                                     (mm_train_malignant_ft, mm_train_malignant_masks_ft, "malignant"),
                                     (mm_train_normal_ft, mm_train_normal_masks_ft, "normal")]),
                          ("validation", [(mm_val_benign_ft, mm_val_benign_masks_ft, "benign"),
                                          (mm_val_malignant_ft, mm_val_malignant_masks_ft, "malignant"),
                                          (mm_val_normal_ft, mm_val_normal_masks_ft, "normal")]),
                          ("test", [(mm_test_benign_ft, mm_test_benign_masks_ft, "benign"),
                                    (mm_test_malignant_ft, mm_test_malignant_masks_ft, "malignant"),
                                    (mm_test_normal_ft, mm_test_normal_masks_ft, "normal")])]:
    
    split_dir = os.path.join(mammography_finetune_dir, split)
    for images, masks, class_name in class_data:
        class_img_dir = os.path.join(split_dir, class_name, "images")
        class_mask_dir = os.path.join(split_dir, class_name, "masks")
        os.makedirs(class_img_dir, exist_ok=True)
        os.makedirs(class_mask_dir, exist_ok=True)
        
        # Copy images
        for img_path in images:
            shutil.copy(img_path, class_img_dir)
        
        # Copy corresponding masks
        # NOTE(review): the matching helpers store None for images without a
        # mask, and a None entry would make this copy fail; the None check
        # lives in the following cell, so consider running it first.
        for mask_path in masks:
            shutil.copy(mask_path, class_mask_dir)

In [41]:
# List of variables to check for None values
# Maps a display name to each mammography split list. None entries can only
# originate from the mask-matching helpers, which append None when an image
# has no matching mask.
variables = {
    "train_benign": mm_train_benign, "val_benign": mm_val_benign,
    "train_malignant": mm_train_malignant, "val_malignant": mm_val_malignant,
    "train_normal": mm_train_normal, "val_normal": mm_val_normal,
    
    # Finetune variables
    "train_benign_ft": mm_train_benign_ft, "val_benign_ft": mm_val_benign_ft, "test_benign_ft": mm_test_benign_ft,
    "train_benign_masks_ft": mm_train_benign_masks_ft, "val_benign_masks_ft": mm_val_benign_masks_ft, "test_benign_masks_ft": mm_test_benign_masks_ft,
    
    "train_normal_ft": mm_train_normal_ft, "val_normal_ft": mm_val_normal_ft, "test_normal_ft": mm_test_normal_ft,
    "train_normal_masks_ft": mm_train_normal_masks_ft, "val_normal_masks_ft": mm_val_normal_masks_ft, "test_normal_masks_ft": mm_test_normal_masks_ft,
    
    # Malignant variables with auxiliary splits
    "train_cbis_malignant_aux": mm_train_cbis_malignant_aux, "val_cbis_malignant_aux": mm_val_cbis_malignant_aux, "test_cbis_malignant_aux": mm_test_cbis_malignant_aux,
    "train_cbis_malignant_masks_aux": mm_train_cbis_malignant_masks_aux, "val_cbis_malignant_masks_aux": mm_val_cbis_malignant_masks_aux, "test_cbis_malignant_masks_aux": mm_test_cbis_malignant_masks_aux,
    
    "train_inbreast_malignant_aux": mm_train_inbreast_malignant_aux, "val_inbreast_malignant_aux": mm_val_inbreast_malignant_aux, "test_inbreast_malignant_aux": mm_test_inbreast_malignant_aux,
    "train_inbreast_malignant_masks_aux": mm_train_inbreast_malignant_masks_aux, "val_inbreast_malignant_masks_aux": mm_val_inbreast_malignant_masks_aux, "test_inbreast_malignant_masks_aux": mm_test_inbreast_malignant_masks_aux
}

# Check for None values in each variable
# Only a warning is printed (one per affected list); execution is not stopped.
for var_name, paths in variables.items():
    if any(path is None for path in paths):
        print(f"Warning: None value(s) found in variable '{var_name}'")


#### 2.2.4. Multimodal dataset

In [32]:
# Draw a random half of a list, without replacement
def random_50_percent(data_list):
    """Return ``len(data_list) // 2`` randomly chosen items of ``data_list``.

    Odd lengths round down. Sampling uses the global ``random`` state, so the
    result is reproducible only if that state has been seeded.
    """
    return random.sample(data_list, len(data_list) // 2)

# Sample half of the images and look up the masks that belong to that half
def get_half_images_and_masks(image_list, mask_list, match_function, **match_args):
    """Randomly keep half of the images and match masks for the kept images.

    Args:
        image_list: Images to sample from; ``len(image_list) // 2`` are kept.
        mask_list: Candidate masks handed to ``match_function``.
        match_function: Called as
            ``match_function(half_image_list, mask_list, **match_args)``.
        **match_args: Extra keyword arguments forwarded to ``match_function``.

    Returns:
        ``(half_image_list, half_mask_list)`` where the second element is
        whatever ``match_function`` returned.
    """
    n_keep = len(image_list) // 2
    half_image_list = random.sample(image_list, n_keep)
    return half_image_list, match_function(half_image_list, mask_list, **match_args)


In [33]:
# Pretrain
# A random half of every pretraining split is drawn per modality; the halves
# are merged into the multimodal pretraining sets further below.
# Ultra
# (the unprefixed train_*/val_* lists are the ultrasound splits)
us_train_benign_50 = random_50_percent(train_benign)
us_train_malignant_50 = random_50_percent(train_malignant)
us_train_normal_50 = random_50_percent(train_normal)
us_val_benign_50 = random_50_percent(val_benign)
us_val_malignant_50 = random_50_percent(val_malignant)
us_val_normal_50 = random_50_percent(val_normal)

# Mammo
mm_train_benign_50 = random_50_percent(mm_train_benign)
mm_train_malignant_50 = random_50_percent(mm_train_malignant)
mm_train_normal_50 = random_50_percent(mm_train_normal)
mm_val_benign_50 = random_50_percent(mm_val_benign)
mm_val_malignant_50 = random_50_percent(mm_val_malignant)
mm_val_normal_50 = random_50_percent(mm_val_normal)

# Finetune
# For every split and class, a random half of the images is kept and the masks
# are re-matched for exactly those images, so each *_masks_ft_50 list is
# index-aligned with its *_ft_50 image list.
# Ultrasound
# Masks are looked up with match_masks in clean_busi_masks_dir, always against
# the full per-class BUSI mask list.
us_train_benign_ft_50, us_train_benign_masks_ft_50 = get_half_images_and_masks(
    train_benign_ft,
    finetune_busi_masks["benign"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_train_malignant_ft_50, us_train_malignant_masks_ft_50 = get_half_images_and_masks(
    train_malignant_ft,
    finetune_busi_masks["malignant"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_train_normal_ft_50, us_train_normal_masks_ft_50 = get_half_images_and_masks(
    train_normal_ft,
    finetune_busi_masks["normal"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_val_benign_ft_50, us_val_benign_masks_ft_50 = get_half_images_and_masks(
    val_benign_ft,
    finetune_busi_masks["benign"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_val_malignant_ft_50, us_val_malignant_masks_ft_50 = get_half_images_and_masks(
    val_malignant_ft,
    finetune_busi_masks["malignant"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_val_normal_ft_50, us_val_normal_masks_ft_50 = get_half_images_and_masks(
    val_normal_ft,
    finetune_busi_masks["normal"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_test_benign_ft_50, us_test_benign_masks_ft_50 = get_half_images_and_masks(
    test_benign_ft,
    finetune_busi_masks["benign"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_test_malignant_ft_50, us_test_malignant_masks_ft_50 = get_half_images_and_masks(
    test_malignant_ft,
    finetune_busi_masks["malignant"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

us_test_normal_ft_50, us_test_normal_masks_ft_50 = get_half_images_and_masks(
    test_normal_ft,
    finetune_busi_masks["normal"],
    match_function=match_masks,
    masks_folder=clean_busi_masks_dir
)

# Mammography
# Same procedure as for ultrasound: keep a random half of each split and
# re-match the masks for the kept images. Benign and normal images come from
# INBreast, so the INBreast matcher is used.
mm_train_benign_ft_50, mm_train_benign_masks_ft_50 = get_half_images_and_masks(
    mm_train_benign_ft,
    finetune_inbreast_benign_masks,
    match_function=match_masks_mammography_inbreast
)

mm_train_normal_ft_50, mm_train_normal_masks_ft_50 = get_half_images_and_masks(
    mm_train_normal_ft,
    finetune_inbreast_normal_masks,
    match_function=match_masks_mammography_inbreast
)

mm_val_benign_ft_50, mm_val_benign_masks_ft_50 = get_half_images_and_masks(
    mm_val_benign_ft,
    finetune_inbreast_benign_masks,
    match_function=match_masks_mammography_inbreast
)

mm_val_normal_ft_50, mm_val_normal_masks_ft_50 = get_half_images_and_masks(
    mm_val_normal_ft,
    finetune_inbreast_normal_masks,
    match_function=match_masks_mammography_inbreast
)

mm_test_benign_ft_50, mm_test_benign_masks_ft_50 = get_half_images_and_masks(
    mm_test_benign_ft,
    finetune_inbreast_benign_masks,
    match_function=match_masks_mammography_inbreast
)

mm_test_normal_ft_50, mm_test_normal_masks_ft_50 = get_half_images_and_masks(
    mm_test_normal_ft,
    finetune_inbreast_normal_masks,
    match_function=match_masks_mammography_inbreast
)

# Malignant case set aside
# Malignant mammograms come from two sources, so each source is halved on its
# own with its own matcher; the results keep the "_aux" suffix.
# INBreast malignant halves
mm_train_inbreast_malignant_ft_50_aux, mm_train_inbreast_malignant_masks_ft_50_aux = get_half_images_and_masks(
    mm_train_inbreast_malignant_aux,
    finetune_inbreast_malignant_masks,
    match_function=match_masks_mammography_inbreast
)

mm_val_inbreast_malignant_ft_50_aux, mm_val_inbreast_malignant_masks_ft_50_aux = get_half_images_and_masks(
    mm_val_inbreast_malignant_aux,
    finetune_inbreast_malignant_masks,
    match_function=match_masks_mammography_inbreast
)

mm_test_inbreast_malignant_ft_50_aux, mm_test_inbreast_malignant_masks_ft_50_aux = get_half_images_and_masks(
    mm_test_inbreast_malignant_aux,
    finetune_inbreast_malignant_masks,
    match_function=match_masks_mammography_inbreast
)

# CBIS malignant halves
mm_train_cbis_malignant_ft_50_aux, mm_train_cbis_malignant_masks_ft_50_aux = get_half_images_and_masks(
    mm_train_cbis_malignant_aux,
    finetune_cbis_malignant_masks,
    match_function=match_masks_mammography_cbis
)

mm_val_cbis_malignant_ft_50_aux, mm_val_cbis_malignant_masks_ft_50_aux = get_half_images_and_masks(
    mm_val_cbis_malignant_aux,
    finetune_cbis_malignant_masks,
    match_function=match_masks_mammography_cbis
)

mm_test_cbis_malignant_ft_50_aux, mm_test_cbis_malignant_masks_ft_50_aux = get_half_images_and_masks(
    mm_test_cbis_malignant_aux,
    finetune_cbis_malignant_masks,
    match_function=match_masks_mammography_cbis
)

# Merging of modalities for pretrain data
# Each multimodal list is the ultrasound half followed by the mammography half
# of the same class and split.
multimodal_train_benign = us_train_benign_50 + mm_train_benign_50
multimodal_train_malignant = us_train_malignant_50 + mm_train_malignant_50
multimodal_train_normal = us_train_normal_50 + mm_train_normal_50
multimodal_val_benign = us_val_benign_50 + mm_val_benign_50
multimodal_val_malignant = us_val_malignant_50 + mm_val_malignant_50
multimodal_val_normal = us_val_normal_50 + mm_val_normal_50

# Now for finetune
# Every multimodal list is the ultrasound half followed by the mammography
# half. Image and mask lists are concatenated in the same order so that index i
# of a mask list belongs to index i of the matching image list.
#
# NOTE(review): malignant mammograms are merged as CBIS followed by INBreast in
# every split, mirroring the mammography-only merge. Previously the malignant
# lists mixed sources inconsistently between splits and between images and
# masks; confirm that including both sources everywhere is the intended
# composition.

# Train
multimodal_train_benign_ft = us_train_benign_ft_50 + mm_train_benign_ft_50
multimodal_train_benign_masks_ft = us_train_benign_masks_ft_50 + mm_train_benign_masks_ft_50
multimodal_train_malignant_ft = (
    us_train_malignant_ft_50
    + mm_train_cbis_malignant_ft_50_aux
    + mm_train_inbreast_malignant_ft_50_aux
)
multimodal_train_malignant_masks_ft = (
    us_train_malignant_masks_ft_50
    + mm_train_cbis_malignant_masks_ft_50_aux
    + mm_train_inbreast_malignant_masks_ft_50_aux
)
multimodal_train_normal_ft = us_train_normal_ft_50 + mm_train_normal_ft_50
multimodal_train_normal_masks_ft = us_train_normal_masks_ft_50 + mm_train_normal_masks_ft_50

# Validation
multimodal_val_benign_ft = us_val_benign_ft_50 + mm_val_benign_ft_50
multimodal_val_benign_masks_ft = us_val_benign_masks_ft_50 + mm_val_benign_masks_ft_50
multimodal_val_malignant_ft = (
    us_val_malignant_ft_50
    + mm_val_cbis_malignant_ft_50_aux
    + mm_val_inbreast_malignant_ft_50_aux
)
multimodal_val_malignant_masks_ft = (
    us_val_malignant_masks_ft_50
    + mm_val_cbis_malignant_masks_ft_50_aux
    + mm_val_inbreast_malignant_masks_ft_50_aux
)
multimodal_val_normal_ft = us_val_normal_ft_50 + mm_val_normal_ft_50
multimodal_val_normal_masks_ft = us_val_normal_masks_ft_50 + mm_val_normal_masks_ft_50

# Test
multimodal_test_benign_ft = us_test_benign_ft_50 + mm_test_benign_ft_50
multimodal_test_benign_masks_ft = us_test_benign_masks_ft_50 + mm_test_benign_masks_ft_50
multimodal_test_malignant_ft = (
    us_test_malignant_ft_50
    + mm_test_cbis_malignant_ft_50_aux
    + mm_test_inbreast_malignant_ft_50_aux
)
multimodal_test_malignant_masks_ft = (
    us_test_malignant_masks_ft_50
    + mm_test_cbis_malignant_masks_ft_50_aux
    + mm_test_inbreast_malignant_masks_ft_50_aux
)
multimodal_test_normal_ft = us_test_normal_ft_50 + mm_test_normal_ft_50
multimodal_test_normal_masks_ft = us_test_normal_masks_ft_50 + mm_test_normal_masks_ft_50

In [None]:
# Each printed row is one class (or source) with its train, validation and test sizes.

# Full ultrasound finetune splits: benign, malignant, normal
print("us only finetune")
print(len(train_benign_ft), len(val_benign_ft), len(test_benign_ft))
print(len(train_malignant_ft), len(val_malignant_ft), len(test_malignant_ft))
print(len(train_normal_ft), len(val_normal_ft), len(test_normal_ft))

# Full mammography finetune splits: benign, normal, INBreast malignant, CBIS malignant
print("mm only finetune")
print(len(mm_train_benign_ft), len(mm_val_benign_ft), len(mm_test_benign_ft))
print(len(mm_train_normal_ft), len(mm_val_normal_ft), len(mm_test_normal_ft))
print(len(mm_train_inbreast_malignant_aux), len(mm_val_inbreast_malignant_aux), len(mm_test_inbreast_malignant_aux))
print(len(mm_train_cbis_malignant_aux), len(mm_val_cbis_malignant_aux), len(mm_test_cbis_malignant_aux))

# Sampled ultrasound halves, same row order as above
print("50 picked from us only finetune")
print(len(us_train_benign_ft_50), len(us_val_benign_ft_50), len(us_test_benign_ft_50))
print(len(us_train_malignant_ft_50), len(us_val_malignant_ft_50), len(us_test_malignant_ft_50))
print(len(us_train_normal_ft_50), len(us_val_normal_ft_50), len(us_test_normal_ft_50))

# Sampled mammography halves, same row order as above
print("50 picked mm only finetune")
print(len(mm_train_benign_ft_50), len(mm_val_benign_ft_50), len(mm_test_benign_ft_50))
print(len(mm_train_normal_ft_50), len(mm_val_normal_ft_50), len(mm_test_normal_ft_50))
print(len(mm_train_inbreast_malignant_ft_50_aux), len(mm_val_inbreast_malignant_ft_50_aux), len(mm_test_inbreast_malignant_ft_50_aux))
print(len(mm_train_cbis_malignant_ft_50_aux), len(mm_val_cbis_malignant_ft_50_aux), len(mm_test_cbis_malignant_ft_50_aux))

In [None]:
# Report the size of a sampled list next to its source, plus one example entry
def print_info(variable, source_variable, variable_name):
    """Print the sizes of a list and of its source, and the list's first item.

    Args:
        variable: The (sampled) sequence to describe.
        source_variable: The sequence ``variable`` was drawn from.
        variable_name: Label used in every printed line.
    """
    print(f"Count of {variable_name}: {len(variable)}")
    print(f"Count of source {variable_name}: {len(source_variable)}")

    if len(variable) == 0:
        print(f"No values in {variable_name}")
        return

    print(f"Sample value from {variable_name}: {variable[0]}")

# Compare every sampled pretraining half with the split it was drawn from.
# The *_50 lists were produced with random_50_percent, so each should hold
# len(source) // 2 items.

# For Ultrasound Data
print_info(us_train_benign_50, train_benign, 'us_train_benign_50')
print_info(us_train_malignant_50, train_malignant, 'us_train_malignant_50')
print_info(us_train_normal_50, train_normal, 'us_train_normal_50')
print_info(us_val_benign_50, val_benign, 'us_val_benign_50')
print_info(us_val_malignant_50, val_malignant, 'us_val_malignant_50')
print_info(us_val_normal_50, val_normal, 'us_val_normal_50')

# For Mammography Data
print_info(mm_train_benign_50, mm_train_benign, 'mm_train_benign_50')
print_info(mm_train_malignant_50, mm_train_malignant, 'mm_train_malignant_50')
print_info(mm_train_normal_50, mm_train_normal, 'mm_train_normal_50')
print_info(mm_val_benign_50, mm_val_benign, 'mm_val_benign_50')
print_info(mm_val_malignant_50, mm_val_malignant, 'mm_val_malignant_50')
print_info(mm_val_normal_50, mm_val_normal, 'mm_val_normal_50')

In [None]:
# Data for original datasets (ultrasound and mammography sources)
# Hard-coded counts per class for the original splits and the selected halves.
# The table is deliberately NOT called `data`: that name is the alias of
# `lightly.data` imported in the first cell, and rebinding it here would break
# any later use of the module through that alias.
# NOTE(review): these numbers are typed in, not computed from the lists above;
# keep them in sync with the actual split sizes.
selection_counts = {
    'Classes': ['Benign', 'Malignant', 'Normal'],
    'US Train Original': [513, 396, 83],
    'US Val Original': [129, 99, 21],
    'MM Train Original': [513, 396, 83],
    'MM Val Original': [129, 99, 21],
    'US Train Selected (50%)': [256, 198, 41],
    'US Val Selected (50%)': [64, 49, 10],
    'MM Train Selected (50%)': [256, 198, 41],
    'MM Val Selected (50%)': [64, 49, 10],
}

df = pd.DataFrame(selection_counts)

sns.set(style="whitegrid")

# Bar width and position adjustment
bar_width = 0.2
r1 = np.arange(len(df['Classes']))  # Positions for the first set of bars
r2 = [x + bar_width for x in r1]    # Shift positions for the second set
r3 = [x + bar_width for x in r2]    # Shift positions for the third set
r4 = [x + bar_width for x in r3]    # Shift positions for the fourth set

fig, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

# Plot for train data
axs[0].bar(r1, df['US Train Original'], color='lightblue', width=bar_width, label='US Train Original')
axs[0].bar(r2, df['US Train Selected (50%)'], color='blue', width=bar_width, label='US Train Selected (50%)')
axs[0].bar(r3, df['MM Train Original'], color='lightgreen', width=bar_width, label='MM Train Original')
axs[0].bar(r4, df['MM Train Selected (50%)'], color='green', width=bar_width, label='MM Train Selected (50%)')
axs[0].set_title('Comparison of Train Data (Original vs Selected 50%)')

# Plot for validation data
axs[1].bar(r1, df['US Val Original'], color='lightblue', width=bar_width, label='US Val Original')
axs[1].bar(r2, df['US Val Selected (50%)'], color='blue', width=bar_width, label='US Val Selected (50%)')
axs[1].bar(r3, df['MM Val Original'], color='lightgreen', width=bar_width, label='MM Val Original')
axs[1].bar(r4, df['MM Val Selected (50%)'], color='green', width=bar_width, label='MM Val Selected (50%)')
axs[1].set_title('Comparison of Validation Data (Original vs Selected 50%)')

# Add legends and labels
for ax in axs:
    ax.set_ylabel('Count')
    ax.set_xlabel('Classes')
    # The four bars of a group are centred at r, r+w, r+2w and r+3w, so the
    # middle of the group is r + 1.5*w; put the class label there.
    ax.set_xticks([r + 1.5 * bar_width for r in range(len(df['Classes']))])
    ax.set_xticklabels(df['Classes'])
    ax.legend()

# Display the plots
plt.tight_layout()
plt.show()

In [None]:
def verify_finetune_data():
    """Report every 50% finetune selection next to the list it is compared with.

    Each row of the table below is forwarded to ``print_info`` as
    ``(first_list, second_list, label)``; presumably the first list is the
    50% selection and the second the full finetune split (confirm against
    ``print_info``). Rows are emitted in a fixed order: benign, malignant
    (ultrasound, then CBIS and INBreast mammography), normal.
    """
    # Reference sizes per class and split. They are not compared against
    # anything in this function; they only document the intended counts.
    expected_counts = {
        'benign': {'train': 121, 'val': 41, 'test': 41},
        'malignant': {'train': 97, 'val': 33, 'test': 33},
        'normal': {'train': 19, 'val': 7, 'test': 7}
    }

    # Mammography malignant case: split between CBIS and INBreast
    expected_malignant_mammo = {
        'cbis': {'train': 37, 'val': 13, 'test': 13},
        'inbreast': {'train': 60, 'val': 20, 'test': 20}
    }

    report_rows = (
        # Benign
        (us_train_benign_ft_50, train_benign_ft, "Benign Train (Ultrasound)"),
        (us_val_benign_ft_50, val_benign_ft, "Benign Validation (Ultrasound)"),
        (us_test_benign_ft_50, test_benign_ft, "Benign Test (Ultrasound)"),
        (mm_train_benign_ft_50, mm_train_benign_ft, "Benign Train (Mammography)"),
        (mm_val_benign_ft_50, mm_val_benign_ft, "Benign Validation (Mammography)"),
        (mm_test_benign_ft_50, mm_test_benign_ft, "Benign Test (Mammography)"),
        # Malignant, ultrasound
        (us_train_malignant_ft_50, train_malignant_ft, "Malignant Train (Ultrasound)"),
        (us_val_malignant_ft_50, val_malignant_ft, "Malignant Validation (Ultrasound)"),
        (us_test_malignant_ft_50, test_malignant_ft, "Malignant Test (Ultrasound)"),
        # Malignant, mammography: CBIS and INBreast are reported separately
        (mm_train_cbis_malignant_ft_50_aux, mm_train_cbis_malignant_aux, "Malignant Train (CBIS - Mammography)"),
        (mm_val_cbis_malignant_ft_50_aux, mm_val_cbis_malignant_aux, "Malignant Validation (CBIS - Mammography)"),
        (mm_test_cbis_malignant_ft_50_aux, mm_test_cbis_malignant_aux, "Malignant Test (CBIS - Mammography)"),
        (mm_train_inbreast_malignant_ft_50_aux, mm_train_inbreast_malignant_aux, "Malignant Train (INBreast - Mammography)"),
        (mm_val_inbreast_malignant_ft_50_aux, mm_val_inbreast_malignant_aux, "Malignant Validation (INBreast - Mammography)"),
        (mm_test_inbreast_malignant_ft_50_aux, mm_test_inbreast_malignant_aux, "Malignant Test (INBreast - Mammography)"),
        # Normal
        (us_train_normal_ft_50, train_normal_ft, "Normal Train (Ultrasound)"),
        (us_val_normal_ft_50, val_normal_ft, "Normal Validation (Ultrasound)"),
        (us_test_normal_ft_50, test_normal_ft, "Normal Test (Ultrasound)"),
        (mm_train_normal_ft_50, mm_train_normal_ft, "Normal Train (Mammography)"),
        (mm_val_normal_ft_50, mm_val_normal_ft, "Normal Validation (Mammography)"),
        (mm_test_normal_ft_50, mm_test_normal_ft, "Normal Test (Mammography)"),
    )

    for first_list, second_list, label in report_rows:
        print_info(first_list, second_list, label)

verify_finetune_data()

In [None]:
def copy_files(paths, phaseFolder, splitFolder, classFolder, imagetypeFolder, modalityFolder):
    """Copy every file in ``paths`` into the matching multimodal dataset folder.

    The destination lives under the module-level ``multimodal_dir``:

    * ``phaseFolder == 'pretrainData'``:
      ``<multimodal_dir>/<phase>/<split>/<modality>`` (``classFolder`` and
      ``imagetypeFolder`` are ignored);
    * any other phase:
      ``<multimodal_dir>/<phase>/<split>/<class>/<imagetype>/<modality>``.

    The destination is created if missing, and the number of copies made is
    printed at the end.

    Args:
        paths: iterable of source file paths.
        phaseFolder: phase folder name, e.g. 'pretrainData' or 'finetuneData'.
        splitFolder: split folder name, e.g. 'train', 'validation', 'test'.
        classFolder: class folder name (unused for the pretrain layout).
        imagetypeFolder: 'images' or 'masks' (unused for the pretrain layout).
        modalityFolder: leaf folder name for the modality.
    """
    # Local import: the notebook's import cell does not list shutil, so the
    # function brings it into scope itself rather than relying on another cell.
    import shutil

    if phaseFolder == 'pretrainData':
        destination_dir = os.path.join(multimodal_dir, phaseFolder, splitFolder, modalityFolder)
    else:
        destination_dir = os.path.join(multimodal_dir, phaseFolder, splitFolder, classFolder, imagetypeFolder, modalityFolder)
    os.makedirs(destination_dir, exist_ok=True)

    # `copied` stays 0 when `paths` is empty; otherwise it ends at the number
    # of files copied. shutil.copy overwrites an existing file of the same name.
    copied = 0
    for copied, file_path in enumerate(paths, start=1):
        shutil.copy(file_path, destination_dir)
    print("we did", copied, "copies for folder:", destination_dir)

In [None]:
# Pretrain copies: every class of a split lands in the same modality folder,
# because the pretrain layout has no class/imagetype level ('N/A' is ignored).
_pretrain_copy_plan = (
    # Ultrasound pretrain
    (us_train_benign_50, 'train', 'ultrasoundImages'),
    (us_train_malignant_50, 'train', 'ultrasoundImages'),
    (us_train_normal_50, 'train', 'ultrasoundImages'),
    (us_val_benign_50, 'validation', 'ultrasoundImages'),
    (us_val_malignant_50, 'validation', 'ultrasoundImages'),
    (us_val_normal_50, 'validation', 'ultrasoundImages'),
    # Mammography pretrain
    (mm_train_benign_50, 'train', 'mammographyImages'),
    (mm_train_malignant_50, 'train', 'mammographyImages'),
    (mm_train_normal_50, 'train', 'mammographyImages'),
    (mm_val_benign_50, 'validation', 'mammographyImages'),
    (mm_val_malignant_50, 'validation', 'mammographyImages'),
    (mm_val_normal_50, 'validation', 'mammographyImages'),
)

for _paths, _split, _modality in _pretrain_copy_plan:
    copy_files(_paths, 'pretrainData', _split, 'N/A', 'N/A', _modality)

In [None]:
# Finetune ultrasound: for each (split, class) copy the images first, then
# the matching masks, into the multimodal finetune layout.
_us_ft_copy_plan = (
    # Finetune ultrasound benign
    ('train', 'benign', us_train_benign_ft_50, us_train_benign_masks_ft_50),
    ('validation', 'benign', us_val_benign_ft_50, us_val_benign_masks_ft_50),
    ('test', 'benign', us_test_benign_ft_50, us_test_benign_masks_ft_50),
    # Finetune ultrasound malignant
    ('train', 'malignant', us_train_malignant_ft_50, us_train_malignant_masks_ft_50),
    ('validation', 'malignant', us_val_malignant_ft_50, us_val_malignant_masks_ft_50),
    ('test', 'malignant', us_test_malignant_ft_50, us_test_malignant_masks_ft_50),
    # Finetune ultrasound normal
    ('train', 'normal', us_train_normal_ft_50, us_train_normal_masks_ft_50),
    ('validation', 'normal', us_val_normal_ft_50, us_val_normal_masks_ft_50),
    ('test', 'normal', us_test_normal_ft_50, us_test_normal_masks_ft_50),
)

for _split, _class_name, _image_paths, _mask_paths in _us_ft_copy_plan:
    copy_files(_image_paths, 'finetuneData', _split, _class_name, 'images', 'ultrasoundImages')
    copy_files(_mask_paths, 'finetuneData', _split, _class_name, 'masks', 'UltrasoundMasks')

In [None]:
# Malignant mammography comes from two sources; join CBIS-DDSM and INBreast
# (images and masks in the same order) before copying.
combined_train_malignant_ft_50 = mm_train_cbis_malignant_ft_50_aux + mm_train_inbreast_malignant_ft_50_aux
combined_train_malignant_masks_ft_50 = mm_train_cbis_malignant_masks_ft_50_aux + mm_train_inbreast_malignant_masks_ft_50_aux

combined_val_malignant_ft_50 = mm_val_cbis_malignant_ft_50_aux + mm_val_inbreast_malignant_ft_50_aux
combined_val_malignant_masks_ft_50 = mm_val_cbis_malignant_masks_ft_50_aux + mm_val_inbreast_malignant_masks_ft_50_aux

combined_test_malignant_ft_50 = mm_test_cbis_malignant_ft_50_aux + mm_test_inbreast_malignant_ft_50_aux
combined_test_malignant_masks_ft_50 = mm_test_cbis_malignant_masks_ft_50_aux + mm_test_inbreast_malignant_masks_ft_50_aux

# Finetune mammography: for each (split, class) copy images, then masks.
_mm_ft_copy_plan = (
    # Combined malignant cases
    ('train', 'malignant', combined_train_malignant_ft_50, combined_train_malignant_masks_ft_50),
    ('validation', 'malignant', combined_val_malignant_ft_50, combined_val_malignant_masks_ft_50),
    ('test', 'malignant', combined_test_malignant_ft_50, combined_test_malignant_masks_ft_50),
    # Mammography benign
    ('train', 'benign', mm_train_benign_ft_50, mm_train_benign_masks_ft_50),
    ('validation', 'benign', mm_val_benign_ft_50, mm_val_benign_masks_ft_50),
    ('test', 'benign', mm_test_benign_ft_50, mm_test_benign_masks_ft_50),
    # Mammography normal
    ('train', 'normal', mm_train_normal_ft_50, mm_train_normal_masks_ft_50),
    ('validation', 'normal', mm_val_normal_ft_50, mm_val_normal_masks_ft_50),
    ('test', 'normal', mm_test_normal_ft_50, mm_test_normal_masks_ft_50),
)

for _split, _class_name, _image_paths, _mask_paths in _mm_ft_copy_plan:
    copy_files(_image_paths, 'finetuneData', _split, _class_name, 'images', 'mammographyImages')
    copy_files(_mask_paths, 'finetuneData', _split, _class_name, 'masks', 'MammographyMasks')

## 3. Preparing datasets for models

### 3.1. Setting classes and variables for all datapoints across datasets

In [None]:
# Root folder holding the datasets actually consumed by the models.
data_used_dir = os.path.join(data_dir, "dataUsed")

# Class names; the integer code of a class is its position in this list.
classes = ["benign", "malignant", "normal"]
classes_numbercoded = {class_name: code for code, class_name in enumerate(classes)}
num_classes = len(classes_numbercoded)

In [None]:
# Root folder of each dataset under data_used_dir.
_mg_root = os.path.join(data_used_dir, 'MammographyDataset')
_us_root = os.path.join(data_used_dir, 'UltrasoundDataset')
_multi_root = os.path.join(data_used_dir, 'MultimodalDataset')


def _image_mask_dirs(dataset_root, split, class_name, image_leaf=None, mask_leaf=None):
    """Return (images_dir, masks_dir) for one finetune split/class of a dataset.

    ``image_leaf`` / ``mask_leaf`` append one extra folder level below
    'images' / 'masks' (used by the multimodal layout only).
    """
    class_dir = os.path.join(dataset_root, 'finetuneData', split, class_name)
    image_parts = ['images'] if image_leaf is None else ['images', image_leaf]
    mask_parts = ['masks'] if mask_leaf is None else ['masks', mask_leaf]
    return os.path.join(class_dir, *image_parts), os.path.join(class_dir, *mask_parts)


def _multi_us_dirs(split, class_name):
    """(images_dir, masks_dir) of the ultrasound part of the multimodal dataset."""
    return _image_mask_dirs(_multi_root, split, class_name, 'ultrasoundImages', 'UltrasoundMasks')


def _multi_mg_dirs(split, class_name):
    """(images_dir, masks_dir) of the mammography part of the multimodal dataset."""
    return _image_mask_dirs(_multi_root, split, class_name, 'mammographyImages', 'MammographyMasks')


# Mammography pretraining data
mg_pt_train_im_dir = os.path.join(_mg_root, 'pretrainData', 'train')
mg_pt_val_im_dir = os.path.join(_mg_root, 'pretrainData', 'validation')

# Mammography finetuning data
mg_ft_train_b_im_dir, mg_ft_train_b_msk_dir = _image_mask_dirs(_mg_root, 'train', 'benign')
mg_ft_train_m_im_dir, mg_ft_train_m_msk_dir = _image_mask_dirs(_mg_root, 'train', 'malignant')
mg_ft_train_n_im_dir, mg_ft_train_n_msk_dir = _image_mask_dirs(_mg_root, 'train', 'normal')

mg_ft_val_b_im_dir, mg_ft_val_b_msk_dir = _image_mask_dirs(_mg_root, 'validation', 'benign')
mg_ft_val_m_im_dir, mg_ft_val_m_msk_dir = _image_mask_dirs(_mg_root, 'validation', 'malignant')
mg_ft_val_n_im_dir, mg_ft_val_n_msk_dir = _image_mask_dirs(_mg_root, 'validation', 'normal')

mg_ft_test_b_im_dir, mg_ft_test_b_msk_dir = _image_mask_dirs(_mg_root, 'test', 'benign')
mg_ft_test_m_im_dir, mg_ft_test_m_msk_dir = _image_mask_dirs(_mg_root, 'test', 'malignant')
mg_ft_test_n_im_dir, mg_ft_test_n_msk_dir = _image_mask_dirs(_mg_root, 'test', 'normal')

# Ultrasound pretraining data
us_pt_train_im_dir = os.path.join(_us_root, 'pretrainData', 'train')
us_pt_val_im_dir = os.path.join(_us_root, 'pretrainData', 'validation')

# Ultrasound finetuning data
us_ft_train_b_im_dir, us_ft_train_b_msk_dir = _image_mask_dirs(_us_root, 'train', 'benign')
us_ft_train_m_im_dir, us_ft_train_m_msk_dir = _image_mask_dirs(_us_root, 'train', 'malignant')
us_ft_train_n_im_dir, us_ft_train_n_msk_dir = _image_mask_dirs(_us_root, 'train', 'normal')

us_ft_val_b_im_dir, us_ft_val_b_msk_dir = _image_mask_dirs(_us_root, 'validation', 'benign')
us_ft_val_m_im_dir, us_ft_val_m_msk_dir = _image_mask_dirs(_us_root, 'validation', 'malignant')
us_ft_val_n_im_dir, us_ft_val_n_msk_dir = _image_mask_dirs(_us_root, 'validation', 'normal')

us_ft_test_b_im_dir, us_ft_test_b_msk_dir = _image_mask_dirs(_us_root, 'test', 'benign')
us_ft_test_m_im_dir, us_ft_test_m_msk_dir = _image_mask_dirs(_us_root, 'test', 'malignant')
us_ft_test_n_im_dir, us_ft_test_n_msk_dir = _image_mask_dirs(_us_root, 'test', 'normal')

# Multimodal pretraining data (Ultrasound part)
multi_pt_train_im_usdata_dir = os.path.join(_multi_root, 'pretrainData', 'train', 'ultrasoundImages')
multi_pt_val_im_usdata_dir = os.path.join(_multi_root, 'pretrainData', 'validation', 'ultrasoundImages')

# Multimodal pretraining data (Mammography part)
multi_pt_train_im_mgdata_dir = os.path.join(_multi_root, 'pretrainData', 'train', 'mammographyImages')
multi_pt_val_im_mgdata_dir = os.path.join(_multi_root, 'pretrainData', 'validation', 'mammographyImages')

# Train phase for Multimodal finetuning data (Ultrasound part)
multi_ft_train_b_im_usdata_dir, multi_ft_train_b_msk_usdata_dir = _multi_us_dirs('train', 'benign')
multi_ft_train_m_im_usdata_dir, multi_ft_train_m_msk_usdata_dir = _multi_us_dirs('train', 'malignant')
multi_ft_train_n_im_usdata_dir, multi_ft_train_n_msk_usdata_dir = _multi_us_dirs('train', 'normal')

# Train phase for Multimodal finetuning data (Mammography part)
multi_ft_train_b_im_mgdata_dir, multi_ft_train_b_msk_mgdata_dir = _multi_mg_dirs('train', 'benign')
multi_ft_train_m_im_mgdata_dir, multi_ft_train_m_msk_mgdata_dir = _multi_mg_dirs('train', 'malignant')
multi_ft_train_n_im_mgdata_dir, multi_ft_train_n_msk_mgdata_dir = _multi_mg_dirs('train', 'normal')

# Validation phase for Multimodal finetuning (Ultrasound part)
multi_ft_val_b_im_usdata_dir, multi_ft_val_b_msk_usdata_dir = _multi_us_dirs('validation', 'benign')
multi_ft_val_m_im_usdata_dir, multi_ft_val_m_msk_usdata_dir = _multi_us_dirs('validation', 'malignant')
multi_ft_val_n_im_usdata_dir, multi_ft_val_n_msk_usdata_dir = _multi_us_dirs('validation', 'normal')

# Validation phase for Multimodal finetuning (Mammography part)
multi_ft_val_b_im_mgdata_dir, multi_ft_val_b_msk_mgdata_dir = _multi_mg_dirs('validation', 'benign')
multi_ft_val_m_im_mgdata_dir, multi_ft_val_m_msk_mgdata_dir = _multi_mg_dirs('validation', 'malignant')
multi_ft_val_n_im_mgdata_dir, multi_ft_val_n_msk_mgdata_dir = _multi_mg_dirs('validation', 'normal')

# Test phase for Multimodal finetuning (Ultrasound)
multi_ft_test_b_im_usdata_dir, multi_ft_test_b_msk_usdata_dir = _multi_us_dirs('test', 'benign')
multi_ft_test_m_im_usdata_dir, multi_ft_test_m_msk_usdata_dir = _multi_us_dirs('test', 'malignant')
multi_ft_test_n_im_usdata_dir, multi_ft_test_n_msk_usdata_dir = _multi_us_dirs('test', 'normal')

# Test phase for Multimodal finetuning (Mammography)
multi_ft_test_b_im_mgdata_dir, multi_ft_test_b_msk_mgdata_dir = _multi_mg_dirs('test', 'benign')
multi_ft_test_m_im_mgdata_dir, multi_ft_test_m_msk_mgdata_dir = _multi_mg_dirs('test', 'malignant')
multi_ft_test_n_im_mgdata_dir, multi_ft_test_n_msk_mgdata_dir = _multi_mg_dirs('test', 'normal')

In [None]:
def count_files_in_dir(directory):
    """Return how many entries directly inside ``directory`` are regular files.

    Sub-directories are not counted and are not descended into.
    """
    total = 0
    for entry_name in os.listdir(directory):
        if os.path.isfile(os.path.join(directory, entry_name)):
            total += 1
    return total

# Per-dataset counters, all starting at zero.
counts = dict.fromkeys(("MammographyDataset", "UltrasoundDataset", "MultimodalDataset"), 0)

# Mammography dataset
_mammography_directories = [
    mg_pt_train_im_dir, mg_pt_val_im_dir,
    mg_ft_train_b_im_dir, mg_ft_train_b_msk_dir, mg_ft_train_m_im_dir, mg_ft_train_m_msk_dir, mg_ft_train_n_im_dir, mg_ft_train_n_msk_dir,
    mg_ft_val_b_im_dir, mg_ft_val_b_msk_dir, mg_ft_val_m_im_dir, mg_ft_val_m_msk_dir, mg_ft_val_n_im_dir, mg_ft_val_n_msk_dir,
    mg_ft_test_b_im_dir, mg_ft_test_b_msk_dir, mg_ft_test_m_im_dir, mg_ft_test_m_msk_dir, mg_ft_test_n_im_dir, mg_ft_test_n_msk_dir,
]

# Ultrasound dataset
_ultrasound_directories = [
    us_pt_train_im_dir, us_pt_val_im_dir,
    us_ft_train_b_im_dir, us_ft_train_b_msk_dir, us_ft_train_m_im_dir, us_ft_train_m_msk_dir, us_ft_train_n_im_dir, us_ft_train_n_msk_dir,
    us_ft_val_b_im_dir, us_ft_val_b_msk_dir, us_ft_val_m_im_dir, us_ft_val_m_msk_dir, us_ft_val_n_im_dir, us_ft_val_n_msk_dir,
    us_ft_test_b_im_dir, us_ft_test_b_msk_dir, us_ft_test_m_im_dir, us_ft_test_m_msk_dir, us_ft_test_n_im_dir, us_ft_test_n_msk_dir,
]

# Multimodal dataset
_multimodal_directories = [
    multi_pt_train_im_usdata_dir, multi_pt_val_im_usdata_dir, multi_pt_train_im_mgdata_dir, multi_pt_val_im_mgdata_dir,
    multi_ft_train_b_im_usdata_dir, multi_ft_train_b_msk_usdata_dir, multi_ft_train_b_im_mgdata_dir, multi_ft_train_b_msk_mgdata_dir,
    multi_ft_train_m_im_usdata_dir, multi_ft_train_m_msk_usdata_dir, multi_ft_train_m_im_mgdata_dir, multi_ft_train_m_msk_mgdata_dir,
    multi_ft_train_n_im_usdata_dir, multi_ft_train_n_msk_usdata_dir, multi_ft_train_n_im_mgdata_dir, multi_ft_train_n_msk_mgdata_dir,
    multi_ft_val_b_im_usdata_dir, multi_ft_val_b_msk_usdata_dir, multi_ft_val_b_im_mgdata_dir, multi_ft_val_b_msk_mgdata_dir,
    multi_ft_val_m_im_usdata_dir, multi_ft_val_m_msk_usdata_dir, multi_ft_val_m_im_mgdata_dir, multi_ft_val_m_msk_mgdata_dir,
    multi_ft_val_n_im_usdata_dir, multi_ft_val_n_msk_usdata_dir, multi_ft_val_n_im_mgdata_dir, multi_ft_val_n_msk_mgdata_dir,
    multi_ft_test_b_im_usdata_dir, multi_ft_test_b_msk_usdata_dir, multi_ft_test_b_im_mgdata_dir, multi_ft_test_b_msk_mgdata_dir,
    multi_ft_test_m_im_usdata_dir, multi_ft_test_m_msk_usdata_dir, multi_ft_test_m_im_mgdata_dir, multi_ft_test_m_msk_mgdata_dir,
    multi_ft_test_n_im_usdata_dir, multi_ft_test_n_msk_usdata_dir, multi_ft_test_n_im_mgdata_dir, multi_ft_test_n_msk_mgdata_dir,
]

# Every leaf directory, in dataset order: mammography, ultrasound, multimodal.
directories = [*_mammography_directories, *_ultrasound_directories, *_multimodal_directories]

def count_files(directories):
    """Count directory entries per path and total them per dataset.

    A directory is attributed to a dataset by the dataset folder that appears
    as a component of its path ('MammographyDataset', 'UltrasoundDataset' or
    'MultimodalDataset'). Matching on whole path components avoids the
    case-sensitive substring pitfalls of the folder names: the single-modality
    dataset folders are capitalised, while the multimodal leaf folders contain
    lowercase 'mammographyImages' / 'ultrasoundImages'.

    Args:
        directories: iterable of directory paths. Paths that do not exist are
            skipped silently.

    Returns:
        A tuple ``(counts, dataset_sums)`` where ``counts`` maps each existing
        path to ``len(os.listdir(path))`` (all entries, not only regular files)
        and ``dataset_sums`` holds the totals under the keys 'mammography',
        'ultrasound' and 'multimodal'. Paths that belong to none of the three
        dataset folders appear in ``counts`` only.
    """
    # Dataset folder name -> key used in the per-dataset totals.
    dataset_keys = {
        'MammographyDataset': 'mammography',
        'UltrasoundDataset': 'ultrasound',
        'MultimodalDataset': 'multimodal',
    }

    counts = {}
    dataset_sums = {'mammography': 0, 'ultrasound': 0, 'multimodal': 0}

    for dir_path in directories:
        if not os.path.exists(dir_path):
            continue

        num_files = len(os.listdir(dir_path))
        counts[dir_path] = num_files

        # Attribute the count to the dataset whose folder is part of the path.
        components = os.path.normpath(dir_path).split(os.sep)
        for folder_name, sum_key in dataset_keys.items():
            if folder_name in components:
                dataset_sums[sum_key] += num_files
                break

    return counts, dataset_sums

# Count the entries of every directory listed above and total them per dataset.
file_counts, dataset_totals = count_files(directories)
# Bare expression as the last statement of the cell so the notebook displays both results.
file_counts, dataset_totals

View example image and mask

In [None]:
# Preview: show the alphabetically first benign ultrasound training image and,
# when it exists, the mask stored under the same name with a '_mask' suffix.
benign_images = sorted(
    file_name for file_name in os.listdir(us_ft_train_b_im_dir) if file_name.endswith('.png')
)

if not benign_images:
    print('No benign images found in the specified directory.')
else:
    benign_image_name = benign_images[0]
    benign_image_path = os.path.join(us_ft_train_b_im_dir, benign_image_name)

    # OpenCV loads BGR; matplotlib expects RGB.
    benign_image = cv2.cvtColor(cv2.imread(benign_image_path), cv2.COLOR_BGR2RGB)
    plt.imshow(benign_image)
    plt.title(f'Benign Image: {benign_image_name}')
    plt.axis('off')
    plt.show()

    mask_image_name = benign_image_name.replace('.png', '_mask.png')
    mask_image_path = os.path.join(us_ft_train_b_msk_dir, mask_image_name)

    if not os.path.exists(mask_image_path):
        print(f'Mask image not found: {mask_image_path}')
    else:
        mask_image = cv2.cvtColor(cv2.imread(mask_image_path), cv2.COLOR_BGR2RGB)
        plt.imshow(mask_image)
        plt.title(f'Mask Image: {mask_image_name}')
        plt.axis('off')
        plt.show()


### 3.2. Setting augmentations

In [None]:
class ContrastiveTransformations(object):
    """Callable that applies one transform pipeline several times to an input.

    Calling an instance returns a list with ``n_views`` results of
    ``base_transforms(x)``, each produced by a separate call.
    """

    def __init__(self, base_transforms, n_views=2):
        self.n_views = n_views
        self.base_transforms = base_transforms

    def __call__(self, x):
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transforms(x))
        return views

class addBackgroundPadding:
    """Transform that pads an image with zeros until it is square.

    The input must expose ``.size`` as ``(width, height)``. The shorter side
    is padded on both ends; when the missing amount is odd, the extra pixel
    goes to the right / bottom.
    """

    def __call__(self, img):
        width, height = img.size
        side = max(width, height)

        # Total padding needed per axis (one of the two is always 0).
        missing_w = side - width
        missing_h = side - height

        left, top = missing_w // 2, missing_h // 2
        right, bottom = missing_w - left, missing_h - top

        # Fill the new border with 0s
        return transforms.functional.pad(img, (left, top, right, bottom), fill=0)

In [None]:
# Alternative contrastive augmentation pipeline. Compared with
# `contrast_transforms` below it adds a small random rotation and applies the
# Gaussian blur only with probability 0.3.
contrast_transformsTHEV2VERSION = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Color is irrelevant
    transforms.Resize((110, 110)),  # Resize larger than the final crop to allow varied cropping
    transforms.RandomRotation(degrees=10),  # Small rotation, applied after the resize and before the random crop
    transforms.RandomResizedCrop(size=64, scale=(0.5, 0.9)),  # Random crop of 50-90% of the area, resized to 64x64
    transforms.RandomHorizontalFlip(),
    # NOTE(review): the image is single-channel at this point, so the
    # saturation/hue jitter presumably has no effect - confirm.
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
    ], p=0.8),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=9)
    ], p=0.3),  # Blur only 30% of the time
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Contrastive augmentation pipeline: no rotation, and the Gaussian blur is
# applied to every sample.
contrast_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Color is irrelevant
    transforms.Resize((110, 110)),  # Resize larger than the final crop to allow varied cropping
    transforms.RandomResizedCrop(size=64, scale=(0.5, 0.9)),  # Random crop of 50-90% of the area, resized to 64x64
    transforms.RandomHorizontalFlip(),
    # NOTE(review): the image is single-channel at this point, so the
    # saturation/hue jitter presumably has no effect - confirm.
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
    ], p=0.8),
    transforms.GaussianBlur(kernel_size=9),  # Always applied (not wrapped in RandomApply)
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Finetune only: deterministic preprocessing, no augmentation
only_resize_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Color is irrelevant
    addBackgroundPadding(), # Pad the image to square with black
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Finetune only on segmentation (masks)
# NOTE(review): Resize is used with its default interpolation; if the masks
# must stay strictly binary, confirm that downstream code thresholds them.
only_resize_transforms_masks = transforms.Compose([
    addBackgroundPadding(),  # Same padding for masks
    transforms.Resize((64, 64)),  # Resize to final target size
    transforms.ToTensor()  # No normalization for masks
])

### 3.3. Building data loaders for all splits of the three big datasets

In [None]:
# Batch sizes for the pretraining and finetuning dataloaders, respectively.
pretrain_batch_size, finetune_batch_size = 64, 32

#### 3.3.1. Pretraining dataloaders construction for both classification and segmentation

In [None]:
class PretrainingDataset(Dataset):
    """Unlabeled image folder that yields two views of each image.

    Args:
        image_dir: directory whose regular files are the dataset items.
        transform: optional callable that receives the grayscale PIL image and
            returns a sequence with at least two augmented views (for example
            a ``ContrastiveTransformations`` instance). When omitted, the
            unmodified image is returned twice.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always refers to the same file across runs and machines, and
        # skip sub-directories, which can never be opened as images.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Color is irrelevant: work on a single-channel copy and release the
        # file handle as soon as the pixels are loaded.
        with Image.open(img_path) as source:
            image = source.convert("L")

        # Generate two augmented views of the same image
        view1, view2 = self._get_augmented_image(image)
        return view1, view2

    def _get_augmented_image(self, image):
        """Return the first two views from ``transform``, or the image twice."""
        if self.transform:
            views = self.transform(image)  # Sequence of augmented images
            return views[0], views[1]
        return image, image  # Return original if no transform provided

In [194]:
# Pretrain unlabeled
def _contrastive_pretrain_set(image_dir):
    """Build a two-view contrastive pretraining dataset over ``image_dir``."""
    two_view_transform = ContrastiveTransformations(contrast_transforms, n_views=2)
    return PretrainingDataset(image_dir=image_dir, transform=two_view_transform)


# Ultrasound-only pretraining
us_pretrain_train_set = _contrastive_pretrain_set(us_pt_train_im_dir)
us_pretrain_validation_set = _contrastive_pretrain_set(us_pt_val_im_dir)

# Mammography-only pretraining
mg_pretrain_train_set = _contrastive_pretrain_set(mg_pt_train_im_dir)
mg_pretrain_validation_set = _contrastive_pretrain_set(mg_pt_val_im_dir)

# Multimodal pretraining: one dataset per modality and split
multi_pretrain_train_usdata = _contrastive_pretrain_set(multi_pt_train_im_usdata_dir)
multi_pretrain_train_mgdata = _contrastive_pretrain_set(multi_pt_train_im_mgdata_dir)
multi_pretrain_validation_usdata = _contrastive_pretrain_set(multi_pt_val_im_usdata_dir)
multi_pretrain_validation_mgdata = _contrastive_pretrain_set(multi_pt_val_im_mgdata_dir)

# Concatenate the ultrasound and mammography pretrain datasets (train split)
multi_pretrain_train_set = torch.utils.data.ConcatDataset(
    [multi_pretrain_train_usdata, multi_pretrain_train_mgdata]
)

# Concatenate the ultrasound and mammography pretrain datasets (validation split)
multi_pretrain_validation_set = torch.utils.data.ConcatDataset(
    [multi_pretrain_validation_usdata, multi_pretrain_validation_mgdata]
)

In [None]:
def _pretrain_loader(dataset, shuffle):
    """Wrap ``dataset`` in a pretraining DataLoader; only ``shuffle`` varies per split."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=pretrain_batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=8,
    )


# Train loaders shuffle; validation loaders keep the dataset order.
us_pretrain_train_dataloader = _pretrain_loader(us_pretrain_train_set, shuffle=True)
us_pretrain_val_dataloader = _pretrain_loader(us_pretrain_validation_set, shuffle=False)

mg_pretrain_train_dataloader = _pretrain_loader(mg_pretrain_train_set, shuffle=True)
mg_pretrain_val_dataloader = _pretrain_loader(mg_pretrain_validation_set, shuffle=False)

multi_pretrain_train_dataloader = _pretrain_loader(multi_pretrain_train_set, shuffle=True)
multi_pretrain_val_dataloader = _pretrain_loader(multi_pretrain_validation_set, shuffle=False)

#### 3.3.2. Finetune dataloaders construction for classification models

##### 3.3.2.1. Setting finetune classification dataset class

In [None]:
class FinetuningClassifDataset(Dataset):
    """Folder-backed dataset in which every image shares one class label.

    Args:
        image_dir: Directory whose entries are the image files to load.
        label: Class name; looked up in the module-level
            ``classes_numbercoded`` mapping each time a sample is fetched.
        transform: Optional callable applied to the loaded grayscale PIL image.
    """

    def __init__(self, image_dir, label, transform = None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order. Sorting keeps sample
        # indices stable across runs and machines, matching what
        # FinetuningSegmDataset does for its image list.
        self.images = sorted(os.listdir(image_dir))
        self.label = label

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        The image is converted to single-channel mode "L"; the label is a
        ``torch.long`` scalar tensor.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        # The context manager closes the underlying file as soon as the
        # converted copy has been produced.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)

        label = torch.tensor(classes_numbercoded[self.label], dtype=torch.long)

        return image, label

##### 3.3.2.2. Ultrasound finetune dataloaders

In [197]:
# Ultrasound classification fine-tuning data: one folder-backed dataset per
# class and split, merged into a single dataset per split
# (order: benign, malignant, normal).

# Training split
classif_us_finetune_train_benign = FinetuningClassifDataset(
    image_dir=us_ft_train_b_im_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_us_finetune_train_malignant = FinetuningClassifDataset(
    image_dir=us_ft_train_m_im_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_us_finetune_train_normal = FinetuningClassifDataset(
    image_dir=us_ft_train_n_im_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_us_finetune_train_set = torch.utils.data.ConcatDataset([
    classif_us_finetune_train_benign,
    classif_us_finetune_train_malignant,
    classif_us_finetune_train_normal,
])

# Validation split
classif_us_finetune_val_benign = FinetuningClassifDataset(
    image_dir=us_ft_val_b_im_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_us_finetune_val_malignant = FinetuningClassifDataset(
    image_dir=us_ft_val_m_im_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_us_finetune_val_normal = FinetuningClassifDataset(
    image_dir=us_ft_val_n_im_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_us_finetune_validation_set = torch.utils.data.ConcatDataset([
    classif_us_finetune_val_benign,
    classif_us_finetune_val_malignant,
    classif_us_finetune_val_normal,
])

# Test split
classif_us_finetune_test_benign = FinetuningClassifDataset(
    image_dir=us_ft_test_b_im_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_us_finetune_test_malignant = FinetuningClassifDataset(
    image_dir=us_ft_test_m_im_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_us_finetune_test_normal = FinetuningClassifDataset(
    image_dir=us_ft_test_n_im_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_us_finetune_test_set = torch.utils.data.ConcatDataset([
    classif_us_finetune_test_benign,
    classif_us_finetune_test_malignant,
    classif_us_finetune_test_normal,
])

In [None]:
# Ultrasound classification loaders: only the training loader shuffles, and
# no loader drops the final (possibly smaller) batch.
classif_us_finetune_train_dataloader = DataLoader(
    classif_us_finetune_train_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=True,
)
classif_us_finetune_val_dataloader = DataLoader(
    classif_us_finetune_validation_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)
classif_us_finetune_test_dataloader = DataLoader(
    classif_us_finetune_test_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)

##### 3.3.2.3. Mammography finetune dataloaders

In [199]:
# Mammography classification fine-tuning data: one folder-backed dataset per
# class and split, merged into a single dataset per split
# (order: benign, malignant, normal).

# Training split
classif_mg_finetune_train_benign = FinetuningClassifDataset(
    image_dir=mg_ft_train_b_im_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_mg_finetune_train_malignant = FinetuningClassifDataset(
    image_dir=mg_ft_train_m_im_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_mg_finetune_train_normal = FinetuningClassifDataset(
    image_dir=mg_ft_train_n_im_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_mg_finetune_train_set = torch.utils.data.ConcatDataset([
    classif_mg_finetune_train_benign,
    classif_mg_finetune_train_malignant,
    classif_mg_finetune_train_normal,
])

# Validation split
classif_mg_finetune_val_benign = FinetuningClassifDataset(
    image_dir=mg_ft_val_b_im_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_mg_finetune_val_malignant = FinetuningClassifDataset(
    image_dir=mg_ft_val_m_im_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_mg_finetune_val_normal = FinetuningClassifDataset(
    image_dir=mg_ft_val_n_im_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_mg_finetune_validation_set = torch.utils.data.ConcatDataset([
    classif_mg_finetune_val_benign,
    classif_mg_finetune_val_malignant,
    classif_mg_finetune_val_normal,
])

# Test split
classif_mg_finetune_test_benign = FinetuningClassifDataset(
    image_dir=mg_ft_test_b_im_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_mg_finetune_test_malignant = FinetuningClassifDataset(
    image_dir=mg_ft_test_m_im_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_mg_finetune_test_normal = FinetuningClassifDataset(
    image_dir=mg_ft_test_n_im_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_mg_finetune_test_set = torch.utils.data.ConcatDataset([
    classif_mg_finetune_test_benign,
    classif_mg_finetune_test_malignant,
    classif_mg_finetune_test_normal,
])

In [None]:
# Mammography classification loaders: only the training loader shuffles, and
# no loader drops the final (possibly smaller) batch.
classif_mg_finetune_train_dataloader = DataLoader(
    classif_mg_finetune_train_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=True,
)
classif_mg_finetune_val_dataloader = DataLoader(
    classif_mg_finetune_validation_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)
classif_mg_finetune_test_dataloader = DataLoader(
    classif_mg_finetune_test_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)

##### 3.3.2.4. Multimodal finetune dataloaders

In [201]:
# Multimodal classification fine-tuning data. For every split there is one
# dataset per (class, modality) pair; they are concatenated in the order
# benign-US, benign-MG, malignant-US, malignant-MG, normal-US, normal-MG.

# ---------------- Training split ----------------
classif_multi_finetune_train_benign_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_train_b_im_usdata_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_multi_finetune_train_benign_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_train_b_im_mgdata_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_multi_finetune_train_malignant_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_train_m_im_usdata_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_multi_finetune_train_malignant_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_train_m_im_mgdata_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_multi_finetune_train_normal_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_train_n_im_usdata_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_multi_finetune_train_normal_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_train_n_im_mgdata_dir,
    label="normal",
    transform=only_resize_transforms,
)

# Complete multimodal training set
classif_multi_finetune_train_set = torch.utils.data.ConcatDataset([
    classif_multi_finetune_train_benign_usdata,
    classif_multi_finetune_train_benign_mgdata,
    classif_multi_finetune_train_malignant_usdata,
    classif_multi_finetune_train_malignant_mgdata,
    classif_multi_finetune_train_normal_usdata,
    classif_multi_finetune_train_normal_mgdata,
])

# ---------------- Validation split ----------------
classif_multi_finetune_val_benign_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_val_b_im_usdata_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_multi_finetune_val_benign_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_val_b_im_mgdata_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_multi_finetune_val_malignant_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_val_m_im_usdata_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_multi_finetune_val_malignant_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_val_m_im_mgdata_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_multi_finetune_val_normal_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_val_n_im_usdata_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_multi_finetune_val_normal_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_val_n_im_mgdata_dir,
    label="normal",
    transform=only_resize_transforms,
)

# Complete multimodal validation set
classif_multi_finetune_val_set = torch.utils.data.ConcatDataset([
    classif_multi_finetune_val_benign_usdata,
    classif_multi_finetune_val_benign_mgdata,
    classif_multi_finetune_val_malignant_usdata,
    classif_multi_finetune_val_malignant_mgdata,
    classif_multi_finetune_val_normal_usdata,
    classif_multi_finetune_val_normal_mgdata,
])

# ---------------- Test split ----------------
classif_multi_finetune_test_benign_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_test_b_im_usdata_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_multi_finetune_test_benign_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_test_b_im_mgdata_dir,
    label="benign",
    transform=only_resize_transforms,
)
classif_multi_finetune_test_malignant_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_test_m_im_usdata_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_multi_finetune_test_malignant_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_test_m_im_mgdata_dir,
    label="malignant",
    transform=only_resize_transforms,
)
classif_multi_finetune_test_normal_usdata = FinetuningClassifDataset(
    image_dir=multi_ft_test_n_im_usdata_dir,
    label="normal",
    transform=only_resize_transforms,
)
classif_multi_finetune_test_normal_mgdata = FinetuningClassifDataset(
    image_dir=multi_ft_test_n_im_mgdata_dir,
    label="normal",
    transform=only_resize_transforms,
)

# Complete multimodal test set
classif_multi_finetune_test_set = torch.utils.data.ConcatDataset([
    classif_multi_finetune_test_benign_usdata,
    classif_multi_finetune_test_benign_mgdata,
    classif_multi_finetune_test_malignant_usdata,
    classif_multi_finetune_test_malignant_mgdata,
    classif_multi_finetune_test_normal_usdata,
    classif_multi_finetune_test_normal_mgdata,
])

In [None]:
# Multimodal classification loaders: only the training loader shuffles, and
# no loader drops the final (possibly smaller) batch.
classif_multi_finetune_train_dataloader = DataLoader(
    classif_multi_finetune_train_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=True,
)
classif_multi_finetune_val_dataloader = DataLoader(
    classif_multi_finetune_val_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)
classif_multi_finetune_test_dataloader = DataLoader(
    classif_multi_finetune_test_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)

#### 3.3.3. Finetune dataloaders construction for segmentation models

##### 3.3.3.1. Setting finetune segmentation dataset classes

In [None]:
class FinetuningSegmDataset(Dataset):
    """Folder-backed dataset of (image, mask) pairs for segmentation fine-tuning.

    Image filenames are read from ``image_dir`` in sorted order, and the
    matching mask filename for each one is derived once, at construction time,
    from the filename pattern of the given ``modality``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the segmentation masks.
        modality: ``"mammo"`` or ``"ultra"``; selects the naming rule used to
            map an image filename to its mask filename. Any other value
            leaves every mask name unresolved.
        transform: Optional callable applied to the grayscale PIL image.
        mask_transform: Optional callable applied to the grayscale PIL mask.
            Its output is binarized with a 0.5 threshold.
    """

    def __init__(self, image_dir, mask_dir, modality, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.modality = modality
        self.transform = transform
        self.mask_transform = mask_transform

        # Listing of mask_dir, filled on first use by _mask_dir_entries().
        self._mask_dir_listing = None

        # Sorted list of image filenames
        self.image_filenames = sorted(os.listdir(image_dir))

        # Corresponding list of mask filenames (None where no mask name could be derived)
        self.mask_filenames = [self._generate_mask_filename(filename) for filename in self.image_filenames]

    def _mask_dir_entries(self):
        """Return the sorted entries of ``mask_dir``, reading the directory only once."""
        listing = getattr(self, "_mask_dir_listing", None)
        if listing is None:
            # Sorted so the first match is deterministic if several masks match.
            listing = sorted(os.listdir(self.mask_dir))
            self._mask_dir_listing = listing
        return listing

    def _generate_mask_filename(self, filename):
        """Map an image filename to its mask filename, or ``None`` if no rule applies."""
        mask_name = None
        if self.modality == "mammo":
            if "ANON" in filename:
                # Mammography ANON format: "<id>_..._<class>.<ext>"
                parts = filename.split('_')
                id_token = parts[0]
                class_token = parts[-1].split('.')[0]
                mask_name = f"{id_token}_{class_token}_syntheticMask.png" if class_token == "normal" else f"{id_token}_{class_token}_mask.png"

            elif "_id" in filename:
                # Mammography format with an "id<N>" token as second underscore-separated field
                id_token = filename.split('_')[1].replace('id', '')

                # Pick the first mask whose name carries the same id token.
                # The cached listing avoids re-reading mask_dir for every image.
                needle = f"_id{id_token}_"
                mask_name = next(
                    (candidate for candidate in self._mask_dir_entries() if needle in candidate),
                    None,
                )

        elif self.modality == "ultra":
            # Simple naming convention for ultrasound
            if filename.endswith('.png'):
                mask_name = filename.replace('.png', '_mask.png')

        return mask_name

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the sample at position ``idx``.

        Raises:
            FileNotFoundError: If no mask filename could be derived for the image.
        """
        # Fail with a readable message instead of letting os.path.join choke on None.
        mask_name = self.mask_filenames[idx]
        if mask_name is None:
            raise FileNotFoundError(
                f"No mask filename could be derived for image "
                f"'{self.image_filenames[idx]}' "
                f"(modality={self.modality!r}, mask_dir={self.mask_dir!r})"
            )

        # Load the image; the context manager closes the file handle promptly
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")

        # Load the corresponding mask
        mask_path = os.path.join(self.mask_dir, mask_name)
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
            # Binarize the transformed mask
            mask = (mask > 0.5).float()

        return image, mask

##### 3.3.3.2. Ultrasound finetune dataloaders

In [204]:
# Ultrasound segmentation fine-tuning data: one image/mask dataset per class
# and split, merged into a single dataset per split
# (order: benign, malignant, normal).

# Training split
segm_us_finetune_train_benign = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_train_b_im_dir,
    mask_dir=us_ft_train_b_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_train_malignant = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_train_m_im_dir,
    mask_dir=us_ft_train_m_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_train_normal = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_train_n_im_dir,
    mask_dir=us_ft_train_n_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_train_set = torch.utils.data.ConcatDataset([
    segm_us_finetune_train_benign,
    segm_us_finetune_train_malignant,
    segm_us_finetune_train_normal,
])

# Validation split
segm_us_finetune_val_benign = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_val_b_im_dir,
    mask_dir=us_ft_val_b_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_val_malignant = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_val_m_im_dir,
    mask_dir=us_ft_val_m_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_val_normal = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_val_n_im_dir,
    mask_dir=us_ft_val_n_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_validation_set = torch.utils.data.ConcatDataset([
    segm_us_finetune_val_benign,
    segm_us_finetune_val_malignant,
    segm_us_finetune_val_normal,
])

# Test split
segm_us_finetune_test_benign = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_test_b_im_dir,
    mask_dir=us_ft_test_b_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_test_malignant = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_test_m_im_dir,
    mask_dir=us_ft_test_m_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_test_normal = FinetuningSegmDataset(
    modality="ultra",
    image_dir=us_ft_test_n_im_dir,
    mask_dir=us_ft_test_n_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_us_finetune_test_set = torch.utils.data.ConcatDataset([
    segm_us_finetune_test_benign,
    segm_us_finetune_test_malignant,
    segm_us_finetune_test_normal,
])


In [None]:
# Data loaders for ultrasound segmentation fine-tuning.
# Only the training loader shuffles; no loader drops the final batch.
segm_us_finetune_train_dataloader = torch.utils.data.DataLoader(
    segm_us_finetune_train_set,
    batch_size = finetune_batch_size,
    shuffle = True,
    drop_last = False
)

segm_us_finetune_val_dataloader = torch.utils.data.DataLoader(
    segm_us_finetune_validation_set,
    batch_size = finetune_batch_size,
    shuffle = False,
    drop_last = False
)

segm_us_finetune_test_dataloader = torch.utils.data.DataLoader(
    segm_us_finetune_test_set,
    batch_size = finetune_batch_size,
    shuffle = False,
    drop_last = False
)

##### 3.3.3.3. Mammography finetune dataloaders

In [206]:
# Mammography segmentation fine-tuning data: one image/mask dataset per class
# and split, merged into a single dataset per split
# (order: benign, malignant, normal).

# Training split
segm_mg_finetune_train_benign = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_train_b_im_dir,
    mask_dir=mg_ft_train_b_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_train_malignant = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_train_m_im_dir,
    mask_dir=mg_ft_train_m_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_train_normal = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_train_n_im_dir,
    mask_dir=mg_ft_train_n_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_train_set = torch.utils.data.ConcatDataset([
    segm_mg_finetune_train_benign,
    segm_mg_finetune_train_malignant,
    segm_mg_finetune_train_normal,
])

# Validation split
segm_mg_finetune_val_benign = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_val_b_im_dir,
    mask_dir=mg_ft_val_b_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_val_malignant = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_val_m_im_dir,
    mask_dir=mg_ft_val_m_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_val_normal = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_val_n_im_dir,
    mask_dir=mg_ft_val_n_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_validation_set = torch.utils.data.ConcatDataset([
    segm_mg_finetune_val_benign,
    segm_mg_finetune_val_malignant,
    segm_mg_finetune_val_normal,
])

# Test split
segm_mg_finetune_test_benign = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_test_b_im_dir,
    mask_dir=mg_ft_test_b_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_test_malignant = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_test_m_im_dir,
    mask_dir=mg_ft_test_m_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_test_normal = FinetuningSegmDataset(
    modality="mammo",
    image_dir=mg_ft_test_n_im_dir,
    mask_dir=mg_ft_test_n_msk_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_mg_finetune_test_set = torch.utils.data.ConcatDataset([
    segm_mg_finetune_test_benign,
    segm_mg_finetune_test_malignant,
    segm_mg_finetune_test_normal,
])

In [None]:
# Mammography segmentation loaders: only the training loader shuffles, and
# no loader drops the final (possibly smaller) batch.
segm_mg_finetune_train_dataloader = DataLoader(
    segm_mg_finetune_train_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=True,
)
segm_mg_finetune_val_dataloader = DataLoader(
    segm_mg_finetune_validation_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)
segm_mg_finetune_test_dataloader = DataLoader(
    segm_mg_finetune_test_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)

##### 3.3.3.4. Multimodal finetune dataloaders

In [208]:
# Multimodal segmentation fine-tuning data. For every split there is one
# image/mask dataset per (class, modality) pair; they are concatenated in the
# order benign-US, benign-MG, malignant-US, malignant-MG, normal-US, normal-MG.

# ---------------- Training split ----------------
segm_multi_finetune_train_benign_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_train_b_im_usdata_dir,
    mask_dir=multi_ft_train_b_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_train_benign_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_train_b_im_mgdata_dir,
    mask_dir=multi_ft_train_b_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_train_malignant_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_train_m_im_usdata_dir,
    mask_dir=multi_ft_train_m_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_train_malignant_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_train_m_im_mgdata_dir,
    mask_dir=multi_ft_train_m_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_train_normal_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_train_n_im_usdata_dir,
    mask_dir=multi_ft_train_n_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_train_normal_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_train_n_im_mgdata_dir,
    mask_dir=multi_ft_train_n_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)

# Complete multimodal training set
segm_multi_finetune_train_set = torch.utils.data.ConcatDataset([
    segm_multi_finetune_train_benign_usdata,
    segm_multi_finetune_train_benign_mgdata,
    segm_multi_finetune_train_malignant_usdata,
    segm_multi_finetune_train_malignant_mgdata,
    segm_multi_finetune_train_normal_usdata,
    segm_multi_finetune_train_normal_mgdata,
])

# ---------------- Validation split ----------------
segm_multi_finetune_val_benign_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_val_b_im_usdata_dir,
    mask_dir=multi_ft_val_b_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_val_benign_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_val_b_im_mgdata_dir,
    mask_dir=multi_ft_val_b_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_val_malignant_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_val_m_im_usdata_dir,
    mask_dir=multi_ft_val_m_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_val_malignant_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_val_m_im_mgdata_dir,
    mask_dir=multi_ft_val_m_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_val_normal_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_val_n_im_usdata_dir,
    mask_dir=multi_ft_val_n_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_val_normal_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_val_n_im_mgdata_dir,
    mask_dir=multi_ft_val_n_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)

# Complete multimodal validation set
segm_multi_finetune_val_set = torch.utils.data.ConcatDataset([
    segm_multi_finetune_val_benign_usdata,
    segm_multi_finetune_val_benign_mgdata,
    segm_multi_finetune_val_malignant_usdata,
    segm_multi_finetune_val_malignant_mgdata,
    segm_multi_finetune_val_normal_usdata,
    segm_multi_finetune_val_normal_mgdata,
])

# ---------------- Test split ----------------
segm_multi_finetune_test_benign_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_test_b_im_usdata_dir,
    mask_dir=multi_ft_test_b_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_test_benign_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_test_b_im_mgdata_dir,
    mask_dir=multi_ft_test_b_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_test_malignant_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_test_m_im_usdata_dir,
    mask_dir=multi_ft_test_m_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_test_malignant_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_test_m_im_mgdata_dir,
    mask_dir=multi_ft_test_m_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_test_normal_usdata = FinetuningSegmDataset(
    modality="ultra",
    image_dir=multi_ft_test_n_im_usdata_dir,
    mask_dir=multi_ft_test_n_msk_usdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)
segm_multi_finetune_test_normal_mgdata = FinetuningSegmDataset(
    modality="mammo",
    image_dir=multi_ft_test_n_im_mgdata_dir,
    mask_dir=multi_ft_test_n_msk_mgdata_dir,
    transform=only_resize_transforms,
    mask_transform=only_resize_transforms_masks,
)

# Complete multimodal test set
segm_multi_finetune_test_set = torch.utils.data.ConcatDataset([
    segm_multi_finetune_test_benign_usdata,
    segm_multi_finetune_test_benign_mgdata,
    segm_multi_finetune_test_malignant_usdata,
    segm_multi_finetune_test_malignant_mgdata,
    segm_multi_finetune_test_normal_usdata,
    segm_multi_finetune_test_normal_mgdata,
])

In [None]:
# Multimodal segmentation loaders: only the training loader shuffles, and
# no loader drops the final (possibly smaller) batch.
segm_multi_finetune_train_dataloader = DataLoader(
    segm_multi_finetune_train_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=True,
)
segm_multi_finetune_val_dataloader = DataLoader(
    segm_multi_finetune_val_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)
segm_multi_finetune_test_dataloader = DataLoader(
    segm_multi_finetune_test_set,
    batch_size=finetune_batch_size,
    drop_last=False,
    shuffle=False,
)

### 3.4. View pretrain augmented example and finetune normalized image example

In [None]:
# Visualize augmented images: both contrastive views of the first few
# ultrasound pretraining samples, laid out in a single row.

# The notebook's import cell only binds `transforms`
# (import torchvision.transforms as transforms), so make the `torchvision`
# package name itself available for make_grid.
import torchvision.utils

NUM_IMAGES = 3
imgs = []

for idx in range(NUM_IMAGES):
    view1, view2 = us_pretrain_train_set[idx]

    # Both views are already tensors; keep them next to each other
    imgs.extend([view1, view2])

# Stack images along the batch dimension
imgs = torch.stack(imgs)

print(imgs.size())

# Display a grid of images; one row holding every view (2 per sample)
img_grid = torchvision.utils.make_grid(imgs, nrow=2 * NUM_IMAGES, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0).numpy()

plt.figure(figsize=(10, 5))
plt.title('Augmented image examples')
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

In [None]:
# Visualize the first few ultrasound classification fine-tuning images.

# The notebook's import cell only binds `transforms`
# (import torchvision.transforms as transforms), so make the `torchvision`
# package name itself available for make_grid.
import torchvision.utils

NUM_IMAGES = 6
imgs = []

for idx in range(NUM_IMAGES):
    img, _label = classif_us_finetune_train_set[idx]  # Only the image is displayed
    imgs.append(img)

# Stack images along the batch dimension
imgs = torch.stack(imgs)

print("Stacked image tensor size:", imgs.size())

# Display a grid of images; one row holding every image
img_grid = torchvision.utils.make_grid(imgs, nrow=NUM_IMAGES, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0).numpy()

plt.figure(figsize=(10, 5))
plt.title('Normalized and padded finetuning image examples')
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

In [None]:
# Visualize the first few ultrasound segmentation samples: images on the left,
# their masks on the right.

# The notebook's import cell only binds `transforms`
# (import torchvision.transforms as transforms), so make the `torchvision`
# package name itself available for make_grid.
import torchvision.utils

NUM_IMAGES = 6
imgs = []
masks = []

for idx in range(NUM_IMAGES):
    img, mask = segm_us_finetune_train_set[idx]  # Retrieve image and mask
    imgs.append(img)
    masks.append(mask)

# Stack images and masks along the batch dimension
imgs = torch.stack(imgs)
masks = torch.stack(masks)

print("Stacked image tensor size:", imgs.size())
print("Stacked mask tensor size:", masks.size())

# Create grids of images and masks; one row holding every sample
img_grid = torchvision.utils.make_grid(imgs, nrow=NUM_IMAGES, normalize=True, pad_value=0.9)
mask_grid = torchvision.utils.make_grid(masks, nrow=NUM_IMAGES, normalize=True, pad_value=0.9)

# Plot images
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(img_grid.permute(1, 2, 0).numpy())
axes[0].set_title('Images')
axes[0].axis('off')

# Plot masks
axes[1].imshow(mask_grid.permute(1, 2, 0).numpy(), cmap='gray')
axes[1].set_title('Masks')
axes[1].axis('off')

plt.show()
plt.close()

In [None]:
# Pull one batch from the ultrasound segmentation training loader and split
# it into its image and mask tensors
batch = next(iter(segm_us_finetune_train_dataloader))
images, masks = batch

# Report tensor shapes
print(f"Shape of images: {images.shape}")  # Should be [batch_size, channels, height, width]
print(f"Shape of masks: {masks.shape}")    # Should be [batch_size, 1, height, width]

# Keep only the leading samples of the batch for plotting
NUM_IMAGES = 3
imgs_to_plot = images[:NUM_IMAGES]
masks_to_plot = masks[:NUM_IMAGES]

# Undo the normalization ([-1, 1] -> [0, 1]) and move channels last: [N, H, W, C]
unnormalized_imgs = (imgs_to_plot * 0.5 + 0.5).permute(0, 2, 3, 1)

# One row per sample: image in the left column, mask in the right column
plt.figure(figsize=(10, NUM_IMAGES * 5))
for row in range(NUM_IMAGES):
    panels = (
        (f"Image {row + 1}", unnormalized_imgs[row].squeeze(-1)),  # Grayscale image
        (f"Mask {row + 1}", masks_to_plot[row].squeeze(0)),        # Single-channel mask
    )
    for col, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(NUM_IMAGES, 2, 2 * row + col)
        plt.imshow(picture, cmap='gray')
        plt.title(caption)
        plt.axis('off')

plt.tight_layout()
plt.show()

## 4. Preparing model design components

### 4.1. Setting loss criteria

In [None]:
class SimCLRNTXentLoss(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss for SimCLR.

    For every embedding, the embedding of the other view of the same sample is
    the positive and every remaining embedding in the batch is a negative.

    Args:
        temperature: Scalar that divides the cosine similarities before the
            softmax; smaller values sharpen the distribution.
    """

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Compute the loss for two aligned batches of embeddings.

        Args:
            z_i: Embeddings of the first view. Assumed to be shaped
                (batch, dim), with row k of ``z_i`` and ``z_j`` coming from
                the same sample.
            z_j: Embeddings of the second view, same shape as ``z_i``.

        Returns:
            Scalar loss tensor. If the similarity matrix contains NaN/Inf, a
            warning is printed and a zero loss is returned instead.
        """
        batch_size = z_i.size(0)

        # Normalize the feature vectors so dot products are cosine similarities
        z_i = F.normalize(z_i, p=2, dim=-1)
        z_j = F.normalize(z_j, p=2, dim=-1)

        # Concatenate the representations from both views: rows [0, B) are
        # view 1, rows [B, 2B) are view 2
        representations = torch.cat([z_i, z_j], dim=0)

        # Temperature-scaled cosine similarity between all pairs. The values
        # are NOT clamped afterwards: clamping to [-1, 1] would cancel both
        # the temperature scaling and the diagonal mask below.
        similarity_matrix = torch.matmul(representations, representations.T) / self.temperature

        # Bail out before masking if the embeddings produced NaN/Inf
        if not torch.isfinite(similarity_matrix).all():
            print("Warning: NaN or Inf detected in similarity matrix")
            print(f"Similarity matrix: {similarity_matrix}")
            # Zero loss that still supports .backward(), so the step is
            # skipped instead of crashing
            return torch.zeros(
                (),
                device=similarity_matrix.device,
                dtype=similarity_matrix.dtype,
                requires_grad=True,
            )

        # Remove self-similarity: fill the diagonal with a large negative
        # constant (not -inf), capped to what the dtype can represent
        fill_value = max(-1e9, torch.finfo(similarity_matrix.dtype).min)
        mask = torch.eye(similarity_matrix.size(0), dtype=torch.bool, device=similarity_matrix.device)
        similarity_matrix = similarity_matrix.masked_fill(mask, fill_value)

        # Positive of row k (view 1) is row k + B (view 2) and vice versa
        view_index = torch.arange(batch_size, device=similarity_matrix.device)
        labels = torch.cat([view_index + batch_size, view_index])

        # Apply cross-entropy loss
        loss = F.cross_entropy(similarity_matrix, labels)

        return loss

# MoCo also uses an NT-Xent style loss, adapted so that the negative samples
# come from a memory queue.
class MoCoNTXentLoss(nn.Module):
    """Contrastive loss with one positive key per query and queue negatives.

    Args:
        queue_size: Stored on the instance; not read by ``forward``.
        temperature: Scalar dividing the cosine similarities.
    """

    def __init__(self, queue_size, temperature):
        super().__init__()
        self.queue_size = queue_size
        self.temperature = temperature
        self.similarity_fn = nn.CosineSimilarity(dim=-1)

    def forward(self, query_features, key_features, queue):
        """Return the cross-entropy between each query's logits and class 0.

        Column 0 of the logits is the query/key similarity; the remaining
        columns are the similarities between the query and every queue entry.
        Assumes ``queue`` is indexed as (num_entries, dim) — TODO confirm.
        """
        n_queries = query_features.size(0)

        # Query vs. its own key (positive), scaled by the temperature
        pos_logits = self.similarity_fn(query_features, key_features) / self.temperature

        # Query vs. every queue entry (negatives), scaled by the temperature
        neg_logits = self.similarity_fn(
            query_features.unsqueeze(1), queue.unsqueeze(0)
        ) / self.temperature

        # Positive logit first, then the negatives
        logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)

        assert logits.size(0) == query_features.size(0), \
        f"Logits size {logits.size()} doesn't match batch size {query_features.size(0)}"

        # Bound the logits to [-10, 10] to keep the loss free of NaN/Inf
        logits = logits.clamp(min=-10, max=10)

        # The positive sits at index 0 for every query
        targets = torch.zeros(n_queries, dtype=torch.long).to(query_features.device)

        return F.cross_entropy(logits, targets)

# BYOL does not use an explicit contrastive (negative-pair) term.
class BYOLLoss(nn.Module):
    """Mean ``1 - cosine_similarity`` between two batches of vectors.

    Args:
        epsilon: Stored on the instance; not read by ``forward``.
    """

    def __init__(self, epsilon = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, z_i, z_j):
        """Return the batch mean of ``1 - cos(z_i, z_j)`` along the last dim."""
        # Scale both inputs to unit length first
        unit_i, unit_j = (F.normalize(vec, p=2, dim=-1) for vec in (z_i, z_j))

        # 0 for aligned vectors, 2 for opposite vectors
        distance = 1 - F.cosine_similarity(unit_i, unit_j, dim=-1)

        return distance.mean()
    

### 4.2. Setting contrastive learning model classes

In [None]:
# Shared by the MoCo v2 and SimCLR v2 model classes
class ProjectionHead(nn.Module):
    """Two-stage MLP: Linear -> ReLU -> BatchNorm1d, followed by a Linear.

    Args:
        input_dim: Size of the incoming feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the returned projection.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        hidden_stage = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            # Not in original MoCo v2 / SimCLR v2; added for stability
            nn.BatchNorm1d(hidden_dim),
        ]
        self.layer1 = nn.Sequential(*hidden_stage)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` through the hidden stage and the output layer."""
        return self.layer2(self.layer1(x))

# Pretraining model for SimCLR v2: a backbone followed by an MLP projection head
class SimCLR_v2(nn.Module):
    """Backbone plus ``ProjectionHead`` producing ``projection_dim``-sized outputs.

    Args:
        backbone: Feature extractor. Its flattened output must have 512
            features per sample, because the head is built with
            ``input_dim=512``.
        projection_dim: Output size of the projection head.
    """

    def __init__(self, backbone, projection_dim = 128):
        super().__init__()
        self.projection_dim = projection_dim
        self.backbone = backbone

        # Besides a deeper backbone, the change from SimCLR v1 to v2 is a
        # projection head with additional layers; the custom ProjectionHead
        # class provides it.
        self.projection_head = ProjectionHead(input_dim=512, hidden_dim=512, output_dim=projection_dim)

    def forward(self, x):
        """Return the projection of ``x``."""
        flat_features = torch.flatten(self.backbone(x), start_dim=1)
        return self.projection_head(flat_features)

# Source used queue_size = 65536, in experiments here is reduced to 25%
class MoCo_v2(nn.Module):
    """MoCo v2 pretraining model: query encoder, momentum key encoder and a key queue.

    Args:
        backbone: Query feature extractor. Its flattened output must have 512
            features per sample (the heads are built with ``input_dim=512``).
        projection_dim: Output size of both projection heads and width of the
            queue.
        queue_size: Number of key slots in the queue.
        momentum: EMA coefficient used to update the key encoder from the
            query encoder.
        queue_decay: When ``< 1.0``, new keys are blended into the existing
            slot as ``decay * old + (1 - decay) * new``; otherwise they
            replace the slot.

    ``forward`` does not touch the queue. Callers enqueue the returned keys
    with ``update_queue`` after computing the loss, so that a batch never
    appears among its own negatives and is not enqueued twice.
    """

    def __init__(self, backbone, projection_dim=128, queue_size=16384, momentum=0.999, queue_decay=0.99):

        super(MoCo_v2, self).__init__()
        self.backbone = backbone
        self.projection_head = ProjectionHead(input_dim=512, hidden_dim=512, output_dim=projection_dim)
        self.queue_size = queue_size
        self.momentum = momentum
        self.queue_decay = queue_decay

        # Key-side projection head starts as an exact copy of the query head
        self.projection_head_k = copy.deepcopy(self.projection_head)

        # NOTE(review): not used by forward(); it only aliases the backbone
        # layers (minus the last child) and the query projection head. Kept so
        # the attribute and its state_dict keys remain available.
        self.key_encoder = nn.Sequential(
            *list(self.backbone.children())[:-1],
            self.projection_head
        )

        # Momentum (key) encoder starts as an exact copy of the backbone
        self.momentum_encoder = copy.deepcopy(self.backbone)

        # Queue of unit-norm random keys plus the write pointer
        self.register_buffer("queue", F.normalize(torch.randn(queue_size, projection_dim), dim=1))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))  # Queue pointer

        # The key side is only ever updated through the momentum rule
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False
        for param in self.projection_head_k.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def momentum_update_key_encoder(self):
        """EMA-update key encoder and key head parameters from the query side."""
        for param_q, param_k in zip(self.backbone.parameters(), self.momentum_encoder.parameters()):
            param_k.data = param_k.data * self.momentum + param_q.data * (1. - self.momentum)
        for param_q, param_k in zip(self.projection_head.parameters(), self.projection_head_k.parameters()):
            param_k.data = param_k.data * self.momentum + param_q.data * (1. - self.momentum)

    @torch.no_grad()
    def update_queue(self, key_features):
        """Write a batch of keys into the queue at the current pointer.

        Args:
            key_features: Keys indexed as ``(batch, projection_dim)``; they
                are detached and L2-normalized before being stored.

        Raises:
            ValueError: If the batch holds more keys than the queue has slots.
        """
        key_features = F.normalize(key_features.detach(), dim=1).to(self.queue)

        batch_size = key_features.size(0)
        if batch_size > self.queue_size:
            raise ValueError(
                f"Batch of {batch_size} keys does not fit in a queue of size {self.queue_size}"
            )
        ptr = int(self.queue_ptr)

        # Target slots, wrapping around the end of the queue
        slots = (ptr + torch.arange(batch_size, device=self.queue.device)) % self.queue_size

        if self.queue_decay < 1.0:
            # Blend the new keys with the existing queue values
            self.queue[slots] = self.queue_decay * self.queue[slots] + (1 - self.queue_decay) * key_features
        else:
            # Directly replace the keys in the queue
            self.queue[slots] = key_features

        # Update the queue pointer
        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size

    def forward(self, x_q, x_k):
        """Encode a query view and a key view.

        Args:
            x_q: Input batch for the query encoder.
            x_k: Input batch for the momentum (key) encoder.

        Returns:
            ``(q, k)``: L2-normalized query and key projections; ``k`` carries
            no gradient.
        """
        # Forward pass for query
        q = self.backbone(x_q).flatten(start_dim=1)
        q = self.projection_head(q)

        # Forward pass for key using momentum encoder
        with torch.no_grad():
            # Only move the key encoder while training, so evaluation passes
            # leave the model state untouched
            if self.training:
                self.momentum_update_key_encoder()
            k = self.momentum_encoder(x_k).flatten(start_dim=1)
            k = self.projection_head_k(k)

        # Normalize q and k
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)

        return q, k

# Pretraining model for BYOL: online encoder + prediction head, and an EMA target encoder
class BYOL(nn.Module):
    """Online/target encoder pair with a prediction head on the online side.

    Args:
        backbone: Feature extractor exposing an ``output_dim`` attribute; its
            output is fed straight into a ``Linear(output_dim, projection_dim)``.
        projection_dim: Width of the encoder output and the prediction head.
        momentum: EMA coefficient used by ``update_target_encoder``.
    """

    def __init__(self, backbone, projection_dim = 128, momentum=0.99):
        super().__init__()

        # Online encoder: backbone followed by Linear -> BatchNorm -> ReLU
        self.online_encoder = nn.Sequential(
            backbone,
            nn.Linear(backbone.output_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
        )

        # Prediction head applied to the online encoder output
        predictor_layers = [
            nn.Linear(projection_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
        ]
        self.prediction_head = nn.Sequential(*predictor_layers)

        # Target encoder is an independent copy of the online encoder and is
        # never trained by gradient descent
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_encoder.requires_grad_(False)

        self.momentum = momentum

    @torch.no_grad()
    def update_target_encoder(self):
        """Move every target parameter towards its online counterpart (EMA)."""
        pairs = zip(self.online_encoder.parameters(), self.target_encoder.parameters())
        for online_param, target_param in pairs:
            target_param.data = self.momentum * target_param.data + (1 - self.momentum) * online_param.data

    def forward(self, x1, x2):
        """Return online predictions and target encodings for both views.

        Returns:
            ``(p1_online, p2_online, z1_target, z2_target)``; the target
            outputs are computed without gradient tracking.
        """
        # Online branch: encode both views, then predict from each encoding
        online_1 = self.online_encoder(x1)
        online_2 = self.online_encoder(x2)
        pred_1 = self.prediction_head(online_1)
        pred_2 = self.prediction_head(online_2)

        # Target branch (no gradients)
        with torch.no_grad():
            target_1 = self.target_encoder(x1)
            target_2 = self.target_encoder(x2)

        return pred_1, pred_2, target_1, target_2


### 4.3. Setting saving and loading functions

In [None]:
def results_save(model_id, phase, save_path, results, current_epoch, total_epoch_count):
    """Persist per-epoch results to ``logs_{model_id}_{phase}.pkl`` in ``save_path``.

    If the file already exists, the new entries are appended to it. Entries
    whose epoch number is not greater than the last epoch already stored are
    skipped, so a checkpoint is never written twice.

    Args:
        model_id: Identifier embedded in the file name.
        phase: Phase tag embedded in the file name.
        save_path: Directory to write to; created if missing.
        results: Dict mapping a metric name to a list with one entry per
            epoch. An ``'epoch'`` list, when present, is used to skip entries
            that are already stored.
        current_epoch: 1-based number of the epoch just finished; only used
            for the completion message.
        total_epoch_count: Planned number of epochs.

    Returns:
        Path of the pickle file.

    Raises:
        RuntimeError: If the existing file already contains
            ``total_epoch_count`` or more epochs.
    """
    results_path = os.path.join(save_path, f"logs_{model_id}_{phase}.pkl")

    os.makedirs(save_path, exist_ok=True)

    if os.path.exists(results_path):
        # NOTE: pickle executes code on load; only ever read files written here.
        with open(results_path, 'rb') as f:
            existing_results = pickle.load(f)

        # Get last epoch saved
        last_saved_epoch = max(existing_results.get('epoch', []), default=0)

        # If last saved epoch reaches the total epoch count, training is done
        if last_saved_epoch >= total_epoch_count:
            raise RuntimeError(f"Training is already complete for {model_id} in {phase} phase. "
                               f"Last saved epoch: {last_saved_epoch}. Restart training if necessary.")

        # Positions of the entries that are newer than what is on disk
        new_epochs = results.get('epoch')
        if new_epochs is None:
            keep = None
        else:
            keep = [i for i, epoch in enumerate(new_epochs) if epoch > last_saved_epoch]

        for key, values in results.items():
            # Only filter lists that are aligned with the epoch list
            if keep is not None and len(values) == len(new_epochs):
                values = [values[i] for i in keep]
            existing_results.setdefault(key, []).extend(values)

        results_to_save = existing_results
    else:
        # No existing file, save new results
        results_to_save = results

    with open(results_path, 'wb') as f:
        pickle.dump(results_to_save, f)

    if current_epoch == total_epoch_count:
        print(f"Training for {model_id} in {phase} phase is complete at epoch {current_epoch}.")

    return results_path
    
def load_results(results_file):
    """Return the object unpickled from ``results_file``."""
    with open(results_file, 'rb') as handle:
        return pickle.load(handle)

### 4.4. Setting SimCLR Pretrain functions usable by classification and segmentation models alike

In [None]:
def train_one_epoch_simclr(pretrain_model, pretrain_criterion, train_dataloader, pretrain_optimizer, device):
    """Train the SimCLR model for one pass over ``train_dataloader``.

    Args:
        pretrain_model: Model called once per view; returns the projections.
        pretrain_criterion: Callable ``(z0, z1) -> scalar loss``.
        train_dataloader: Iterable yielding ``(x0, x1)`` view pairs.
        pretrain_optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean loss over the samples actually processed
        (``0.0`` if the loader yields nothing).
    """
    pretrain_model.train()
    total_loss = 0.0
    # Count samples as they are seen: len(dataset) over-counts when the
    # loader drops the last batch or uses a sampler.
    samples_seen = 0

    for x0, x1 in train_dataloader:
        pretrain_optimizer.zero_grad()

        x0 = x0.to(device)
        x1 = x1.to(device)

        z0 = pretrain_model(x0)
        z1 = pretrain_model(x1)

        loss = pretrain_criterion(z0, z1)
        loss.backward()

        # Cap the global gradient norm at 1.0 before stepping
        torch.nn.utils.clip_grad_norm_(pretrain_model.parameters(), max_norm=1.0)

        pretrain_optimizer.step()

        batch_size = x0.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size

    return total_loss / samples_seen if samples_seen else 0.0

def validate_one_epoch_simclr(pretrain_model, pretrain_criterion, val_dataloader, device):
    """Evaluate the SimCLR loss over ``val_dataloader`` without gradients.

    Args:
        pretrain_model: Model called once per view; returns the projections.
        pretrain_criterion: Callable ``(z0, z1) -> scalar loss``.
        val_dataloader: Iterable yielding ``(x0, x1)`` view pairs.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean loss over the samples actually processed
        (``0.0`` if the loader yields nothing).
    """
    pretrain_model.eval()
    total_loss = 0.0
    # Count samples as they are seen rather than trusting len(dataset)
    samples_seen = 0

    with torch.no_grad():
        for x0, x1 in val_dataloader:
            x0 = x0.to(device)
            x1 = x1.to(device)

            loss = pretrain_criterion(pretrain_model(x0), pretrain_model(x1))

            batch_size = x0.size(0)
            total_loss += loss.item() * batch_size
            samples_seen += batch_size

    return total_loss / samples_seen if samples_seen else 0.0

def run_pretraining_simclr(pretrain_model, pretrain_criterion, train_dataloader, val_dataloader, pretrain_optimizer, pretrain_scheduler, device, total_epoch_count, storing_path, identifier, checkpoint_interval):
    """Run SimCLR pretraining and write the per-epoch losses to disk.

    Args:
        pretrain_model: Model to train; moved to ``device`` first.
        pretrain_criterion: Loss passed to the per-epoch functions.
        train_dataloader: Loader for the training epoch.
        val_dataloader: Loader for the validation epoch.
        pretrain_optimizer: Optimizer stepped during training.
        pretrain_scheduler: Optional scheduler; it is stepped with the
            validation loss, so it must accept a metric argument.
        device: Device used for model and batches.
        total_epoch_count: Number of epochs to run.
        storing_path: Directory handed to ``results_save``.
        identifier: Model identifier handed to ``results_save``.
        checkpoint_interval: Results are flushed every this many epochs and
            after the final epoch.

    Returns:
        Path of the results file, or ``None`` if ``results_save`` raised a
        ``RuntimeError``.
    """
    pretrain_model.to(device)

    results_path = None
    results = {
        'epoch': [],
        'train_loss': [],
        'val_loss': []
    }

    print("»»» Starting Network Pretraining process \n")

    for epoch in range(total_epoch_count):
        epoch_number = epoch + 1  # 1-based, as stored in the logs

        avg_train_loss = train_one_epoch_simclr(
            pretrain_model, pretrain_criterion, train_dataloader, pretrain_optimizer, device
        )
        avg_val_loss = validate_one_epoch_simclr(
            pretrain_model, pretrain_criterion, val_dataloader, device
        )

        print(f"Epoch [{epoch_number}/{total_epoch_count}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Stored as part of the results data to analyze performance later
        results['epoch'].append(epoch_number)
        results['train_loss'].append(avg_train_loss)
        results['val_loss'].append(avg_val_loss)

        lr_before_scheduler = pretrain_optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch_number}/{total_epoch_count}], Learning Rate before scheduler: {lr_before_scheduler:.6f}")

        # Apply the scheduler if present
        if pretrain_scheduler:
            pretrain_scheduler.step(avg_val_loss)
            lr_after_scheduler = pretrain_optimizer.param_groups[0]['lr']
            print(f"Learning Rate after scheduler: {lr_after_scheduler:.6f}")

        if epoch_number % checkpoint_interval == 0 or epoch_number == total_epoch_count:
            print("»»» Checkpoint reached: Saving model state and results")
            try:
                # Pass the 1-based epoch so results_save can detect the final epoch
                results_path = results_save(identifier, "pt", storing_path, results, epoch_number, total_epoch_count)
            except RuntimeError as e:
                print(e)
                return None

            # Clear results to avoid duplication on the next flush
            results = {key: [] for key in results.keys()}

    print("\n»»» Network Pretraining process complete\n»»» Logs saved at: ", results_path)

    return results_path


### 4.5. Setting MoCo Pretrain functions usable by classification and segmentation models alike

In [None]:
def train_one_epoch_moco(pretrain_model, pretrain_criterion, train_dataloader, pretrain_optimizer, device):
    """Train the MoCo model for one pass over ``train_dataloader``.

    Args:
        pretrain_model: Model exposing ``momentum_encoder``, ``queue`` and
            ``update_queue``; called as ``model(x_q, x_k) -> (q, k)``.
        pretrain_criterion: Callable ``(q, k, queue) -> scalar loss``.
        train_dataloader: Iterable yielding ``(x_q, x_k)`` view pairs.
        pretrain_optimizer: Optimizer over the trainable parameters.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean loss over the samples actually processed
        (``0.0`` if the loader yields nothing).
    """
    # Query side trains; the momentum encoder stays in eval mode and frozen
    pretrain_model.train()
    pretrain_model.momentum_encoder.eval()
    for param in pretrain_model.momentum_encoder.parameters():
        param.requires_grad = False

    total_loss = 0.0
    # Count samples as they are seen rather than trusting len(dataset)
    samples_seen = 0

    for x_q, x_k in train_dataloader:
        x_q = x_q.to(device)
        x_k = x_k.to(device)

        # Clear gradients before the backward pass so nothing left over from
        # earlier code leaks into the first step
        pretrain_optimizer.zero_grad()

        # Get the query and key representations
        q, k = pretrain_model(x_q, x_k)

        loss = pretrain_criterion(q, k, pretrain_model.queue)

        loss.backward()
        pretrain_optimizer.step()

        # Enqueue the keys only after the loss has been computed
        pretrain_model.update_queue(k)

        batch_size = x_q.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size

    return total_loss / samples_seen if samples_seen else 0.0

def validate_one_epoch_moco(pretrain_model, pretrain_criterion, val_dataloader, device):
    """Evaluate the MoCo loss over ``val_dataloader`` without gradients.

    Args:
        pretrain_model: Model exposing ``queue``; called as
            ``model(x_q, x_k) -> (q, k)``.
        pretrain_criterion: Callable ``(q, k, queue) -> scalar loss``.
        val_dataloader: Iterable yielding ``(x_q, x_k)`` view pairs.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean loss over the samples actually processed
        (``0.0`` if the loader yields nothing).
    """
    pretrain_model.eval()
    total_loss = 0.0
    # Count samples as they are seen rather than trusting len(dataset)
    samples_seen = 0

    with torch.no_grad():
        for x_q, x_k in val_dataloader:
            x_q = x_q.to(device)
            x_k = x_k.to(device)

            # Get the query and key representations
            q, k = pretrain_model(x_q, x_k)

            # This function never enqueues the validation keys itself
            loss = pretrain_criterion(q, k, pretrain_model.queue)

            batch_size = x_q.size(0)
            total_loss += loss.item() * batch_size
            samples_seen += batch_size

    return total_loss / samples_seen if samples_seen else 0.0

def run_pretraining_moco(pretrain_model, pretrain_criterion, train_dataloader, val_dataloader, pretrain_optimizer, pretrain_scheduler, device, total_epoch_count, storing_path, identifier, checkpoint_interval):
    """Run MoCo pretraining and write the per-epoch losses to disk.

    Args:
        pretrain_model: Model to train; moved to ``device`` first.
        pretrain_criterion: Loss passed to the per-epoch functions.
        train_dataloader: Loader for the training epoch.
        val_dataloader: Loader for the validation epoch.
        pretrain_optimizer: Optimizer stepped during training.
        pretrain_scheduler: Optional scheduler; stepped once per epoch
            without arguments.
        device: Device used for model and batches.
        total_epoch_count: Number of epochs to run.
        storing_path: Directory handed to ``results_save``.
        identifier: Model identifier handed to ``results_save``.
        checkpoint_interval: Results are flushed every this many epochs and
            after the final epoch.

    Returns:
        Path of the results file, or ``None`` if ``results_save`` raised a
        ``RuntimeError``.
    """
    pretrain_model.to(device)

    results_path = None
    results = {
        'epoch': [],
        'train_loss': [],
        'val_loss': []
    }

    print("»»» Starting MoCo Pretraining process \n")

    for epoch in range(total_epoch_count):
        epoch_number = epoch + 1  # 1-based, as stored in the logs

        # Training
        avg_train_loss = train_one_epoch_moco(
            pretrain_model, pretrain_criterion, train_dataloader, pretrain_optimizer, device
        )

        # Validation
        avg_val_loss = validate_one_epoch_moco(
            pretrain_model, pretrain_criterion, val_dataloader, device
        )

        print(f"Epoch [{epoch_number}/{total_epoch_count}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Append results for the current epoch
        results['epoch'].append(epoch_number)
        results['train_loss'].append(avg_train_loss)
        results['val_loss'].append(avg_val_loss)

        if pretrain_scheduler:
            pretrain_scheduler.step()

        if epoch_number % checkpoint_interval == 0 or epoch_number == total_epoch_count:
            print("»»» Checkpoint reached: Saving model state and results")
            try:
                # Pass the 1-based epoch so results_save can detect the final epoch
                results_path = results_save(identifier, "pt", storing_path, results, epoch_number, total_epoch_count)
            except RuntimeError as e:
                print(f"Error saving checkpoint at epoch {epoch_number}: {e}")
                return None

            # Clear results to avoid duplication on the next flush
            results = {key: [] for key in results.keys()}

    print("\n»»» MoCo Pretraining process complete\n»»» Logs saved at: ", results_path)

    return results_path


### 4.6. Setting BYOL Pretrain functions usable by classification and segmentation models alike

In [None]:
def train_one_epoch_byol(pretrain_model, pretrain_criterion, train_dataloader, pretrain_optimizer, device):
    """Train the BYOL model for one pass over ``train_dataloader``.

    Args:
        pretrain_model: Model called as ``model(x1, x2)`` returning
            ``(p1_online, p2_online, z1_target, z2_target)`` and exposing
            ``update_target_encoder``.
        pretrain_criterion: Callable ``(prediction, target) -> scalar loss``.
        train_dataloader: Iterable yielding ``(x1, x2)`` view pairs.
        pretrain_optimizer: Optimizer over the online parameters.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean loss over the samples actually processed
        (``0.0`` if the loader yields nothing).
    """
    pretrain_model.train()
    total_loss = 0.0
    # Count samples as they are seen rather than trusting len(dataset)
    samples_seen = 0

    for x1, x2 in train_dataloader:
        x1, x2 = x1.to(device), x2.to(device)

        # Forward pass
        p1_online, p2_online, z1_target, z2_target = pretrain_model(x1, x2)

        # BYOL objective: the online prediction of one view must match the
        # target encoding of the OTHER view (symmetrized over both views).
        # Pairing a prediction with the target of the same view would not
        # enforce any invariance between the two augmentations.
        loss = pretrain_criterion(p1_online, z2_target) + pretrain_criterion(p2_online, z1_target)

        # Backpropagation
        pretrain_optimizer.zero_grad()
        loss.backward()
        pretrain_optimizer.step()

        # Move the target encoder towards the freshly updated online encoder
        pretrain_model.update_target_encoder()

        batch_size = x1.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size

    avg_train_loss = total_loss / samples_seen if samples_seen else 0.0

    return avg_train_loss

def validate_one_epoch_byol(pretrain_model, pretrain_criterion, val_dataloader, device):
    """Evaluate the BYOL loss over ``val_dataloader`` without gradients.

    Args:
        pretrain_model: Model called as ``model(x1, x2)`` returning
            ``(p1_online, p2_online, z1_target, z2_target)``.
        pretrain_criterion: Callable ``(prediction, target) -> scalar loss``.
        val_dataloader: Iterable yielding ``(x1, x2)`` view pairs.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean loss over the samples actually processed
        (``0.0`` if the loader yields nothing).
    """
    pretrain_model.eval()
    total_loss = 0.0
    # Count samples as they are seen rather than trusting len(dataset)
    samples_seen = 0

    with torch.no_grad():
        for x1, x2 in val_dataloader:
            x1, x2 = x1.to(device), x2.to(device)

            # Forward pass
            p1_online, p2_online, z1_target, z2_target = pretrain_model(x1, x2)

            # Cross-view pairing: each online prediction is compared with the
            # target encoding of the other view
            loss = pretrain_criterion(p1_online, z2_target) + pretrain_criterion(p2_online, z1_target)

            batch_size = x1.size(0)
            total_loss += loss.item() * batch_size
            samples_seen += batch_size

    return total_loss / samples_seen if samples_seen else 0.0

def run_pretraining_byol(pretrain_model, pretrain_criterion, train_dataloader, val_dataloader, pretrain_optimizer, device, total_epoch_count, storing_path, identifier, checkpoint_interval):
    """Run BYOL pretraining and write the per-epoch losses to disk.

    Args:
        pretrain_model: Model to train; moved to ``device`` first.
        pretrain_criterion: Loss passed to the per-epoch functions.
        train_dataloader: Loader for the training epoch.
        val_dataloader: Loader for the validation epoch.
        pretrain_optimizer: Optimizer stepped during training.
        device: Device used for model and batches.
        total_epoch_count: Number of epochs to run.
        storing_path: Directory handed to ``results_save``.
        identifier: Model identifier handed to ``results_save``.
        checkpoint_interval: Results are flushed every this many epochs and
            after the final epoch.

    Returns:
        Path of the results file, or ``None`` if ``results_save`` raised a
        ``RuntimeError``.
    """
    pretrain_model.to(device)

    results_path = None
    results = {
        'epoch': [],
        'train_loss': [],
        'val_loss': []
    }

    print("»»» Starting BYOL Pretraining process \n")

    for epoch in range(total_epoch_count):
        epoch_number = epoch + 1  # 1-based, as stored in the logs

        # Training
        avg_train_loss = train_one_epoch_byol(
            pretrain_model, pretrain_criterion, train_dataloader, pretrain_optimizer, device
        )

        # Validation
        avg_val_loss = validate_one_epoch_byol(
            pretrain_model, pretrain_criterion, val_dataloader, device
        )

        print(f"Epoch [{epoch_number}/{total_epoch_count}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Append results for the current epoch
        results['epoch'].append(epoch_number)
        results['train_loss'].append(avg_train_loss)
        results['val_loss'].append(avg_val_loss)

        if epoch_number % checkpoint_interval == 0 or epoch_number == total_epoch_count:
            print("»»» Checkpoint reached: Saving model state and results")
            try:
                # Pass the 1-based epoch so results_save can detect the final epoch
                results_path = results_save(identifier, "pt", storing_path, results, epoch_number, total_epoch_count)
            except RuntimeError as e:
                print(f"Error saving checkpoint at epoch {epoch_number}: {e}")
                return None

            # Clear results to avoid duplication on the next flush
            results = {key: [] for key in results.keys()}

    print("\n»»» BYOL Pretraining process complete\n»»» Logs saved at: ", results_path)

    return results_path


### 4.7. Setting Pretrain phase loss plot function

In [None]:
def plot_pretrain_results(results_path):
    """Plot train/val loss curves from a pretraining results pickle.

    The plot is written next to the results file as
    ``{identifier}_loss_pt.png``. The identifier is everything between the
    first and the last ``_``-separated token of the file name, so that
    ``logs_{model_id}_{phase}.pkl`` yields the full ``model_id`` even when it
    contains underscores.

    Args:
        results_path: Path of the pickle holding ``'epoch'``,
            ``'train_loss'`` and ``'val_loss'`` lists.

    Returns:
        None. Prints a message and returns early when the file is missing,
        the file name has no underscore, or there is no loss data.
    """
    if not os.path.exists(results_path):
        print(f"File not found: {results_path}")
        return

    print(f"»»» Data taken from path: {results_path}")
    # NOTE: pickle executes code on load; only point this at trusted files.
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    directory_path = os.path.dirname(results_path)
    filename = os.path.basename(results_path)
    parts = filename.split('_')

    if len(parts) < 2:
        print(f"»»» Unexpected filename format: {filename}")
        return

    # Drop the leading prefix token and the trailing "{phase}.pkl" token;
    # fall back to the second token for two-token names
    identifier = '_'.join(parts[1:-1]) or parts[1]

    plot_path = os.path.join(directory_path, f"{identifier}_loss_pt.png")

    train_losses = results.get('train_loss', [])
    val_losses = results.get('val_loss', [])

    if not train_losses or not val_losses:
        print("»»» No loss data available for plotting.")
        return

    # Without stored epoch numbers, number the entries 1..n
    epochs = results.get('epoch') or list(range(1, len(train_losses) + 1))

    # Check for loss trends
    if train_losses[-1] > train_losses[0]:
        print("»»» Warning: Training loss increased over epochs.")
    if val_losses[-1] > val_losses[0]:
        print("»»» Warning: Validation loss increased over epochs.")

    # Plotting
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, train_losses, label=f'Train Loss ({train_losses[-1]:.4f})', color='blue')
    plt.plot(epochs, val_losses, label=f'Val Loss ({val_losses[-1]:.4f})', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Pretrained model {identifier} - Loss graph (Train / Val)')
    plt.legend()

    # Tick positions at every 10th epoch starting from 1, plus the last epoch
    last_epoch = len(epochs)
    epoch_ticks = list(range(1, last_epoch + 1, 10))
    if epoch_ticks[-1] != last_epoch:
        epoch_ticks.append(last_epoch)

    # Shift interior ticks by one so they read 10, 20, ... instead of 11, 21, ...
    adjusted_ticks = [tick - 1 if tick not in (1, last_epoch) else tick for tick in epoch_ticks]

    plt.xticks(adjusted_ticks, labels=[str(tick) for tick in adjusted_ticks])

    plt.gca().yaxis.set_major_formatter(plt.FormatStrFormatter('%.1f'))
    plt.savefig(plot_path)
    plt.close()

    print(f"\n»»» Train/Val Loss Plot saved at {directory_path}")
    

### 4.8. Dataset class setting for t-distributed stochastic neighbour embedding visualization

In [None]:
# Pretraining unlabeled set - validation set for t-sne process (no augmentations)
class UnchangedDataset(Dataset):
    """Dataset of grayscale images read from one directory.

    Args:
        image_dir: Directory scanned (non-recursively) for files whose names
            end in ``jpg``, ``png`` or ``jpeg``.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, image_dir, transform = None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices (and the resulting embeddings) are reproducible across runs
        self.images = sorted(
            f for f in os.listdir(image_dir) if f.endswith(('jpg', 'png', 'jpeg'))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as single-channel ("L") and apply the transform, if any."""
        img_path = os.path.join(self.image_dir, self.images[idx])
        image = Image.open(img_path).convert("L")
        if self.transform is not None:
            image = self.transform(image)
        return image

def collate_fn(batch):
    """Stack the first element of every sample into one tensor.

    Each sample is expected to be indexable, with a tensor at position 0.
    """
    first_views = tuple(sample[0] for sample in batch)
    return torch.stack(first_views)

def extract_embeddings(model, dataloader, device, is_byol=False):
    """Run every batch through the encoder and return the flattened features.

    Args:
        model: Module put into eval mode. With ``is_byol`` it is called
            directly; otherwise ``model.backbone`` is called.
        dataloader: Iterable of batches. A tensor is used as is, a list is
            stacked into one tensor, and the first element of a tuple is used.
        device: Device the images are moved to.
        is_byol: Selects ``model`` itself instead of ``model.backbone``.

    Returns:
        NumPy array with one flattened feature row per image.

    Raises:
        TypeError: If a batch is not a tensor, list or tuple.
    """
    def as_image_tensor(batch):
        # Normalize the supported batch layouts to a single tensor on `device`
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        if isinstance(batch, list):
            return torch.stack(batch).to(device)
        if isinstance(batch, tuple):
            return batch[0].to(device)
        raise TypeError(f"Unexpected batch type: {type(batch)}")

    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dataloader:
            images = as_image_tensor(batch)
            # BYOL callers pass the encoder itself; SimCLR/MoCo expose `.backbone`
            encoder = model if is_byol else model.backbone
            flat = encoder(images).flatten(start_dim=1)
            chunks.append(flat.cpu().numpy())

    return np.concatenate(chunks)

def plot_tsne(embeddings, results_path):
    """Project ``embeddings`` to 2-D with t-SNE and save a scatter plot.

    The image is written next to ``results_path`` as ``{model_id}_tsne.png``.
    The model id is everything between the first and the last
    ``_``-separated token of the file name, so
    ``logs_{model_id}_{phase}.pkl`` yields the full ``model_id`` even when it
    contains underscores.

    Args:
        embeddings: 2-D array with one row per sample.
        results_path: Path of the results file; only its directory and file
            name are used.
    """
    directory_path = os.path.dirname(results_path)
    parts = os.path.basename(results_path).split('_')
    # Fall back to the last token when there is nothing between first and last
    model_id = '_'.join(parts[1:-1]) or parts[-1]

    # t-SNE requires perplexity < n_samples; keep the default of 30 whenever
    # the set is large enough and shrink it for small sets
    perplexity = min(30.0, max(1.0, len(embeddings) - 1.0))

    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_embeddings = tsne.fit_transform(embeddings)
    plt.figure(figsize=(10, 8))
    plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], alpha=0.6)

    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    fig = plt.gcf()  # Get current figure
    fig.suptitle(f'Pretrained model {model_id} - t-distributed stochastic neighbor embedding visualization', fontsize=16)

    # Save the t-SNE plot image
    plot_path = os.path.join(directory_path, f"{model_id}_tsne.png")
    plt.savefig(plot_path)
    plt.close()
    print(f"»»» t-SNE plot saved at: {directory_path}")

def tsne_pipeline(model, dataloader, device, results_path, is_byol=False):
    """Extract embeddings for ``dataloader`` and save their t-SNE plot.

    For non-BYOL models the backbone parameters are frozen while the
    embeddings are computed. Afterwards every parameter gets back the
    ``requires_grad`` value it had before the call, also when extraction or
    plotting raises.

    Args:
        model: Model handed to ``extract_embeddings``; must expose
            ``backbone`` unless ``is_byol`` is set.
        dataloader: Batches handed to ``extract_embeddings``.
        device: Device used for extraction.
        results_path: Results file path handed to ``plot_tsne``.
        is_byol: Skips the backbone freezing and is forwarded to
            ``extract_embeddings``.
    """
    backbone_params = [] if is_byol else list(model.backbone.parameters())
    # Remember the current flags so a previously frozen backbone stays frozen
    previous_flags = [param.requires_grad for param in backbone_params]

    for param in backbone_params:
        param.requires_grad = False

    try:
        embeddings = extract_embeddings(model, dataloader, device, is_byol)
        plot_tsne(embeddings, results_path)
    finally:
        for param, flag in zip(backbone_params, previous_flags):
            param.requires_grad = flag

In [None]:
# t-SNE input built from `us_pt_val_im_dir`: one view per image using
# `only_resize_transforms`, served unshuffled and without dropping the last batch.
tSNE_us_dataset = UnchangedDataset(
    image_dir=us_pt_val_im_dir,
    transform=ContrastiveTransformations(base_transforms=only_resize_transforms, n_views=1),
)

tSNE_us_dataloader = torch.utils.data.DataLoader(
    dataset=tSNE_us_dataset,
    batch_size=pretrain_batch_size,
    collate_fn=collate_fn,
    num_workers=8,
    shuffle=False,
    drop_last=False,
)

In [None]:
# t-SNE input built from `mg_pt_val_im_dir`: one view per image using
# `only_resize_transforms`, served unshuffled and without dropping the last batch.
tSNE_mg_dataset = UnchangedDataset(
    image_dir=mg_pt_val_im_dir,
    transform=ContrastiveTransformations(base_transforms=only_resize_transforms, n_views=1),
)

tSNE_mg_dataloader = torch.utils.data.DataLoader(
    dataset=tSNE_mg_dataset,
    batch_size=pretrain_batch_size,
    collate_fn=collate_fn,
    num_workers=8,
    shuffle=False,
    drop_last=False,
)

In [None]:
# The multimodal case needs both validation sets merged into one dataset
tSNE_multi_pretrain_validation_usdata = UnchangedDataset(
    image_dir=multi_pt_val_im_usdata_dir,
    transform=ContrastiveTransformations(base_transforms=only_resize_transforms, n_views=1),
)

tSNE_multi_pretrain_validation_mgdata = UnchangedDataset(
    image_dir=multi_pt_val_im_mgdata_dir,
    transform=ContrastiveTransformations(base_transforms=only_resize_transforms, n_views=1),
)

# Ultrasound samples come first, followed by the mammography samples
tSNE_multi_pretrain_validation_dataset = torch.utils.data.ConcatDataset(
    [tSNE_multi_pretrain_validation_usdata, tSNE_multi_pretrain_validation_mgdata]
)

tSNE_multi_dataloader = torch.utils.data.DataLoader(
    dataset=tSNE_multi_pretrain_validation_dataset,
    batch_size=pretrain_batch_size,
    collate_fn=collate_fn,
    num_workers=8,
    shuffle=False,
    drop_last=False,
)

### 4.9. Setting Finetune / Test / Plotting functions specific to classification

In [None]:
# Finetuning Functions
def classif_train_one_epoch(finetune_model, finetune_criterion, train_dataloader, finetune_optimizer, device):
    """Train the classifier for one pass over ``train_dataloader``.

    Args:
        finetune_model: Model mapping an image batch to per-class scores.
        finetune_criterion: Callable ``(outputs, labels) -> scalar loss``.
        train_dataloader: Iterable yielding ``(images, labels)`` batches.
        finetune_optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, all_labels, all_preds, all_probs)``: the sample-weighted
        mean loss over the samples actually processed (``0.0`` if the loader
        yields nothing), and NumPy arrays with the labels, the argmax
        predictions and the softmax probabilities collected during the epoch.
    """
    finetune_model.train()
    total_loss = 0.0
    # Count samples as they are seen: len(dataset) over-counts when the
    # loader drops the last batch or uses a sampler.
    samples_seen = 0
    all_preds = []
    all_labels = []
    all_probs = []

    for images, labels in train_dataloader:
        finetune_optimizer.zero_grad()

        # Transfer Data to Device
        images = images.to(device)
        labels = labels.to(device)

        # Forward Pass
        outputs = finetune_model(images)

        # Calculate Loss (weighted by batch size for the epoch average)
        loss = finetune_criterion(outputs, labels)
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size

        # Predictions and probabilities come from the pre-step outputs
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        probs = F.softmax(outputs, dim=1)
        all_probs.extend(probs.detach().cpu().numpy())

        # Backward Pass and Optimization Step
        loss.backward()
        finetune_optimizer.step()

    avg_loss = total_loss / samples_seen if samples_seen else 0.0

    # Convert lists to numpy arrays for consistency
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    return avg_loss, all_labels, all_preds, all_probs

def classif_validate_one_epoch(finetune_model, finetune_criterion, val_dataloader, device):
    """Evaluate the classifier on a validation loader without updating it.

    Args:
        finetune_model: Network called as ``model(images)``; its output is
            reduced with ``argmax``/``softmax`` over dimension 1.
        finetune_criterion: Loss called as ``criterion(outputs, labels)``.
            The per-batch value is weighted by the batch size, which assumes
            a mean-reduced loss -- TODO confirm against the caller.
        val_dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, all_labels, all_preds, all_probs)``: the
        sample-weighted mean loss, followed by numpy arrays of the true
        labels, the argmax predictions and the softmax probabilities.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    finetune_model.eval()
    total_loss = 0.0
    n_samples = 0
    pred_batches = []
    label_batches = []
    prob_batches = []

    with torch.no_grad():
        for images, labels in val_dataloader:
            # Transfer data to device
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass and loss
            outputs = finetune_model(images)
            loss = finetune_criterion(outputs, labels)

            # Weight the batch loss by its size so partial batches count correctly
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            # Predictions and probabilities
            pred_batches.append(torch.argmax(outputs, dim=1).cpu().numpy())
            label_batches.append(labels.cpu().numpy())
            prob_batches.append(F.softmax(outputs, dim=1).cpu().numpy())

    # Normalise by the samples actually seen: len(dataloader.dataset) would
    # over-count when the loader uses drop_last=True or a subsampling sampler.
    avg_loss = total_loss / n_samples

    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)

    return avg_loss, all_labels, all_preds, all_probs

def classif_run_finetuning(finetune_model, finetune_criterion, train_dataloader, val_dataloader, finetune_optimizer, finetune_scheduler, device, total_epoch_count, storing_path, identifier, checkpoint_interval):
    """Run the classification finetuning loop with periodic checkpoints.

    Each epoch trains and validates once, logs both losses and buffers the
    per-epoch labels, predictions and probabilities. Every
    ``checkpoint_interval`` epochs, and on the final epoch, the buffer is
    handed to ``results_save`` and then emptied.

    Args:
        finetune_model: Classifier being finetuned.
        finetune_criterion: Loss passed to the per-epoch helpers.
        train_dataloader: Loader of training batches.
        val_dataloader: Loader of validation batches.
        finetune_optimizer: Optimizer used during training.
        finetune_scheduler: Optional scheduler, stepped with the validation
            loss after every epoch when truthy.
        device: Device the batches are moved to.
        total_epoch_count: Number of epochs to run.
        storing_path: Forwarded to ``results_save``.
        identifier: Forwarded to ``results_save``.
        checkpoint_interval: Number of epochs between two checkpoints.

    Returns:
        The value returned by the last ``results_save`` call, or ``None``
        when no checkpoint was written or saving raised ``RuntimeError``.
    """
    log_keys = (
        'epoch',
        'train_loss',
        'val_loss',
        'train_labels',
        'train_preds',
        'train_probs',
        'val_labels',
        'val_preds',
        'val_probs',
    )
    results = {key: [] for key in log_keys}
    results_path = None

    print("»»» Starting Classification Network Finetuning process")

    for epoch in range(total_epoch_count):
        epoch_number = epoch + 1

        # One training pass followed by one validation pass
        avg_train_loss, train_labels, train_preds, train_probs = classif_train_one_epoch(
            finetune_model, finetune_criterion, train_dataloader, finetune_optimizer, device
        )
        avg_val_loss, val_labels, val_preds, val_probs = classif_validate_one_epoch(
            finetune_model, finetune_criterion, val_dataloader, device
        )

        print(f"Epoch [{epoch_number}/{total_epoch_count}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Buffer this epoch's record until the next checkpoint
        epoch_log = {
            'epoch': epoch_number,
            'train_loss': avg_train_loss,
            'train_labels': train_labels.tolist(),
            'train_preds': train_preds.tolist(),
            'train_probs': train_probs.tolist(),
            'val_loss': avg_val_loss,
            'val_labels': val_labels.tolist(),
            'val_preds': val_preds.tolist(),
            'val_probs': val_probs.tolist(),
        }
        for key, value in epoch_log.items():
            results[key].append(value)

        if finetune_scheduler:
            finetune_scheduler.step(avg_val_loss)

        at_interval = epoch_number % checkpoint_interval == 0
        if not (at_interval or epoch_number == total_epoch_count):
            continue

        print(f"»»» Checkpoint reached at epoch {epoch_number}: Saving model state and results")
        try:
            results_path = results_save(identifier, "ft", storing_path, results, epoch, total_epoch_count)
        except RuntimeError as e:
            print(e)
            return None

        # Start a fresh buffer so the next checkpoint does not repeat epochs
        results = {key: [] for key in log_keys}

    print("\n»»» Classification Network Finetuning process complete\n»»» Logs saved at: ", results_path)

    # Path of files with model and results
    return results_path

def classif_plot_finetuning_results(results_path, classes):
    """Plot PR curves, ROC curves and summary metrics from a finetuning log.

    Loads the pickled results dictionary at ``results_path`` and writes three
    PNG figures into the same directory: precision-recall curves, ROC curves,
    and a metrics table next to the train/validation loss curves. The stored
    label, prediction and probability lists are concatenated before any curve
    or metric is computed (presumably one entry per logged epoch, so all
    logged epochs are pooled -- verify against ``results_save``).

    Args:
        results_path: Path to the pickled results file. The identifier used in
            figure titles and file names is the second ``_``-separated token
            of its base name.
        classes: Class names ordered by label index. Colours are looked up
            case-insensitively; unknown names use matplotlib's default colour.

    Returns:
        None. Prints a message and returns early when the file is missing.
    """
    if not os.path.exists(results_path):
        print(f"File not found: {results_path}")
        return

    print(f"»»» Data taken from path: {results_path}")
    # NOTE(security): unpickling can execute arbitrary code; only open log
    # files that were written by this notebook.
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    # Extract directory and identifier from results_path
    directory_path = os.path.dirname(results_path)
    identifier = os.path.basename(results_path).split('_')[1]

    plot_paths = {
        'box1_pr': os.path.join(directory_path, f"{identifier}_precision_recall_ft.png"),
        'box2_roc_auc': os.path.join(directory_path, f"{identifier}_roc_auc_ft.png"),
        'box3_loss_and_value_metrics': os.path.join(directory_path, f"{identifier}_loss_value_metrics_ft.png")
    }

    # Extract data
    epochs = np.array(results['epoch'])
    train_loss = np.array(results['train_loss'])
    val_loss = np.array(results['val_loss'])
    train_labels = np.concatenate(results['train_labels'])
    train_probs = np.concatenate(results['train_probs'])
    val_labels = np.concatenate(results['val_labels'])
    val_probs = np.concatenate(results['val_probs'])
    train_preds = np.concatenate(results['train_preds'])
    val_preds = np.concatenate(results['val_preds'])

    # Colors for each class for consistent styling
    colors = {'benign': 'blue', 'malignant': 'orange', 'normal': 'green'}

    def class_color(name):
        # Case-insensitive lookup; None lets matplotlib pick its default colour
        return colors.get(name.lower())

    n_classes = len(classes)
    # One merged row plus one row per class; never fewer than the original 4
    n_rows = max(4, n_classes + 1)

    # Each split is plotted in its own column
    splits = (
        ('Train', train_labels, train_probs),
        ('Validation', val_labels, val_probs),
    )

    #----------------
    # Figure 1: Precision-Recall curves
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, 25))

    # One-hot labels match the shape of the probability matrices
    train_labels_one_hot = np.eye(n_classes)[train_labels]
    val_labels_one_hot = np.eye(n_classes)[val_labels]

    overall_ap = {
        'Train': average_precision_score(train_labels_one_hot, train_probs, average='weighted'),
        'Validation': average_precision_score(val_labels_one_hot, val_probs, average='weighted'),
    }

    for col, (split_name, split_labels, split_probs) in enumerate(splits):
        # Top row: all class curves merged into one panel (best effort per class)
        merged_ax = axes[0, col]
        for i, class_name in enumerate(classes):
            binary_labels = (split_labels == i).astype(int)
            try:
                precision, recall, _ = precision_recall_curve(binary_labels, split_probs[:, i])
                avg_precision = average_precision_score(binary_labels, split_probs[:, i])
                merged_ax.plot(recall, precision, label=f'{class_name} (AP={avg_precision:.2f})', color=class_color(class_name))
            except Exception as e:
                print(f"Error plotting Precision-Recall for {class_name} ({split_name}): {e}")

        merged_ax.set_xlabel('Recall')
        merged_ax.set_ylabel('Precision')
        merged_ax.set_title(f'{split_name} - merged PR curves and overall weighted AP score = {overall_ap[split_name]:.2f}')
        merged_ax.legend()

        # Remaining rows: one panel per class (names are shown lower-cased)
        for i, class_name in enumerate(classes):
            display_name = class_name.lower()
            binary_labels = (split_labels == i).astype(int)

            precision, recall, _ = precision_recall_curve(binary_labels, split_probs[:, i])
            avg_precision = average_precision_score(binary_labels, split_probs[:, i])

            class_ax = axes[i + 1, col]
            class_ax.plot(recall, precision, color=class_color(class_name), label=f'{display_name} (AP={avg_precision:.2f})')
            class_ax.set_title(f'{split_name} - {display_name} PR Curve')
            class_ax.set_xlabel('Recall')
            class_ax.set_ylabel('Precision')
            class_ax.legend()

    # Adjust layout and save the figure
    fig.suptitle(f'Classification model {identifier} - Weighted Precision-Recall Curves (Train / Val)', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(plot_paths['box1_pr'])
    plt.close()

    print(f"»»» Combined Precision-Recall plot saved at {plot_paths['box1_pr']}")

    #----------------
    # Figure 2: ROC curves
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, 25))

    # Top row: merged ROC curves with the overall weighted ROC-AUC in the title
    try:
        roc_auc_weighted = {
            'Train': roc_auc_score(train_labels, train_probs, multi_class='ovr', average='weighted'),
            'Validation': roc_auc_score(val_labels, val_probs, multi_class='ovr', average='weighted'),
        }

        for col, (split_name, split_labels, split_probs) in enumerate(splits):
            merged_ax = axes[0, col]
            for i, class_name in enumerate(classes):
                binary_labels = (split_labels == i).astype(int)
                fpr, tpr, _ = roc_curve(binary_labels, split_probs[:, i])
                roc_auc = roc_auc_score(binary_labels, split_probs[:, i])
                merged_ax.plot(fpr, tpr, label=f'{class_name} (AUC={roc_auc:.2f})', color=class_color(class_name))

            # Line for random guess and a title with the weighted AUC
            merged_ax.plot([0, 1], [0, 1], 'k--', label='Random Guess')
            merged_ax.set_title(f'{split_name} - merged ROC curves and overall weighted ROC-AUC score = {roc_auc_weighted[split_name]:.2f}')
            merged_ax.set_xlabel('False Positive Rate')
            merged_ax.set_ylabel('True Positive Rate')
            merged_ax.legend()
    except Exception as e:
        print(f"Error plotting weighted ROC-AUC: {e}")

    # Remaining rows: class-specific ROC curves (best effort per panel)
    for i, class_name in enumerate(classes):
        for col, (split_name, split_labels, split_probs) in enumerate(splits):
            try:
                binary_labels = (split_labels == i).astype(int)
                fpr, tpr, _ = roc_curve(binary_labels, split_probs[:, i])
                class_auc = roc_auc_score(binary_labels, split_probs[:, i])
                # The AUC shown is always the one computed for this split and
                # class; the legend wording of each column is kept as it was.
                if split_name == 'Train':
                    curve_label = f'AUC={class_auc:.2f}'
                else:
                    curve_label = f'{class_name} (AUC={class_auc:.2f})'

                class_ax = axes[i + 1, col]
                class_ax.plot(fpr, tpr, label=curve_label, color=class_color(class_name))
                class_ax.plot([0, 1], [0, 1], 'k--', label='Random Guess')
                class_ax.set_title(f'{split_name} - {class_name} ROC Curve')
                class_ax.set_xlabel('False Positive Rate')
                class_ax.set_ylabel('True Positive Rate')
                class_ax.legend()
            except Exception as e:
                print(f"Error plotting ROC-AUC for {class_name} ({split_name}): {e}")

    # Adjust layout
    fig.suptitle(f'Classification model {identifier} - Weighted ROC Curves (Train / Val)', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # Save the figure
    plt.savefig(plot_paths['box2_roc_auc'])
    plt.close()

    print(f"»»» Combined ROC-AUC plot saved at {plot_paths['box2_roc_auc']}")

    #----------------
    # Figure 3: single-value metrics table and loss curves
    metrics = {
        "MCC (Train)": round(matthews_corrcoef(train_labels, train_preds), 5),
        "MCC (Validation)": round(matthews_corrcoef(val_labels, val_preds), 5),
        "Balanced Accuracy (Train)": round(balanced_accuracy_score(train_labels, train_preds), 5),
        "Balanced Accuracy (Validation)": round(balanced_accuracy_score(val_labels, val_preds), 5),
        "Weighted F1 Score (Train)": round(f1_score(train_labels, train_preds, average='weighted'), 5),
        "Weighted F1 Score (Validation)": round(f1_score(val_labels, val_preds, average='weighted'), 5)
    }

    # One row per metric, a single value column
    metrics_df = pd.DataFrame(metrics, index=[0]).T

    # Left panel holds the table, right panel the loss curves
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Last logged loss values, shown in the legend
    final_train_loss = results['train_loss'][-1] if results['train_loss'] else None
    final_val_loss = results['val_loss'][-1] if results['val_loss'] else None

    axes[1].plot(epochs, train_loss, label=f'Train Loss ({final_train_loss:.4f})', color='orange')
    axes[1].plot(epochs, val_loss, label=f'Val Loss ({final_val_loss:.4f})', color='blue')

    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('Loss')
    axes[1].legend(loc='upper right')

    # Tick every 10th position, always including the final epoch count
    epoch_ticks = list(range(1, len(epochs) + 1, 10))
    if epoch_ticks[-1] != len(epochs):
        epoch_ticks.append(len(epochs))

    # Shift interior ticks by one so they land on 10, 20, ... instead of 11, 21, ...
    adjusted_ticks = [tick - 1 if tick != 1 and tick != len(epochs) else tick for tick in epoch_ticks]

    axes[1].set_xticks(adjusted_ticks)
    axes[1].set_xticklabels([str(epoch) for epoch in adjusted_ticks])

    # Left: metrics table without header
    axes[0].axis('off')
    table = axes[0].table(cellText=metrics_df.values, rowLabels=metrics_df.index, cellLoc='center', loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.auto_set_column_width(col=list(range(len(metrics_df.columns))))

    # Add overall title for the figure
    fig.suptitle(f'Classification model {identifier} - Single value metrics and Loss graph (Train / Val)', fontsize=16)

    # Save the combined figure
    plt.savefig(plot_paths['box3_loss_and_value_metrics'], bbox_inches='tight')
    plt.close()

    # Print confirmation
    print(f"»»» Combined (loss plot + metrics table) image saved at {plot_paths['box3_loss_and_value_metrics']}")

# Testing Functions
def classif_run_testing(test_model, test_criterion, test_dataloader, storing_path, identifier, device):
    """Evaluate the classifier on the test loader and pickle the results.

    Besides the loss, labels, predictions and probabilities, the function
    stores the confusion matrix and the per-class sensitivity and
    specificity derived from it.

    Args:
        test_model: Network called as ``model(images)``; its output is
            reduced with ``argmax``/``softmax`` over dimension 1.
        test_criterion: Loss called as ``criterion(outputs, labels)``. The
            per-batch value is weighted by the batch size, which assumes a
            mean-reduced loss -- TODO confirm against the caller.
        test_dataloader: Iterable yielding ``(images, labels)`` batches.
        storing_path: Directory the log file is written to.
        identifier: Tag embedded in the log file name
            (``logs_<identifier>_test.pkl``).
        device: Device the batches are moved to.

    Returns:
        Path of the pickled results file.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    test_model.eval()
    total_loss = 0.0
    n_samples = 0
    num_classes = 0
    all_preds = []
    all_labels = []
    all_probs = []

    print("»»» Starting Classification Network Testing process \n")

    with torch.no_grad():
        for images, labels in test_dataloader:
            # Transfer data to device
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass and loss
            outputs = test_model(images)
            loss = test_criterion(outputs, labels)

            # Weight the batch loss by its size so partial batches count correctly
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            # Width of the output gives the number of classes the model knows
            num_classes = max(num_classes, outputs.size(1))

            # Predictions, labels and softmax probabilities as plain lists
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_probs.extend(F.softmax(outputs, dim=1).cpu().tolist())

    # Normalise by the samples actually seen rather than len(dataset), which
    # over-counts with drop_last=True or a subsampling sampler.
    test_loss = total_loss / n_samples

    # Pin the label set so that row/column i is always class i, even when a
    # class is missing from both the test labels and the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))

    # Sensitivity and specificity per class (one-vs-rest)
    sensitivity = {}
    specificity = {}
    for i in range(cm.shape[0]):
        tp = cm[i, i]  # True Positive: diagonal element
        fn = cm[i, :].sum() - tp  # False Negative: row sum minus TP
        fp = cm[:, i].sum() - tp  # False Positive: column sum minus TP
        tn = cm.sum() - (tp + fn + fp)  # True Negative: everything else

        sensitivity[i] = float(tp / (tp + fn)) if (tp + fn) > 0 else 0
        specificity[i] = float(tn / (tn + fp)) if (tn + fp) > 0 else 0

    results_filename = f"logs_{identifier}_test.pkl"
    results_path = os.path.join(storing_path, results_filename)

    results = {
        'test_loss': test_loss,
        'test_labels': all_labels,  # True labels
        'test_preds': all_preds,    # Predicted labels
        'test_probs': all_probs,    # Probabilities
        'sensitivity': sensitivity,
        'specificity': specificity,
        'cm': cm.tolist()
    }

    # Save results using pickle
    with open(results_path, 'wb') as f:
        pickle.dump(results, f)

    print(f"\n»»» Classification Network Testing process complete\n»»» Logs saved at: {results_path}")

    return results_path

def classif_plot_testing_results(results_path, classes):
    """Plot PR curves, ROC curves and summary metrics from a test log.

    Loads the pickled results dictionary at ``results_path`` and writes three
    PNG figures into the same directory: precision-recall curves, ROC curves,
    and a metrics table next to the confusion matrix.

    Args:
        results_path: Path to the pickled results file. The identifier used in
            figure titles and file names is the second ``_``-separated token
            of its base name.
        classes: Class names ordered by label index. Colours are looked up
            case-insensitively; unknown names use matplotlib's default colour.

    Returns:
        None. Prints a message and returns early when the file is missing.
    """
    if not os.path.exists(results_path):
        print(f"File not found: {results_path}")
        return

    print(f"»»» Data taken from path: {results_path}")
    # NOTE(security): unpickling can execute arbitrary code; only open log
    # files that were written by this notebook.
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    # Extract directory and identifier from results_path
    directory_path = os.path.dirname(results_path)
    identifier = os.path.basename(results_path).split('_')[1]  # Adjust if needed

    # Define plot paths for the images
    plot_paths = {
        'box1_pr': os.path.join(directory_path, f"{identifier}_precision_recall_test.png"),
        'box2_roc_auc': os.path.join(directory_path, f"{identifier}_roc_auc_test.png"),
        'box3_value_metrics_cmatrix': os.path.join(directory_path, f"{identifier}_value_metrics_cmatrix_test.png")
    }

    # Extract data for test phase from results
    test_labels = results['test_labels']
    test_probs = np.array(results['test_probs'])
    test_preds = results['test_preds']
    cm = np.array(results['cm'])
    labels_arr = np.asarray(test_labels)

    # Define colors for each class
    colors = {'benign': 'blue', 'malignant': 'orange', 'normal': 'green', 'overall': 'black'}

    def class_color(name):
        # Case-insensitive lookup; None lets matplotlib pick its default colour
        return colors.get(name.lower())

    # One merged row plus one row per class; never fewer than the original 4
    n_rows = max(4, len(classes) + 1)

    # Precision-Recall Curves
    fig = None
    try:
        fig, axes = plt.subplots(n_rows, 1, figsize=(10, 20))

        per_class_ap = []
        per_class_support = []
        for i, class_name in enumerate(classes):
            binary_labels = np.where(labels_arr == i, 1, 0)

            if np.any(binary_labels):
                precision, recall, _ = precision_recall_curve(binary_labels, test_probs[:, i])
                avg_precision = average_precision_score(binary_labels, test_probs[:, i])
                per_class_ap.append(avg_precision)
                per_class_support.append(int(binary_labels.sum()))
            else:
                # No positive samples: degenerate single point, AP reported as 0
                precision, recall = [0], [1]
                avg_precision = 0.0

            curve_label = f'{class_name} (AP={avg_precision:.2f})'
            curve_color = class_color(class_name)

            # Same curve in the merged top panel and in the class's own panel
            axes[0].plot(recall, precision, label=curve_label, color=curve_color)
            axes[i + 1].plot(recall, precision, label=curve_label, color=curve_color)
            axes[i + 1].set_title(f'Test - {class_name} PR Curve')
            axes[i + 1].set_xlabel('Recall')
            axes[i + 1].set_ylabel('Precision')
            axes[i + 1].legend()

        # Overall score: per-class AP weighted by the number of positives of
        # each class; classes without positives carry no weight.
        if per_class_support:
            overall_ap = float(np.average(per_class_ap, weights=per_class_support))
        else:
            overall_ap = 0.0

        axes[0].set_title(f'Test - merged PR curves and overall weighted AP score = {overall_ap:.2f}')
        axes[0].set_xlabel('Recall')
        axes[0].set_ylabel('Precision')
        axes[0].legend()

        plt.tight_layout()
        plt.subplots_adjust(top=0.95)
        fig.suptitle(f'Classification model {identifier} - Weighted Precision-Recall Curves (Test)', fontsize=16)
        plt.savefig(plot_paths['box1_pr'])
        print(f"»»» Precision-Recall curve saved at {plot_paths['box1_pr']}")
    except Exception as e:
        print(f"Error plotting Precision-Recall curves: {e}")
    finally:
        # Release the figure even when plotting failed half-way
        if fig is not None:
            plt.close(fig)

    # ROC-AUC Curves
    fig = None
    try:
        fig, axes = plt.subplots(n_rows, 1, figsize=(10, 20))

        roc_auc_weighted_test = roc_auc_score(test_labels, test_probs, multi_class='ovr', average='weighted')

        # Random guess line of the merged panel
        axes[0].plot([0, 1], [0, 1], 'k--', label='Random Guess')

        for i, class_name in enumerate(classes):
            binary_labels = np.where(labels_arr == i, 1, 0)
            fpr, tpr, _ = roc_curve(binary_labels, test_probs[:, i])
            roc_auc = roc_auc_score(binary_labels, test_probs[:, i])

            curve_label = f'{class_name} (AUC={roc_auc:.2f})'
            curve_color = class_color(class_name)

            # Same curve in the merged top panel and in the class's own panel
            axes[0].plot(fpr, tpr, label=curve_label, color=curve_color)
            axes[i + 1].plot(fpr, tpr, label=curve_label, color=curve_color)
            axes[i + 1].plot([0, 1], [0, 1], 'k--', label='Random Guess')  # Random guess line
            axes[i + 1].set_title(f'Test - {class_name} ROC Curve')
            axes[i + 1].set_xlabel('False Positive Rate')
            axes[i + 1].set_ylabel('True Positive Rate')
            axes[i + 1].legend()

        axes[0].set_title(f'Test - merged ROC curves and overall weighted ROC-AUC score = {roc_auc_weighted_test:.2f}')
        axes[0].set_xlabel('False Positive Rate')
        axes[0].set_ylabel('True Positive Rate')
        axes[0].legend()

        plt.tight_layout()
        plt.subplots_adjust(top=0.95)
        fig.suptitle(f'Classification model {identifier} - Weighted ROC Curves (Test)', fontsize=16)
        plt.savefig(plot_paths['box2_roc_auc'])
        print(f"»»» ROC-AUC curve saved at {plot_paths['box2_roc_auc']}")
    except Exception as e:
        print(f"Error plotting ROC curves: {e}")
    finally:
        # Release the figure even when plotting failed half-way
        if fig is not None:
            plt.close(fig)

    # Test Metrics
    test_loss = results.get('test_loss', 0)  # Final test loss
    sensitivity = results['sensitivity']
    specificity = results['specificity']

    # Define metrics for the table
    value_metrics = {
        "MCC": round(matthews_corrcoef(test_labels, test_preds), 5),
        "Balanced Accuracy": round(balanced_accuracy_score(test_labels, test_preds), 5),
        "Weighted F1 Score": round(f1_score(test_labels, test_preds, average='weighted'), 5),
        "Final Loss Value": round(test_loss, 5)
    }

    # Adding Sensitivity and Specificity per class
    for i, class_name in enumerate(classes):
        value_metrics[f"Sensitivity ({class_name})"] = round(sensitivity.get(i, 0), 5)
        value_metrics[f"Specificity ({class_name})"] = round(specificity.get(i, 0), 5)

    # One row per metric, a single value column
    metrics_df = pd.DataFrame(value_metrics, index=[0]).T

    # Plotting the metrics table and confusion matrix side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    ax1.axis('off')  # Turn off axis for table
    table = ax1.table(cellText=metrics_df.values, rowLabels=metrics_df.index, loc='center')
    # Only size the columns the table actually has
    table.auto_set_column_width(list(range(len(metrics_df.columns))))

    fig.suptitle(f'Classification model {identifier} - Single value metrics and Confusion Matrix (Test)', fontsize=16, ha='center')

    # Plot the confusion matrix on the second axis (ax2)
    cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    cm_display.plot(cmap='Blues', ax=ax2)

    # Adjust the layout to prevent overlap
    plt.tight_layout()

    # Save the figure
    plt.savefig(plot_paths['box3_value_metrics_cmatrix'], bbox_inches='tight')
    plt.close()

    print(f"»»» Metrics table and confusion matrix saved at {plot_paths['box3_value_metrics_cmatrix']}")
    

### 4.10. Setting Finetune / Test / Plotting functions specific to segmentation

In [None]:
# Finetuning Functions
def segm_train_one_epoch(finetune_model, finetune_criterion, train_dataloader, finetune_optimizer, device):
    """Train the segmentation network for one epoch.

    Args:
        finetune_model: Network called as ``model(inputs)``. Its output is
            binarised with a 0.5 threshold, which assumes it already emits
            probabilities -- TODO confirm against the model definition.
        finetune_criterion: Loss called as ``criterion(outputs, labels)``.
            The per-batch value is weighted by the batch size, which assumes
            a mean-reduced loss -- TODO confirm against the caller.
        train_dataloader: Iterable yielding ``(inputs, labels)`` batches.
        finetune_optimizer: Optimizer updating the model parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(train_loss, train_iou, train_pixel_accuracy)``: the
        sample-weighted mean loss, the binary Jaccard score over all pixels
        of the epoch, and the fraction of correctly classified pixels.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    finetune_model.train()
    running_loss = 0.0
    n_samples = 0
    correct_pixels = 0
    total_pixels = 0
    pred_chunks = []
    label_chunks = []

    for inputs, labels in train_dataloader:
        # Move data to device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        finetune_optimizer.zero_grad()

        # Forward pass and loss
        outputs = finetune_model(inputs)
        loss = finetune_criterion(outputs, labels)

        # Weight the batch loss by its size so partial batches count correctly
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        n_samples += batch_size

        # Backward pass, gradient clipping to a max norm, optimizer step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(finetune_model.parameters(), max_norm=0.5)
        finetune_optimizer.step()

        # Binarise the (detached) predictions and count matching pixels
        binary_preds = (outputs.detach() > 0.5).float()
        correct_pixels += (binary_preds == labels).sum().item()
        total_pixels += labels.numel()

        # Keep one flat array per batch; extending a Python list pixel by
        # pixel would box every value and is far slower and heavier.
        pred_chunks.append(binary_preds.cpu().numpy().astype(int).ravel())
        label_chunks.append(labels.cpu().numpy().astype(int).ravel())

    # Normalise by the samples actually seen rather than len(dataset), which
    # over-counts with drop_last=True or a subsampling sampler.
    train_loss = running_loss / n_samples

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    train_iou = jaccard_score(all_labels, all_preds, average='binary')
    train_pixel_accuracy = correct_pixels / total_pixels

    return train_loss, train_iou, train_pixel_accuracy

def segm_validate_one_epoch(finetune_model, finetune_criterion, val_dataloader, device, is_last_epoch):
    """Evaluate the segmentation network on the validation loader.

    Args:
        finetune_model: Network called as ``model(inputs)``. Its output is
            binarised with a 0.5 threshold, which assumes it already emits
            probabilities -- TODO confirm against the model definition.
        finetune_criterion: Loss called as ``criterion(outputs, labels)``.
            The per-batch value is weighted by the batch size, which assumes
            a mean-reduced loss -- TODO confirm against the caller.
        val_dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        is_last_epoch: When true, the inputs, masks and binarised
            predictions of every batch are also kept for example views.

    Returns:
        Tuple ``(val_loss, val_iou, val_dice, val_pixel_accuracy, all_labels,
        all_preds, final_epoch_labels, final_epoch_preds,
        final_epoch_inputs)``. ``all_labels`` and ``all_preds`` are flat
        integer arrays over every pixel. The three ``final_epoch_*`` values
        are arrays concatenated along the batch axis when ``is_last_epoch``
        is true and empty lists otherwise.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    finetune_model.eval()
    running_loss = 0.0
    n_samples = 0
    correct_pixels = 0
    total_pixels = 0
    pred_chunks = []
    label_chunks = []

    final_epoch_preds = []
    final_epoch_labels = []
    final_epoch_inputs = []

    with torch.no_grad():
        for inputs, labels in val_dataloader:
            # Move data to device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass and loss
            outputs = finetune_model(inputs)
            loss = finetune_criterion(outputs, labels)

            # Weight the batch loss by its size so partial batches count correctly
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            n_samples += batch_size

            # Binarise the predictions and count matching pixels
            binary_preds = (outputs > 0.5).float()
            correct_pixels += (binary_preds == labels).sum().item()
            total_pixels += labels.numel()

            preds_np = binary_preds.cpu().numpy()
            labels_np = labels.cpu().numpy()

            # Keep one flat array per batch; extending a Python list pixel by
            # pixel would box every value and is far slower and heavier.
            pred_chunks.append(preds_np.astype(int).ravel())
            label_chunks.append(labels_np.astype(int).ravel())

            # On the last epoch, keep images, masks and predictions for example views
            if is_last_epoch:
                final_epoch_preds.append(preds_np)
                final_epoch_labels.append(labels_np)
                final_epoch_inputs.append(inputs.cpu().numpy())

    # Normalise by the samples actually seen rather than len(dataset), which
    # over-counts with drop_last=True or a subsampling sampler.
    val_loss = running_loss / n_samples

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    if is_last_epoch:
        # Stack the stored batches along the batch axis
        final_epoch_preds = np.concatenate(final_epoch_preds, axis=0) if final_epoch_preds else np.array([])
        final_epoch_labels = np.concatenate(final_epoch_labels, axis=0) if final_epoch_labels else np.array([])
        final_epoch_inputs = np.concatenate(final_epoch_inputs, axis=0) if final_epoch_inputs else np.array([])

    val_pixel_accuracy = correct_pixels / total_pixels
    val_iou = jaccard_score(all_labels, all_preds, average='binary')
    val_dice = f1_score(all_labels, all_preds, average='binary')

    return val_loss, val_iou, val_dice, val_pixel_accuracy, all_labels, all_preds, final_epoch_labels, final_epoch_preds, final_epoch_inputs

def segm_run_finetuning(finetune_model, finetune_criterion, train_dataloader, val_dataloader, finetune_optimizer, finetune_scheduler, device, total_epoch_count, storing_path, identifier, checkpoint_interval):
    """Run the segmentation fine-tuning loop and periodically persist the logs.

    For every epoch the model is trained with ``segm_train_one_epoch`` and
    evaluated with ``segm_validate_one_epoch``. Per-epoch metrics are collected
    in a dict of lists; on the last epoch the validation inputs, labels and
    predictions returned by the validation helper are stored as well.

    Every ``checkpoint_interval`` epochs (and on the last epoch) the collected
    results are handed to ``results_save`` and the in-memory dict is emptied.

    Args:
        finetune_model: Model passed through to the train/validate helpers.
        finetune_criterion: Loss passed through to the train/validate helpers.
        train_dataloader: Training data passed to ``segm_train_one_epoch``.
        val_dataloader: Validation data passed to ``segm_validate_one_epoch``.
        finetune_optimizer: Optimizer passed to ``segm_train_one_epoch``.
        finetune_scheduler: Optional scheduler; when truthy it is stepped with
            the epoch's validation loss.
        device: Device passed through to the helpers.
        total_epoch_count: Number of epochs to run.
        storing_path: Forwarded to ``results_save``.
        identifier: Forwarded to ``results_save``.
        checkpoint_interval: Save every this many epochs.

    Returns:
        The value returned by the last successful ``results_save`` call, or
        ``None`` when saving raised ``RuntimeError`` or no checkpoint was hit.
    """
    result_keys = (
        'epoch',
        'train_loss',
        'train_iou',
        'train_pixel_accuracy',
        'val_loss',
        'val_iou',
        'val_dice',
        'val_pixel_accuracy',
        'val_labels',
        'val_preds',
        'final_epoch_labels',
        'final_epoch_preds',
        'final_epoch_inputs',
    )
    results = {key: [] for key in result_keys}
    results_path = None

    print("»»» Starting Segmentation Network Finetuning process")

    for epoch in range(total_epoch_count):
        epoch_number = epoch + 1
        reached_last_epoch = epoch_number == total_epoch_count

        # Training
        avg_train_loss, train_iou, train_pixel_accuracy = segm_train_one_epoch(
            finetune_model, finetune_criterion, train_dataloader, finetune_optimizer, device
        )

        # Validation (example data is only requested for the last epoch)
        (
            avg_val_loss,
            val_iou,
            val_dice,
            val_pixel_accuracy,
            val_labels,
            val_preds,
            final_epoch_labels,
            final_epoch_preds,
            final_epoch_inputs,
        ) = segm_validate_one_epoch(
            finetune_model, finetune_criterion, val_dataloader, device, is_last_epoch=reached_last_epoch
        )

        if reached_last_epoch:
            results['final_epoch_labels'].append(final_epoch_labels)
            results['final_epoch_preds'].append(final_epoch_preds)
            results['final_epoch_inputs'].append(final_epoch_inputs)

        print(f"Epoch [{epoch+1}/{total_epoch_count}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        epoch_record = (
            ('epoch', epoch_number),
            ('train_loss', avg_train_loss),
            ('train_iou', train_iou),
            ('train_pixel_accuracy', train_pixel_accuracy),
            ('val_loss', avg_val_loss),
            ('val_iou', val_iou),
            ('val_dice', val_dice),
            ('val_pixel_accuracy', val_pixel_accuracy),
            ('val_labels', val_labels),
            ('val_preds', val_preds),
        )
        for key, value in epoch_record:
            results[key].append(value)

        if finetune_scheduler:
            finetune_scheduler.step(avg_val_loss)

        if epoch_number % checkpoint_interval == 0 or reached_last_epoch:
            print(f"»»» Checkpoint reached at epoch {epoch_number}: Saving model state and results")

            try:
                results_path = results_save(identifier, "ft", storing_path, results, epoch, total_epoch_count)
            except RuntimeError as e:
                print(e)
                return None

            # Start from empty lists again so already-saved epochs are not
            # handed to results_save a second time.
            results = {key: [] for key in results}

    print("\n»»» Segmentation Network Finetuning process complete\n»»» Logs saved at: ", results_path)

    return results_path  # Path of files with model and results

#---
#auxiliary functions
def categorize_and_select_examples_iou_ft(inputs, labels, preds):
    """Select up to three best, three average and three worst examples by IoU.

    The IoU of every example is computed between the pixels equal to 1 in
    ``labels[idx, 0]`` and ``preds[idx, 0]``. Examples whose mask and prediction
    both sum to zero ("normal" cases) are kept apart: they are stored with an
    IoU of 0 (their union is empty) and are pooled with the best box when the
    best examples are drawn.

    The remaining examples are ranked by IoU. The best and worst boxes each
    hold at most 20 examples and never more than half of the ranked examples,
    so the two boxes cannot overlap; everything in between is "average".

    Args:
        inputs: Only its length is used; it sets how many examples are scanned.
        labels: Ground-truth masks, indexed as ``labels[idx, 0]``.
        preds: Predicted masks, indexed as ``preds[idx, 0]``.

    Returns:
        Tuple ``(selected_best, selected_average, selected_worst)``; each item
        is a list of ``(index, iou)`` pairs drawn with ``random.sample``.
    """
    normal_case_box = []
    all_examples = []

    # Categorize examples
    for idx in range(len(inputs)):
        label_image = labels[idx, 0]
        pred_image = preds[idx, 0]

        label_foreground = label_image == 1
        pred_foreground = pred_image == 1
        intersection = np.sum(label_foreground & pred_foreground)
        union = np.sum(label_foreground | pred_foreground)

        # An empty union has no defined ratio; it is reported as 0
        iou = intersection / union if union > 0 else 0

        # Normal cases: neither the mask nor the prediction marks any pixel
        if np.sum(label_image) == 0 and np.sum(pred_image) == 0:
            normal_case_box.append((idx, iou))
        else:
            all_examples.append((idx, iou))

    # Rank by IoU, highest first
    sorted_examples = sorted(all_examples, key=lambda x: x[1], reverse=True)

    # Cap each box at half of the ranked examples so that a small set cannot
    # put the same example in both the best and the worst box.
    box_size = min(20, len(sorted_examples) // 2)
    best_box = sorted_examples[:box_size]
    worst_box = sorted_examples[len(sorted_examples) - box_size:]

    # Average cases keep their original (unsorted) order
    extreme_indices = {i for i, _ in best_box} | {i for i, _ in worst_box}
    average_box = [(i, iou) for i, iou in all_examples if i not in extreme_indices]

    # Draw up to 3 examples per box (best, worst, average: same draw order as before)
    best_pool = best_box + normal_case_box
    selected_best = random.sample(best_pool, min(3, len(best_pool)))
    selected_worst = random.sample(worst_box, min(3, len(worst_box)))
    selected_average = random.sample(average_box, min(3, len(average_box)))

    return selected_best, selected_average, selected_worst

def categorize_and_select_examples_dice_ft(final_epoch_labels, final_epoch_preds):
    """Select up to three best, three average and three worst examples by Dice.

    Examples whose mask and prediction both sum to zero ("normal" cases) are
    kept apart with a stored score of 0.0 and are pooled with the best box
    when the best examples are drawn. For every other example the Dice score
    is computed from the pixels equal to 1 as ``2 * overlap / (|label| + |pred|)``,
    which equals the binary F1 score for 0/1 masks.

    The scored examples are ranked by Dice. The best and worst boxes each hold
    at most 20 examples and never more than half of the ranked examples, so
    the two boxes cannot overlap; everything in between is "average".

    Args:
        final_epoch_labels: Ground-truth masks, indexed as ``[i, 0]``.
        final_epoch_preds: Predicted masks, indexed as ``[i, 0]``.

    Returns:
        Tuple ``(selected_best, selected_average, selected_worst)``; each item
        is a list of ``(index, dice)`` pairs drawn with ``random.sample``.
    """
    normal_case_box = []
    all_examples = []

    # Categorize examples
    for i in range(len(final_epoch_labels)):
        label = final_epoch_labels[i, 0]
        pred = final_epoch_preds[i, 0]

        # Normal cases: neither the mask nor the prediction marks any pixel
        if np.sum(label) == 0 and np.sum(pred) == 0:
            normal_case_box.append((i, 0.0))  # Dice score is 0 for no tumor
            continue

        label_foreground = label == 1
        pred_foreground = pred == 1
        overlap = np.sum(label_foreground & pred_foreground)
        total = np.sum(label_foreground) + np.sum(pred_foreground)
        # total == 0 mirrors the zero_division=1 convention of the F1 score
        dice_score = float(2 * overlap / total) if total > 0 else 1.0
        all_examples.append((i, dice_score))

    # Rank by Dice, highest first
    sorted_examples = sorted(all_examples, key=lambda x: x[1], reverse=True)

    # Cap each box at half of the ranked examples so that a small set cannot
    # put the same example in both the best and the worst box.
    box_size = min(20, len(sorted_examples) // 2)
    best_box = sorted_examples[:box_size]
    worst_box = sorted_examples[len(sorted_examples) - box_size:]

    # Average cases keep their original (unsorted) order
    extreme_indices = {i for i, _ in best_box} | {i for i, _ in worst_box}
    average_box = [(i, dice) for i, dice in all_examples if i not in extreme_indices]

    # Draw up to 3 examples per box (best, worst, average: same draw order as before)
    best_pool = best_box + normal_case_box
    selected_best = random.sample(best_pool, min(3, len(best_pool)))
    selected_worst = random.sample(worst_box, min(3, len(worst_box)))
    selected_average = random.sample(average_box, min(3, len(average_box)))

    return selected_best, selected_average, selected_worst

def categorize_and_select_examples_iou_test(labels, preds):
    """Select up to three best, three average and three worst test examples by IoU.

    The IoU of every example is computed between the pixels equal to 1 in
    ``labels[idx, 0]`` and ``preds[idx, 0]``. Examples whose mask and prediction
    both sum to zero ("normal" cases) are kept apart: they are stored with an
    IoU of 0 (their union is empty) and are pooled with the best box when the
    best examples are drawn.

    The remaining examples are ranked by IoU. The best and worst boxes each
    hold at most 20 examples and never more than half of the ranked examples,
    so the two boxes cannot overlap; everything in between is "average".

    Args:
        labels: Ground-truth masks, indexed as ``labels[idx, 0]``; its length
            sets how many examples are scanned.
        preds: Predicted masks, indexed as ``preds[idx, 0]``.

    Returns:
        Tuple ``(selected_best, selected_average, selected_worst)``; each item
        is a list of ``(index, iou)`` pairs drawn with ``random.sample``.
    """
    normal_case_box = []
    all_examples = []

    # Categorize examples
    for idx in range(len(labels)):
        label_image = labels[idx, 0]
        pred_image = preds[idx, 0]

        label_foreground = label_image == 1
        pred_foreground = pred_image == 1
        intersection = np.sum(label_foreground & pred_foreground)
        union = np.sum(label_foreground | pred_foreground)

        # An empty union has no defined ratio; it is reported as 0
        iou = intersection / union if union > 0 else 0

        # Normal cases: neither the mask nor the prediction marks any pixel
        if np.sum(label_image) == 0 and np.sum(pred_image) == 0:
            normal_case_box.append((idx, iou))
        else:
            all_examples.append((idx, iou))

    # Rank by IoU, highest first
    sorted_examples = sorted(all_examples, key=lambda x: x[1], reverse=True)

    # Cap each box at half of the ranked examples so that a small set cannot
    # put the same example in both the best and the worst box.
    box_size = min(20, len(sorted_examples) // 2)
    best_box = sorted_examples[:box_size]
    worst_box = sorted_examples[len(sorted_examples) - box_size:]

    # Average cases keep their original (unsorted) order
    extreme_indices = {i for i, _ in best_box} | {i for i, _ in worst_box}
    average_box = [(i, iou) for i, iou in all_examples if i not in extreme_indices]

    # Draw up to 3 examples per box (best, worst, average: same draw order as before)
    best_pool = best_box + normal_case_box
    selected_best = random.sample(best_pool, min(3, len(best_pool)))
    selected_worst = random.sample(worst_box, min(3, len(worst_box)))
    selected_average = random.sample(average_box, min(3, len(average_box)))

    return selected_best, selected_average, selected_worst

def categorize_and_select_examples_dice_test(labels, preds):
    """Select up to three best, three average and three worst test examples by Dice.

    Examples whose mask and prediction both sum to zero ("normal" cases) are
    kept apart with a stored score of 0.0 and are pooled with the best box
    when the best examples are drawn. For every other example the Dice score
    is computed from the pixels equal to 1 as ``2 * overlap / (|label| + |pred|)``,
    which equals the binary F1 score for 0/1 masks.

    The scored examples are ranked by Dice. The best and worst boxes each hold
    at most 20 examples and never more than half of the ranked examples, so
    the two boxes cannot overlap; everything in between is "average".

    Args:
        labels: Ground-truth masks, indexed as ``labels[i, 0]``; its length
            sets how many examples are scanned.
        preds: Predicted masks, indexed as ``preds[i, 0]``.

    Returns:
        Tuple ``(selected_best, selected_average, selected_worst)``; each item
        is a list of ``(index, dice)`` pairs drawn with ``random.sample``.
    """
    normal_case_box = []
    all_examples = []

    # Categorize examples
    for i in range(len(labels)):
        label = labels[i, 0]
        pred = preds[i, 0]

        # Normal cases: neither the mask nor the prediction marks any pixel
        if np.sum(label) == 0 and np.sum(pred) == 0:
            normal_case_box.append((i, 0.0))  # Dice score is 0 for no tumor
            continue

        label_foreground = label == 1
        pred_foreground = pred == 1
        overlap = np.sum(label_foreground & pred_foreground)
        total = np.sum(label_foreground) + np.sum(pred_foreground)
        # total == 0 mirrors the zero_division=1 convention of the F1 score
        dice_score = float(2 * overlap / total) if total > 0 else 1.0
        all_examples.append((i, dice_score))

    # Rank by Dice, highest first
    sorted_examples = sorted(all_examples, key=lambda x: x[1], reverse=True)

    # Cap each box at half of the ranked examples so that a small set cannot
    # put the same example in both the best and the worst box.
    box_size = min(20, len(sorted_examples) // 2)
    best_box = sorted_examples[:box_size]
    worst_box = sorted_examples[len(sorted_examples) - box_size:]

    # Average cases keep their original (unsorted) order
    extreme_indices = {i for i, _ in best_box} | {i for i, _ in worst_box}
    average_box = [(i, dice) for i, dice in all_examples if i not in extreme_indices]

    # Draw up to 3 examples per box (best, worst, average: same draw order as before)
    best_pool = best_box + normal_case_box
    selected_best = random.sample(best_pool, min(3, len(best_pool)))
    selected_worst = random.sample(worst_box, min(3, len(worst_box)))
    selected_average = random.sample(average_box, min(3, len(average_box)))

    return selected_best, selected_average, selected_worst

#---

def _segm_ft_last_value(values):
    """Return the last entry of ``values`` rounded to 5 decimals, or None when empty."""
    return round(values[-1], 5) if values.size > 0 else None


def _segm_ft_plot_curves(axis, epochs, curves, metric_name, legend_loc, epoch_ticks):
    """Draw ``(values, label, color)`` curves on ``axis``, titled and labelled ``metric_name``."""
    for values, label, color in curves:
        axis.plot(epochs, values, label=label, color=color)
    axis.set_xlabel('Epochs')
    axis.set_ylabel(metric_name)
    axis.legend(loc=legend_loc)
    axis.set_title(metric_name)
    axis.set_xticks(epoch_ticks)


def _segm_ft_save_example_grid(selected_examples, inputs, labels, preds, metric_name, figure_title, save_path):
    """Save a grid with one row (input, mask, prediction, overlay) per selected example.

    Returns True when the figure was written and False when ``selected_examples``
    is empty, in which case nothing is plotted.
    """
    if not selected_examples:
        print(f"No examples available for the {metric_name} example views; nothing saved at: {save_path}")
        return False

    row_count = len(selected_examples)
    # squeeze=False keeps the axes array two-dimensional even for a single row
    fig, ax = plt.subplots(row_count, 4, figsize=(12, row_count * 3), dpi=150, squeeze=False)
    fig.suptitle(figure_title, fontsize=14, y=0.99)
    fig.subplots_adjust(top=0.7)

    for row, (idx, score) in enumerate(selected_examples):
        input_image = inputs[idx, 0]
        label_image = labels[idx, 0]
        pred_image = preds[idx, 0]

        # Map the input back to 0..255 (presumably undoing a mean=0.5 / std=0.5
        # normalisation - confirm against the dataset transforms)
        input_image = ((input_image * 0.5 + 0.5) * 255).astype(np.uint8)

        overlay = np.zeros((label_image.shape[0], label_image.shape[1], 3), dtype=np.uint8)
        overlay[(pred_image == 1) & (label_image == 1)] = [144, 238, 144]  # Light green: agreement
        overlay[(pred_image == 1) & (label_image == 0)] = [255, 99, 71]    # Red: predicted only
        overlay[(label_image == 1) & (pred_image == 0)] = [100, 149, 237]  # Blue: mask only

        gray_panels = (
            (input_image, "Input"),
            (label_image, "Ground Truth"),
            (pred_image, "Prediction"),
        )
        for col, (image, column_title) in enumerate(gray_panels):
            ax[row, col].imshow(image, cmap='gray')
            ax[row, col].set_title(column_title if row == 0 else "")  # column headers on the first row only
            ax[row, col].axis('off')

        overlay_title = f"Overlay ({metric_name}: {score:.2f})"
        if np.sum(label_image) == 0:
            overlay_title = f"Overlay ({metric_name}: {score:.2f}) (no tumor)"

        ax[row, 3].imshow(overlay)
        ax[row, 3].set_title(overlay_title)
        ax[row, 3].axis('off')

    fig.tight_layout(h_pad=1.5, w_pad=1.0)
    fig.savefig(save_path)
    plt.close(fig)
    return True


def segm_plot_finetuning_results(results_path):
    """Create the fine-tuning report figures from a pickled results file.

    Four PNG files are written next to ``results_path``:
      1. train/validation loss, IoU and pixel accuracy per epoch,
      2. validation Dice, precision and recall per epoch (recomputed from the
         stored ``val_labels`` / ``val_preds``),
      3. example views chosen by Dice score,
      4. example views chosen by IoU score.

    The identifier used in titles and file names is the second
    underscore-separated token of the results file name.

    Args:
        results_path: Path to the pickle file holding the results dict.

    Returns:
        None. Prints a message and returns early when the file does not exist.
    """
    if not os.path.exists(results_path):
        print(f"File not found: {results_path}")
        return

    print(f"»»» Data taken from path: {results_path}")
    # NOTE: pickle.load can execute arbitrary code; only open result files
    # produced by this project.
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    directory_path = os.path.dirname(results_path)
    identifier = os.path.basename(results_path).split('_')[1]  # Adjust if needed

    plot_paths = {
        'box1_loss_iou_pxlacc': os.path.join(directory_path, f"{identifier}_loss_iou_pxlacc_ft.png"),
        'box2_dice_prec_rec': os.path.join(directory_path, f"{identifier}_dice_prec_rec_ft.png"),
        'box3_example_views_dice': os.path.join(directory_path, f"{identifier}_example_views_dice_ft.png"),
        'box4_example_views_iou': os.path.join(directory_path, f"{identifier}_example_views_iou_ft.png")
    }

    epochs = np.array(results['epoch'])

    # Ticks at epoch 1, every 10th epoch, and the last epoch
    epoch_ticks = [1] + list(range(10, len(epochs) + 1, 10))
    if epoch_ticks[-1] != len(epochs):
        epoch_ticks.append(len(epochs))

    #----
    # Box 1: loss / IoU / pixel accuracy (train vs. validation)

    # (results key suffix, legend name, axis name, legend location)
    box1_metrics = (
        ('loss', 'Loss', 'Loss', 'upper right'),
        ('iou', 'IoU', 'IoU', 'upper right'),
        ('pixel_accuracy', 'Pixel Acc', 'Pixel Accuracy', 'lower right'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))  # Three plots side by side
    for axis, (key, legend_name, metric_name, legend_loc) in zip(axes, box1_metrics):
        train_values = np.array(results[f'train_{key}'])
        val_values = np.array(results[f'val_{key}'])
        curves = (
            (train_values, f'Train {legend_name} ({_segm_ft_last_value(train_values)})', 'orange'),
            (val_values, f'Val {legend_name} ({_segm_ft_last_value(val_values)})', 'blue'),
        )
        _segm_ft_plot_curves(axis, epochs, curves, metric_name, legend_loc, epoch_ticks)

    fig.suptitle(f'Segmentation model {identifier} - Loss / IoU / Pixel Accuracy Graphs (Train / Val)', fontsize=16)
    fig.savefig(plot_paths['box1_loss_iou_pxlacc'], bbox_inches='tight')
    plt.close(fig)

    print(f"»»» Loss / Intersection over union / Pixel accuracy image saved at {plot_paths['box1_loss_iou_pxlacc']}")

    #---
    # Box 2: validation Dice / precision / recall per epoch

    val_dice = []
    val_precision = []
    val_recall = []
    for labels, preds in zip(results['val_labels'], results['val_preds']):
        labels = np.array(labels)
        preds = np.array(preds)

        val_dice.append(f1_score(labels, preds, average='binary'))
        val_precision.append(precision_score(labels, preds, average='binary'))
        val_recall.append(recall_score(labels, preds, average='binary'))

    val_dice = np.array(val_dice)
    val_precision = np.array(val_precision)
    val_recall = np.array(val_recall)

    # (values, legend name, axis name, color)
    box2_metrics = (
        (val_dice, 'Dice', 'Dice Score', 'blue'),
        (val_precision, 'Precision', 'Precision', 'green'),
        (val_recall, 'Recall', 'Recall', 'purple'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for axis, (values, legend_name, metric_name, color) in zip(axes, box2_metrics):
        curves = ((values, f'{legend_name} ({_segm_ft_last_value(values)})', color),)
        _segm_ft_plot_curves(axis, epochs, curves, metric_name, 'lower right', epoch_ticks)

    fig.suptitle(f'Segmentation model {identifier} - Dice / Precision / Recall Graphs (Validation)', fontsize=16)
    fig.savefig(plot_paths['box2_dice_prec_rec'], bbox_inches='tight')
    plt.close(fig)

    print(f"»»» Dice / Precision / Recall image saved at {plot_paths['box2_dice_prec_rec']}")

    #---
    # Box 3: example views selected by Dice score

    # The final-epoch data is stored as a one-element list; take that element
    # (commented as [1, N, 1, H, W] -> [N, 1, H, W] in the original code)
    final_epoch_inputs = np.array(results['final_epoch_inputs'])[0]
    final_epoch_labels = np.array(results['final_epoch_labels'])[0]
    final_epoch_preds = np.array(results['final_epoch_preds'])[0]

    selected_best_dice, selected_average_dice, selected_worst_dice = categorize_and_select_examples_dice_ft(
        final_epoch_labels, final_epoch_preds
    )
    # Rows are ordered best, then average, then worst from top to bottom
    selected_examples_dice = selected_best_dice + selected_average_dice + selected_worst_dice

    dice_saved = _segm_ft_save_example_grid(
        selected_examples_dice,
        final_epoch_inputs,
        final_epoch_labels,
        final_epoch_preds,
        "Dice",
        f"Segmentation model {identifier} - Random example visualizations based on Dice score (Validation)",
        plot_paths['box3_example_views_dice'],
    )
    if dice_saved:
        print(f"»»» Example views with overlay based on Dice image saved at: {plot_paths['box3_example_views_dice']}")

    #---
    # Box 4: example views selected by IoU score

    selected_best_iou, selected_average_iou, selected_worst_iou = categorize_and_select_examples_iou_ft(
        final_epoch_inputs, final_epoch_labels, final_epoch_preds
    )
    # Rows are ordered best, then average, then worst from top to bottom
    selected_examples_iou = selected_best_iou + selected_average_iou + selected_worst_iou

    iou_saved = _segm_ft_save_example_grid(
        selected_examples_iou,
        final_epoch_inputs,
        final_epoch_labels,
        final_epoch_preds,
        "IoU",
        f"Segmentation model {identifier} - Random example visualizations based on IoU score (Validation)",
        plot_paths['box4_example_views_iou'],
    )
    if iou_saved:
        print(f"»»» Example views with overlay based on IoU image saved at: {plot_paths['box4_example_views_iou']}")

# Testing Functions
def segm_run_testing(test_model, test_criterion, test_dataloader, storing_path, identifier, device):
    """Evaluate a segmentation model on the test set and pickle the results.

    The model is run in eval mode with gradient tracking disabled. Outputs are
    binarised at 0.5 (this presumes the model emits probabilities - confirm
    against the model head). Loss, pixel accuracy, IoU, Dice and the mean
    symmetric Hausdorff distance are computed; the Hausdorff distance only
    uses samples where both the prediction and the mask contain foreground.

    Args:
        test_model: Model called as ``test_model(inputs)``.
        test_criterion: Loss called as ``test_criterion(outputs, labels)``.
        test_dataloader: Iterable of ``(inputs, labels)`` batches that exposes
            ``.dataset`` with a length.
        storing_path: Directory for the log file; created when missing.
        identifier: Used in the file name ``logs_{identifier}_test.pkl``.
        device: Device the batches are moved to.

    Returns:
        Path of the pickle file holding the results dict.
    """
    test_model.eval()

    # Per-batch arrays, concatenated once after the loop
    all_preds = []
    all_labels = []
    all_inputs = []
    correct_pixels = 0
    total_pixels = 0
    running_loss = 0.0
    hausdorff_vals = []

    # Initialize a dictionary to store results
    results = {
        'test_loss': [],
        'test_iou': [],
        'test_pixel_accuracy': [],
        'test_dice': [],
        'test_labels': [],
        'test_preds': [],
        'test_inputs': [],
        'test_hausdorff': []
    }

    print("»»» Starting Segmentation Network Testing process \n")

    # No gradients are needed for evaluation; skipping them saves memory
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            # Move data to device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = test_model(inputs)

            # Compute loss (weighted by batch size so the mean is per sample)
            loss = test_criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)

            # Threshold predictions
            binary_preds = (outputs > 0.5).float()

            # Update pixel accuracy counts
            correct_pixels += (binary_preds == labels).sum().item()
            total_pixels += labels.numel()

            # Convert each tensor to numpy once and reuse it below
            preds_np = binary_preds.cpu().numpy()
            labels_np = labels.cpu().numpy()

            all_preds.append(preds_np)
            all_labels.append(labels_np)
            all_inputs.append(inputs.cpu().numpy())

            # Symmetric Hausdorff distance per sample (first channel only)
            for pred, label in zip(preds_np, labels_np):
                pred_coords = np.argwhere(pred[0])
                label_coords = np.argwhere(label[0])
                if pred_coords.size > 0 and label_coords.size > 0:  # Needs foreground on both sides
                    forward_hd = directed_hausdorff(pred_coords, label_coords)[0]
                    reverse_hd = directed_hausdorff(label_coords, pred_coords)[0]
                    hausdorff_vals.append(max(forward_hd, reverse_hd))

    # Concatenate all batches into cohesive NumPy arrays
    if all_preds:
        results['test_preds'] = np.concatenate(all_preds, axis=0)
        results['test_labels'] = np.concatenate(all_labels, axis=0)
        results['test_inputs'] = np.concatenate(all_inputs, axis=0)
    else:
        results['test_preds'] = np.array([])
        results['test_labels'] = np.array([])
        results['test_inputs'] = np.array([])

    # Flat integer views over all pixels for the global metrics
    all_preds_flat = results['test_preds'].astype(int).ravel()
    all_labels_flat = results['test_labels'].astype(int).ravel()

    # Calculate global metrics (IoU, Dice, etc.)
    dataset_size = len(test_dataloader.dataset)
    test_loss = running_loss / dataset_size if dataset_size > 0 else np.nan
    test_pixel_accuracy = correct_pixels / total_pixels if total_pixels > 0 else np.nan

    if all_preds_flat.size > 0 and all_labels_flat.size > 0:
        test_iou = jaccard_score(all_labels_flat, all_preds_flat, average='binary')
        test_dice = f1_score(all_labels_flat, all_preds_flat, average='binary')
    else:
        test_iou = np.nan
        test_dice = np.nan

    results['test_loss'].append(test_loss)
    results['test_iou'].append(test_iou)
    results['test_pixel_accuracy'].append(test_pixel_accuracy)
    results['test_dice'].append(test_dice)

    # NaN when no sample had foreground in both prediction and mask
    final_hausdorff = np.mean(hausdorff_vals) if hausdorff_vals else np.nan
    results['test_hausdorff'].append(final_hausdorff)

    # Save results to file
    results_filename = f"logs_{identifier}_test.pkl"
    results_path = os.path.join(storing_path, results_filename)

    # Make sure the target directory exists so the run is not lost at save time
    if storing_path:
        os.makedirs(storing_path, exist_ok=True)

    with open(results_path, 'wb') as f:
        pickle.dump(results, f)

    print(f"\n»»» Segmentation Network Testing process complete\n»»» Logs saved at: {results_path}")

    return results_path

def _segm_test_save_example_grid(selected_examples, inputs, labels, preds, metric_name, figure_title, save_path):
    """Save a grid with one row (input, mask, prediction, overlay) per selected example.

    Returns True when the figure was written and False when ``selected_examples``
    is empty, in which case nothing is plotted.
    """
    if not selected_examples:
        print(f"No examples available for the {metric_name} example views; nothing saved at: {save_path}")
        return False

    row_count = len(selected_examples)
    # squeeze=False keeps the axes array two-dimensional even for a single row
    fig, ax = plt.subplots(row_count, 4, figsize=(12, row_count * 3), dpi=150, squeeze=False)
    fig.suptitle(figure_title, fontsize=14, y=0.99)
    fig.subplots_adjust(top=0.7)

    for row, (idx, score) in enumerate(selected_examples):
        input_image = inputs[idx, 0]
        label_image = labels[idx, 0]
        pred_image = preds[idx, 0]

        # Map the input back to 0..255 (presumably undoing a mean=0.5 / std=0.5
        # normalisation - confirm against the dataset transforms)
        input_image = ((input_image * 0.5 + 0.5) * 255).astype(np.uint8)

        overlay = np.zeros((label_image.shape[0], label_image.shape[1], 3), dtype=np.uint8)
        overlay[(pred_image == 1) & (label_image == 1)] = [144, 238, 144]  # Light green: agreement
        overlay[(pred_image == 1) & (label_image == 0)] = [255, 99, 71]    # Red: predicted only
        overlay[(label_image == 1) & (pred_image == 0)] = [100, 149, 237]  # Blue: mask only

        gray_panels = (
            (input_image, "Input"),
            (label_image, "Ground Truth"),
            (pred_image, "Prediction"),
        )
        for col, (image, column_title) in enumerate(gray_panels):
            ax[row, col].imshow(image, cmap='gray')
            ax[row, col].set_title(column_title if row == 0 else "")  # column headers on the first row only
            ax[row, col].axis('off')

        overlay_title = f"Overlay ({metric_name}: {score:.2f})"
        if np.sum(label_image) == 0:
            overlay_title = f"Overlay ({metric_name}: {score:.2f}) (no tumor)"

        ax[row, 3].imshow(overlay)
        ax[row, 3].set_title(overlay_title)
        ax[row, 3].axis('off')

    fig.tight_layout(h_pad=1.5, w_pad=1.0)
    fig.savefig(save_path)
    plt.close(fig)
    return True


def segm_plot_testing_results(results_path):
    """Create the test report figures from a pickled test results file.

    Three PNG files are written next to ``results_path``:
      1. a table with the single-value metrics (loss, IoU, pixel accuracy,
         Dice, Hausdorff distance),
      2. example views chosen by Dice score,
      3. example views chosen by IoU score.

    The identifier used in titles and file names is the second
    underscore-separated token of the results file name.

    Args:
        results_path: Path to the pickle file holding the results dict.

    Returns:
        None. Prints a message and returns early when the file does not exist.
    """
    if not os.path.exists(results_path):
        print(f"File not found: {results_path}")
        return

    print(f"»»» Data taken from path: {results_path}")
    # NOTE: pickle.load can execute arbitrary code; only open result files
    # produced by this project.
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    # Extract directory and identifier from results_path
    directory_path = os.path.dirname(results_path)
    identifier = os.path.basename(results_path).split('_')[1]

    # Define plot paths for segmentation results
    plot_paths = {
        'box1_value_metrics': os.path.join(directory_path, f"{identifier}_value_metrics_test.png"),
        'box2_example_views_dice': os.path.join(directory_path, f"{identifier}_example_views_dice_test.png"),
        'box3_example_views_iou': os.path.join(directory_path, f"{identifier}_example_views_iou_test.png")
    }

    test_inputs = results['test_inputs']
    test_labels = results['test_labels']
    test_preds = results['test_preds']

    #----
    # Box 1: table of single-value metrics

    value_metrics = {
        "Test Loss": results['test_loss'][0],
        "Test IoU": results['test_iou'][0],
        "Pixel Accuracy": results['test_pixel_accuracy'][0],
        "Dice Similarity Coefficient": results['test_dice'][0],
        "Hausdorff Distance": results['test_hausdorff'][0]
    }

    fig, ax = plt.subplots(figsize=(6, 3))
    ax.axis('tight')
    ax.axis('off')

    fig.suptitle(f'Segmentation model {identifier} - Single value metrics (Test)', fontsize=16, ha='center')

    table_data = [[key, f"{value:.4f}"] for key, value in value_metrics.items()]
    table = ax.table(cellText=table_data, loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.auto_set_column_width([0, 1])

    fig.savefig(plot_paths['box1_value_metrics'], bbox_inches='tight')
    plt.close(fig)

    print(f"»»» Metrics table image saved at {plot_paths['box1_value_metrics']}")

    #----
    # Box 2: example views selected by Dice score

    selected_best_dice, selected_average_dice, selected_worst_dice = categorize_and_select_examples_dice_test(
        test_labels, test_preds
    )
    # Rows are ordered best, then average, then worst from top to bottom
    selected_examples_dice = selected_best_dice + selected_average_dice + selected_worst_dice

    dice_saved = _segm_test_save_example_grid(
        selected_examples_dice,
        test_inputs,
        test_labels,
        test_preds,
        "Dice",
        f"Segmentation model {identifier} - Random example visualizations based on Dice score (Test)",
        plot_paths['box2_example_views_dice'],
    )
    if dice_saved:
        print(f"»»» Example views with overlay based on Dice image saved at: {plot_paths['box2_example_views_dice']}")

    #----
    # Box 3: example views selected by IoU score

    selected_best_iou, selected_average_iou, selected_worst_iou = categorize_and_select_examples_iou_test(
        test_labels, test_preds
    )
    # Rows are ordered best, then average, then worst from top to bottom
    selected_examples_iou = selected_best_iou + selected_average_iou + selected_worst_iou

    iou_saved = _segm_test_save_example_grid(
        selected_examples_iou,
        test_inputs,
        test_labels,
        test_preds,
        "IoU",
        f"Segmentation model {identifier} - Random example visualizations based on IoU score (Test)",
        plot_paths['box3_example_views_iou'],
    )
    if iou_saved:
        print(f"»»» Example views with overlay based on IoU image saved at: {plot_paths['box3_example_views_iou']}")

'\n#Old, keep here though\n# def segm_run_testing(test_model, test_criterion, test_dataloader, storing_path, identifier, device):\n#     test_model.eval()\n    \n#     # Initialize variables for the overall test pass\n#     all_preds = []\n#     all_labels = []\n#     correct_pixels = 0\n#     total_pixels = 0\n#     running_loss = 0.0\n\n#     # Initialize Hausdorff-specific variables\n#     hausdorff_preds = []  # Store coordinates of predicted regions\n#     hausdorff_labels = []  # Store coordinates of ground truth regions\n\n#     # Initialize a dictionary to store results\n#     results = {\n#         \'test_loss\': [],\n#         \'test_iou\': [],\n#         \'test_pixel_accuracy\': [],\n#         \'test_dice\': [],\n#         \'test_labels\': [],\n#         \'test_preds\': [],\n#         \'test_inputs\': [],\n#         \'test_hausdorff\': []\n#     }\n\n#     print("»»» Starting Segmentation Network Testing process \n")\n\n#     for batch_idx, (inputs, labels) in enumerate(test

## 5. Pretrain Pretext Task - Contrastive Representation Learning

### 5.0. Setting pretrain storing paths / backbone builder function / general pretrain parameters

In [None]:
# Result directories of the pretraining phase, grouped by contrastive method.
# Within each group the order is: ultrasound, mammography, multimodal.

# SimCLR
storing_path51, storing_path52, storing_path53 = (
    "../results/pretrainPhase/simclr/ultrasound",
    "../results/pretrainPhase/simclr/mammography",
    "../results/pretrainPhase/simclr/multimodal",
)

# MoCo
storing_path54, storing_path55, storing_path56 = (
    "../results/pretrainPhase/moco/ultrasound",
    "../results/pretrainPhase/moco/mammography",
    "../results/pretrainPhase/moco/multimodal",
)

# BYOL
storing_path57, storing_path58, storing_path59 = (
    "../results/pretrainPhase/byol/ultrasound",
    "../results/pretrainPhase/byol/mammography",
    "../results/pretrainPhase/byol/multimodal",
)

def initialize_resnet_backbone(is_byol):
    """Build a randomly initialised, single-channel ResNet-18 feature extractor.

    The stock ResNet-18 stem convolution is replaced by a convolution with one
    input channel and the same geometry, the original average pool and fully
    connected head are dropped, and an adaptive average pool to 1x1 is appended.

    Args:
        is_byol: When truthy, a ``flatten`` module is appended to the backbone
            and the backbone gets an ``output_dim`` attribute set to 512.

    Returns:
        nn.Sequential: the backbone network.
    """
    # Imported locally so the builder does not rely on a notebook-level
    # ``import torchvision`` (the library cell only binds ``torchvision.transforms``).
    import torchvision

    resnet = torchvision.models.resnet18(weights=None)
    # Skip child 0 (the 3-channel stem conv) and the last two children (avgpool, fc).
    resnet_body = list(resnet.children())[1:-2]
    backbone = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
        *resnet_body,
        nn.AdaptiveAvgPool2d((1, 1)),
    )

    if is_byol:
        # BYOL variant: emit flat feature vectors and expose their size.
        backbone.add_module('flatten', nn.Flatten())
        backbone.output_dim = 512
    return backbone

# Pretraining hyperparameters shared by every method / dataset run of section 5
total_epoch_count = 100   # epochs passed to each run_pretraining_* call (also T_max of the MoCo cosine schedule)
checkpoint_interval = 5   # passed to run_pretraining_*: save a checkpoint every N epochs
lr_pretrain = 0.014       # initial learning rate handed to the Adam optimizers
temperature = 0.07        # temperature handed to the SimCLR / MoCo NT-Xent losses
weight_decay = 1e-6       # weight decay handed to the Adam optimizers

In [None]:
#from torchsummary import summary
#model = initialize_resnet_backbone(False).to(device)
#summary(model, input_size=(1, 64, 64))

In [None]:
#---------------------------------------------------------------------------
## Code snippet to show how to load the pretrained model if needed (only run if needed)
# 
# (SIMCLR ONLY) 
## 1. create a new backbone and simclr model using said backbone
#backbone_simclr_us_pretrain_loaded = initialize_resnet_backbone(False)
#model_simclr_us_pretrain_loaded = SimCLR_v2(backbone_simclr_us_pretrain_loaded, projection_dim=128)
#model_simclr_us_pretrain_loaded.to(device)

## 2. get path to the saved model
#model_filename = "M1_savedModel.pth"
#model_path = os.path.join(storing_path51, model_filename)

## Load the model weights into the model
#model_simclr_us_pretrain_loaded.load_state_dict(torch.load(model_path))

#------------

# (MOCO ONLY)

## 1. create a new backbone and moco model using said backbone
#backbone_moco_us_pretrain_loaded = initialize_resnet_backbone(False)
#model_moco_us_pretrain_loaded = MoCo_v2(backbone_moco_us_pretrain_loaded)
#model_moco_us_pretrain_loaded.to(device)

## 2. get path to the saved model
#model_filename = "M4_savedModel.pth"
#model_path = os.path.join(storing_path54, model_filename)

## Load the model weights into the model
#model_moco_us_pretrain_loaded.load_state_dict(torch.load(model_path))

#------------

# (BYOL ONLY)
## 1. Create a new backbone and BYOL model using said backbone
#backbone_byol_us_pretrain_loaded = initialize_resnet_backbone(True)  #just the byol case uses true here
#model_byol_us_pretrain_loaded = BYOL(backbone_byol_us_pretrain_loaded, projection_dim=128)
#model_byol_us_pretrain_loaded.to(device)

## 2. Get the path to the saved model
#model_filename = "M7_savedModel.pth"
#model_path = os.path.join(storing_path57, model_filename)

## 3. Load the model weights into the BYOL model
#model_byol_us_pretrain_loaded.load_state_dict(torch.load(model_path))

### 5.1. Using <font color='green'>SimCLR</font> method and <font color='purple'>Ultrasound</font> dataset

In [None]:
# This is a 4 step process
# 1. Prepare model
# Build the single-channel ResNet backbone (non-BYOL variant) and wrap it in the SimCLR model
backbone_simclr_us_pretrain = initialize_resnet_backbone(False)
model_simclr_us_pretrain = SimCLR_v2(backbone_simclr_us_pretrain, projection_dim=128)
model_simclr_us_pretrain.to(device)

# Contrastive loss, parameterized by the shared pretraining temperature
criterion_simclr_us_pretrain = SimCLRNTXentLoss(temperature)
criterion_simclr_us_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
optimizer_simclr_us_pretrain = optim.Adam(model_simclr_us_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: halves the lr (factor 0.5) once the monitored value
# has stopped decreasing ('min' mode) for longer than the patience of 3 scheduler steps
scheduler_simclr_us_pretrain = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_simclr_us_pretrain, mode = 'min', patience = 3, factor = 0.5)

In [None]:
# 2. Pretrain loop
# SimCLR pretraining on the ultrasound loaders. The return value is stored
# because step 4 passes it to plot_pretrain_results and tsne_pipeline.
pretraining_results_path_simclr_us = run_pretraining_simclr(
    model_simclr_us_pretrain,                  # SimCLR model
    criterion_simclr_us_pretrain,              # NTXentLoss (contrastive loss function)
    us_pretrain_train_dataloader,              # Dataloader for training data
    us_pretrain_val_dataloader,                # Dataloader for validation data
    optimizer_simclr_us_pretrain,              # Adam optimizer
    scheduler_simclr_us_pretrain,              # Learning rate scheduler (ReduceLROnPlateau)
    device,                                    # Device to run training on (GPU/CPU)
    total_epoch_count,                         # Number of epochs to train
    storing_path51,                            # Path to store logs
    "M1",                                      # Model identifier
    checkpoint_interval                        # Save checkpoints every N epochs
)

In [168]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path51, exist_ok=True)
model_filename = "M1_savedModel.pth"
model_path = os.path.join(storing_path51, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_simclr_us_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_simclr_us)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# SimCLR case: the whole model is passed and is_byol is False
tsne_pipeline(model_simclr_us_pretrain, tSNE_us_dataloader, device, pretraining_results_path_simclr_us, is_byol=False)

### 5.2. Using <font color='green'>SimCLR</font> method and <font color='pink'>Mammography</font> dataset

In [181]:
# 1. Prepare model
# Build the single-channel ResNet backbone (non-BYOL variant) and wrap it in the SimCLR model
backbone_simclr_mammo_pretrain = initialize_resnet_backbone(False)
model_simclr_mammo_pretrain = SimCLR_v2(backbone_simclr_mammo_pretrain, projection_dim=128)
model_simclr_mammo_pretrain.to(device)

# Contrastive loss, parameterized by the shared pretraining temperature
criterion_simclr_mammo_pretrain = SimCLRNTXentLoss(temperature)
criterion_simclr_mammo_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
optimizer_simclr_mammo_pretrain = optim.Adam(model_simclr_mammo_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: halves the lr (factor 0.5) once the monitored value
# has stopped decreasing ('min' mode) for longer than the patience of 3 scheduler steps
scheduler_simclr_mammo_pretrain = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_simclr_mammo_pretrain, mode = 'min', patience = 3, factor = 0.5)

In [None]:
# 2. Pretrain loop
# SimCLR pretraining on the mammography loaders. The return value is stored
# because step 4 passes it to plot_pretrain_results and tsne_pipeline.
pretraining_results_path_simclr_mammo = run_pretraining_simclr(
    model_simclr_mammo_pretrain,                   # SimCLR model
    criterion_simclr_mammo_pretrain,               # NTXentLoss (contrastive loss function)
    mg_pretrain_train_dataloader,                  # Dataloader for training data
    mg_pretrain_val_dataloader,                    # Dataloader for validation data
    optimizer_simclr_mammo_pretrain,               # Adam optimizer
    scheduler_simclr_mammo_pretrain,               # Learning rate scheduler (ReduceLROnPlateau)
    device,                                        # Device to run training on (GPU/CPU)
    total_epoch_count,                             # Number of epochs to train
    storing_path52,                                # Path to store logs
    "M2",                                          # Model identifier
    checkpoint_interval                            # Save checkpoints every N epochs
)

In [187]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path52, exist_ok=True)
model_filename = "M2_savedModel.pth"
model_path = os.path.join(storing_path52, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_simclr_mammo_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_simclr_mammo)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# is_byol is passed explicitly, matching the other SimCLR / MoCo t-SNE calls of this notebook
tsne_pipeline(model_simclr_mammo_pretrain, tSNE_mg_dataloader, device, pretraining_results_path_simclr_mammo, is_byol=False)

### 5.3. Using <font color='green'>SimCLR</font> method and <font color='orange'>Multimodal</font> dataset

In [189]:
# 1. Prepare model
# Build the single-channel ResNet backbone (non-BYOL variant) and wrap it in the SimCLR model
backbone_simclr_multi_pretrain = initialize_resnet_backbone(False)
model_simclr_multi_pretrain = SimCLR_v2(backbone_simclr_multi_pretrain, projection_dim=128)
model_simclr_multi_pretrain.to(device)

# Contrastive loss, parameterized by the shared pretraining temperature
criterion_simclr_multi_pretrain = SimCLRNTXentLoss(temperature)
criterion_simclr_multi_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
optimizer_simclr_multi_pretrain = optim.Adam(model_simclr_multi_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: halves the lr (factor 0.5) once the monitored value
# has stopped decreasing ('min' mode) for longer than the patience of 3 scheduler steps
scheduler_simclr_multi_pretrain = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_simclr_multi_pretrain, mode = 'min', patience = 3, factor = 0.5)

In [None]:
# 2. Pretrain loop
# SimCLR pretraining on the multimodal loaders. The return value is stored
# because step 4 passes it to plot_pretrain_results and tsne_pipeline.
pretraining_results_path_simclr_multi = run_pretraining_simclr(
    model_simclr_multi_pretrain,                      # SimCLR model
    criterion_simclr_multi_pretrain,                  # NTXentLoss (contrastive loss function)
    multi_pretrain_train_dataloader,                  # Dataloader for training data
    multi_pretrain_val_dataloader,                    # Dataloader for validation data
    optimizer_simclr_multi_pretrain,                  # Adam optimizer
    scheduler_simclr_multi_pretrain,                  # Learning rate scheduler (ReduceLROnPlateau)
    device,                                           # Device to run training on (GPU/CPU)
    total_epoch_count,                                # Number of epochs to train
    storing_path53,                                   # Path to store logs
    "M3",                                             # Model identifier
    checkpoint_interval                               # Save checkpoints every N epochs
)

In [191]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path53, exist_ok=True)
model_filename = "M3_savedModel.pth"
model_path = os.path.join(storing_path53, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_simclr_multi_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_simclr_multi)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# is_byol is passed explicitly, matching the other SimCLR / MoCo t-SNE calls of this notebook
tsne_pipeline(model_simclr_multi_pretrain, tSNE_multi_dataloader, device, pretraining_results_path_simclr_multi, is_byol=False)

### 5.4. Using <font color='turquoise'>MoCo</font> method and <font color='purple'>Ultrasound</font> dataset

In [None]:
#new (checking)
# 1. Prepare model
# Build the single-channel ResNet backbone (non-BYOL variant) and wrap it in the MoCo model
backbone_moco_us_pretrain = initialize_resnet_backbone(False)
model_moco_us_pretrain = MoCo_v2(backbone_moco_us_pretrain)
model_moco_us_pretrain.to(device)

# Contrastive loss configured with a queue size of 16384 and the shared pretraining temperature
criterion_moco_us_pretrain = MoCoNTXentLoss(queue_size = 16384, temperature = temperature)
criterion_moco_us_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
optimizer_moco_us_pretrain = optim.Adam(model_moco_us_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

# MoCo runs use a CosineAnnealingLR scheduler instead of ReduceLROnPlateau
# (author's note: MoCo_v2 works better with it)
scheduler_moco_us_pretrain = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_moco_us_pretrain, 
    T_max = total_epoch_count,  # lr is annealed down to eta_min over this many scheduler steps
    eta_min = 0,  # Minimum learning rate (you can adjust this if needed)
    last_epoch = -1  # Start from scratch
)

In [None]:
# 2. Pretrain loop
# MoCo pretraining on the ultrasound loaders. The return value is stored
# because step 4 passes it to plot_pretrain_results and tsne_pipeline.
pretraining_results_path_moco_us = run_pretraining_moco(
    model_moco_us_pretrain,             # MoCo model
    criterion_moco_us_pretrain,         # Adapted NTXentLoss (contrastive loss function)
    us_pretrain_train_dataloader,       # Dataloader for training data
    us_pretrain_val_dataloader,         # Dataloader for validation data
    optimizer_moco_us_pretrain,         # Adam optimizer
    scheduler_moco_us_pretrain,         # Learning rate scheduler (CosineAnnealingLR)
    device,                             # Device to run training on (GPU/CPU)
    total_epoch_count,                  # Number of epochs to train
    storing_path54,                     # Path to store logs
    "M4",                               # Model identifier
    checkpoint_interval                 # Save checkpoints every N epochs
)

In [195]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path54, exist_ok=True)
model_filename = "M4_savedModel.pth"
model_path = os.path.join(storing_path54, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_moco_us_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_moco_us)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# MoCo case: the whole model is passed and is_byol is False
tsne_pipeline(model_moco_us_pretrain, tSNE_us_dataloader, device, pretraining_results_path_moco_us, is_byol = False)

### 5.5. Using <font color='turquoise'>MoCo</font> method and <font color='pink'>Mammography</font> dataset

In [None]:
#new (checking)
# 1. Prepare model
# Build the single-channel ResNet backbone (non-BYOL variant) and wrap it in the MoCo model
backbone_moco_mammo_pretrain = initialize_resnet_backbone(False)
model_moco_mammo_pretrain = MoCo_v2(backbone_moco_mammo_pretrain)
model_moco_mammo_pretrain.to(device)

# Contrastive loss configured with a queue size of 16384 and the shared pretraining temperature
criterion_moco_mammo_pretrain = MoCoNTXentLoss(queue_size = 16384, temperature = temperature)
criterion_moco_mammo_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
optimizer_moco_mammo_pretrain = optim.Adam(model_moco_mammo_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

# MoCo runs use a CosineAnnealingLR scheduler instead of ReduceLROnPlateau
# (author's note: MoCo_v2 works better with it)
scheduler_moco_mammo_pretrain = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_moco_mammo_pretrain, 
    T_max = total_epoch_count,  # lr is annealed down to eta_min over this many scheduler steps
    eta_min = 0,  # minimum learning rate (you can adjust this if needed)
    last_epoch = -1  # start from scratch
)

In [None]:
# 2. Pretrain loop
# MoCo pretraining on the mammography loaders. The return value is stored
# because step 4 passes it to plot_pretrain_results and tsne_pipeline.
pretraining_results_path_moco_mammo = run_pretraining_moco(
    model_moco_mammo_pretrain,             # MoCo model
    criterion_moco_mammo_pretrain,         # Adapted NTXentLoss (contrastive loss function)
    mg_pretrain_train_dataloader,          # Dataloader for training data
    mg_pretrain_val_dataloader,            # Dataloader for validation data
    optimizer_moco_mammo_pretrain,         # Adam optimizer
    scheduler_moco_mammo_pretrain,         # Learning rate scheduler (CosineAnnealingLR)
    device,                                # Device to run training on (GPU/CPU)
    total_epoch_count,                     # Number of epochs to train
    storing_path55,                        # Path to store logs
    "M5",                                  # Model identifier
    checkpoint_interval                    # Save checkpoints every N epochs
)

In [199]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path55, exist_ok=True)
model_filename = "M5_savedModel.pth"
model_path = os.path.join(storing_path55, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_moco_mammo_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_moco_mammo)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# MoCo case: the whole model is passed and is_byol is False
tsne_pipeline(model_moco_mammo_pretrain, tSNE_mg_dataloader, device, pretraining_results_path_moco_mammo, is_byol = False)

### 5.6. Using <font color='turquoise'>MoCo</font> method and <font color='orange'>Multimodal</font> dataset

In [None]:
#new (checking)
# 1. Prepare model
# Build the single-channel ResNet backbone (non-BYOL variant) and wrap it in the MoCo model
backbone_moco_multi_pretrain = initialize_resnet_backbone(False)
model_moco_multi_pretrain = MoCo_v2(backbone_moco_multi_pretrain)
model_moco_multi_pretrain.to(device)

# Contrastive loss configured with a queue size of 16384 and the shared pretraining temperature
criterion_moco_multi_pretrain = MoCoNTXentLoss(queue_size = 16384, temperature = temperature)
criterion_moco_multi_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
optimizer_moco_multi_pretrain = optim.Adam(model_moco_multi_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

# MoCo runs use a CosineAnnealingLR scheduler instead of ReduceLROnPlateau
# (author's note: MoCo_v2 works better with it)
scheduler_moco_multi_pretrain = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_moco_multi_pretrain, 
    T_max = total_epoch_count,  # lr is annealed down to eta_min over this many scheduler steps
    eta_min = 0,  # minimum learning rate (you can adjust this if needed)
    last_epoch = -1  # start from scratch
)

In [None]:
# 2. Pretrain loop
# MoCo pretraining on the multimodal loaders. The return value is stored
# because step 4 passes it to plot_pretrain_results and tsne_pipeline.
pretraining_results_path_moco_multi= run_pretraining_moco(
    model_moco_multi_pretrain,             # MoCo model
    criterion_moco_multi_pretrain,         # Adapted NTXentLoss (contrastive loss function)
    multi_pretrain_train_dataloader,       # Dataloader for training data
    multi_pretrain_val_dataloader,         # Dataloader for validation data
    optimizer_moco_multi_pretrain,         # Adam optimizer
    scheduler_moco_multi_pretrain,         # Learning rate scheduler (CosineAnnealingLR)
    device,                                # Device to run training on (GPU/CPU)
    total_epoch_count,                     # Number of epochs to train
    storing_path56,                        # Path to store logs
    "M6",                                  # Model identifier
    checkpoint_interval                    # Save checkpoints every N epochs
)

In [203]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path56, exist_ok=True)
model_filename = "M6_savedModel.pth"
model_path = os.path.join(storing_path56, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_moco_multi_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_moco_multi)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# MoCo case: the whole model is passed and is_byol is False
tsne_pipeline(model_moco_multi_pretrain, tSNE_multi_dataloader, device, pretraining_results_path_moco_multi, is_byol = False)

### 5.7. Using <font color='yellow'>BYOL</font> method and <font color='purple'>Ultrasound</font> dataset

In [None]:
#new (checking)
# 1. Prepare model
# Initialize the BYOL model with the ResNet backbone
# (BYOL variant of the backbone: initialize_resnet_backbone is called with True)
backbone_byol_us_pretrain = initialize_resnet_backbone(True)
model_byol_us_pretrain = BYOL(backbone_byol_us_pretrain)
model_byol_us_pretrain.to(device)

# Initialize loss function
criterion_byol_us_pretrain = BYOLLoss()
criterion_byol_us_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
optimizer_byol_us_pretrain = optim.Adam(model_byol_us_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

# No learning rate scheduler is created for BYOL (author's note: the original BYOL work did not use one)

# All BYOL cases involve model_byol.online_encoder and is_byol = True

In [None]:
# 2. Pretrain loop
# BYOL pretraining on the ultrasound loaders (no scheduler argument, unlike the
# SimCLR / MoCo calls). The return value is stored because step 4 passes it to
# plot_pretrain_results and tsne_pipeline.
pretraining_results_path_byol_us = run_pretraining_byol(
    model_byol_us_pretrain,                  # BYOL model
    criterion_byol_us_pretrain,              # Negative cosine similarity
    us_pretrain_train_dataloader,            # Dataloader for training data
    us_pretrain_val_dataloader,              # Dataloader for validation data
    optimizer_byol_us_pretrain,              # Adam optimizer
    device,                                  # Device to run training on (GPU/CPU)
    total_epoch_count,                       # Number of epochs to train
    storing_path57,                          # Path to store logs
    "M7",                                    # Model identifier
    checkpoint_interval                      # Save checkpoints every N epochs
)

In [207]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path57, exist_ok=True)
model_filename = "M7_savedModel.pth"
model_path = os.path.join(storing_path57, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_byol_us_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_byol_us)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# BYOL case: only the online encoder is passed (not the whole BYOL wrapper) and is_byol is True
tsne_pipeline(model_byol_us_pretrain.online_encoder, tSNE_us_dataloader, device, pretraining_results_path_byol_us, is_byol = True)

### 5.8. Using <font color='yellow'>BYOL</font> method and <font color='pink'>Mammography</font> dataset

In [None]:
#new (checking)
# 1. Prepare model
# Initialize the BYOL model with the ResNet backbone
# (BYOL variant of the backbone: initialize_resnet_backbone is called with True)
backbone_byol_mammo_pretrain = initialize_resnet_backbone(True)
model_byol_mammo_pretrain = BYOL(backbone_byol_mammo_pretrain)
model_byol_mammo_pretrain.to(device)

# Initialize loss function
criterion_byol_mammo_pretrain = BYOLLoss()
criterion_byol_mammo_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
# (no learning rate scheduler is created for BYOL)
optimizer_byol_mammo_pretrain = optim.Adam(model_byol_mammo_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

In [None]:
# 2. Pretrain loop
# BYOL pretraining on the mammography loaders (no scheduler argument, unlike the
# SimCLR / MoCo calls). The return value is stored because step 4 passes it to
# plot_pretrain_results and tsne_pipeline.
pretraining_results_path_byol_mammo = run_pretraining_byol(
    model_byol_mammo_pretrain,                  # BYOL model
    criterion_byol_mammo_pretrain,              # Negative cosine similarity
    mg_pretrain_train_dataloader,               # Dataloader for training data
    mg_pretrain_val_dataloader,                 # Dataloader for validation data
    optimizer_byol_mammo_pretrain,              # Adam optimizer
    device,                                     # Device to run training on (GPU/CPU)
    total_epoch_count,                          # Number of epochs to train
    storing_path58,                             # Path to store logs
    "M8",                                       # Model identifier
    checkpoint_interval                         # Save checkpoints every N epochs
)

In [211]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path58, exist_ok=True)
model_filename = "M8_savedModel.pth"
model_path = os.path.join(storing_path58, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_byol_mammo_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_byol_mammo)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# BYOL case: only the online encoder is passed (not the whole BYOL wrapper) and is_byol is True
tsne_pipeline(model_byol_mammo_pretrain.online_encoder, tSNE_mg_dataloader, device, pretraining_results_path_byol_mammo, is_byol = True)

### 5.9. Using <font color='yellow'>BYOL</font> method and <font color='orange'>Multimodal</font> dataset

In [None]:
#new (checking)
# 1. Prepare model
# Initialize the BYOL model with the ResNet backbone
# (BYOL variant of the backbone: initialize_resnet_backbone is called with True)
backbone_byol_multi_pretrain = initialize_resnet_backbone(True)
model_byol_multi_pretrain = BYOL(backbone_byol_multi_pretrain)
model_byol_multi_pretrain.to(device)

# Initialize loss function
criterion_byol_multi_pretrain = BYOLLoss()
criterion_byol_multi_pretrain.to(device)

# Adam optimizer over all model parameters with the shared pretraining lr / weight decay
# (no learning rate scheduler is created for BYOL)
optimizer_byol_multi_pretrain = optim.Adam(model_byol_multi_pretrain.parameters(), lr = lr_pretrain, weight_decay = weight_decay)

In [None]:
# 2. Pretrain loop
# BYOL pretraining on the multimodal loaders (no scheduler argument, unlike the
# SimCLR / MoCo calls). The return value is stored because step 4 passes it to
# plot_pretrain_results and tsne_pipeline.
pretraining_results_path_byol_multi = run_pretraining_byol(
    model_byol_multi_pretrain,                     # BYOL model
    criterion_byol_multi_pretrain,                 # Negative cosine similarity
    multi_pretrain_train_dataloader,               # Dataloader for training data
    multi_pretrain_val_dataloader,                 # Dataloader for validation data
    optimizer_byol_multi_pretrain,                 # Adam optimizer
    device,                                        # Device to run training on (GPU/CPU)
    total_epoch_count,                             # Number of epochs to train
    storing_path59,                                # Path to store logs
    "M9",                                          # Model identifier
    checkpoint_interval                            # Save checkpoints every N epochs
)

In [217]:
# 3. Save the pretrained model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before the weights are written.
os.makedirs(storing_path59, exist_ok=True)
model_filename = "M9_savedModel.pth"
model_path = os.path.join(storing_path59, model_filename)
# Only the state_dict (weights) is stored, not the pickled model object
torch.save(model_byol_multi_pretrain.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of pretrain loop
# Plot and save pretraining results (input is the value returned by the pretrain loop)
plot_pretrain_results(pretraining_results_path_byol_multi)

# Pre-Finetune t-distributed Stochastic Neighbor Embedding (t-SNE) Visualization
# BYOL case: only the online encoder is passed (not the whole BYOL wrapper) and is_byol is True
tsne_pipeline(model_byol_multi_pretrain.online_encoder, tSNE_multi_dataloader, device, pretraining_results_path_byol_multi, is_byol=True)

## 6. Finetuning / Testing for Downstream Task - Multiclassification

### 6.0. Setting classification models architecture and storing paths

In [None]:
# Output folders of the classification (finetuning / testing) phase,
# one per (pretraining method, dataset) pair plus the supervised baselines.

# Backbones pretrained with SimCLR
storing_path61 = "../results/classification/simclr/ultrasound"
storing_path62 = "../results/classification/simclr/mammography"
storing_path63 = "../results/classification/simclr/multimodal"

# Backbones pretrained with MoCo
storing_path64 = "../results/classification/moco/ultrasound"
storing_path65 = "../results/classification/moco/mammography"
storing_path66 = "../results/classification/moco/multimodal"

# Backbones pretrained with BYOL
storing_path67 = "../results/classification/byol/ultrasound"
storing_path68 = "../results/classification/byol/mammography"
storing_path69 = "../results/classification/byol/multimodal"

# Fully supervised baselines (no pretraining)
storing_path610 = "../results/classification/supervised/ultrasound"
storing_path611 = "../results/classification/supervised/mammography"
storing_path612 = "../results/classification/supervised/multimodal"

class classif_model_finetune(nn.Module):
    """Classifier made of a given backbone followed by a new linear head.

    Args:
        base_model: Feature extractor whose flattened output has 512 features
            per sample (the linear head is built for that width).
        num_classes: Number of output logits.
    """

    def __init__(self, base_model, num_classes):
        super().__init__()
        # Backbone handed over from the concluded pretraining phase
        self.base_model = base_model
        # Freshly initialised classification layer on top of the 512 features
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        # Collapse everything after the batch dimension before the linear head
        features = torch.flatten(self.base_model(x), start_dim=1)
        return self.fc(features)

# Supervised classification model with randomly initialized weights
class classif_model_supervised(nn.Module):
    """ResNet-18 classifier trained from scratch (baseline without pretraining).

    The network takes 1-channel (grayscale) images: the first convolution is
    replaced accordingly, the stock fully connected layer is swapped for an
    identity, and a new linear classification layer is attached.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        # Imported locally so the class does not rely on a notebook-level
        # ``models`` name (the library cell only binds ``torchvision.transforms``).
        from torchvision import models

        super().__init__()
        self.base_model = models.resnet18(weights=None)  # Random init (No Pretraining)

        # Modify first convolutional layer to accept 1-channel grayscale images
        self.base_model.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # Remember the feature width of the stock head, then neutralise that head
        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()

        # New classification layer operating on the backbone features
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        x = self.base_model(x).flatten(start_dim=1)
        x = self.fc(x)
        return x
    
# Optimizer hyperparameters shared by every classification run below (passed to torch.optim.Adam)
# NOTE(review): `weight_decay` is a generic global name; if an earlier (pretraining) cell also
# defines it, this assignment replaces that value for all later cells - confirm this is intended.
lr_finetune = 0.001
weight_decay = 1e-6

### 6.1. With models pretrained with <font color='green'>SimCLR</font> method and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# This is a 6 step process
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the SimCLR pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_simclr_us_pretrain.backbone)
model_simclr_us_classif = classif_model_finetune(finetune_backbone, num_classes)
model_simclr_us_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_simclr_us_classif = nn.CrossEntropyLoss()
criterion_simclr_us_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_simclr_us_classif = torch.optim.Adam(
    model_simclr_us_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_simclr_us_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_simclr_us_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_simclr_us = classif_run_finetuning(
    model_simclr_us_classif,                           # Classifier built on the SimCLR-pretrained backbone
    criterion_simclr_us_classif,                       # CrossEntropyLoss (the standard choice for classification)
    classif_us_finetune_train_dataloader,              # Dataloader for training data
    classif_us_finetune_val_dataloader,                # Dataloader for validation data
    optimizer_simclr_us_classif,                       # Adam optimizer
    scheduler_simclr_us_classif,                       # Learning rate scheduler
    device,                                            # Device to run training on (GPU/CPU)
    total_epoch_count,                                 # Number of epochs to finetune
    storing_path61,                                    # Path to store logs
    "M1_classification",                               # Model identifier
    checkpoint_interval                                # Save checkpoints every N epochs
)

In [222]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M1_classification_savedModel.pth"
model_path = os.path.join(storing_path61, model_filename)
torch.save(model_simclr_us_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_simclr_us, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_simclr_us = classif_run_testing(
    model_simclr_us_classif,                 # classifier finetuned in step 2
    criterion_simclr_us_classif,             # same loss object used during finetuning
    classif_us_finetune_test_dataloader,     # ultrasound test dataloader
    storing_path61,                          # same storing path as the finetuning run
    "M1_classification",                     # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_simclr_us, classes)

### 6.2. With models pretrained with <font color='green'>SimCLR</font> method and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the SimCLR pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_simclr_mammo_pretrain.backbone)
model_simclr_mammo_classif = classif_model_finetune(finetune_backbone, num_classes)
model_simclr_mammo_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_simclr_mammo_classif = nn.CrossEntropyLoss()
criterion_simclr_mammo_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_simclr_mammo_classif = torch.optim.Adam(
    model_simclr_mammo_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_simclr_mammo_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_simclr_mammo_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_simclr_mammo = classif_run_finetuning(
    model_simclr_mammo_classif,                           # Classifier built on the SimCLR-pretrained backbone
    criterion_simclr_mammo_classif,                       # CrossEntropyLoss (the standard choice for classification)
    classif_mg_finetune_train_dataloader,                 # Dataloader for training data
    classif_mg_finetune_val_dataloader,                   # Dataloader for validation data
    optimizer_simclr_mammo_classif,                       # Adam optimizer
    scheduler_simclr_mammo_classif,                       # Learning rate scheduler
    device,                                               # Device to run training on (GPU/CPU)
    total_epoch_count,                                    # Number of epochs to finetune
    storing_path62,                                       # Path to store logs
    "M2_classification",                                  # Model identifier
    checkpoint_interval                                   # Save checkpoints every N epochs
)

In [228]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M2_classification_savedModel.pth"
model_path = os.path.join(storing_path62, model_filename)
torch.save(model_simclr_mammo_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_simclr_mammo, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_simclr_mammo = classif_run_testing(
    model_simclr_mammo_classif,              # classifier finetuned in step 2
    criterion_simclr_mammo_classif,          # same loss object used during finetuning
    classif_mg_finetune_test_dataloader,     # mammography test dataloader
    storing_path62,                          # same storing path as the finetuning run
    "M2_classification",                     # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_simclr_mammo, classes)

### 6.3. With models pretrained with <font color='green'>SimCLR</font> method and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the SimCLR pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_simclr_multi_pretrain.backbone)
model_simclr_multi_classif = classif_model_finetune(finetune_backbone, num_classes)
model_simclr_multi_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_simclr_multi_classif = nn.CrossEntropyLoss()
criterion_simclr_multi_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_simclr_multi_classif = torch.optim.Adam(
    model_simclr_multi_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_simclr_multi_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_simclr_multi_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_simclr_multi = classif_run_finetuning(
    model_simclr_multi_classif,                           # Classifier built on the SimCLR-pretrained backbone
    criterion_simclr_multi_classif,                       # CrossEntropyLoss (the standard choice for classification)
    classif_multi_finetune_train_dataloader,              # Dataloader for training data
    classif_multi_finetune_val_dataloader,                # Dataloader for validation data
    optimizer_simclr_multi_classif,                       # Adam optimizer
    scheduler_simclr_multi_classif,                       # Learning rate scheduler
    device,                                               # Device to run training on (GPU/CPU)
    total_epoch_count,                                    # Number of epochs to finetune
    storing_path63,                                       # Path to store logs
    "M3_classification",                                  # Model identifier
    checkpoint_interval                                   # Save checkpoints every N epochs
)

In [234]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M3_classification_savedModel.pth"
model_path = os.path.join(storing_path63, model_filename)
torch.save(model_simclr_multi_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_simclr_multi, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_simclr_multi = classif_run_testing(
    model_simclr_multi_classif,                 # classifier finetuned in step 2
    criterion_simclr_multi_classif,             # same loss object used during finetuning
    classif_multi_finetune_test_dataloader,     # multimodal test dataloader
    storing_path63,                             # same storing path as the finetuning run
    "M3_classification",                        # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_simclr_multi, classes)

### 6.4. With models pretrained with <font color='turquoise'>MoCo</font> method and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the MoCo pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_moco_us_pretrain.backbone)
model_moco_us_classif = classif_model_finetune(finetune_backbone, num_classes)
model_moco_us_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_moco_us_classif = nn.CrossEntropyLoss()
criterion_moco_us_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_moco_us_classif = torch.optim.Adam(
    model_moco_us_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_moco_us_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_moco_us_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_moco_us = classif_run_finetuning(
    model_moco_us_classif,                                # Classifier built on the MoCo-pretrained backbone
    criterion_moco_us_classif,                            # CrossEntropyLoss (the standard choice for classification)
    classif_us_finetune_train_dataloader,                 # Dataloader for training data
    classif_us_finetune_val_dataloader,                   # Dataloader for validation data
    optimizer_moco_us_classif,                            # Adam optimizer
    scheduler_moco_us_classif,                            # Learning rate scheduler
    device,                                               # Device to run training on (GPU/CPU)
    total_epoch_count,                                    # Number of epochs to finetune
    storing_path64,                                       # Path to store logs
    "M4_classification",                                  # Model identifier
    checkpoint_interval                                   # Save checkpoints every N epochs
)

In [240]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M4_classification_savedModel.pth"
model_path = os.path.join(storing_path64, model_filename)
torch.save(model_moco_us_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_moco_us, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_moco_us = classif_run_testing(
    model_moco_us_classif,                   # classifier finetuned in step 2
    criterion_moco_us_classif,               # same loss object used during finetuning
    classif_us_finetune_test_dataloader,     # ultrasound test dataloader
    storing_path64,                          # same storing path as the finetuning run
    "M4_classification",                     # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_moco_us, classes)

### 6.5. With models pretrained with <font color='turquoise'>MoCo</font> method and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the MoCo pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_moco_mammo_pretrain.backbone)
model_moco_mammo_classif = classif_model_finetune(finetune_backbone, num_classes)
model_moco_mammo_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_moco_mammo_classif = nn.CrossEntropyLoss()
criterion_moco_mammo_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_moco_mammo_classif = torch.optim.Adam(
    model_moco_mammo_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_moco_mammo_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_moco_mammo_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_moco_mammo = classif_run_finetuning(
    model_moco_mammo_classif,                           # Classifier built on the MoCo-pretrained backbone
    criterion_moco_mammo_classif,                       # CrossEntropyLoss (the standard choice for classification)
    classif_mg_finetune_train_dataloader,               # Dataloader for training data
    classif_mg_finetune_val_dataloader,                 # Dataloader for validation data
    optimizer_moco_mammo_classif,                       # Adam optimizer
    scheduler_moco_mammo_classif,                       # Learning rate scheduler
    device,                                             # Device to run training on (GPU/CPU)
    total_epoch_count,                                  # Number of epochs to finetune
    storing_path65,                                     # Path to store logs
    "M5_classification",                                # Model identifier
    checkpoint_interval                                 # Save checkpoints every N epochs
)

In [246]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M5_classification_savedModel.pth"
model_path = os.path.join(storing_path65, model_filename)
torch.save(model_moco_mammo_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_moco_mammo, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_moco_mammo = classif_run_testing(
    model_moco_mammo_classif,                # classifier finetuned in step 2
    criterion_moco_mammo_classif,            # same loss object used during finetuning
    classif_mg_finetune_test_dataloader,     # mammography test dataloader
    storing_path65,                          # same storing path as the finetuning run
    "M5_classification",                     # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_moco_mammo, classes)

### 6.6. With models pretrained with <font color='turquoise'>MoCo</font> method and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the MoCo pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_moco_multi_pretrain.backbone)
model_moco_multi_classif = classif_model_finetune(finetune_backbone, num_classes)
model_moco_multi_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_moco_multi_classif = nn.CrossEntropyLoss()
criterion_moco_multi_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_moco_multi_classif = torch.optim.Adam(
    model_moco_multi_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_moco_multi_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_moco_multi_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_moco_multi = classif_run_finetuning(
    model_moco_multi_classif,                             # Classifier built on the MoCo-pretrained backbone
    criterion_moco_multi_classif,                         # CrossEntropyLoss (the standard choice for classification)
    classif_multi_finetune_train_dataloader,              # Dataloader for training data
    classif_multi_finetune_val_dataloader,                # Dataloader for validation data
    optimizer_moco_multi_classif,                         # Adam optimizer
    scheduler_moco_multi_classif,                         # Learning rate scheduler
    device,                                               # Device to run training on (GPU/CPU)
    total_epoch_count,                                    # Number of epochs to finetune
    storing_path66,                                       # Path to store logs
    "M6_classification",                                  # Model identifier
    checkpoint_interval                                   # Save checkpoints every N epochs
)

In [252]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M6_classification_savedModel.pth"
model_path = os.path.join(storing_path66, model_filename)
torch.save(model_moco_multi_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_moco_multi, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_moco_multi = classif_run_testing(
    model_moco_multi_classif,                   # classifier finetuned in step 2
    criterion_moco_multi_classif,               # same loss object used during finetuning
    classif_multi_finetune_test_dataloader,     # multimodal test dataloader
    storing_path66,                             # same storing path as the finetuning run
    "M6_classification",                        # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_moco_multi, classes)

### 6.7. With models pretrained with <font color='yellow'>BYOL</font> method and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the BYOL pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_byol_us_pretrain.online_encoder[0])  # the backbone for byol is accessed with .online_encoder[0]
model_byol_us_classif = classif_model_finetune(finetune_backbone, num_classes)
model_byol_us_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_byol_us_classif = nn.CrossEntropyLoss()
criterion_byol_us_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_byol_us_classif = torch.optim.Adam(
    model_byol_us_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_byol_us_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_byol_us_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_byol_us = classif_run_finetuning(
    model_byol_us_classif,                             # Classifier built on the BYOL-pretrained backbone
    criterion_byol_us_classif,                         # CrossEntropyLoss (the standard choice for classification)
    classif_us_finetune_train_dataloader,              # Dataloader for training data
    classif_us_finetune_val_dataloader,                # Dataloader for validation data
    optimizer_byol_us_classif,                         # Adam optimizer
    scheduler_byol_us_classif,                         # Learning rate scheduler
    device,                                            # Device to run training on (GPU/CPU)
    total_epoch_count,                                 # Number of epochs to finetune
    storing_path67,                                    # Path to store logs
    "M7_classification",                               # Model identifier
    checkpoint_interval                                # Save checkpoints every N epochs
)

In [258]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M7_classification_savedModel.pth"
model_path = os.path.join(storing_path67, model_filename)
torch.save(model_byol_us_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_byol_us, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_byol_us = classif_run_testing(
    model_byol_us_classif,                   # classifier finetuned in step 2
    criterion_byol_us_classif,               # same loss object used during finetuning
    classif_us_finetune_test_dataloader,     # ultrasound test dataloader
    storing_path67,                          # same storing path as the finetuning run
    "M7_classification",                     # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_byol_us, classes)

### 6.8. With models pretrained with <font color='yellow'>BYOL</font> method and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the BYOL pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_byol_mammo_pretrain.online_encoder[0])  # BYOL backbone lives at online_encoder[0]
model_byol_mammo_classif = classif_model_finetune(finetune_backbone, num_classes)
model_byol_mammo_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_byol_mammo_classif = nn.CrossEntropyLoss()
criterion_byol_mammo_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_byol_mammo_classif = torch.optim.Adam(
    model_byol_mammo_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_byol_mammo_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_byol_mammo_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_byol_mammo = classif_run_finetuning(
    model_byol_mammo_classif,                           # Classifier built on the BYOL-pretrained backbone
    criterion_byol_mammo_classif,                       # CrossEntropyLoss (the standard choice for classification)
    classif_mg_finetune_train_dataloader,               # Dataloader for training data
    classif_mg_finetune_val_dataloader,                 # Dataloader for validation data
    optimizer_byol_mammo_classif,                       # Adam optimizer
    scheduler_byol_mammo_classif,                       # Learning rate scheduler
    device,                                             # Device to run training on (GPU/CPU)
    total_epoch_count,                                  # Number of epochs to finetune
    storing_path68,                                     # Path to store logs
    "M8_classification",                                # Model identifier
    checkpoint_interval                                 # Save checkpoints every N epochs
)

In [264]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M8_classification_savedModel.pth"
model_path = os.path.join(storing_path68, model_filename)
torch.save(model_byol_mammo_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_byol_mammo, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_byol_mammo = classif_run_testing(
    model_byol_mammo_classif,                # classifier finetuned in step 2
    criterion_byol_mammo_classif,            # same loss object used during finetuning
    classif_mg_finetune_test_dataloader,     # mammography test dataloader
    storing_path68,                          # same storing path as the finetuning run
    "M8_classification",                     # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_byol_mammo, classes)

### 6.9. With models pretrained with <font color='yellow'>BYOL</font> method and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model
# Initialize the finetuning model with the ResNet backbone from the BYOL pretraining phase.
# Always work on a deep copy so the original pretrained model is not inadvertently changed:
# it is reused for each of the two downstream tasks.
# copy.deepcopy is called through the `copy` module, which the notebook's import cell provides
# (a bare `deepcopy` is not imported there).
finetune_backbone = copy.deepcopy(model_byol_multi_pretrain.online_encoder[0])  # BYOL backbone lives at online_encoder[0]
model_byol_multi_classif = classif_model_finetune(finetune_backbone, num_classes)
model_byol_multi_classif.to(device)

# Standard cross entropy loss function used in classification
criterion_byol_multi_classif = nn.CrossEntropyLoss()
criterion_byol_multi_classif.to(device)

# Adam over all classifier parameters (backbone copy and new head), using the finetune learning rate
optimizer_byol_multi_classif = torch.optim.Adam(
    model_byol_multi_classif.parameters(), lr=lr_finetune, weight_decay=weight_decay
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_byol_multi_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_byol_multi_classif, mode='min', patience=3, factor=0.1
)

In [None]:
# 2. Finetune loop
# Arguments are positional; their order must match classif_run_finetuning (defined elsewhere in this notebook).
# The returned value is handed to classif_plot_finetuning_results in step 4.
classif_ft_results_path_byol_multi = classif_run_finetuning(
    model_byol_multi_classif,                             # Classifier built on the BYOL-pretrained backbone
    criterion_byol_multi_classif,                         # CrossEntropyLoss (the standard choice for classification)
    classif_multi_finetune_train_dataloader,              # Dataloader for training data
    classif_multi_finetune_val_dataloader,                # Dataloader for validation data
    optimizer_byol_multi_classif,                         # Adam optimizer
    scheduler_byol_multi_classif,                         # Learning rate scheduler
    device,                                               # Device to run training on (GPU/CPU)
    total_epoch_count,                                    # Number of epochs to finetune
    storing_path69,                                       # Path to store logs
    "M9_classification",                                  # Model identifier
    checkpoint_interval                                   # Save checkpoints every N epochs
)

In [270]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires rebuilding classif_model_finetune with a
# matching backbone first. model_filename / model_path are overwritten by each save cell.
model_filename = "M9_classification_savedModel.pth"
model_path = os.path.join(storing_path69, model_filename)
torch.save(model_byol_multi_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_byol_multi, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_byol_multi = classif_run_testing(
    model_byol_multi_classif,                   # classifier finetuned in step 2
    criterion_byol_multi_classif,               # same loss object used during finetuning
    classif_multi_finetune_test_dataloader,     # multimodal test dataloader
    storing_path69,                             # same storing path as the finetuning run
    "M9_classification",                        # same model identifier as the finetuning run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_byol_multi, classes)

### 6.10. With models that aren't <font color='red'>Pretrained</font> and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# 1. Prepare model
# Supervised baseline for ultrasound: classifier built without any pretraining.
# nn.Module.to() returns the module itself, so creation and device placement are chained.
model_supervised_us_classif = classif_model_supervised(num_classes).to(device)

# Cross entropy is the standard loss for classification
criterion_supervised_us_classif = nn.CrossEntropyLoss().to(device)

# Adam with the same finetune learning rate / weight decay as the pretrained variants
optimizer_supervised_us_classif = torch.optim.Adam(
    model_supervised_us_classif.parameters(),
    weight_decay=weight_decay,
    lr=lr_finetune,
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_supervised_us_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_supervised_us_classif,
    factor=0.1,
    patience=3,
    mode='min',
)

In [None]:
# 2. Finetune loop
# For the supervised baseline this is training from random initialisation; the same
# classif_run_finetuning routine is reused. Arguments are positional, so their order must match
# that function (defined elsewhere in this notebook). The returned value feeds step 4.
classif_ft_results_path_supervised_us = classif_run_finetuning(
    model_supervised_us_classif,                       # supervised counterpart model
    criterion_supervised_us_classif,                   # CrossEntropyLoss (the standard choice for classification)
    classif_us_finetune_train_dataloader,              # Dataloader for training data
    classif_us_finetune_val_dataloader,                # Dataloader for validation data
    optimizer_supervised_us_classif,                   # Adam optimizer
    scheduler_supervised_us_classif,                   # Learning rate scheduler
    device,                                            # Device to run training on (GPU/CPU)
    total_epoch_count,                                 # Number of epochs to finetune
    storing_path610,                                   # Path to store logs
    "M10_classification",                              # Model identifier
    checkpoint_interval                                # Save checkpoints every N epochs
)

In [248]:
# 3. Save the finetuned classification model
# Only the state_dict is written: reloading requires constructing classif_model_supervised
# first. model_filename / model_path are overwritten by each save cell.
model_filename = "M10_classification_savedModel.pth"
model_path = os.path.join(storing_path610, model_filename)
torch.save(model_supervised_us_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2, together with the class names.
classif_plot_finetuning_results(classif_ft_results_path_supervised_us, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Arguments are positional; note that `device` comes last here, unlike in classif_run_finetuning.
# The returned value is handed to classif_plot_testing_results in step 6.
classif_test_results_path_supervised_us = classif_run_testing(
    model_supervised_us_classif,             # classifier trained in step 2
    criterion_supervised_us_classif,         # same loss object used during training
    classif_us_finetune_test_dataloader,     # ultrasound test dataloader
    storing_path610,                         # same storing path as the training run
    "M10_classification",                    # same model identifier as the training run
    device
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5, together with the class names.
classif_plot_testing_results(classif_test_results_path_supervised_us, classes)

### 6.11. With models that aren't <font color='red'>Pretrained</font> and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model
# Supervised baseline for mammography: classifier built without any pretraining.
# nn.Module.to() returns the module itself, so creation and device placement are chained.
model_supervised_mammo_classif = classif_model_supervised(num_classes).to(device)

# Cross entropy is the standard loss for classification
criterion_supervised_mammo_classif = nn.CrossEntropyLoss().to(device)

# Adam with the same finetune learning rate / weight decay as the pretrained variants
optimizer_supervised_mammo_classif = torch.optim.Adam(
    model_supervised_mammo_classif.parameters(),
    weight_decay=weight_decay,
    lr=lr_finetune,
)

# ReduceLROnPlateau: multiply the LR by 0.1 after the monitored value fails to decrease for more than 3 epochs
scheduler_supervised_mammo_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_supervised_mammo_classif,
    factor=0.1,
    patience=3,
    mode='min',
)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (supervised counterpart), loss (CrossEntropyLoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
classif_ft_results_path_supervised_mammo = classif_run_finetuning(
    model_supervised_mammo_classif, criterion_supervised_mammo_classif,
    classif_mg_finetune_train_dataloader, classif_mg_finetune_val_dataloader,
    optimizer_supervised_mammo_classif, scheduler_supervised_mammo_classif,
    device, total_epoch_count,
    storing_path611, "M11_classification", checkpoint_interval,
)

In [254]:
# 3. Save the finetuned classification model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M11_classification_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path611, exist_ok=True)
model_path = os.path.join(storing_path611, model_filename)
torch.save(model_supervised_mammo_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2 plus `classes`
# (defined elsewhere in the notebook).
classif_plot_finetuning_results(classif_ft_results_path_supervised_mammo, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
classif_test_results_path_supervised_mammo = classif_run_testing(
    model_supervised_mammo_classif, criterion_supervised_mammo_classif,
    classif_mg_finetune_test_dataloader,
    storing_path611, "M11_classification", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5 plus `classes`
# (defined elsewhere in the notebook).
classif_plot_testing_results(classif_test_results_path_supervised_mammo, classes)

### 6.12. With models that aren't <font color='red'>Pretrained</font> and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model
# Supervised counterpart (no pretraining). `.to(device)` returns the module
# itself, so construction and device placement are chained.
model_supervised_multi_classif = classif_model_supervised(num_classes).to(device)

# Standard cross-entropy loss for classification
criterion_supervised_multi_classif = nn.CrossEntropyLoss().to(device)

# Adam optimizer, as in the pretraining runs, but with the finetuning learning rate
optimizer_supervised_multi_classif = torch.optim.Adam(
    model_supervised_multi_classif.parameters(),
    lr=lr_finetune,
    weight_decay=weight_decay,
)

# ReduceLROnPlateau: multiplies the learning rate by `factor` once the monitored
# value (mode='min') has stopped decreasing for more than `patience` scheduler steps
scheduler_supervised_multi_classif = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_supervised_multi_classif,
    mode='min',
    factor=0.1,
    patience=3,
)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (supervised counterpart), loss (CrossEntropyLoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
classif_ft_results_path_supervised_multi = classif_run_finetuning(
    model_supervised_multi_classif, criterion_supervised_multi_classif,
    classif_multi_finetune_train_dataloader, classif_multi_finetune_val_dataloader,
    optimizer_supervised_multi_classif, scheduler_supervised_multi_classif,
    device, total_epoch_count,
    storing_path612, "M12_classification", checkpoint_interval,
)

In [268]:
# 3. Save the finetuned classification model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M12_classification_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path612, exist_ok=True)
model_path = os.path.join(storing_path612, model_filename)
torch.save(model_supervised_multi_classif.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by classif_run_finetuning in step 2 plus `classes`
# (defined elsewhere in the notebook).
classif_plot_finetuning_results(classif_ft_results_path_supervised_multi, classes)

In [None]:
# 5. Final test run of finetuned classification model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
classif_test_results_path_supervised_multi = classif_run_testing(
    model_supervised_multi_classif, criterion_supervised_multi_classif,
    classif_multi_finetune_test_dataloader,
    storing_path612, "M12_classification", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned classification model
# Input is the value returned by classif_run_testing in step 5 plus `classes`
# (defined elsewhere in the notebook).
classif_plot_testing_results(classif_test_results_path_supervised_multi, classes)

## 7. Finetuning / Testing for Downstream Task - Binary Segmentation

### 7.0. Setting segmentation models architecture and storing paths

In [None]:
# Storing paths for the segmentation runs, one per (method, dataset) pair.
# Numbering follows the section numbers 7.1 - 7.12:
#   71-73   -> simclr     (ultrasound, mammography, multimodal)
#   74-76   -> moco       (ultrasound, mammography, multimodal)
#   77-79   -> byol       (ultrasound, mammography, multimodal)
#   710-712 -> supervised (ultrasound, mammography, multimodal)
(
    storing_path71, storing_path72, storing_path73,
    storing_path74, storing_path75, storing_path76,
    storing_path77, storing_path78, storing_path79,
    storing_path710, storing_path711, storing_path712,
) = [
    f"../results/segmentation/{method}/{dataset}"
    for method in ("simclr", "moco", "byol", "supervised")
    for dataset in ("ultrasound", "mammography", "multimodal")
]

def extract_encoder_layers(pretrained_backbone):
    """Split an indexable backbone into its stem and the four entries after it.

    Args:
        pretrained_backbone: Object supporting slicing and integer indexing.
            Presumably an ``nn.Sequential`` ResNet trunk whose first four
            children are Conv2d, BatchNorm2d, ReLU and MaxPool2d (confirm
            against the pretraining code).

    Returns:
        A 5-tuple ``(initial_layers, layer1, layer2, layer3, layer4)`` where
        ``initial_layers`` is ``pretrained_backbone[:4]`` and the remaining
        items are the elements at indices 4, 5, 6 and 7.
    """
    stem = pretrained_backbone[:4]
    # Hierarchical stages sit right after the stem, at positions 4..7
    stages = tuple(pretrained_backbone[position] for position in range(4, 8))
    return (stem, *stages)

class segm_model_finetune(nn.Module):
    """Encoder-decoder segmentation network built on a pretrained backbone.

    The backbone is split by ``extract_encoder_layers`` into an initial stem
    and four stages. The stage outputs are fused by four decoder blocks and a
    1x1 convolution produces ``num_classes`` logit maps.

    Args:
        base_model_backbone: Indexable backbone accepted by
            ``extract_encoder_layers``; the extracted parts must be
            ``nn.Module`` instances because they are stored in a ``ModuleDict``.
            The decoder expects the four stages to output 64, 128, 256 and 512
            channels respectively.
        num_classes: Number of output channels of the segmentation head.
        output_size: Spatial size ``(H, W)`` the logits are resized to before
            being returned. Defaults to ``(64, 64)``, the mask size that used
            to be hard-coded.
    """

    def __init__(self, base_model_backbone, num_classes=1, output_size=(64, 64)):
        super().__init__()

        initial, l1, l2, l3, l4 = extract_encoder_layers(base_model_backbone)
        self.encoder = nn.ModuleDict({
            "initial": initial,
            "layer1": l1,
            "layer2": l2,
            "layer3": l3,
            "layer4": l4
        })

        # Decoder blocks. The input width of each block is the sum of the
        # channels of the two tensors concatenated in forward().
        self.decoder1 = self._conv_block(512 + 256, 256, upsample=True)  # encoder4 (512) + resized encoder3 (256)
        self.decoder2 = self._conv_block(256 + 128, 128, upsample=True)  # decoder1 output (256) + resized encoder2 (128)
        self.decoder3 = self._conv_block(128 + 64, 64, upsample=True)    # decoder2 output (128) + resized encoder1 (64)
        self.decoder4 = self._conv_block(64 + 64, 32, upsample=True)     # decoder3 output (64) + resized encoder1 (64)

        # Segmentation head (1x1 convolution for pixel-wise classification)
        self.segmentation_head = nn.Conv2d(32, num_classes, kernel_size=1)

        # Size the logits are interpolated to at the end of forward()
        self.output_size = output_size

        # Initialize decoder/head weights only; the encoder keeps its weights
        self._initialize_weights()

    def _conv_block(self, in_channels, out_channels, upsample=False):
        """Build ``[Upsample x2] -> Conv3x3 -> BatchNorm -> ReLU`` as a Sequential.

        The bilinear upsampling step, when requested, is applied before the
        convolution.
        """
        layers = []
        if upsample:
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-initialize every Conv2d that does not belong to the encoder."""
        # Collect the encoder's modules once; the membership test below is then
        # O(1) instead of re-walking the whole encoder for every module.
        encoder_module_ids = {id(module) for module in self.encoder.modules()}
        for m in self.modules():
            if isinstance(m, nn.Conv2d) and id(m) not in encoder_module_ids:
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return segmentation logits of spatial size ``self.output_size``."""
        # Encoder: keep the intermediate outputs for the skip connections
        x = self.encoder["initial"](x)
        encoder1 = self.encoder["layer1"](x)  # Early layer
        encoder2 = self.encoder["layer2"](encoder1)  # Downsampled
        encoder3 = self.encoder["layer3"](encoder2)  # Further downsampled
        encoder4 = self.encoder["layer4"](encoder3)  # Final encoder stage

        # Each skip tensor is resized to the spatial size of the tensor it is
        # concatenated with; the decoder block then upsamples by 2.
        encoder3_resized = F.interpolate(encoder3, size=encoder4.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder1(torch.cat([encoder4, encoder3_resized], dim=1))

        encoder2_resized = F.interpolate(encoder2, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder2(torch.cat([x, encoder2_resized], dim=1))

        encoder1_resized = F.interpolate(encoder1, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder3(torch.cat([x, encoder1_resized], dim=1))

        # encoder1 is reused for the last block, resized to the new size of x
        encoder1_resized = F.interpolate(encoder1, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder4(torch.cat([x, encoder1_resized], dim=1))

        # Segmentation head
        output = self.segmentation_head(x)

        # Resize output to match the target mask size
        output = F.interpolate(output, size=self.output_size, mode='bilinear', align_corners=False)

        return output

class segm_model_supervised(nn.Module):
    """Encoder-decoder segmentation network on a ResNet-style backbone object.

    Unlike ``segm_model_finetune``, the backbone is used through its named
    attributes instead of being indexed. The decoder and head are the same.

    Args:
        base_model_backbone: Module exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool`` and ``layer1`` .. ``layer4``. The decoder expects
            ``layer1`` .. ``layer4`` to output 64, 128, 256 and 512 channels.
            Whether it carries pretrained weights is up to the caller; this
            class never re-initializes it.
        num_classes: Number of output channels of the segmentation head.
        output_size: Spatial size ``(H, W)`` the logits are resized to before
            being returned. Defaults to ``(64, 64)``, the mask size that used
            to be hard-coded.
    """

    def __init__(self, base_model_backbone, num_classes=1, output_size=(64, 64)):
        super().__init__()

        # Backbone used for feature extraction
        self.encoder = base_model_backbone

        # Decoder blocks. The input width of each block is the sum of the
        # channels of the two tensors concatenated in forward().
        self.decoder1 = self._conv_block(512 + 256, 256, upsample=True)  # encoder4 (512) + resized encoder3 (256)
        self.decoder2 = self._conv_block(256 + 128, 128, upsample=True)  # decoder1 output (256) + resized encoder2 (128)
        self.decoder3 = self._conv_block(128 + 64, 64, upsample=True)    # decoder2 output (128) + resized encoder1 (64)
        self.decoder4 = self._conv_block(64 + 64, 32, upsample=True)     # decoder3 output (64) + resized encoder1 (64)

        # Segmentation head (1x1 convolution for pixel-wise classification)
        self.segmentation_head = nn.Conv2d(32, num_classes, kernel_size=1)

        # Size the logits are interpolated to at the end of forward()
        self.output_size = output_size

        # Initialize decoder/head weights only; the encoder keeps its weights
        self._initialize_weights()

    def _conv_block(self, in_channels, out_channels, upsample=False):
        """Build ``[Upsample x2] -> Conv3x3 -> BatchNorm -> ReLU`` as a Sequential.

        The bilinear upsampling step, when requested, is applied before the
        convolution.
        """
        layers = []
        if upsample:
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-initialize every Conv2d that does not belong to the encoder."""
        # Collect the encoder's modules once; the membership test below is then
        # O(1) instead of re-walking the whole encoder for every module.
        encoder_module_ids = {id(module) for module in self.encoder.modules()}
        for m in self.modules():
            if isinstance(m, nn.Conv2d) and id(m) not in encoder_module_ids:
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return segmentation logits of spatial size ``self.output_size``."""
        # Encoder stem
        x = self.encoder.conv1(x)  # Conv1 layer
        x = self.encoder.bn1(x)    # BatchNorm1 layer
        x = self.encoder.relu(x)   # ReLU activation
        x = self.encoder.maxpool(x)  # MaxPool layer

        # Encoder stages: keep the intermediate outputs for the skip connections
        encoder1 = self.encoder.layer1(x)
        encoder2 = self.encoder.layer2(encoder1)
        encoder3 = self.encoder.layer3(encoder2)
        encoder4 = self.encoder.layer4(encoder3)

        # Each skip tensor is resized to the spatial size of the tensor it is
        # concatenated with; the decoder block then upsamples by 2.
        encoder3_resized = F.interpolate(encoder3, size=encoder4.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder1(torch.cat([encoder4, encoder3_resized], dim=1))

        encoder2_resized = F.interpolate(encoder2, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder2(torch.cat([x, encoder2_resized], dim=1))

        encoder1_resized = F.interpolate(encoder1, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder3(torch.cat([x, encoder1_resized], dim=1))

        # encoder1 is reused for the last block, resized to the new size of x
        encoder1_resized = F.interpolate(encoder1, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = self.decoder4(torch.cat([x, encoder1_resized], dim=1))

        # Segmentation head (final output map)
        output = self.segmentation_head(x)

        # Resize output to match the target mask size
        output = F.interpolate(output, size=self.output_size, mode='bilinear', align_corners=False)

        return output

# Loss function used for segmentation nets
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy (on logits) and a soft Dice loss.

    Args:
        smooth: Constant added to the numerator and denominator of the Dice
            ratio so that it stays defined when both prediction and target
            sums are zero. Defaults to ``1e-6``, the value that used to be
            hard-coded.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``BCE(preds, targets) + (1 - Dice)`` as a scalar tensor.

        Args:
            preds: Raw logits (no sigmoid applied).
            targets: Tensor of the same shape as ``preds``; it is multiplied
                element-wise with the predicted probabilities.
        """
        # BCE loss directly applied to raw logits
        bce_loss = self.bce_loss(preds, targets)

        # Dice works on probabilities, so squash the logits first. The sums run
        # over every element of the batch at once (one global Dice ratio).
        probs = torch.sigmoid(preds)
        intersection = torch.sum(probs * targets)
        dice_loss = 1 - (2. * intersection + self.smooth) / (torch.sum(probs) + torch.sum(targets) + self.smooth)

        # Combined loss
        return bce_loss + dice_loss


### 7.1. With models pretrained with <font color='green'>SimCLR</font> method and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# This is a 6 step process
# 1. Prepare model
# Always work on a copy so the original pretrained model is never changed
# inadvertently. `copy.deepcopy` is spelled out because the notebook's import
# cell only binds the `copy` module, not a bare `deepcopy` name.
finetune_backbone = copy.deepcopy(model_simclr_us_pretrain.backbone)
model_simclr_us_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_simclr_us_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_simclr_us_segm = DiceBCELoss()
criterion_simclr_us_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_simclr_us_segm = torch.optim.Adam(model_simclr_us_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: multiplies the learning rate by 0.1 once the
# monitored value (mode = 'min') has stopped decreasing for more than 3 scheduler steps
scheduler_simclr_us_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_simclr_us_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (built on the pretrained backbone), loss (DiceBCELoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
segm_ft_results_path_simclr_us = segm_run_finetuning(
    model_simclr_us_segm, criterion_simclr_us_segm,
    segm_us_finetune_train_dataloader, segm_us_finetune_val_dataloader,
    optimizer_simclr_us_segm, scheduler_simclr_us_segm,
    device, total_epoch_count,
    storing_path71, "M1_segmentation", checkpoint_interval,
)

In [448]:
# 3. Save the finetuned segmentation model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M1_segmentation_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path71, exist_ok=True)
model_path = os.path.join(storing_path71, model_filename)
torch.save(model_simclr_us_segm.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_simclr_us)

In [None]:
# 5. Final test run of finetuned segmentation model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
segm_test_results_path_simclr_us = segm_run_testing(
    model_simclr_us_segm, criterion_simclr_us_segm,
    segm_us_finetune_test_dataloader,
    storing_path71, "M1_segmentation", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# Input is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_simclr_us)

### 7.2. With models pretrained with <font color='green'>SimCLR</font> method and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model
# Always work on a copy so the original pretrained model is never changed
# inadvertently. `copy.deepcopy` is spelled out because the notebook's import
# cell only binds the `copy` module, not a bare `deepcopy` name.
finetune_backbone = copy.deepcopy(model_simclr_mammo_pretrain.backbone)
model_simclr_mammo_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_simclr_mammo_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_simclr_mammo_segm = DiceBCELoss()
criterion_simclr_mammo_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_simclr_mammo_segm = torch.optim.Adam(model_simclr_mammo_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: multiplies the learning rate by 0.1 once the
# monitored value (mode = 'min') has stopped decreasing for more than 3 scheduler steps
scheduler_simclr_mammo_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_simclr_mammo_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (built on the pretrained backbone), loss (DiceBCELoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
segm_ft_results_path_simclr_mammo = segm_run_finetuning(
    model_simclr_mammo_segm, criterion_simclr_mammo_segm,
    segm_mg_finetune_train_dataloader, segm_mg_finetune_val_dataloader,
    optimizer_simclr_mammo_segm, scheduler_simclr_mammo_segm,
    device, total_epoch_count,
    storing_path72, "M2_segmentation", checkpoint_interval,
)

In [454]:
# 3. Save the finetuned segmentation model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M2_segmentation_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path72, exist_ok=True)
model_path = os.path.join(storing_path72, model_filename)
torch.save(model_simclr_mammo_segm.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_simclr_mammo)

In [None]:
# 5. Final test run of finetuned segmentation model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
segm_test_results_path_simclr_mammo = segm_run_testing(
    model_simclr_mammo_segm, criterion_simclr_mammo_segm,
    segm_mg_finetune_test_dataloader,
    storing_path72, "M2_segmentation", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# Input is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_simclr_mammo)

### 7.3. With models pretrained with <font color='green'>SimCLR</font> method and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model
# Always work on a copy so the original pretrained model is never changed
# inadvertently. `copy.deepcopy` is spelled out because the notebook's import
# cell only binds the `copy` module, not a bare `deepcopy` name.
finetune_backbone = copy.deepcopy(model_simclr_multi_pretrain.backbone)
model_simclr_multi_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_simclr_multi_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_simclr_multi_segm = DiceBCELoss()
criterion_simclr_multi_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_simclr_multi_segm = torch.optim.Adam(model_simclr_multi_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: multiplies the learning rate by 0.1 once the
# monitored value (mode = 'min') has stopped decreasing for more than 3 scheduler steps
scheduler_simclr_multi_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_simclr_multi_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (built on the pretrained backbone), loss (DiceBCELoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
segm_ft_results_path_simclr_multi = segm_run_finetuning(
    model_simclr_multi_segm, criterion_simclr_multi_segm,
    segm_multi_finetune_train_dataloader, segm_multi_finetune_val_dataloader,
    optimizer_simclr_multi_segm, scheduler_simclr_multi_segm,
    device, total_epoch_count,
    storing_path73, "M3_segmentation", checkpoint_interval,
)

In [460]:
# 3. Save the finetuned segmentation model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M3_segmentation_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path73, exist_ok=True)
model_path = os.path.join(storing_path73, model_filename)
torch.save(model_simclr_multi_segm.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_simclr_multi)

In [None]:
# 5. Final test run of finetuned segmentation model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
segm_test_results_path_simclr_multi = segm_run_testing(
    model_simclr_multi_segm, criterion_simclr_multi_segm,
    segm_multi_finetune_test_dataloader,
    storing_path73, "M3_segmentation", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# Input is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_simclr_multi)

### 7.4. With models pretrained with <font color='turquoise'>MoCo</font> method and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# TODO: You were here, running this part (7.4). Run it to the end and, if everything is OK, go back to writing the report and wait for the professor's reply about running this on an FCUL machine.

In [None]:
# 1. Prepare model
# Always work on a copy so the original pretrained model is never changed
# inadvertently. `copy.deepcopy` is spelled out because the notebook's import
# cell only binds the `copy` module, not a bare `deepcopy` name.
finetune_backbone = copy.deepcopy(model_moco_us_pretrain.backbone)
model_moco_us_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_moco_us_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_moco_us_segm = DiceBCELoss()
criterion_moco_us_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_moco_us_segm = torch.optim.Adam(model_moco_us_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: multiplies the learning rate by 0.1 once the
# monitored value (mode = 'min') has stopped decreasing for more than 3 scheduler steps
scheduler_moco_us_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_moco_us_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (built on the pretrained backbone), loss (DiceBCELoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
segm_ft_results_path_moco_us = segm_run_finetuning(
    model_moco_us_segm, criterion_moco_us_segm,
    segm_us_finetune_train_dataloader, segm_us_finetune_val_dataloader,
    optimizer_moco_us_segm, scheduler_moco_us_segm,
    device, total_epoch_count,
    storing_path74, "M4_segmentation", checkpoint_interval,
)

In [470]:
# 3. Save the finetuned segmentation model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M4_segmentation_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path74, exist_ok=True)
model_path = os.path.join(storing_path74, model_filename)
torch.save(model_moco_us_segm.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_moco_us)

In [None]:
# 5. Final test run of finetuned segmentation model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
segm_test_results_path_moco_us = segm_run_testing(
    model_moco_us_segm, criterion_moco_us_segm,
    segm_us_finetune_test_dataloader,
    storing_path74, "M4_segmentation", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# Input is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_moco_us)

### 7.5. With models pretrained with <font color='turquoise'>MoCo</font> method and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model
# Always work on a copy so the original pretrained model is never changed
# inadvertently. `copy.deepcopy` is spelled out because the notebook's import
# cell only binds the `copy` module, not a bare `deepcopy` name.
finetune_backbone = copy.deepcopy(model_moco_mammo_pretrain.backbone)
model_moco_mammo_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_moco_mammo_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_moco_mammo_segm = DiceBCELoss()
criterion_moco_mammo_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_moco_mammo_segm = torch.optim.Adam(model_moco_mammo_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: multiplies the learning rate by 0.1 once the
# monitored value (mode = 'min') has stopped decreasing for more than 3 scheduler steps
scheduler_moco_mammo_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_moco_mammo_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (built on the pretrained backbone), loss (DiceBCELoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
segm_ft_results_path_moco_mammo = segm_run_finetuning(
    model_moco_mammo_segm, criterion_moco_mammo_segm,
    segm_mg_finetune_train_dataloader, segm_mg_finetune_val_dataloader,
    optimizer_moco_mammo_segm, scheduler_moco_mammo_segm,
    device, total_epoch_count,
    storing_path75, "M5_segmentation", checkpoint_interval,
)

In [476]:
# 3. Save the finetuned segmentation model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M5_segmentation_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path75, exist_ok=True)
model_path = os.path.join(storing_path75, model_filename)
torch.save(model_moco_mammo_segm.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_moco_mammo)

In [None]:
# 5. Final test run of finetuned segmentation model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
segm_test_results_path_moco_mammo = segm_run_testing(
    model_moco_mammo_segm, criterion_moco_mammo_segm,
    segm_mg_finetune_test_dataloader,
    storing_path75, "M5_segmentation", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# Input is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_moco_mammo)

### 7.6. With models pretrained with <font color='turquoise'>MoCo</font> method and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model
# Always work on a copy so the original pretrained model is never changed
# inadvertently. `copy.deepcopy` is spelled out because the notebook's import
# cell only binds the `copy` module, not a bare `deepcopy` name.
finetune_backbone = copy.deepcopy(model_moco_multi_pretrain.backbone)
model_moco_multi_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_moco_multi_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_moco_multi_segm = DiceBCELoss()
criterion_moco_multi_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_moco_multi_segm = torch.optim.Adam(model_moco_multi_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: multiplies the learning rate by 0.1 once the
# monitored value (mode = 'min') has stopped decreasing for more than 3 scheduler steps
scheduler_moco_multi_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_moco_multi_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop
# Positional arguments, in order:
#   model (built on the pretrained backbone), loss (DiceBCELoss),
#   training dataloader, validation dataloader,
#   optimizer (Adam), learning rate scheduler,
#   device (GPU/CPU), number of epochs to finetune,
#   path to store logs, model identifier, checkpoint interval (every N epochs)
segm_ft_results_path_moco_multi = segm_run_finetuning(
    model_moco_multi_segm, criterion_moco_multi_segm,
    segm_multi_finetune_train_dataloader, segm_multi_finetune_val_dataloader,
    optimizer_moco_multi_segm, scheduler_moco_multi_segm,
    device, total_epoch_count,
    storing_path76, "M6_segmentation", checkpoint_interval,
)

In [482]:
# 3. Save the finetuned segmentation model
# Only the weights (state_dict) are written, under this run's storing path.
model_filename = "M6_segmentation_savedModel.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists (this also lets the cell be re-run on its own).
os.makedirs(storing_path76, exist_ok=True)
model_path = os.path.join(storing_path76, model_filename)
torch.save(model_moco_multi_segm.state_dict(), model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# Input is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_moco_multi)

In [None]:
# 5. Final test run of finetuned segmentation model
# Positional arguments, in order: model, loss, test dataloader, storing path,
# model identifier, device. The returned value feeds the plotting cell (step 6).
segm_test_results_path_moco_multi = segm_run_testing(
    model_moco_multi_segm, criterion_moco_multi_segm,
    segm_multi_finetune_test_dataloader,
    storing_path76, "M6_segmentation", device,
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# Input is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_moco_multi)

### 7.7. With models pretrained with <font color='yellow'>BYOL</font> method and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# 1. Prepare model
# Always work on a copy so the original pretrained model is never changed
# inadvertently. `copy.deepcopy` is spelled out because the notebook's import
# cell only binds the `copy` module, not a bare `deepcopy` name.
# The backbone for BYOL is accessed with .online_encoder[0]
finetune_backbone = copy.deepcopy(model_byol_us_pretrain.online_encoder[0])

model_byol_us_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_byol_us_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_byol_us_segm = DiceBCELoss()
criterion_byol_us_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_byol_us_segm = torch.optim.Adam(model_byol_us_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler: multiplies the learning rate by 0.1 once the
# monitored value (mode = 'min') has stopped decreasing for more than 3 scheduler steps
scheduler_byol_us_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_byol_us_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop (BYOL backbone, Ultrasound data)
segm_ft_results_path_byol_us = segm_run_finetuning(
    model_byol_us_segm,  # Segmentation model built on the pretrained backbone
    criterion_byol_us_segm,  # Loss function: DiceBCELoss
    segm_us_finetune_train_dataloader,  # Training dataloader
    segm_us_finetune_val_dataloader,  # Validation dataloader
    optimizer_byol_us_segm,  # Adam optimizer
    scheduler_byol_us_segm,  # ReduceLROnPlateau learning rate scheduler
    device,  # Device to run training on (GPU/CPU)
    total_epoch_count,  # Number of finetuning epochs
    storing_path77,  # Path to store logs
    "M7_segmentation",  # Model identifier
    checkpoint_interval,  # Save checkpoints every N epochs
)

In [488]:
# 3. Persist the finetuned segmentation model (weights only, via state_dict)
model_filename = "M7_segmentation_savedModel.pth"
model_path = os.path.join(storing_path77, model_filename)
torch.save(obj = model_byol_us_segm.state_dict(), f = model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# The argument is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_byol_us)

In [None]:
# 5. Evaluate the finetuned segmentation model on the held-out test data
segm_test_results_path_byol_us = segm_run_testing(
    model_byol_us_segm,  # Finetuned segmentation model
    criterion_byol_us_segm,  # Loss function (same criterion as in the finetune loop)
    segm_us_finetune_test_dataloader,  # Dataloader for test data
    storing_path77,  # Same storing path that was passed to the finetune loop
    "M7_segmentation",  # Model identifier
    device,  # Device to run on (GPU/CPU)
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# The argument is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_byol_us)

### 7.8. With models pretrained with <font color='yellow'>BYOL</font> method and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model
# Always work on a deep copy of the backbone so that finetuning cannot
# inadvertently change the original pretrained BYOL model.
# The backbone for BYOL is accessed with .online_encoder[0].
# deepcopy is reached through the `copy` module, which the notebook's import
# cell provides (`import copy`); a bare `deepcopy` name is not imported there.
finetune_backbone = copy.deepcopy(model_byol_mammo_pretrain.online_encoder[0])

# Segmentation model built on the copied backbone, with a single output class
model_byol_mammo_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_byol_mammo_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_byol_mammo_segm = DiceBCELoss()
criterion_byol_mammo_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_byol_mammo_segm = torch.optim.Adam(model_byol_mammo_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler (mode 'min'): scales the learning rate by 0.1 when the
# monitored value stops decreasing, with a patience of 3 epochs
scheduler_byol_mammo_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_byol_mammo_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop (BYOL backbone, Mammography data)
segm_ft_results_path_byol_mammo = segm_run_finetuning(
    model_byol_mammo_segm,  # Segmentation model built on the pretrained backbone
    criterion_byol_mammo_segm,  # Loss function: DiceBCELoss
    segm_mg_finetune_train_dataloader,  # Training dataloader
    segm_mg_finetune_val_dataloader,  # Validation dataloader
    optimizer_byol_mammo_segm,  # Adam optimizer
    scheduler_byol_mammo_segm,  # ReduceLROnPlateau learning rate scheduler
    device,  # Device to run training on (GPU/CPU)
    total_epoch_count,  # Number of finetuning epochs
    storing_path78,  # Path to store logs
    "M8_segmentation",  # Model identifier
    checkpoint_interval,  # Save checkpoints every N epochs
)

In [494]:
# 3. Persist the finetuned segmentation model (weights only, via state_dict)
model_filename = "M8_segmentation_savedModel.pth"
model_path = os.path.join(storing_path78, model_filename)
torch.save(obj = model_byol_mammo_segm.state_dict(), f = model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# The argument is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_byol_mammo)

In [None]:
# 5. Evaluate the finetuned segmentation model on the held-out test data
segm_test_results_path_byol_mammo = segm_run_testing(
    model_byol_mammo_segm,  # Finetuned segmentation model
    criterion_byol_mammo_segm,  # Loss function (same criterion as in the finetune loop)
    segm_mg_finetune_test_dataloader,  # Dataloader for test data
    storing_path78,  # Same storing path that was passed to the finetune loop
    "M8_segmentation",  # Model identifier
    device,  # Device to run on (GPU/CPU)
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# The argument is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_byol_mammo)

### 7.9. With models pretrained with <font color='yellow'>BYOL</font> method and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model
# Always work on a deep copy of the backbone so that finetuning cannot
# inadvertently change the original pretrained BYOL model.
# The backbone for BYOL is accessed with .online_encoder[0].
# deepcopy is reached through the `copy` module, which the notebook's import
# cell provides (`import copy`); a bare `deepcopy` name is not imported there.
finetune_backbone = copy.deepcopy(model_byol_multi_pretrain.online_encoder[0])

# Segmentation model built on the copied backbone, with a single output class
model_byol_multi_segm = segm_model_finetune(finetune_backbone, num_classes = 1)
model_byol_multi_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_byol_multi_segm = DiceBCELoss()
criterion_byol_multi_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_byol_multi_segm = torch.optim.Adam(model_byol_multi_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler (mode 'min'): scales the learning rate by 0.1 when the
# monitored value stops decreasing, with a patience of 3 epochs
scheduler_byol_multi_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_byol_multi_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop (BYOL backbone, Multimodal data)
segm_ft_results_path_byol_multi = segm_run_finetuning(
    model_byol_multi_segm,  # Segmentation model built on the pretrained backbone
    criterion_byol_multi_segm,  # Loss function: DiceBCELoss
    segm_multi_finetune_train_dataloader,  # Training dataloader
    segm_multi_finetune_val_dataloader,  # Validation dataloader
    optimizer_byol_multi_segm,  # Adam optimizer
    scheduler_byol_multi_segm,  # ReduceLROnPlateau learning rate scheduler
    device,  # Device to run training on (GPU/CPU)
    total_epoch_count,  # Number of finetuning epochs
    storing_path79,  # Path to store logs
    "M9_segmentation",  # Model identifier
    checkpoint_interval,  # Save checkpoints every N epochs
)

In [500]:
# 3. Persist the finetuned segmentation model (weights only, via state_dict)
model_filename = "M9_segmentation_savedModel.pth"
model_path = os.path.join(storing_path79, model_filename)
torch.save(obj = model_byol_multi_segm.state_dict(), f = model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# The argument is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_byol_multi)

In [None]:
# 5. Evaluate the finetuned segmentation model on the held-out test data
segm_test_results_path_byol_multi = segm_run_testing(
    model_byol_multi_segm,  # Finetuned segmentation model
    criterion_byol_multi_segm,  # Loss function (same criterion as in the finetune loop)
    segm_multi_finetune_test_dataloader,  # Dataloader for test data
    storing_path79,  # Same storing path that was passed to the finetune loop
    "M9_segmentation",  # Model identifier
    device,  # Device to run on (GPU/CPU)
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# The argument is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_byol_multi)

### 7.10. With models that aren't <font color='red'>Pretrained</font> and employing the <font color='purple'>Ultrasound</font> dataset

In [None]:
# 1. Prepare model (Supervised - no pretraining)
# The notebook's import cell only has `import torchvision.transforms as transforms`,
# which binds `transforms` but not the name `torchvision`, so import the package here.
import torchvision

# ResNet-18 backbone created without pretrained weights (weights = None)
supervised_us_backbone = torchvision.models.resnet18(weights = None)

# Change the first convolutional layer to accept 1 input channel
supervised_us_backbone.conv1 = nn.Conv2d(1, 64, kernel_size = (7, 7), stride = (2, 2), padding = (3, 3), bias = False)

# Segmentation model built on the untrained backbone, with a single output class
model_supervised_us_segm = segm_model_supervised(base_model_backbone = supervised_us_backbone, num_classes = 1)
model_supervised_us_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_supervised_us_segm = DiceBCELoss()
criterion_supervised_us_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_supervised_us_segm = torch.optim.Adam(model_supervised_us_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler (mode 'min'): scales the learning rate by 0.1 when the
# monitored value stops decreasing, with a patience of 3 epochs
scheduler_supervised_us_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_supervised_us_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop (supervised model without pretraining, Ultrasound data)
segm_ft_results_path_supervised_us = segm_run_finetuning(
    model_supervised_us_segm,  # Segmentation model on a backbone that was not pretrained
    criterion_supervised_us_segm,  # Loss function: DiceBCELoss
    segm_us_finetune_train_dataloader,  # Training dataloader
    segm_us_finetune_val_dataloader,  # Validation dataloader
    optimizer_supervised_us_segm,  # Adam optimizer
    scheduler_supervised_us_segm,  # ReduceLROnPlateau learning rate scheduler
    device,  # Device to run training on (GPU/CPU)
    total_epoch_count,  # Number of training epochs
    storing_path710,  # Path to store logs
    "M10_segmentation",  # Model identifier
    checkpoint_interval,  # Save checkpoints every N epochs
)

In [275]:
# 3. Persist the trained segmentation model (weights only, via state_dict)
model_filename = "M10_segmentation_savedModel.pth"
model_path = os.path.join(storing_path710, model_filename)
torch.save(obj = model_supervised_us_segm.state_dict(), f = model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# The argument is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_supervised_us)

In [None]:
# 5. Evaluate the trained segmentation model on the held-out test data
segm_test_results_path_supervised_us = segm_run_testing(
    model_supervised_us_segm,  # Trained segmentation model
    criterion_supervised_us_segm,  # Loss function (same criterion as in the finetune loop)
    segm_us_finetune_test_dataloader,  # Dataloader for test data
    storing_path710,  # Same storing path that was passed to the finetune loop
    "M10_segmentation",  # Model identifier
    device,  # Device to run on (GPU/CPU)
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# The argument is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_supervised_us)

### 7.11. With models that aren't <font color='red'>Pretrained</font> and employing the <font color='pink'>Mammography</font> dataset

In [None]:
# 1. Prepare model (Supervised - no pretraining)
# The notebook's import cell only has `import torchvision.transforms as transforms`,
# which binds `transforms` but not the name `torchvision`, so import the package here.
import torchvision

# ResNet-18 backbone created without pretrained weights (weights = None)
supervised_mammo_backbone = torchvision.models.resnet18(weights = None)

# Change the first convolutional layer to accept 1 input channel
supervised_mammo_backbone.conv1 = nn.Conv2d(1, 64, kernel_size = (7, 7), stride = (2, 2), padding = (3, 3), bias = False)

# Segmentation model built on the untrained backbone, with a single output class
model_supervised_mammo_segm = segm_model_supervised(base_model_backbone = supervised_mammo_backbone, num_classes = 1)
model_supervised_mammo_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_supervised_mammo_segm = DiceBCELoss()
criterion_supervised_mammo_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_supervised_mammo_segm = torch.optim.Adam(model_supervised_mammo_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler (mode 'min'): scales the learning rate by 0.1 when the
# monitored value stops decreasing, with a patience of 3 epochs
scheduler_supervised_mammo_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_supervised_mammo_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop (supervised model without pretraining, Mammography data)
segm_ft_results_path_supervised_mammo = segm_run_finetuning(
    model_supervised_mammo_segm,  # Segmentation model on a backbone that was not pretrained
    criterion_supervised_mammo_segm,  # Loss function: DiceBCELoss
    segm_mg_finetune_train_dataloader,  # Training dataloader
    segm_mg_finetune_val_dataloader,  # Validation dataloader
    optimizer_supervised_mammo_segm,  # Adam optimizer
    scheduler_supervised_mammo_segm,  # ReduceLROnPlateau learning rate scheduler
    device,  # Device to run training on (GPU/CPU)
    total_epoch_count,  # Number of training epochs
    storing_path711,  # Path to store logs
    "M11_segmentation",  # Model identifier
    checkpoint_interval,  # Save checkpoints every N epochs
)

In [281]:
# 3. Persist the trained segmentation model (weights only, via state_dict)
model_filename = "M11_segmentation_savedModel.pth"
model_path = os.path.join(storing_path711, model_filename)
torch.save(obj = model_supervised_mammo_segm.state_dict(), f = model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# The argument is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_supervised_mammo)

In [None]:
# 5. Evaluate the trained segmentation model on the held-out test data
segm_test_results_path_supervised_mammo = segm_run_testing(
    model_supervised_mammo_segm,  # Trained segmentation model
    criterion_supervised_mammo_segm,  # Loss function (same criterion as in the finetune loop)
    segm_mg_finetune_test_dataloader,  # Dataloader for test data
    storing_path711,  # Same storing path that was passed to the finetune loop
    "M11_segmentation",  # Model identifier
    device,  # Device to run on (GPU/CPU)
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# The argument is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_supervised_mammo)

### 7.12. With models that aren't <font color='red'>Pretrained</font> and employing the <font color='orange'>Multimodal</font> dataset

In [None]:
# 1. Prepare model (Supervised - no pretraining)
# The notebook's import cell only has `import torchvision.transforms as transforms`,
# which binds `transforms` but not the name `torchvision`, so import the package here.
import torchvision

# ResNet-18 backbone created without pretrained weights (weights = None)
supervised_multi_backbone = torchvision.models.resnet18(weights = None)

# Change the first convolutional layer to accept 1 input channel
supervised_multi_backbone.conv1 = nn.Conv2d(1, 64, kernel_size = (7, 7), stride = (2, 2), padding = (3, 3), bias = False)

# Segmentation model built on the untrained backbone, with a single output class
model_supervised_multi_segm = segm_model_supervised(base_model_backbone = supervised_multi_backbone, num_classes = 1)
model_supervised_multi_segm.to(device)

# Combined dice and binary cross entropy loss function used for segmentation here
criterion_supervised_multi_segm = DiceBCELoss()
criterion_supervised_multi_segm.to(device)

# Adam optimizer same as pretrain cases but with finetune learning rate
optimizer_supervised_multi_segm = torch.optim.Adam(model_supervised_multi_segm.parameters(), lr = lr_finetune, weight_decay = weight_decay)

# ReduceLROnPlateau scheduler (mode 'min'): scales the learning rate by 0.1 when the
# monitored value stops decreasing, with a patience of 3 epochs
scheduler_supervised_multi_segm = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_supervised_multi_segm, mode = 'min', patience = 3, factor = 0.1)

In [None]:
# 2. Finetune loop (supervised model without pretraining, Multimodal data)
segm_ft_results_path_supervised_multi = segm_run_finetuning(
    model_supervised_multi_segm,  # Segmentation model on a backbone that was not pretrained
    criterion_supervised_multi_segm,  # Loss function: DiceBCELoss
    segm_multi_finetune_train_dataloader,  # Training dataloader
    segm_multi_finetune_val_dataloader,  # Validation dataloader
    optimizer_supervised_multi_segm,  # Adam optimizer
    scheduler_supervised_multi_segm,  # ReduceLROnPlateau learning rate scheduler
    device,  # Device to run training on (GPU/CPU)
    total_epoch_count,  # Number of training epochs
    storing_path712,  # Path to store logs
    "M12_segmentation",  # Model identifier
    checkpoint_interval,  # Save checkpoints every N epochs
)

In [287]:
# 3. Persist the trained segmentation model (weights only, via state_dict)
model_filename = "M12_segmentation_savedModel.pth"
model_path = os.path.join(storing_path712, model_filename)
torch.save(obj = model_supervised_multi_segm.state_dict(), f = model_path)

In [None]:
# 4. Saving plots based on results of finetune loop
# The argument is the value returned by segm_run_finetuning in step 2.
segm_plot_finetuning_results(segm_ft_results_path_supervised_multi)

In [None]:
# 5. Evaluate the trained segmentation model on the held-out test data
segm_test_results_path_supervised_multi = segm_run_testing(
    model_supervised_multi_segm,  # Trained segmentation model
    criterion_supervised_multi_segm,  # Loss function (same criterion as in the finetune loop)
    segm_multi_finetune_test_dataloader,  # Dataloader for test data
    storing_path712,  # Same storing path that was passed to the finetune loop
    "M12_segmentation",  # Model identifier
    device,  # Device to run on (GPU/CPU)
)

In [None]:
# 6. Saving plots based on results of final test of finetuned segmentation model
# The argument is the value returned by segm_run_testing in step 5.
segm_plot_testing_results(segm_test_results_path_supervised_multi)