# Biomarker Detection in OLIVES using a Pretrained ResNet50 Model and Fusion Mechanisms for Clinical Data


### Step 1: Import data
Consistent for all models. Only change output size!

In [2]:
import os
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR, CyclicLR
from torch.utils.data import Dataset, Subset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# set the size of the image according to your model needs
imageSize = 224 # ResNet works with 224x224 pixels

# Custom Dataset
class BiomarkerDataset(Dataset):
    def __init__(self, label_file, transform=None, num_frames=0):
        """
        Args:
            label_file (str): Path to the CSV file.
            transform (callable, optional): Transform to be applied on a sample.
            num_frames (int): Number of adjacent frames to use in the input sequence (1 adjacent frame -> 3 consecutive images).
        """
        self.data = pd.read_csv(label_file)
        self.transform = transform
        self.num_frames = num_frames
        
        # Normalization
        self.bcva_mean = self.data.iloc[:, 19].mean()
        self.bcva_std = self.data.iloc[:, 19].std()
        self.cst_mean = self.data.iloc[:, 20].mean()
        self.cst_std = self.data.iloc[:, 20].std()

        
        # Exclude indices which don't have enough adjacent images
        self.valid_indices = self.data[(self.data.iloc[:, 1] > num_frames) & (self.data.iloc[:, 1] < (50-num_frames))].index.tolist()

    def __len__(self):
        # we can't use the length of the data since we have to exclude the first and last image (for num_frames=1) of each OCT scan
        return len(self.valid_indices)

    def __getitem__(self, idx):
        
        # Base path
        img_base_path = '/storage/ice1/shared/d-pace_community/makerspace-datasets/MEDICAL/OLIVES/OLIVES'
        
        # Get the actual data index
        index = self.valid_indices[idx]
        
        # Initialize
        images = []
        
        # Load a sequence of consecutive images
        for i in range(index - self.num_frames, index + self.num_frames +1):
            img_path = img_base_path + self.data.iloc[i, 0]
            img = Image.open(img_path).convert("L") # 'L' is for grayscale; can be removed!?
            
            if self.transform is not None:
                # apply data transformations (transforms it to tensor)
                img = self.transform(img)
            
            # stack torch tensor
            img = img.squeeze(0)  # Removes the first dimension if it's 1
            images.append(img)
        
        # Stack the 3 grayscale images along the channel dimension
        # Resulting tensor shape will be [3, H, W]
        images = torch.stack(images, dim=0)
        # print(images.shape) # debugging
        
        # Biomarker columns
        labels = torch.tensor(self.data.iloc[index, 2:18].astype(float), dtype=torch.float32)
        
        # Get clinical data:
        eye_id = self.data.iloc[index, 18]
        bcva = self.data.iloc[index, 19]
        cst = self.data.iloc[index, 20]
        patient_id = self.data.iloc[index, 21]
        
        # Normalize
        bcva = (bcva - self.bcva_mean) / self.bcva_std
        cst = (cst - self.cst_mean) / self.cst_std

        # Convert clinical data to tensor
        clinical_data = torch.tensor([eye_id, bcva, cst, patient_id], dtype=torch.float32)
        
        return images, labels, clinical_data
    
    
# Define transformers

# Values for normalization taken from example paper
mean = 0.1706
std = 0.2112

# train with data augmentation
train_transformer = transforms.Compose([
    # WORSE PERFORMANCE # transforms.RandomPerspective(distortion_scale=0.1, p=0.5, fill=0),  # Add perspective shift
    # WORSE PERFORMANCE # transforms.RandomResizedCrop(size=imageSize, scale=(0.9, 1.0)), # RandomCrop between 70% to 100% of original size
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Adjust color properties
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
    transforms.RandomRotation(degrees=10, fill=0),  # Rotates randomly between + and - degree and fills new pixels with black
    transforms.Resize(imageSize), # Resize to models needs
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean, std) # we have to calculate these values for our dataset
])
# train without data augmentation
test_transformer = transforms.Compose([   
    transforms.Resize(imageSize), # Resize to models needs
    transforms.CenterCrop(imageSize), # shouldn't do anything
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])


# set up train loader (just example since cross validation uses new ones)
train_dataset = BiomarkerDataset(label_file='OLIVES_Dataset_Labels/BiomarkerLabel_train_data.csv', transform=train_transformer, num_frames=1)
trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

# set up test loader (this one actually is being used)
test_dataset = BiomarkerDataset(label_file='OLIVES_Dataset_Labels/BiomarkerLabel_train_data.csv', transform=test_transformer, num_frames=1)
testloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=16, pin_memory=True)


### Step 2: Train model
First we initialize our model as well as some training parameters.

In [3]:
## --- Settings ---
num_epochs=40
batch_size=64
num_workers=16 # need this amount of CPUs for parallel data loading
k_folds=5
patience=8  # Number of epochs to wait for improvement

# get to cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## ---- Fusion model ----
class FusionModel(nn.Module):
    def __init__(self, num_eyeID, num_patientID):
        super(FusionModel, self).__init__()
        
        # Image feature extractor
        self.image_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.image_model.BackboneOutput = self.image_model.fc.in_features # should be 2048
        self.image_model.fc = nn.Identity() # remove final layer
        
        # Freeze ResNet parameters
#         for param in self.image_model.parameters():
#             param.requires_grad = False
        
        # Embedding layers for categorical data
        # These are only categorial data. Therefore just using the numerical value doesn't fit
        # The only difference a high patient ID f.e. should make, is that it's a new patient. Shouldn't lead to higher percentages of a biomaker
        # Alternativ: One-Hot Encoding
        self.patient_id_embedding = nn.Embedding(num_patientID, 16)
        self.eye_id_embedding = nn.Embedding(num_eyeID, 16)
        
        # Fully connected layers for numerical clinical data (e.g., BCVA, CST)
        self.clinical_fc = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Combine and compress all clinical data
        self.clinical_fusion_fc = nn.Sequential(
            nn.Linear(16 + 16 + 32, 32),  # Combine all clinical data
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 64)
        )
        
#         # Attention mechanism
#         self.attention = nn.Sequential(
#             nn.Linear(16 + 16, 32),  # Combined image and clinical features
#             nn.ReLU(),
#             nn.Linear(32, 2),        # Output: attention weights for image and clinical features
#             nn.Softmax(dim=1)        # Normalize weights
#         )
        
        # Final fusion layers
        self.fusion_fc = nn.Sequential(
            nn.Linear(self.image_model.BackboneOutput + 64, 16),  # Combine weighted features
#             nn.ReLU(),
#             nn.Dropout(0.1),
#             nn.Linear(256, 16)  # 16 output biomarkers
         )
    
    def forward(self, image, clinical_data):
        # Extract features from the image using ResNet50
        image_features = self.image_model(image)
        
        # Clinical Data
        # Split clinical data into components
        eye_id = clinical_data[:, 0].long()
        bcva = clinical_data[:, 1].view(-1, 1).float()
        cst = clinical_data[:, 2].view(-1, 1).float()
        patient_id = clinical_data[:, 3].long()
        # Embed categorical clinical data
        eye_id_features = self.eye_id_embedding(eye_id).squeeze(1)
        patient_id_features = self.patient_id_embedding(patient_id).squeeze(1)
        # Process numerical clinical data
        numerical_features = torch.cat([bcva, cst], dim=1)
        clinical_features = self.clinical_fc(numerical_features)
        # Combine all clinical features
        combined_clinical_features = torch.cat([patient_id_features, eye_id_features, clinical_features], dim=1)
        fused_clinical_features = self.clinical_fusion_fc(combined_clinical_features)
        
#         # Attention Mechanism
#         # Compute attention weights
#         combined_features = torch.cat([image_features, fused_clinical_features], dim=1)
#         attention_weights = self.attention(combined_features)
#         # Apply attention weights to image and clinical features
#         image_weighted = attention_weights[:, 0].unsqueeze(1) * image_features
#         clinical_weighted = attention_weights[:, 1].unsqueeze(1) * fused_clinical_features       
#         # Combine weighted features
#         fused_features = image_weighted + clinical_weighted
        
        fused_features = torch.cat([image_features, fused_clinical_features], dim=1)
   
        # Final prediction through fusion layers
        output = self.fusion_fc(fused_features)
        
        return output

# Make sure there are enough different IDs available
num_eyeID = 100
num_patientID = 300
# Create the model
model = FusionModel(num_eyeID,num_patientID).to(device)
model.name = "FusionModelV2_ResNet50"

# DIDN'T IMPROVE # Compute class weights
# # We need to get some information on the class balance fro the original dataset
# BiomarkerLabel_df = pd.read_csv('OLIVES_Dataset_Labels/Biomarker_Clinical_Data_Images.csv')
# BiomarkerLabel_df = BiomarkerLabel_df.fillna(0) # NaN fix
# labels = BiomarkerLabel_df.columns[2:18] # Extract label columns names (biomarkers)
# df_labels = BiomarkerLabel_df[labels]
# # Compute class weights
# class_counts = df_labels.sum(axis=0).astype(int)
# # num_classes = len(class_counts)
# total_samples = len(df_labels)
# class_weights = total_samples / class_counts
# class_weights = torch.tensor(class_weights, dtype=torch.float32)          
                                
# Loss function, optimizer nad Learning rate sheduler
# loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights.to(device)) # Loss function for multi-label classification
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.9) # NOT GOOD: This set every f1 score to 0!!
    # weight decay to reduce overfitting 
# Scheduler (have to be used at different positions in the loop)
# scheduler = ExponentialLR(optimizer, gamma=0.9)
schedulerBatch = CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=2000, mode='triangular')

# Creates all needed folders to store the model weights if they don't exist already
os.makedirs("ModelWeights_TempSaves", exist_ok=True)
os.makedirs(f"TrainedModels/{model.name}", exist_ok=True)



Now we go over to the training process where we do a cross-validation.

In [3]:
# --- Train/Test Loops ---
# Training loop
def train_loop(model, train_loader, optimizer, loss_fn, device):
    # Set model to train mode
    model.train()
    
    # Initialize
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    # for images, labels, _ in train_loader:
    for images, labels, clinical_data in tqdm(train_loader, desc="Training"):
        # shift to cuda
        images = images.to(device)
        labels = labels.to(device)
        clinical_data = clinical_data.to(device)
        
        # Zero the parameter gradients 
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images, clinical_data)
        loss = loss_fn(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        schedulerBatch.step() # CyclicLR after each batch
        
        # Track predictions and labels for metrics calculation
        running_loss += loss.item() * images.size(0)
        all_preds.append(outputs)
        all_labels.append(labels)
    
    # Average loss
    avg_loss = running_loss / len(train_loader.dataset)
    
    return avg_loss

# test loop
def test_loop(model, test_loader, loss_fn, device):
    # Set model to evaluation mode
    model.eval()
    
    # Initialize
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels, clinical_data in test_loader:
        # for images, labels, _ in tqdm(val_loader, desc="Validating"):
            # Store labels since they won't be altered
            # all_labels.append(labels.numpy())
            
            # Shift to cuda
            images = images.to(device)
            labels = labels.to(device)
            clinical_data = clinical_data.to(device)
        
            # Forward pass
            outputs = model(images, clinical_data)
            
            # Get metrics
            loss = loss_fn(outputs, labels)
            running_loss += loss.item() * images.size(0)
            
            # Sigmoid activation to get probabilities, then threshold at 0.5 for binary classification
            preds = torch.sigmoid(outputs) > 0.5 
            # preds = (torch.sigmoid(outputs) > 0.5).int()  # Apply sigmoid and threshold at 0.5

            # Store (numpy for easier processing later)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Calculate average loss
    avg_loss = running_loss / len(test_loader.dataset)

    # Convert lists of predictions and labels into a 2D array where each row is a sample, each column is a biomarker
    all_preds = np.concatenate(all_preds, axis=0)  # Shape: (num_samples, num_biomarkers)
    all_labels = np.concatenate(all_labels, axis=0)  # Shape: (num_samples, num_biomarkers)
    
    # Calculate F1 score for each biomarker (column) independently
#     f1_scores = []
#     for i in range(all_labels.shape[1]):  # Iterate over each biomarker
#         f1 = f1_score(all_labels[:, i], all_preds[:, i], average='binary')  # Compute F1 score for the ith biomarker
#         f1_scores.append(f1)

    # Average loss
    val_loss = running_loss / len(test_loader.dataset)
    # return val_loss, f1_scores, all_preds, all_labels
    return val_loss, all_preds, all_labels


# --- Cross-Validation ---

# Initialize object to split dataset in kfold
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=0)

fold_metrics = []
    
for fold, (train_idx, val_idx) in enumerate(kfold.split(train_dataset)):
    print(f"Fold {fold+1}/{k_folds}")

    # Split the training dataset into training and validation folds
    train_fold = Subset(train_dataset, train_idx)
    val_fold = Subset(train_dataset, val_idx)
    
    # Set up Dataloaders
    train_loader = DataLoader(train_fold, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_fold, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # Reset parameters
    best_val_loss = float('inf')
    best_val_f1 = 0.0
    counter_NoImprovement = 0

    for epoch in range(num_epochs):
    # for epoch in tqdm(range(num_epochs), desc="Training Epochs", unit="epoch"):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Train the model for one epoch
        train_loss = train_loop(model, train_loader, optimizer, loss_fn, device)
        print(f"Train Loss: {train_loss:.4f}")

        # Validate the model after training using validation fold
        val_loss, all_preds, all_labels = test_loop(model, val_loader, loss_fn, device)
        val_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        print(f"Validation Loss: {val_loss:.4f}, Validation F1: {val_f1:.4f}")
        
        # Update learning rate
#         scheduler.step() # ExponentialLR after each epoch
        
        # Early stopping logic: Check if F1 improved
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            counter_NoImprovement = 0  # Reset counter
            torch.save(model.state_dict(), f"ModelWeights_TempSaves/best_{model.name}_fold_{fold+1}.pth")
        else:
            counter_NoImprovement += 1

        # Stop training if no improvement for 'patience' epochs
        if counter_NoImprovement >= patience:
            print("Early stopping triggered.")
            counter_NoImprovement = 0
            break
        
#         # Save the best model based on validation loss
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             torch.save(model.state_dict(), f"ModelWeights_TempSaves/best_{model.name}_fold_{fold+1}.pth")

    # Load the weights of the model with the best validationg loss
    model.load_state_dict(torch.load(f"ModelWeights_TempSaves/best_{model.name}_fold_{fold+1}.pth", weights_only=True))
 
    # Get Accuracy
    # preds = (torch.sigmoid(torch.tensor(all_preds)) > 0.5).int()  # Apply sigmoid and threshold at 0.5
    val_accuracy = accuracy_score(all_labels, all_preds)
    fold_metrics.append(val_accuracy)
    print(f"Validation Accuracy for Fold {fold+1}: {val_accuracy:.4f}")

avg_accuracy = np.mean(fold_metrics)
print(f"\nAverage Accuracy over all folds: {avg_accuracy:.4f}")


Fold 1/5
Epoch 1/40


Training: 100%|██████████| 90/90 [00:11<00:00,  7.63it/s]

Train Loss: 0.4899





Validation Loss: 0.2594, Validation F1: 0.5443
Epoch 2/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.14it/s]

Train Loss: 0.2235





Validation Loss: 0.1935, Validation F1: 0.7205
Epoch 3/40


Training: 100%|██████████| 90/90 [00:06<00:00, 13.67it/s]

Train Loss: 0.1714





Validation Loss: 0.1549, Validation F1: 0.7978
Epoch 4/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.03it/s]

Train Loss: 0.1415





Validation Loss: 0.1376, Validation F1: 0.8312
Epoch 5/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.06it/s]

Train Loss: 0.1196





Validation Loss: 0.1290, Validation F1: 0.8501
Epoch 6/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.08it/s]

Train Loss: 0.1104





Validation Loss: 0.1317, Validation F1: 0.8443
Epoch 7/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.25it/s]

Train Loss: 0.1027





Validation Loss: 0.1181, Validation F1: 0.8603
Epoch 8/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.30it/s]

Train Loss: 0.0986





Validation Loss: 0.1202, Validation F1: 0.8473
Epoch 9/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.20it/s]

Train Loss: 0.0920





Validation Loss: 0.1177, Validation F1: 0.8677
Epoch 10/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.11it/s]

Train Loss: 0.0862





Validation Loss: 0.1485, Validation F1: 0.8321
Epoch 11/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.23it/s]

Train Loss: 0.0838





Validation Loss: 0.1187, Validation F1: 0.8616
Epoch 12/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.02it/s]

Train Loss: 0.0830





Validation Loss: 0.1095, Validation F1: 0.8642
Epoch 13/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.19it/s]

Train Loss: 0.0771





Validation Loss: 0.1054, Validation F1: 0.8777
Epoch 14/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.05it/s]

Train Loss: 0.0759





Validation Loss: 0.1194, Validation F1: 0.8700
Epoch 15/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.21it/s]

Train Loss: 0.0753





Validation Loss: 0.1191, Validation F1: 0.8752
Epoch 16/40


Training: 100%|██████████| 90/90 [00:06<00:00, 13.97it/s]

Train Loss: 0.0697





Validation Loss: 0.1026, Validation F1: 0.8815
Epoch 17/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.24it/s]

Train Loss: 0.0700





Validation Loss: 0.1338, Validation F1: 0.8627
Epoch 18/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.05it/s]

Train Loss: 0.0683





Validation Loss: 0.1167, Validation F1: 0.8773
Epoch 19/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.21it/s]

Train Loss: 0.0647





Validation Loss: 0.1068, Validation F1: 0.8831
Epoch 20/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.12it/s]

Train Loss: 0.0653





Validation Loss: 0.1024, Validation F1: 0.8842
Epoch 21/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.17it/s]

Train Loss: 0.0605





Validation Loss: 0.1268, Validation F1: 0.8522
Epoch 22/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.06it/s]

Train Loss: 0.0618





Validation Loss: 0.1040, Validation F1: 0.8908
Epoch 23/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.13it/s]

Train Loss: 0.0587





Validation Loss: 0.1050, Validation F1: 0.8881
Epoch 24/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.25it/s]

Train Loss: 0.0527





Validation Loss: 0.0930, Validation F1: 0.8954
Epoch 25/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.07it/s]

Train Loss: 0.0490





Validation Loss: 0.0930, Validation F1: 0.9027
Epoch 26/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.26it/s]

Train Loss: 0.0438





Validation Loss: 0.0995, Validation F1: 0.9043
Epoch 27/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.19it/s]

Train Loss: 0.0410





Validation Loss: 0.0968, Validation F1: 0.9022
Epoch 28/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.22it/s]

Train Loss: 0.0386





Validation Loss: 0.1111, Validation F1: 0.8968
Epoch 29/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.22it/s]

Train Loss: 0.0352





Validation Loss: 0.0925, Validation F1: 0.9085
Epoch 30/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.13it/s]

Train Loss: 0.0303





Validation Loss: 0.1064, Validation F1: 0.9087
Epoch 31/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.05it/s]

Train Loss: 0.0297





Validation Loss: 0.1060, Validation F1: 0.9043
Epoch 32/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.14it/s]

Train Loss: 0.0245





Validation Loss: 0.1033, Validation F1: 0.9085
Epoch 33/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.10it/s]

Train Loss: 0.0198





Validation Loss: 0.1078, Validation F1: 0.9113
Epoch 34/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.04it/s]

Train Loss: 0.0176





Validation Loss: 0.1055, Validation F1: 0.9191
Epoch 35/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.19it/s]

Train Loss: 0.0137





Validation Loss: 0.1104, Validation F1: 0.9136
Epoch 36/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.22it/s]

Train Loss: 0.0136





Validation Loss: 0.1153, Validation F1: 0.9127
Epoch 37/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.04it/s]

Train Loss: 0.0108





Validation Loss: 0.1152, Validation F1: 0.9150
Epoch 38/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.21it/s]

Train Loss: 0.0086





Validation Loss: 0.1151, Validation F1: 0.9178
Epoch 39/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.12it/s]

Train Loss: 0.0065





Validation Loss: 0.1181, Validation F1: 0.9187
Epoch 40/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.11it/s]

Train Loss: 0.0050





Validation Loss: 0.1196, Validation F1: 0.9204
Validation Accuracy for Fold 1: 0.6438
Fold 2/5
Epoch 1/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.16it/s]

Train Loss: 0.0327





Validation Loss: 0.0106, Validation F1: 0.9922
Epoch 2/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.16it/s]

Train Loss: 0.0182





Validation Loss: 0.0081, Validation F1: 0.9930
Epoch 3/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.08it/s]

Train Loss: 0.0133





Validation Loss: 0.0075, Validation F1: 0.9929
Epoch 4/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.08it/s]

Train Loss: 0.0107





Validation Loss: 0.0063, Validation F1: 0.9949
Epoch 5/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.12it/s]

Train Loss: 0.0089





Validation Loss: 0.0062, Validation F1: 0.9951
Epoch 6/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.20it/s]

Train Loss: 0.0089





Validation Loss: 0.0069, Validation F1: 0.9943
Epoch 7/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.08it/s]

Train Loss: 0.0089





Validation Loss: 0.0084, Validation F1: 0.9914
Epoch 8/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.21it/s]

Train Loss: 0.0085





Validation Loss: 0.0089, Validation F1: 0.9901
Epoch 9/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.39it/s]

Train Loss: 0.0090





Validation Loss: 0.0094, Validation F1: 0.9914
Epoch 10/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.07it/s]

Train Loss: 0.0118





Validation Loss: 0.0160, Validation F1: 0.9829
Epoch 11/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.21it/s]

Train Loss: 0.0134





Validation Loss: 0.0167, Validation F1: 0.9813
Epoch 12/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.15it/s]

Train Loss: 0.0152





Validation Loss: 0.0236, Validation F1: 0.9724
Epoch 13/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.14it/s]

Train Loss: 0.0172





Validation Loss: 0.0303, Validation F1: 0.9661
Early stopping triggered.
Validation Accuracy for Fold 2: 0.8321
Fold 3/5
Epoch 1/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.06it/s]

Train Loss: 0.0242





Validation Loss: 0.0275, Validation F1: 0.9727
Epoch 2/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.20it/s]

Train Loss: 0.0295





Validation Loss: 0.0349, Validation F1: 0.9637
Epoch 3/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.03it/s]

Train Loss: 0.0284





Validation Loss: 0.0338, Validation F1: 0.9628
Epoch 4/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.24it/s]

Train Loss: 0.0264





Validation Loss: 0.0379, Validation F1: 0.9588
Epoch 5/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.32it/s]

Train Loss: 0.0256





Validation Loss: 0.0490, Validation F1: 0.9461
Epoch 6/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.20it/s]

Train Loss: 0.0256





Validation Loss: 0.0497, Validation F1: 0.9445
Epoch 7/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.25it/s]

Train Loss: 0.0277





Validation Loss: 0.0651, Validation F1: 0.9343
Epoch 8/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.27it/s]

Train Loss: 0.0300





Validation Loss: 0.0529, Validation F1: 0.9431
Epoch 9/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.26it/s]

Train Loss: 0.0278





Validation Loss: 0.0673, Validation F1: 0.9336
Early stopping triggered.
Validation Accuracy for Fold 3: 0.7047
Fold 4/5
Epoch 1/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.15it/s]

Train Loss: 0.0546





Validation Loss: 0.0485, Validation F1: 0.9481
Epoch 2/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.14it/s]

Train Loss: 0.0472





Validation Loss: 0.0597, Validation F1: 0.9329
Epoch 3/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.09it/s]

Train Loss: 0.0431





Validation Loss: 0.0612, Validation F1: 0.9305
Epoch 4/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.12it/s]

Train Loss: 0.0406





Validation Loss: 0.0707, Validation F1: 0.9266
Epoch 5/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.10it/s]

Train Loss: 0.0412





Validation Loss: 0.0795, Validation F1: 0.9113
Epoch 6/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.19it/s]

Train Loss: 0.0340





Validation Loss: 0.0603, Validation F1: 0.9352
Epoch 7/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.09it/s]

Train Loss: 0.0295





Validation Loss: 0.0601, Validation F1: 0.9367
Epoch 8/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.35it/s]

Train Loss: 0.0233





Validation Loss: 0.0654, Validation F1: 0.9365
Epoch 9/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.27it/s]

Train Loss: 0.0224





Validation Loss: 0.0520, Validation F1: 0.9468
Early stopping triggered.
Validation Accuracy for Fold 4: 0.7418
Fold 5/5
Epoch 1/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.11it/s]

Train Loss: 0.0463





Validation Loss: 0.0452, Validation F1: 0.9491
Epoch 2/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.14it/s]

Train Loss: 0.0395





Validation Loss: 0.0422, Validation F1: 0.9552
Epoch 3/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.13it/s]

Train Loss: 0.0300





Validation Loss: 0.0393, Validation F1: 0.9571
Epoch 4/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.19it/s]

Train Loss: 0.0263





Validation Loss: 0.0433, Validation F1: 0.9534
Epoch 5/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.10it/s]

Train Loss: 0.0222





Validation Loss: 0.0408, Validation F1: 0.9550
Epoch 6/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.07it/s]

Train Loss: 0.0173





Validation Loss: 0.0347, Validation F1: 0.9635
Epoch 7/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.17it/s]

Train Loss: 0.0140





Validation Loss: 0.0394, Validation F1: 0.9575
Epoch 8/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.08it/s]

Train Loss: 0.0130





Validation Loss: 0.0354, Validation F1: 0.9635
Epoch 9/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.00it/s]

Train Loss: 0.0095





Validation Loss: 0.0327, Validation F1: 0.9650
Epoch 10/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.06it/s]

Train Loss: 0.0080





Validation Loss: 0.0326, Validation F1: 0.9660
Epoch 11/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.00it/s]

Train Loss: 0.0053





Validation Loss: 0.0307, Validation F1: 0.9686
Epoch 12/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.15it/s]

Train Loss: 0.0036





Validation Loss: 0.0306, Validation F1: 0.9710
Epoch 13/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.11it/s]

Train Loss: 0.0027





Validation Loss: 0.0293, Validation F1: 0.9738
Epoch 14/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.08it/s]

Train Loss: 0.0022





Validation Loss: 0.0280, Validation F1: 0.9729
Epoch 15/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.11it/s]

Train Loss: 0.0021





Validation Loss: 0.0286, Validation F1: 0.9720
Epoch 16/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.20it/s]

Train Loss: 0.0016





Validation Loss: 0.0267, Validation F1: 0.9737
Epoch 17/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.12it/s]

Train Loss: 0.0013





Validation Loss: 0.0266, Validation F1: 0.9760
Epoch 18/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.05it/s]

Train Loss: 0.0012





Validation Loss: 0.0294, Validation F1: 0.9726
Epoch 19/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.10it/s]

Train Loss: 0.0010





Validation Loss: 0.0279, Validation F1: 0.9727
Epoch 20/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.15it/s]

Train Loss: 0.0010





Validation Loss: 0.0267, Validation F1: 0.9736
Epoch 21/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.15it/s]

Train Loss: 0.0011





Validation Loss: 0.0271, Validation F1: 0.9753
Epoch 22/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.13it/s]

Train Loss: 0.0015





Validation Loss: 0.0295, Validation F1: 0.9745
Epoch 23/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.23it/s]

Train Loss: 0.0015





Validation Loss: 0.0334, Validation F1: 0.9713
Epoch 24/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.19it/s]

Train Loss: 0.0022





Validation Loss: 0.0384, Validation F1: 0.9670
Epoch 25/40


Training: 100%|██████████| 90/90 [00:06<00:00, 14.04it/s]

Train Loss: 0.0039





Validation Loss: 0.0420, Validation F1: 0.9656
Early stopping triggered.
Validation Accuracy for Fold 5: 0.8312

Average Accuracy over all folds: 0.7507


Now Test

In [7]:
# ============================
# 5. Evaluation
# ============================

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def evaluate_model(model, loader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", unit="batch"):
            if not batch:  # Handle empty batches
                continue
            images, labels, clinical_data = batch
            images, labels, clinical_data = images.to(device), labels.to(device), clinical_data.to(device)
            outputs = model(images,clinical_data) # .logits
            y_true.append(labels.cpu().numpy())
            y_pred.append(outputs.cpu().numpy())
    y_true = np.vstack(y_true)
    y_pred = np.vstack(y_pred)
    return y_true, y_pred

# Evaluate
y_true, y_pred = evaluate_model(model, testloader)

# Convert predicted probabilities to binary predictions
y_pred_binary = (sigmoid(y_pred) > 0.5).astype(int)
# y_pred_binary = (y_pred > 0.5).astype(int)

# Ensure `y_true` is binary
y_true_binary = (y_true > 0.5).astype(int)

# Metrics
report = classification_report(y_true_binary, y_pred_binary,zero_division=0)
report_data = classification_report(y_true_binary, y_pred_binary, output_dict=True,zero_division=0) # this is not clean to print but easier to extract
weighted_f1 = report_data['weighted avg']['f1-score']
macro_f1 = report_data['macro avg']['f1-score']
print("Classification Report:")
print(report)

# Store the predicitons in a csv file
# get the biomarker names
BiomarkerLabel_df = pd.read_csv('OLIVES_Dataset_Labels/BiomarkerLabel_train_data.csv')
biomarkers = BiomarkerLabel_df.columns[2:18] # Extract label columns names (biomarkers)

# Convert the predictions into a pandas DataFrame with biomarker names as columns
df_predictions = pd.DataFrame(y_pred_binary, columns=biomarkers)

# Add the "Index" name (as the row index)
df_predictions.insert(0, "Index", df_predictions.index.to_series().apply(lambda x: f"{x+1:04d}"))

# Save the DataFrame to a CSV file
df_predictions.to_csv("predictions_biomarkers.csv",  index=False)

Evaluating: 100%|██████████| 112/112 [00:03<00:00, 29.82batch/s]


Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        69
           1       0.99      0.96      0.98       516
           2       0.97      0.97      0.97        29
           3       0.99      0.97      0.98       277
           4       0.99      0.99      0.99      4699
           5       1.00      1.00      1.00      2130
           6       1.00      1.00      1.00      4068
           7       1.00      0.99      1.00       677
           8       0.99      0.98      0.98      2102
           9       1.00      1.00      1.00         7
          10       1.00      1.00      1.00      2285
          11       1.00      1.00      1.00      3028
          12       1.00      1.00      1.00       180
          13       1.00      0.89      0.94         9
          14       1.00      1.00      1.00        10
          15       0.98      0.89      0.94        57

   micro avg       1.00      0.99      0.99     20143
   

Save the trained model

In [8]:
# Save model weights to a file
torch.save(model.state_dict(), f"TrainedModels/{model.name}/{model.name}_f1weighted{weighted_f1:.4f}_k{k_folds}_e{num_epochs}_p{patience}_weights.pth")
# torch.save(model.state_dict(), f"TrainedModels/{model.name}/{model.name}_f1samples{samples_f1:.4f}_k{k_folds}_e{num_epochs}_p{patience}_weights.pth")

Include a grad cam

In [1]:
# Define Grad-CAM class for ResNet50
class GradCam:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output  # Save activations

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]  # Save gradients

        # Find the target layer
        for name, module in self.model.image_model.named_modules():
            if name == self.target_layer:
                module.register_forward_hook(forward_hook)
                module.register_backward_hook(backward_hook)

    def generate(self, input_image, clinical_data=None, target_class=None):
        # Forward pass
        output = self.model(input_image, clinical_data)

        # Use the top predicted class if target_class is not specified
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass
        self.model.zero_grad()
        class_score = output[0, target_class]
        class_score.backward()

        # Ensure gradients and activations are captured
        if self.gradients is None or self.activations is None:
            raise ValueError("Gradients or activations were not captured properly.")

        # Compute Grad-CAM
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # Average gradients spatially
        cam = torch.sum(weights * self.activations, dim=1).squeeze()  # Weighted sum of activations

        # Apply ReLU and normalize to [0, 1]
        cam = torch.relu(cam).detach().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min())

        return cam

    def overlay_heatmap(self, heatmap, original_image, alpha=0.4):
        # Resize heatmap to match the original image size
        heatmap = cv2.resize(heatmap, (original_image.size[1], original_image.size[0]))
        heatmap = np.uint8(255 * heatmap)  # Convert to uint8

        # Create heatmap image
        heatmap_img = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        # Overlay heatmap
        overlay = cv2.addWeighted(original_image, 1 - alpha, heatmap_img, alpha, 0)
        return overlay
    
# Visualization function for Grad-CAM
def visualize_gradcam(model, gradcam, data_loader, labels, num_samples=5):
    model.eval()
    count = 0

    for images, _, clinical_data in data_loader:
        images = images.to(device)
        clinical_data = clinical_data.to(device)

        for i in range(images.size(0)):
            input_image = images[i].unsqueeze(0)  # Single image
            input_clinical = clinical_data[i].unsqueeze(0)
            original_image = np.array(to_pil_image(input_image.squeeze(0).cpu()))  # Convert to numpy

            for label_idx, label_name in enumerate(labels):
                # Generate Grad-CAM heatmap
                heatmap = gradcam.generate(input_image, clinical_data=input_clinical, target_class=label_idx)

                # Overlay heatmap on the original image
                overlay = gradcam.overlay_heatmap(heatmap, original_image)

                # Plot
                plt.figure(figsize=(8, 4))
                plt.subplot(1, 2, 1)
                plt.title(f"Original Image: {label_name}")
                plt.imshow(original_image, cmap="gray")
                plt.axis("off")

                plt.subplot(1, 2, 2)
                plt.title(f"Grad-CAM Heatmap: {label_name}")
                plt.imshow(overlay)
                plt.axis("off")

                plt.show()

                count += 1
                if count >= num_samples:
                    return
        break

# Initialize Grad-CAM with the chosen target layer
gradcam = GradCam(model.image_model, target_layer="layer4")

# Visualize Grad-CAM
visualize_gradcam(model, gradcam, data_loader, labels, num_samples=5)
