# Biomarker detection in OLIVES using pretrained Models


### Step 1: Import data
Consistent for all models. Only change output size!

In [1]:
import os
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, Subset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
from tqdm import tqdm

# set the size of the image according to your model needs
imageSize = 224 # ResNet works with 224x224 pixels

# Custom Dataset
class BiomarkerDataset(Dataset):
    def __init__(self, label_file, transform=None, num_frames=0):
        """
        Args:
            label_file (str): Path to the CSV file.
            transform (callable, optional): Transform to be applied on a sample.
            num_frames (int): Number of adjacent frames to use in the input sequence (1 adjacent frame -> 3 consecutive images).
        """
        self.data = pd.read_csv(label_file)
        self.transform = transform
        self.num_frames = num_frames
        
        # Exclude indices which don't have enough adjacent images
        self.valid_indices = self.data[(self.data.iloc[:, 1] > num_frames) & (self.data.iloc[:, 1] < (50-num_frames))].index.tolist()

    def __len__(self):
        # we can't use the length of the data since we have to exclude the first and last image (for num_frames=1) of each OCT scan
        return len(self.valid_indices)

    def __getitem__(self, idx):
        
        # Base path
        img_base_path = '/storage/ice1/shared/d-pace_community/makerspace-datasets/MEDICAL/OLIVES/OLIVES'
        
        # Get the actual data index
        index = self.valid_indices[idx]
        
        # Initialize
        images = []
        
        # Load a sequence of consecutive images
        for i in range(index - self.num_frames, index + self.num_frames +1):
            img_path = img_base_path + self.data.iloc[i, 0]
            img = Image.open(img_path).convert("L") # 'L' is for grayscale; can be removed!?
            
            if self.transform is not None:
                # apply data transformations (transforms it to tensor)
                img = self.transform(img)
            
            # stack torch tensor
            img = img.squeeze(0)  # Removes the first dimension if it's 1
            images.append(img)
        
        # Stack the 3 grayscale images along the channel dimension
        # Resulting tensor shape will be [3, H, W]
        images = torch.stack(images, dim=0)
        # print(images.shape) # debugging
        
        # Biomarker columns
        labels = torch.tensor(self.data.iloc[index, 2:18].astype(float), dtype=torch.float32)
        
        # Get extra clinical data
        clinical_data = {
            "Eye_ID": self.data.iloc[index, 18],
            "BCVA": self.data.iloc[index, 19],
            "CST": self.data.iloc[index, 20],
            "Patient_ID": self.data.iloc[index, 21],
        }
        
        return images, labels, clinical_data
    
    
# Define transformers

# Values for normalization taken from example paper
mean = 0.1706
std = 0.2112

# train with data augmentation
train_transformer = transforms.Compose([   
    # transforms.RandomCrop((0.7, 1.0)),  # RandomCrop between 70% to 100% of original size
    # transforms.RandomPerspective(distortion_scale=0.2, p=0.5, fill=0),  # Add perspective shift
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Adjust color properties
    transforms.RandomRotation(degrees=10, fill=0),  # Rotates randomly between + and - degree and fills new pixels with black
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
    transforms.Resize(imageSize), # Resize to models needs
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean, std) # we have to calculate these values for our dataset
])
# train without data augmentation
test_transformer = transforms.Compose([   
    transforms.Resize(imageSize), # Resize to models needs
    transforms.CenterCrop(imageSize), # shouldn't do anything
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])


# set up train loader (just example since cross validation uses new ones)
train_dataset = BiomarkerDataset(label_file='OLIVES_Dataset_Labels/BiomarkerLabel_train_data.csv', transform=train_transformer, num_frames=1)
trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

# set up test loader (this one actually is being used)
test_dataset = BiomarkerDataset(label_file='OLIVES_Dataset_Labels/BiomarkerLabel_train_data.csv', transform=test_transformer, num_frames=1)
testloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=32, pin_memory=True)


### Step 2: Train model
Could be easily adapted for different models. Uses cross-validation.

In [10]:
## --- Settings ---
num_epochs=10
batch_size=64
num_workers=32 # need this amount of CPUs for parallel data loading
k_folds=5
patience=5  # Number of epochs to wait for improvement

# get to cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## ---- Model ----
# Import pretrained model (choose one of them based on performance)
# model = models.resnet50(pretrained=True)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) # this is like pre-trained true
model.name = "ResNet50"
# Adapt it to the given task (update final layer)
model.fc = nn.Linear(model.fc.in_features, 16)  # Number of output classes: 16 (biomarkers)
# shift to GPU
model = model.to(device)

# Loss function, optimizer nad Learning rate sheduler
loss_fn = nn.BCEWithLogitsLoss()  # For multi-label classification
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.9) # NOT GOOD: This set every f1 score to 0!!
    # weight decay to reduce overfitting 
scheduler = ExponentialLR(optimizer, gamma=0.9)

# Creates all needed folders to store the model weights if they don't exist already
os.makedirs("ModelWeights_TempSaves", exist_ok=True)
os.makedirs(f"TrainedModels/{model.name}", exist_ok=True)



In [3]:
# --- Train/Test Loops ---
# Training loop
def train_loop(model, train_loader, optimizer, loss_fn, device):
    # Set model to train mode
    model.train()
    
    # Initialize
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    # for images, labels, _ in train_loader:
    for images, labels, _ in tqdm(train_loader, desc="Training"):
        # shift to cuda
        images = images.to(device)
        labels = labels.to(device)
        
        # Zero the parameter gradients 
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Track predictions and labels for metrics calculation
        running_loss += loss.item() * images.size(0)
        all_preds.append(outputs)
        all_labels.append(labels)
    
    # Average loss
    avg_loss = running_loss / len(train_loader.dataset)
    
    return avg_loss

# test loop
def test_loop(model, test_loader, loss_fn, device):
    # Set model to evaluation mode
    model.eval()
    
    # Initialize
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels, _ in test_loader:
        # for images, labels, _ in tqdm(val_loader, desc="Validating"):
            # Store labels since they won't be altered
            # all_labels.append(labels.numpy())
            
            # Shift to cuda
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Get metrics
            loss = loss_fn(outputs, labels)
            running_loss += loss.item() * images.size(0)
            
            # Sigmoid activation to get probabilities, then threshold at 0.5 for binary classification
            preds = torch.sigmoid(outputs) > 0.5 
            # preds = (torch.sigmoid(outputs) > 0.5).int()  # Apply sigmoid and threshold at 0.5

            # Store (numpy for easier processing later)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Calculate average loss
    avg_loss = running_loss / len(test_loader.dataset)

    # Convert lists of predictions and labels into a 2D array where each row is a sample, each column is a biomarker
    all_preds = np.concatenate(all_preds, axis=0)  # Shape: (num_samples, num_biomarkers)
    all_labels = np.concatenate(all_labels, axis=0)  # Shape: (num_samples, num_biomarkers)
    
    # Calculate F1 score for each biomarker (column) independently
#     f1_scores = []
#     for i in range(all_labels.shape[1]):  # Iterate over each biomarker
#         f1 = f1_score(all_labels[:, i], all_preds[:, i], average='binary')  # Compute F1 score for the ith biomarker
#         f1_scores.append(f1)

    # Average loss
    val_loss = running_loss / len(test_loader.dataset)
    # return val_loss, f1_scores, all_preds, all_labels
    return val_loss, all_preds, all_labels


# --- Cross-Validation ---

# Initialize object to split dataset in kfold
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=0)

fold_metrics = []
    
for fold, (train_idx, val_idx) in enumerate(kfold.split(train_dataset)):
    print(f"Fold {fold+1}/{k_folds}")

    # Split the training dataset into training and validation folds
    train_fold = Subset(train_dataset, train_idx)
    val_fold = Subset(train_dataset, val_idx)
    
    # Set up Dataloaders
    train_loader = DataLoader(train_fold, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_fold, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # Reset parameters
    best_val_loss = float('inf')
    best_val_f1 = 0.0
    counter_NoImprovement = 0

    for epoch in range(num_epochs):
    # for epoch in tqdm(range(num_epochs), desc="Training Epochs", unit="epoch"):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Train the model for one epoch
        train_loss = train_loop(model, train_loader, optimizer, loss_fn, device)
        print(f"Train Loss: {train_loss:.4f}")

        # Validate the model after training using validation fold
        val_loss, all_preds, all_labels = test_loop(model, val_loader, loss_fn, device)
        val_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        print(f"Validation Loss: {val_loss:.4f}, Validation F1: {val_f1:.4f}")
        
        # Update learning rate
        scheduler.step()
        
        # Validate the model after training using validation fold
        val_loss, all_preds, all_labels = test_loop(model, val_loader, loss_fn, device)
        print(f"Validation Loss: {val_loss:.4f}")
        
#         # Early stopping logic: Check if F1 improved
#         if val_f1 > best_val_f1:
#             best_val_f1 = val_f1
#             no_improvement = 0  # Reset counter
#             torch.save(model.state_dict(), f"ModelWeights_TempSaves/best_{model.name}_fold_{fold+1}.pth")
#         else:
#             no_improvement += 1

#         # Stop training if no improvement for 'patience' epochs
#         if no_improvement >= patience:
#             print("Early stopping triggered.")
#             break
        
        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"ModelWeights_TempSaves/best_{model.name}_fold_{fold+1}.pth")

    # Load the weights of the model with the best validationg loss
    model.load_state_dict(torch.load(f"ModelWeights_TempSaves/best_{model.name}_fold_{fold+1}.pth", weights_only=True))
    
#     # Convert lists of tensors to numpy arrays for evaluation
#     all_preds = torch.cat(all_preds, dim=0).cpu().numpy()
#     all_labels = torch.cat(all_labels, dim=0).cpu().numpy()
 
    # Get Accuracy
    # preds = (torch.sigmoid(torch.tensor(all_preds)) > 0.5).int()  # Apply sigmoid and threshold at 0.5
    val_accuracy = accuracy_score(all_labels, all_preds)
    fold_metrics.append(val_accuracy)
    print(f"Validation Accuracy for Fold {fold+1}: {val_accuracy:.4f}")

avg_accuracy = np.mean(fold_metrics)
print(f"\nAverage Accuracy over all folds: {avg_accuracy:.4f}")


Fold 1/5
Epoch 1/10


Training: 100%|██████████| 90/90 [00:07<00:00, 11.40it/s]

Train Loss: 0.3048





Validation Loss: 0.2070, Validation F1: 0.6765
Validation Loss: 0.2078
Epoch 2/10


Training: 100%|██████████| 90/90 [00:05<00:00, 16.02it/s]

Train Loss: 0.1776





Validation Loss: 0.1657, Validation F1: 0.7688
Validation Loss: 0.1654
Epoch 3/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.24it/s]

Train Loss: 0.1445





Validation Loss: 0.1377, Validation F1: 0.8200
Validation Loss: 0.1386
Epoch 4/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.06it/s]

Train Loss: 0.1231





Validation Loss: 0.1275, Validation F1: 0.8351
Validation Loss: 0.1259
Epoch 5/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.14it/s]

Train Loss: 0.1109





Validation Loss: 0.1262, Validation F1: 0.8394
Validation Loss: 0.1243
Epoch 6/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.05it/s]

Train Loss: 0.0992





Validation Loss: 0.1089, Validation F1: 0.8656
Validation Loss: 0.1085
Epoch 7/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.34it/s]

Train Loss: 0.0910





Validation Loss: 0.1058, Validation F1: 0.8707
Validation Loss: 0.1054
Epoch 8/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.33it/s]

Train Loss: 0.0838





Validation Loss: 0.0999, Validation F1: 0.8797
Validation Loss: 0.0989
Epoch 9/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.31it/s]

Train Loss: 0.0792





Validation Loss: 0.0994, Validation F1: 0.8816
Validation Loss: 0.0977
Epoch 10/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.99it/s]

Train Loss: 0.0735





Validation Loss: 0.0933, Validation F1: 0.8892
Validation Loss: 0.0939
Validation Accuracy for Fold 1: 0.5570
Fold 2/5
Epoch 1/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.18it/s]

Train Loss: 0.0770





Validation Loss: 0.0647, Validation F1: 0.9295
Validation Loss: 0.0638
Epoch 2/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.26it/s]

Train Loss: 0.0712





Validation Loss: 0.0620, Validation F1: 0.9253
Validation Loss: 0.0611
Epoch 3/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.04it/s]

Train Loss: 0.0680





Validation Loss: 0.0646, Validation F1: 0.9251
Validation Loss: 0.0635
Epoch 4/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.36it/s]

Train Loss: 0.0646





Validation Loss: 0.0624, Validation F1: 0.9269
Validation Loss: 0.0631
Epoch 5/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.46it/s]

Train Loss: 0.0621





Validation Loss: 0.0641, Validation F1: 0.9248
Validation Loss: 0.0644
Epoch 6/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.39it/s]

Train Loss: 0.0594





Validation Loss: 0.0588, Validation F1: 0.9292
Validation Loss: 0.0622
Epoch 7/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.09it/s]

Train Loss: 0.0582





Validation Loss: 0.0599, Validation F1: 0.9287
Validation Loss: 0.0613
Epoch 8/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.82it/s]

Train Loss: 0.0562





Validation Loss: 0.0615, Validation F1: 0.9282
Validation Loss: 0.0601
Epoch 9/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.95it/s]

Train Loss: 0.0530





Validation Loss: 0.0624, Validation F1: 0.9240
Validation Loss: 0.0602
Epoch 10/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.21it/s]

Train Loss: 0.0522





Validation Loss: 0.0602, Validation F1: 0.9287
Validation Loss: 0.0594
Validation Accuracy for Fold 2: 0.6732
Fold 3/5
Epoch 1/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.02it/s]

Train Loss: 0.0545





Validation Loss: 0.0473, Validation F1: 0.9476
Validation Loss: 0.0462
Epoch 2/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.01it/s]

Train Loss: 0.0524





Validation Loss: 0.0477, Validation F1: 0.9425
Validation Loss: 0.0475
Epoch 3/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.93it/s]

Train Loss: 0.0510





Validation Loss: 0.0492, Validation F1: 0.9433
Validation Loss: 0.0486
Epoch 4/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.28it/s]

Train Loss: 0.0502





Validation Loss: 0.0491, Validation F1: 0.9422
Validation Loss: 0.0490
Epoch 5/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.28it/s]

Train Loss: 0.0484





Validation Loss: 0.0495, Validation F1: 0.9416
Validation Loss: 0.0496
Epoch 6/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.17it/s]

Train Loss: 0.0477





Validation Loss: 0.0492, Validation F1: 0.9437
Validation Loss: 0.0488
Epoch 7/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.23it/s]

Train Loss: 0.0461





Validation Loss: 0.0491, Validation F1: 0.9406
Validation Loss: 0.0486
Epoch 8/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.03it/s]

Train Loss: 0.0470





Validation Loss: 0.0486, Validation F1: 0.9425
Validation Loss: 0.0481
Epoch 9/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.40it/s]

Train Loss: 0.0460





Validation Loss: 0.0504, Validation F1: 0.9390
Validation Loss: 0.0496
Epoch 10/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.88it/s]

Train Loss: 0.0455





Validation Loss: 0.0497, Validation F1: 0.9398
Validation Loss: 0.0509
Validation Accuracy for Fold 3: 0.7215
Fold 4/5
Epoch 1/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.89it/s]

Train Loss: 0.0515





Validation Loss: 0.0443, Validation F1: 0.9483
Validation Loss: 0.0437
Epoch 2/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.44it/s]

Train Loss: 0.0515





Validation Loss: 0.0418, Validation F1: 0.9535
Validation Loss: 0.0433
Epoch 3/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.80it/s]

Train Loss: 0.0507





Validation Loss: 0.0428, Validation F1: 0.9511
Validation Loss: 0.0434
Epoch 4/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.81it/s]

Train Loss: 0.0510





Validation Loss: 0.0436, Validation F1: 0.9482
Validation Loss: 0.0434
Epoch 5/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.95it/s]

Train Loss: 0.0511





Validation Loss: 0.0441, Validation F1: 0.9512
Validation Loss: 0.0449
Epoch 6/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.21it/s]

Train Loss: 0.0494





Validation Loss: 0.0438, Validation F1: 0.9513
Validation Loss: 0.0445
Epoch 7/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.26it/s]

Train Loss: 0.0494





Validation Loss: 0.0437, Validation F1: 0.9494
Validation Loss: 0.0432
Epoch 8/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.88it/s]

Train Loss: 0.0485





Validation Loss: 0.0443, Validation F1: 0.9509
Validation Loss: 0.0445
Epoch 9/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.38it/s]

Train Loss: 0.0504





Validation Loss: 0.0445, Validation F1: 0.9495
Validation Loss: 0.0436
Epoch 10/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.05it/s]

Train Loss: 0.0480





Validation Loss: 0.0432, Validation F1: 0.9508
Validation Loss: 0.0442
Validation Accuracy for Fold 4: 0.7502
Fold 5/5
Epoch 1/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.00it/s]

Train Loss: 0.0506





Validation Loss: 0.0406, Validation F1: 0.9522
Validation Loss: 0.0407
Epoch 2/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.38it/s]

Train Loss: 0.0496





Validation Loss: 0.0407, Validation F1: 0.9524
Validation Loss: 0.0406
Epoch 3/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.89it/s]

Train Loss: 0.0497





Validation Loss: 0.0423, Validation F1: 0.9505
Validation Loss: 0.0409
Epoch 4/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.19it/s]

Train Loss: 0.0490





Validation Loss: 0.0419, Validation F1: 0.9512
Validation Loss: 0.0415
Epoch 5/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.18it/s]

Train Loss: 0.0492





Validation Loss: 0.0423, Validation F1: 0.9511
Validation Loss: 0.0422
Epoch 6/10


Training: 100%|██████████| 90/90 [00:06<00:00, 14.91it/s]

Train Loss: 0.0486





Validation Loss: 0.0413, Validation F1: 0.9508
Validation Loss: 0.0423
Epoch 7/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.01it/s]

Train Loss: 0.0486





Validation Loss: 0.0409, Validation F1: 0.9525
Validation Loss: 0.0410
Epoch 8/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.06it/s]

Train Loss: 0.0483





Validation Loss: 0.0420, Validation F1: 0.9521
Validation Loss: 0.0425
Epoch 9/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.41it/s]

Train Loss: 0.0483





Validation Loss: 0.0413, Validation F1: 0.9498
Validation Loss: 0.0418
Epoch 10/10


Training: 100%|██████████| 90/90 [00:05<00:00, 15.44it/s]

Train Loss: 0.0487





Validation Loss: 0.0414, Validation F1: 0.9512
Validation Loss: 0.0419
Validation Accuracy for Fold 5: 0.7731

Average Accuracy over all folds: 0.6950


Now Test

In [8]:
# ============================
# 5. Evaluation
# ============================

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def evaluate_model(model, loader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in loader:
            if not batch:  # Handle empty batches
                continue

            images, labels, clinical_data = batch
            images, labels = images.to(device), labels.to(device)
            outputs = model(images) # .logits
            y_true.append(labels.cpu().numpy())
            y_pred.append(outputs.cpu().numpy())
    y_true = np.vstack(y_true)
    y_pred = np.vstack(y_pred)
    return y_true, y_pred

# Evaluate
y_true, y_pred = evaluate_model(model, testloader)

# Convert predicted probabilities to binary predictions
y_pred_binary = (sigmoid(y_pred) > 0.5).astype(int)
# y_pred_binary = (y_pred > 0.5).astype(int)

# Ensure `y_true` is binary
y_true_binary = (y_true > 0.5).astype(int)

# Metrics
report = classification_report(y_true_binary, y_pred_binary,zero_division=0)
report_data = classification_report(y_true_binary, y_pred_binary, output_dict=True,zero_division=0) # this is not clean to print but easier to extract
weighted_f1 = report_data['weighted avg']['f1-score']
samples_f1 = report_data['samples avg']['f1-score']
print("Classification Report:")
print(report)

Classification Report:
              precision    recall  f1-score   support

           0       0.91      0.59      0.72        69
           1       0.90      0.69      0.78       516
           2       0.92      0.83      0.87        29
           3       0.96      0.67      0.79       277
           4       0.92      0.92      0.92      4699
           5       0.99      0.98      0.98      2130
           6       1.00      0.99      0.99      4068
           7       0.98      0.95      0.96       677
           8       0.90      0.90      0.90      2102
           9       0.00      0.00      0.00         7
          10       0.99      0.98      0.98      2285
          11       0.96      0.94      0.95      3028
          12       0.97      0.92      0.94       180
          13       0.00      0.00      0.00         9
          14       1.00      0.50      0.67        10
          15       0.83      0.35      0.49        57

   micro avg       0.96      0.94      0.95     20143
   

Save the trained model

In [9]:
# Save model weights to a file
torch.save(model.state_dict(), f"TrainedModels/{model.name}/{model.name}_f1weighted{weighted_f1:.4f}_k{k_folds}_e{num_epochs}_weights.pth")
# torch.save(model.state_dict(), f"TrainedModels/{model.name}/{model.name}_f1samples{samples_f1:.4f}_k{k_folds}_e{num_epochs}_weights.pth")