In [1]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import train_test_split

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

In [3]:
# from skmultilearn.model_selection import iterative_train_test_split
# t_train, y_train, t_test, y_test = iterative_train_test_split(X, y, test_size = 0.2)

In [2]:
# Dataset Load and split

# 96 eyes, 49 OCT images, 2 visits, 16 biomarkers (binary)
# 96 eyes, 49 OCT images, 2 visits, 496 x 504 OCT images (grayscale)
scan_N = 9408
oct_N = 49
eye_N = 96
sh = [496, 504]

csv_file = '~/scratch/OLIVES/OLIVES/Biomarker_Clinical_Data_Images_Updated.csv'
data = pd.read_csv(csv_file)
col_names = data.columns
file_paths = data['Path (Trial/Arm/Folder/Visit/Eye/Image Name)'].values #[9408,]
file_paths = file_paths.reshape([eye_N,2*oct_N])
bio_markers = data[col_names[2:18]].values
bio_markers = bio_markers.reshape([eye_N,2*oct_N,-1])

clin_data = data[col_names[19:21]].values
clin_data = clin_data.reshape([eye_N,2*oct_N,-1])

home_dir = '/home/hice1/hsuh45/scratch/OLIVES/OLIVES/'



In [3]:
# Check for rows with Nan and identify the rows (get rid of them after Data split)
rows_with_nan = data[data.isna().any(axis=1)]
# data = data.dropna()

In [4]:
# DeiT preprocessing
# transform_deit = transforms.Compose([
#     transforms.Resize((512,512)),                # Resize to square dimensions
#     transforms.ToTensor(),                        # Convert to tensor
#     transforms.Normalize(mean=[0.5], std=[0.5])   # Normalize (adjust mean/std for grayscale)
# ])
# SwinT preprocessing
# resize to square + fit the input size (due to small dataset)
transform_swin = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3 channels
    transforms.CenterCrop(496),
    transforms.Resize((224,224)),                # 224x224 or 384x384
    transforms.ToTensor(),                        # Convert to tensor
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
    transforms.Normalize(mean=[0.485, 0.456, 0.406], # Normalize with ImageNet stats
        std=[0.229, 0.224, 0.225])
    
])


# Create DataLoaders with the preprocessed data
class OCTDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        
        img = Image.open(home_dir + self.file_paths[index][0]).convert("L")
        label = self.labels[index]  # Shape: [sample N, bio_marker_N]
        
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(label, dtype=torch.float32)
        
        return img, label

# Eye-wise split
# Split dataset into train/val/test
train_val_files, test_files, train_val_labels, test_labels = train_test_split(
    file_paths, bio_markers, test_size=0.2, random_state=42
)

train_files, val_files, train_labels, val_labels = train_test_split(
    train_val_files, train_val_labels, test_size=0.25, random_state=42
)
print(train_files.shape, val_files.shape, test_files.shape)
# Eye-wise -> scan-wise
train_files = train_files.reshape([-1,1])
val_files = val_files.reshape([-1,1])
test_files = test_files.reshape([-1,1])

train_labels = train_labels.reshape([-1,16])
val_labels = val_labels.reshape([-1,16])
test_labels = test_labels.reshape([-1,16])

######## Get rid of data points with Nan values #########
train_nan = ~np.isnan(train_labels).any(axis=1)
val_nan = ~np.isnan(val_labels).any(axis=1)
test_nan = ~np.isnan(test_labels).any(axis=1)

train_labels = train_labels[train_nan]
val_labels = val_labels[val_nan]
test_labels = test_labels[test_nan]

train_files = train_files[train_nan]
val_files = val_files[val_nan]
test_files = test_files[test_nan]
#########################################################

train_dataset = OCTDataset(train_files, train_labels, transform=transform_swin)
val_dataset = OCTDataset(val_files, val_labels, transform=transform_swin)
test_dataset = OCTDataset(test_files, test_labels, transform=transform_swin)

# Make DataLoader
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

(57, 98) (19, 98) (20, 98)


In [7]:
train_files.shape

(5582, 1)

In [8]:
np.isnan(test_labels).sum()

0

In [9]:
# Quick analysis of biomarker distribution

print(np.sum(train_labels, axis=(0)) / np.sum(train_labels))
print(np.sum(val_labels, axis=(0)) / np.sum(val_labels))

# discrepancy between train vs val class distribution (not too severe (?))

[7.81488491e-03 2.13688259e-02 2.44215154e-04 1.50192319e-02
 2.32920203e-01 9.93345137e-02 1.98485866e-01 3.23585078e-02
 1.13132670e-01 4.27376519e-04 1.12888455e-01 1.54343977e-01
 8.73069174e-03 1.83161365e-04 6.10537884e-04 2.13688259e-03]
[3.30305533e-03 2.16763006e-02 1.85796862e-03 1.42444261e-02
 2.27497936e-01 1.20767960e-01 2.24194880e-01 2.00247729e-02
 1.27580512e-01 0.00000000e+00 9.49628406e-02 1.32741536e-01
 7.63831544e-03 2.06440958e-04 0.00000000e+00 3.30305533e-03]


In [11]:
# !pip install timm
torch.cuda.is_available()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


True

In [5]:
import timm 
import tqdm


# Load Swin Transformer model
# model = timm.create_model('swin_base_patch4_window12_384', pretrained=True) 
# Load DeiT model
model = timm.create_model('deit_base_patch16_224', pretrained=True) 
print(model.head)
###### Parameters ######
lr = 1e-4
num_classes = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10


# Modify the classifier head for multi-class output
model.head = nn.Sequential(
    nn.Linear(model.head.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes),  # 16 biomarkers
    nn.Sigmoid()  # Multi-label classification (probabilities for each class)
)

print(model.head)
model = model.to('cuda')

# Freeze Vision Encoder layers if needed
for param in model.parameters():
    param.requires_grad = False
for param in model.head.parameters():
    param.requires_grad = True

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
criterion = torch.nn.BCELoss()  # Binary Cross-Entropy Loss for multi-label classification
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Training and validation
def train_one_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm.tqdm(train_loader):
        images, labels = images.to('cuda'), labels.to('cuda')
#         print(images.shape, labels.shape)

        optimizer.zero_grad()
        outputs = model(images)
#         print(outputs.shape, labels.shape, images.shape)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

def validate_one_epoch(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm.tqdm(val_loader):
            images, labels = images.to('cuda'), labels.to('cuda')
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            all_outputs.append(outputs.cpu())
            all_labels.append(labels.cpu())

    return running_loss / len(val_loader)

# Training loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = validate_one_epoch(model, val_loader, criterion, device)
    
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    scheduler.step()

# Test the model
def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    threshold = 0.5
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to('cuda'), labels.to('cuda')
            outputs = model(images)
            predictions = (outputs > threshold).float()  # Threshold at 0.5 for binary decisions
            correct += (predictions == labels).sum().item()
            total += labels.numel()

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Evaluate on test set
test_model(model, test_loader, device)

Linear(in_features=768, out_features=1000, bias=True)
Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=16, bias=True)
  (4): Sigmoid()
)
Epoch 1/10


100%|██████████| 175/175 [01:55<00:00,  1.51it/s]
 52%|█████▏    | 30/58 [00:22<00:20,  1.34it/s]


KeyboardInterrupt: 

In [25]:
# Get F1 scores and AUC 
import sklearn
test_shape = test_labels.shape

def test_with_eval_metric(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    threshold = 0.5
    target = np.zeros(test_shape)
    pred = np.zeros(test_shape)
    batch_size = 32
    
    with torch.no_grad():
        for batch, (images, labels) in enumerate(test_loader):
            images, labels = images.to('cuda'), labels.to('cuda')
            outputs = model(images)
            pred[batch*batch_size: (batch+1)*batch_size] = (outputs > threshold).float().cpu()  # Threshold at 0.5 for binary decisions
            target[batch*batch_size: (batch+1)*batch_size] = labels.cpu()
#             correct += (predictions == labels).sum().item()
#             total += labels.numel()
    print("Finished iteration")
    print(pred[0])
    print(target[0])
    for i in range(num_classes):
        f1 = sklearn.metrics.f1_score(target[:,i], pred[:,i], zero_division =0)
        try:
            auc = sklearn.metrics.roc_auc_score(target[:,i],pred[:,i])
        except ValueError:
            auc = np.nan
        print(f1, auc)
#     accuracy = correct / total
#     print(f"Test Accuracy: {accuracy * 100:.2f}%")

test_with_eval_metric(model, test_loader, device)

Finished iteration
[0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
0.0 0.5
0.0 0.5
0.0 0.5
0.08571428571428572 0.5234961383661482
0.8512679917751885 0.7004670206640651
0.3918799646954987 0.5852806999180057
0.6673407482305359 0.6723738922057378
0.0213903743315508 0.5041510611735331
0.3088512241054614 0.5859964787220338
0.0 0.5
0.8469860896445132 0.8742755553127461
0.8386363636363636 0.8520721444226941
0.31746031746031744 0.5943396226415094
0.0 0.5
0.0 nan
0.0 0.5


In [16]:
test_labels.shape[0]



(1960, 16)

In [56]:

# Modify the classifier head for multi-class output
model.head = nn.Sequential(
    timm.layers.SelectAdaptivePool2d(pool_type='avg'), #, flatten=Identity()),
    nn.Linear(model.head.in_features, 512),
    nn.Identity(),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes),  # 16 biomarkers
    nn.Sigmoid()  # Multi-label classification (probabilities for each class)
)
model.head
# model.head.fc = nn.Linear(model.head.in_features,512)
# model.head.relu = nn.ReLU()
# model.head.drop2 = nn.Dropout(0.3)
# model.head.fc2 = nn.Linear(512,16)
# model.head.sig = nn.Sigmoid()
# model.head

Sequential(
  (0): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
  (1): Linear(in_features=1024, out_features=512, bias=True)
  (2): Identity()
  (3): ReLU()
  (4): Dropout(p=0.3, inplace=False)
  (5): Linear(in_features=512, out_features=16, bias=True)
  (6): Sigmoid()
)