In [1]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import train_test_split

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

In [2]:
# Dataset Load and split

# 96 eyes, 49 OCT images, 2 visits, 16 biomarkers (binary)
# 96 eyes, 49 OCT images, 2 visits, 496 x 504 OCT images (grayscale)
scan_N = 9408
oct_N = 49
eye_N = 96
sh = [496, 504]

csv_file = '~/scratch/OLIVES/OLIVES/Biomarker_Clinical_Data_Images_Updated.csv'
data = pd.read_csv(csv_file)
col_names = data.columns
file_paths = data['Path (Trial/Arm/Folder/Visit/Eye/Image Name)'].values #[9408,]
file_paths = file_paths.reshape([eye_N,2*oct_N])
bio_markers = data[col_names[2:18]].values
bio_markers = bio_markers.reshape([eye_N,2*oct_N,-1])

clin_data = data[col_names[19:21]].values
clin_data = clin_data.reshape([eye_N,2*oct_N,-1])

home_dir = '/home/hice1/hsuh45/scratch/OLIVES/OLIVES/'



In [3]:
# Check for rows with Nan and identify the rows (get rid of them after Data split)
rows_with_nan = data[data.isna().any(axis=1)]
# data = data.dropna()

In [5]:
# DeiT preprocessing

transform_deit = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3 channels
    transforms.CenterCrop(496),
    transforms.Resize((384,384)),                # 224x224 or 384x384
    transforms.ToTensor(),                        # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], # Normalize with ImageNet stats
        std=[0.229, 0.224, 0.225])
])


# Create DataLoaders with the preprocessed data
class OCTDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        
        img = Image.open(home_dir + self.file_paths[index][0]).convert("L")
        label = self.labels[index]  # Shape: [sample N, bio_marker_N]
        
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(label, dtype=torch.float32)
        
        return img, label



In [6]:
'''
Data setup - No Stratification (OCT only)

'''

# Eye-wise split
# Split dataset into train/val/test
train_val_files, test_files, train_val_labels, test_labels = train_test_split(
    file_paths, bio_markers, test_size=0.2, random_state=42
)

train_files, val_files, train_labels, val_labels = train_test_split(
    train_val_files, train_val_labels, test_size=0.25, random_state=42
)
print(train_files.shape, val_files.shape, test_files.shape)
# Eye-wise -> scan-wise
train_files = train_files.reshape([-1,1])
val_files = val_files.reshape([-1,1])
test_files = test_files.reshape([-1,1])

train_labels = train_labels.reshape([-1,16])
val_labels = val_labels.reshape([-1,16])
test_labels = test_labels.reshape([-1,16])

######## Get rid of data points with Nan values #########
train_nan = ~np.isnan(train_labels).any(axis=1)
val_nan = ~np.isnan(val_labels).any(axis=1)
test_nan = ~np.isnan(test_labels).any(axis=1)

train_labels = train_labels[train_nan]
val_labels = val_labels[val_nan]
test_labels = test_labels[test_nan]

train_files = train_files[train_nan]
val_files = val_files[val_nan]
test_files = test_files[test_nan]
#########################################################

train_dataset = OCTDataset(train_files, train_labels, transform=transform_deit)
val_dataset = OCTDataset(val_files, val_labels, transform=transform_deit)
test_dataset = OCTDataset(test_files, test_labels, transform=transform_deit)

# Make DataLoader
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

(57, 98) (19, 98) (20, 98)


In [4]:
'''
Dataset setup with Stratification (OCT only)
'''

# Eye-wise split
# Split dataset into train/val/test

####### Handle rows with Nan -> mark them with ones(16,) ############
nan_rows = np.isnan(bio_markers).any(axis=2)  # Shape (96, 98), True where NaNs are present
print(np.sum(nan_rows))
# Replace NaN rows ones (checked that no data rows are filled with 1's)
alternating_row = np.ones(16)  
bio_markers[nan_rows] = alternating_row

####### Data stratification based on positive label count per eye #######
pos_count = np.sum(bio_markers,axis=(1,2)) # pos label count per eye
bins = [0,200,300,400,np.inf]
# print(pos_count)
stratify_bins = np.digitize(pos_count, bins)
# print(stratify_bins)
train_val_files, test_files, train_val_labels, test_labels = train_test_split(
    file_paths, bio_markers, test_size=0.2,
    stratify = stratify_bins,
    random_state=42
)

pos_count = np.sum(train_val_labels,axis=(1,2)) # pos label count per eye
bins = [0,200,300,400,np.inf]
stratify_bins = np.digitize(pos_count, bins)
# print(pos_count)
train_files, val_files, train_labels, val_labels = train_test_split(
    train_val_files, train_val_labels, test_size=0.25, 
    stratify = stratify_bins,
    random_state=42
)
print(train_files.shape, val_files.shape, test_files.shape)
##########################################################
# Eye-wise -> scan-wise
train_files = train_files.reshape([-1,1])
val_files = val_files.reshape([-1,1])
test_files = test_files.reshape([-1,1])

train_labels = train_labels.reshape([-1,16])
val_labels = val_labels.reshape([-1,16])
test_labels = test_labels.reshape([-1,16])

############################################
# Get rid of marked Nan rows
train_nan = np.sum(train_labels,axis=1)!=16
val_nan = np.sum(val_labels,axis=1)!=16
test_nan = np.sum(test_labels,axis=1)!=16

train_labels = train_labels[train_nan]
val_labels = val_labels[val_nan]
test_labels = test_labels[test_nan]

train_files = train_files[train_nan]
val_files = val_files[val_nan]
test_files = test_files[test_nan]

########################################################

train_dataset = OCTDataset(train_files, train_labels, transform=transform_deit)
val_dataset = OCTDataset(val_files, val_labels, transform=transform_deit)
test_dataset = OCTDataset(test_files, test_labels, transform=transform_deit)

# Make DataLoader
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

12
(57, 98) (19, 98) (20, 98)


In [72]:
# Check distribution of bio markers
print(np.sum(bio_markers,axis=(0,1)))
# Marker 2,9,13,14,15 has <100 markers -> Stratification based on these


array([ 178.,  616.,   44.,  385., 6350., 2996., 5228.,  819., 2848.,
         22., 3015., 4099.,  245.,   22.,   22.,   88.])

In [None]:
'''
Dataset setup with Stratification for low-sample markers (OCT only)
'''

# TODO : 
# Here, stratify_labels indicates whether a sample contains positive instances for any rare biomarker (1 for positive, 0 otherwise).


# stratify_labels = np.any(bio_markers[:, rare_biomarkers] > 0, axis=1).astype(int)
# train_val_files, test_files, train_val_labels, test_labels = train_test_split(
#     file_paths, bio_markers, test_size=0.2, stratify=stratify_labels, random_state=42
# )
# train_val_stratify_labels = np.any(train_val_labels[:, rare_biomarkers] > 0, axis=1).astype(int)

# train_files, val_files, train_labels, val_labels = train_test_split(
#     train_val_files, train_val_labels, test_size=0.25, stratify=train_val_stratify_labels, random_state=42
# )

In [35]:
np.argwhere(np.isnan(clin_data))

array([], shape=(0, 3), dtype=int64)

In [21]:
# Quick analysis of biomarker distribution

train_ratio = np.sum(train_labels, axis=(0)) / np.sum(train_labels)
val_ratio = np.sum(val_labels, axis=(0)) / np.sum(val_labels)
test_ratio = np.sum(test_labels, axis=(0)) / np.sum(test_labels)
print(train_ratio/val_ratio)
print(test_ratio/train_ratio)

# discrepancy between train vs val class distribution (not too severe (?))

[2.36595641 0.98581517 0.13144202 1.05439362 1.02383436 0.82252373
 0.88532738 1.61592384 0.8867551         inf 1.18876451 1.16274063
 1.14301272 0.88723365        inf 0.64694121]
[ 0.50613819  1.25364514 13.98781913  0.69430339  1.09687196  1.3972887
  0.79621124  1.00012212  0.58006124  1.26205887  1.10529691  1.06702551
  1.09143086  5.88960805  0.          2.10343145]


  print(train_ratio/val_ratio)


In [70]:
(2* np.sum(train_labels,axis=0)),(2* np.sum(val_labels,axis=0)),(2* np.sum(test_labels,axis=0))

(array([2.740e+02, 5.780e+02, 5.800e+01, 2.620e+02, 7.258e+03, 3.390e+03,
        6.770e+03, 1.078e+03, 3.914e+03, 0.000e+00, 3.376e+03, 4.634e+03,
        2.340e+02, 4.000e+00, 2.000e+01, 1.020e+02]),
 array([  32.,  408.,    6.,  210., 2532., 1328., 1712.,  316.,  926.,
          14., 1222., 1740.,  136.,   14.,    0.,    8.]),
 array([2.600e+01, 2.220e+02, 0.000e+00, 2.740e+02, 2.886e+03, 1.250e+03,
        1.950e+03, 2.200e+02, 8.320e+02, 6.000e+00, 1.408e+03, 1.800e+03,
        9.600e+01, 2.000e+00, 0.000e+00, 4.200e+01]))

In [7]:
# Weighted BCE for multi-label imbalanced (pos vs. neg) data

train_pos_weights = torch.tensor(train_labels.shape[0] / (2* np.sum(train_labels,axis=0)))
# val_pos_weights = val_labels.shape[0] / (2* np.sum(val_labels,axis=0))
# test_pos_weights = test_labels.shape[0] / (2* np.sum(test_labels,axis=0))

class WeightedBinaryCrossEntropyLoss(nn.Module):
    def __init__(self, pos_weights):
        """
        pos_weights: Tensor of shape (num_biomarkers,) containing weights for positive labels.
        """
        super(WeightedBinaryCrossEntropyLoss, self).__init__()
        self.pos_weights = pos_weights

    def forward(self, logits, targets):
        """
        logits: Predicted logits from the model, shape (batch_size, num_biomarkers).
        targets: Ground truth binary labels, shape (batch_size, num_biomarkers).
        """
        loss = nn.BCEWithLogitsLoss(reduction='none')(logits, targets)  # Compute BCE loss
        weighted_loss = loss * self.pos_weights  # Apply positive weights
        return weighted_loss.mean()
    


class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        """
        Focal Loss for multi-label classification.

        Parameters:
        - gamma (float): Focusing parameter that reduces the loss for well-classified samples (default: 2.0).
        - alpha (float or Tensor): Balancing factor to address class imbalance (default: None).
          If a tensor is provided, it should be of shape (num_classes,).
        - reduction (str): Specifies the reduction to apply to the output: 'none', 'mean', 'sum' (default: 'mean').
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        Compute Focal Loss.

        Parameters:
        - logits (Tensor): Predicted logits of shape (batch_size, num_classes).
        - targets (Tensor): Ground truth labels of shape (batch_size, num_classes).

        Returns:
        - loss (Tensor): Calculated focal loss.
        """
        # Convert logits to probabilities using sigmoid
        probs = torch.sigmoid(logits)
        
        # Binary cross-entropy loss
        bce_loss = F.binary_cross_entropy(probs, targets, reduction='none')
        
        # Compute the modulating factor (1 - p_t)^gamma
        pt = probs * targets + (1 - probs) * (1 - targets)
        focal_factor = (1 - pt) ** self.gamma

        # Apply class balancing factor alpha if provided
        if self.alpha is not None:
            if isinstance(self.alpha, (float, int)):
                alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            elif isinstance(self.alpha, torch.Tensor):
                alpha_factor = self.alpha.unsqueeze(0) * targets + (1 - self.alpha).unsqueeze(0) * (1 - targets)
            else:
                raise ValueError("Alpha must be a float, int, or torch.Tensor.")
            focal_loss = alpha_factor * focal_factor * bce_loss
        else:
            focal_loss = focal_factor * bce_loss

        # Reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Initialize Focal Loss (Example)
# 1. gamma (Focusing Parameter):
# Controls the strength of the focusing effect.
# Higher values put more focus on hard-to-classify samples.

# 2. alpha (Class Balancing Factor):
# Helps address class imbalance.
# If alpha is a scalar, it applies the same balancing for all classes.
# If alpha is a tensor, it applies per-class balancing.

# 3. reduction:
# 'mean': Average loss across the batch.
# 'sum': Sum of the loss across the batch.
# 'none': No reduction is applied; returns loss for each sample.

focal_loss = FocalLoss(gamma=2.0, alpha=0.25)


In [14]:
import timm 
import tqdm

# Load model (model_name, )
model_name = 'deit_base_patch16_384'
input_dim = 384
print(f'model name {model_name}, input size {input_dim}')
model = timm.create_model(model_name, pretrained=True) 
print(model.head)
###### Parameters ######
lr = 1e-3
num_classes = 16
epochs = 10
########################

# Modify the classifier head for multi-class output
model.head = nn.Sequential(
    nn.Linear(model.head.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes),  # 16 biomarkers
    nn.Sigmoid()  # Multi-label classification (probabilities for each class)
)

print(model.head)
model = model.to('cuda')

# Freeze Vision Encoder layers if needed
for param in model.parameters():
    param.requires_grad = False
for param in model.head.parameters():
    param.requires_grad = True

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

######### Loss ###########
# criterion = torch.nn.BCELoss()  # Binary Cross-Entropy Loss for multi-label classification
criterion = WeightedBinaryCrossEntropyLoss(train_pos_weights.to('cuda'))

# train_pos_weights = torch.tensor(1/ (np.sum(train_labels,axis=0)))
# criterion = FocalLoss(gamma=2.0, alpha=train_pos_weights.to('cuda'))
##########################

# Training and validation
def train_one_epoch(model, train_loader, optimizer, criterion):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm.tqdm(train_loader):
        images, labels = images.to('cuda'), labels.to('cuda')

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
#         print('loss : ',loss.item())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

def validate_one_epoch(model, val_loader, criterion):
    model.eval()
    running_loss = 0.0
    all_outputs = np.zeros([len(val_dataset),num_classes])
    all_labels = np.zeros([len(val_dataset),num_classes])
    i = 0
    with torch.no_grad():
        for images, labels in tqdm.tqdm(val_loader):
            images, labels = images.to('cuda'), labels.to('cuda')
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
#             print(i*batch_size, min((i+1)*batch_size,len(val_dataset)))
#             print(outputs.shape)
            all_outputs[i*batch_size: min((i+1)*batch_size,len(val_dataset)),:] = outputs.cpu()
            all_labels[i*batch_size: min((i+1)*batch_size,len(val_dataset)),:] = labels.cpu()
            i+=1

    return running_loss / len(val_loader), all_outputs, all_labels

# Training loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_outputs, val_labels = validate_one_epoch(model, val_loader, criterion)
    
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    scheduler.step()

# # Test the model
# def test_model(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     threshold = 0.5
#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to('cuda'), labels.to('cuda')
#             outputs = model(images)
#             predictions = (outputs > threshold).float()  # Threshold at 0.5 for binary decisions
#             correct += (predictions == labels).sum().item()
#             total += labels.numel()

#     accuracy = correct / total
#     print(f"Test Accuracy: {accuracy * 100:.2f}%")

# # Evaluate on test set

# test_model(model, test_loader, device)

model name deit_base_patch16_224, input size 224
Linear(in_features=768, out_features=1000, bias=True)
Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=512, out_features=16, bias=True)
  (4): Sigmoid()
)
Epoch 1/10


100%|██████████| 175/175 [01:33<00:00,  1.87it/s]
100%|██████████| 59/59 [00:37<00:00,  1.56it/s]


Train Loss: 0.1157, Val Loss: 0.1156
Epoch 2/10


100%|██████████| 175/175 [01:32<00:00,  1.88it/s]
100%|██████████| 59/59 [00:46<00:00,  1.27it/s]


Train Loss: 0.1140, Val Loss: 0.1156
Epoch 3/10


100%|██████████| 175/175 [01:08<00:00,  2.54it/s]
100%|██████████| 59/59 [00:43<00:00,  1.36it/s]


Train Loss: 0.1138, Val Loss: 0.1155
Epoch 4/10


100%|██████████| 175/175 [01:45<00:00,  1.66it/s]
100%|██████████| 59/59 [00:39<00:00,  1.49it/s]


Train Loss: 0.1136, Val Loss: 0.1154
Epoch 5/10


100%|██████████| 175/175 [01:42<00:00,  1.71it/s]
100%|██████████| 59/59 [00:49<00:00,  1.19it/s]


Train Loss: 0.1135, Val Loss: 0.1152
Epoch 6/10


100%|██████████| 175/175 [01:54<00:00,  1.53it/s]
100%|██████████| 59/59 [00:26<00:00,  2.26it/s]


Train Loss: 0.1132, Val Loss: 0.1152
Epoch 7/10


100%|██████████| 175/175 [01:54<00:00,  1.53it/s]
100%|██████████| 59/59 [00:24<00:00,  2.40it/s]


Train Loss: 0.1131, Val Loss: 0.1153
Epoch 8/10


100%|██████████| 175/175 [01:47<00:00,  1.63it/s]
100%|██████████| 59/59 [00:24<00:00,  2.41it/s]


Train Loss: 0.1131, Val Loss: 0.1152
Epoch 9/10


100%|██████████| 175/175 [01:46<00:00,  1.65it/s]
100%|██████████| 59/59 [00:55<00:00,  1.07it/s]


Train Loss: 0.1130, Val Loss: 0.1152
Epoch 10/10


100%|██████████| 175/175 [02:18<00:00,  1.27it/s]
100%|██████████| 59/59 [01:02<00:00,  1.05s/it]

Train Loss: 0.1130, Val Loss: 0.1153





In [45]:
len(val_dataset)

1854

In [15]:
import sklearn
'''
Optimize biomarker-wise threshold for validation on eval metrics

'''

def optimal_thresholds(model_outputs, labels, metric='f1'):
    """
    Calculate optimal thresholds for each biomarker.
    
    Parameters:
        model_outputs (ndarray): Model predictions, shape (N, 16) where N is the number of samples.
        labels (ndarray): True binary labels, shape (N, 16).
        metric (str): Metric to optimize. Options: 'f1', 'auc'.

    Returns:
        thresholds (list): Optimal threshold for each biomarker.
        scores (list): Corresponding best scores for each biomarker.
    """
    num_biomarkers = model_outputs.shape[1]
    thresholds = []
    scores = []
    
    for i in range(num_biomarkers):
        best_threshold = 0.0
        best_score = 0.0
        
        # Thresholds to search
        thresholds_range = np.linspace(0, 1, 100)
        
        for threshold in thresholds_range:
            preds = (model_outputs[:, i] >= threshold).astype(int)
            
            if metric == 'f1':
                score = sklearn.metrics.f1_score(labels[:, i], preds)
            elif metric == 'auc':
                try:
                    score = sklearn.metrics.roc_auc_score(labels[:,i],model_outputs[:,i])
                except ValueError:
                    score = np.nan
#                 # AUC does not depend on a threshold
#                 score = sklearn.metrics.roc_auc_score(labels[:, i], model_outputs[:, i])
#                 best_threshold = None  # No threshold needed for AUC
                
            else:
                raise ValueError("Unsupported metric. Use 'f1' or 'auc'.")
            
            if score > best_score:
                best_score = score
                best_threshold = threshold
        
        thresholds.append(best_threshold)
        scores.append(best_score)
    
    return thresholds, scores

f1_th, f1_scores = optimal_thresholds(val_outputs, val_labels, metric='f1')
# auc_th, auc_scores = optimal_thresholds(val_outputs, val_labels, metric='auc')
print(f'F1 threshold {f1_th}, F1 validation scores {f1_scores}')
# print(f'AUC threshold {auc_th}, AUC validation scores {auc_scores}')
      

F1 threshold [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.010101010101010102, 0.0, 0.0, 0.0, 0.0], F1 validation scores [0.02545068928950159, 0.11443037974683544, 0.018094731240021287, 0.07942238267148015, 0.746971736204576, 0.48309572301425663, 0.7401894451962111, 0.10676156583629894, 0.5032154340836013, 0.008556149732620321, 0.40171673819742487, 0.6725043782837128, 0.04719454640797063, 0.009620523784072688, 0.008556149732620321, 0.02545068928950159]


In [17]:
# Get F1 scores and AUC 
import sklearn
import torch
import json
import numpy as np
import timm
import tqdm
from torch import nn
test_shape = test_labels.shape
 
def test_with_eval_metric(model, test_loader, threshold, batch_size, output_path):
    """
    Evaluate the model using test data and save F1, AUC scores, and thresholds as JSON.
 
    Parameters:
        model (torch.nn.Module): Trained model.
        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
        threshold (list): List of thresholds for classification.
        batch_size (int): Batch size.
        output_path (str): Path to save the JSON metrics file.
    """
    model.eval()
    test_shape = (len(test_loader.dataset), len(threshold))
    target = np.zeros(test_shape)
    pred = np.zeros(test_shape)
 
    threshold_tensor = torch.tensor(threshold).to("cuda")
 
    with torch.no_grad():
        for batch, (images, labels) in enumerate(test_loader):
            images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            pred[
                batch * batch_size : min((batch + 1) * batch_size, len(test_loader.dataset))
            ] = (outputs > threshold_tensor).float().cpu()
            target[
                batch * batch_size : min((batch + 1) * batch_size, len(test_loader.dataset))
            ] = labels.cpu()
 
    metrics = {}
    for i in range(len(threshold)):
        f1 = sklearn.metrics.f1_score(target[:, i], pred[:, i], zero_division=0)
        try:
            auc = sklearn.metrics.roc_auc_score(target[:, i], pred[:, i])
        except ValueError:
            auc = np.nan
        metrics[f"Biomarker_{i}"] = {
            "F1": f1,
            "AUC": auc,
            "Threshold": threshold[i]
        }
 
    # Save to JSON
    with open(output_path, "w") as json_file:
        json.dump(metrics, json_file, indent=4)
    print(f1)
    print(f"Metrics saved to {output_path}")
 
 
# Eval test with optimized F1 Threshold
opt_output_path = "deit384_focal_lr_f1_threshold_metrics.json"
default_output_path = "deit384_focal_lr_0.5threshold_metrics.json"
test_with_eval_metric(model, test_loader, threshold=f1_th, batch_size=batch_size, output_path=opt_output_path)
test_with_eval_metric(model, test_loader, threshold=[0.5]*16, batch_size=batch_size, output_path=default_output_path)

0.02518891687657431
Metrics saved to deit284_focal_lr_f1_threshold_metrics.json
0.0
Metrics saved to deit284_focal_lr_0.5threshold_metrics.json


In [9]:
import torch
import json
import sklearn
import numpy as np
import timm
import tqdm
from torch import nn
 
# Function to save metrics as JSON
def save_metrics_to_json(file_path, metrics):
    """
    Save metrics to a JSON file.
    Parameters:
        file_path (str): Path to save the JSON file.
        metrics (dict): Dictionary containing evaluation metrics.
    """
    with open(file_path, 'w') as json_file:
        json.dump(metrics, json_file, indent=4)
        
    return
 
# Evaluation function with JSON output
def test_with_eval_metric_json(model, test_loader, threshold, batch_size, num_classes, output_path):
    """
    Evaluate the model using test data and save F1 and AUC scores as JSON.
 
    Parameters:
        model (torch.nn.Module): Trained model.
        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
        threshold (list): List of thresholds for classification.
        batch_size (int): Batch size.
        num_classes (int): Number of output classes.
        output_path (str): Path to save the JSON metrics file.
    """
    model.eval()
    test_shape = (len(test_loader.dataset), num_classes)
    target = np.zeros(test_shape)
    pred = np.zeros(test_shape)
    threshold = torch.tensor(threshold).to('cuda')
    with torch.no_grad():
        for batch, (images, labels) in enumerate(test_loader):
            images, labels = images.to('cuda'), labels.to('cuda')
            outputs = model(images)
            pred[batch*batch_size: min((batch+1)*batch_size, len(test_loader.dataset))] = (outputs > threshold).float().cpu()
            target[batch*batch_size: min((batch+1)*batch_size, len(test_loader.dataset))] = labels.cpu()
 
    metrics = {}
    for i in range(num_classes):
        f1 = sklearn.metrics.f1_score(target[:, i], pred[:, i], zero_division=0)
        try:
            auc = sklearn.metrics.roc_auc_score(target[:, i], pred[:, i])
        except ValueError:
            auc = np.nan
        metrics[f"Biomarker_{i}"] = {"F1": f1, "AUC": auc}
 
    # Save to JSON
    save_metrics_to_json(output_path, metrics)
    print(f"Metrics saved to {output_path}")
    return
 
# Example usage for VGG16
# output_path = "prelim_deit224_stratO_BCE_metrics.json"
# output_path = "prelim_deit224_stratO_Weighted_metrics.json"
output_path = "prelim_deit384_metrics.json"
test_with_eval_metric_json(
    model, 
    test_loader, 
    threshold=f1_th, 
    batch_size=batch_size, 
    num_classes=num_classes, 
    output_path=output_path
)

Metrics saved to prelim_deit224_stratO_Focal_metrics.json
