# Setup

In [4]:
# imports

# External modules
import sys
import os
import torch
import tqdm
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.models import resnet18
from transformers import AutoImageProcessor, AutoModelForImageClassification
from huggingface_hub import from_pretrained_keras
from tqdm.notebook import tqdm as tqdm
from sklearn.model_selection import train_test_split

# fixing paths
# Resolve the parent of the current working directory and make it importable so
# that the `src` package imported below can be found. The membership check keeps
# the cell idempotent: re-running it no longer stacks duplicate sys.path entries.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

# owned modules
from src.datasets import SCINDataset
from src.models import FeatureExtractor, ClinicalOutcomePredictor, Adversary
from src.utils import custom_collate_fn

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# models: Hugging Face Hub identifiers referenced by the cells below
google_vit_huge = "google/vit-huge-patch14-224-in21k"
google_derm = "google/derm-foundation"

# seeds: fix the global PyTorch and NumPy random generators (PyTorch first, then NumPy)
for seed_everything in (torch.manual_seed, np.random.seed):
    seed_everything(42)

# Loading dataset

In [5]:
# loading dataset

# Image preprocessing: resize to 224x224, then convert to a tensor.
# NOTE(review): no Normalize step is applied and the imported AutoImageProcessor
# is not used in the cells shown here -- confirm that the input scaling expected
# by the pretrained checkpoint is handled somewhere (e.g. inside SCINDataset).
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Column used as the protected attribute by the dataset.
protected_attribute = 'combined_race'
dataset = SCINDataset(
    protected_attr=protected_attribute,
    transform=transform,
    cases_csv="scin_cases.csv",
    labels_csv="scin_labels.csv",
    root_dir="../data/external/scin/dataset",
)

# Output sizes for the classification heads, taken from the dataset's label encoders.
num_classes = len(dataset.label_encoder.classes_)
num_protected_attributes = len(dataset.protected_label_encoder.classes_)

# Shuffled 80/20 split over sample indices, reproducible through random_state.
indices = [*range(len(dataset))]
train_indices, val_indices = train_test_split(
    indices,
    test_size=0.2,
    shuffle=True,
    random_state=42,
)

train_data = Subset(dataset, train_indices)
val_data = Subset(dataset, val_indices)

# Both loaders share the batch size and collate function; only training batches are shuffled.
shared_loader_args = {"batch_size": 32, "collate_fn": custom_collate_fn}
train_loader = DataLoader(train_data, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_data, shuffle=False, **shared_loader_args)


In [None]:
# validating dataset train/val split

# Print the size of each split, followed by the size of the full dataset.
for caption, split in (
    ("Number of training samples:", train_data),
    ("Number of validation samples:", val_data),
):
    print(caption, len(split))
print(len(dataset))


# Predicting outcomes only (Google ViT Huge)

In [None]:
# training parameters

# Number of passes over the training loader and the AdamW learning rate.
num_epochs = 20
lr = 1e-6

# Probability written into the ViT dropout modules by the model-definition cell.
dropout = 0.3

# NOTE(review): train_loader / val_loader are created in the data-loading cell
# with a literal batch_size=32, so this value is not picked up by any loader
# visible in this notebook -- confirm which batch size is intended.
batch_size = 16

# lambda_ is intentionally left undefined in this section (no adversary here).

In [None]:
# model definition/parameters

# Load the pretrained ViT and replace its classification head with a fresh
# linear layer that has one output per outcome class.
model = AutoModelForImageClassification.from_pretrained(google_vit_huge)
model.classifier = torch.nn.Linear(model.config.hidden_size, num_classes)

# Overwrite the dropout probability on the embedding dropout and on the
# attention / attention-output / output dropout modules of every encoder layer.
model.vit.embeddings.dropout.p = dropout
for layer in model.vit.encoder.layer:
    layer.attention.attention.dropout.p = dropout
    layer.attention.output.dropout.p = dropout
    layer.output.dropout.p = dropout

# Freeze every parameter, then re-enable gradients for the new head only, so the
# optimizer below updates nothing but the classifier.
for param in model.parameters():
    param.requires_grad = False

for param in model.classifier.parameters():
    param.requires_grad = True

# Move the model to the target device only AFTER the head has been swapped in.
# Previously .to(device) ran before the replacement, which left the new Linear
# layer on the CPU and caused a device-mismatch error on CUDA machines.
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=lr)

In [None]:
# training loop

# Accuracy (in percent) recorded once per epoch for each split.
train_accuracy, val_accuracy = [], []

for epoch in tqdm(range(num_epochs), desc="Overall Training Progress"):

    start_time = time.time()

    # TRAINING
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 20)

    # Per-epoch accumulators: summed batch losses, correct predictions, samples seen.
    running_loss, correct, total_samples = 0.0, 0, 0

    train_epoch_time = time.time()

    model.train()
    for batch in tqdm(train_loader, desc=f"Training Epoch [{epoch+1}/{num_epochs}]"):
        # Only the first two batch entries (images, outcome labels) are used here.
        images, outcomes = (tensor.to(device) for tensor in batch[:2])

        optimizer.zero_grad()

        outcomes_pred = model(images).logits
        loss = criterion(outcomes_pred, outcomes)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Highest-scoring class per sample, compared against the targets.
        predicted = outcomes_pred.argmax(dim=1)
        correct += predicted.eq(outcomes).sum().item()
        total_samples += outcomes.shape[0]

    epoch_accuracy = 100 * correct / total_samples
    train_accuracy.append(epoch_accuracy)

    train_time = time.time() - train_epoch_time
    # The reported loss is the sum of the per-batch losses over the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss:.4f}, "
          f"Accuracy: {epoch_accuracy:.2f}%")

    # VALIDATION
    val_correct, val_total = 0, 0

    val_epoch_time = time.time()

    model.eval()
    # No gradients are needed while scoring the validation split.
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch [{epoch+1}/{num_epochs}]"):
            images, outcomes = (tensor.to(device) for tensor in batch[:2])

            outcomes_pred = model(images).logits

            predicted = outcomes_pred.argmax(dim=1)
            val_correct += predicted.eq(outcomes).sum().item()
            val_total += outcomes.shape[0]

    val_epoch_accuracy = 100 * val_correct / val_total
    val_accuracy.append(val_epoch_accuracy)

    val_time = time.time() - val_epoch_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {val_epoch_accuracy:.2f}%")

    # Wall-clock summary for the whole epoch and for each phase.
    epoch_time = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s "
          f"(Train: {train_time:.2f}s, Val: {val_time:.2f}s)")

# Predicting race only (Google ViT Huge)

In [6]:
# training parameters

# Number of passes over the training loader and the AdamW learning rate.
num_epochs = 20
lr = 1e-6

# Probability written into the ViT dropout modules by the model-definition cell.
dropout = 0.3

# NOTE(review): train_loader / val_loader are created in the data-loading cell
# with a literal batch_size=32, so this value is not picked up by any loader
# visible in this notebook -- confirm which batch size is intended.
batch_size = 16

# lambda_ is intentionally left undefined in this section (no adversary here).

In [7]:
# model definition/parameters

# Load the pretrained ViT and replace its classification head with a fresh
# linear layer that has one output per protected-attribute class.
model = AutoModelForImageClassification.from_pretrained(google_vit_huge)
model.classifier = torch.nn.Linear(model.config.hidden_size, num_protected_attributes)

# Overwrite the dropout probability on the embedding dropout and on the
# attention / attention-output / output dropout modules of every encoder layer.
model.vit.embeddings.dropout.p = dropout
for layer in model.vit.encoder.layer:
    layer.attention.attention.dropout.p = dropout
    layer.attention.output.dropout.p = dropout
    layer.output.dropout.p = dropout

# Freeze every parameter, then re-enable gradients for the new head only, so the
# optimizer below updates nothing but the classifier.
for param in model.parameters():
    param.requires_grad = False

for param in model.classifier.parameters():
    param.requires_grad = True

# Move the model to the target device only AFTER the head has been swapped in.
# Previously .to(device) ran before the replacement, which left the new Linear
# layer on the CPU and caused a device-mismatch error on CUDA machines.
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=lr)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-huge-patch14-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# training loop

# Accuracy (in percent) recorded once per epoch for each split.
train_accuracy, val_accuracy = [], []

for epoch in tqdm(range(num_epochs), desc="Overall Training Progress"):

    start_time = time.time()

    # TRAINING
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 20)

    # Per-epoch accumulators: summed batch losses, correct predictions, samples seen.
    running_loss, correct, total_samples = 0.0, 0, 0

    train_epoch_time = time.time()

    model.train()
    for batch in tqdm(train_loader, desc=f"Training Epoch [{epoch+1}/{num_epochs}]"):
        # Batches unpack into three entries; the middle one is moved to the
        # device like the others but is not used in this section.
        images, _, attributes = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()

        attributes_pred = model(images).logits
        loss = criterion(attributes_pred, attributes)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Highest-scoring class per sample, compared against the targets.
        predicted = attributes_pred.argmax(dim=1)
        correct += predicted.eq(attributes).sum().item()
        total_samples += attributes.shape[0]

    epoch_accuracy = 100 * correct / total_samples
    train_accuracy.append(epoch_accuracy)

    train_time = time.time() - train_epoch_time
    # The reported loss is the sum of the per-batch losses over the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss:.4f}, "
          f"Accuracy: {epoch_accuracy:.2f}%")

    # VALIDATION
    val_correct, val_total = 0, 0

    val_epoch_time = time.time()

    model.eval()
    # No gradients are needed while scoring the validation split.
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch [{epoch+1}/{num_epochs}]"):
            images, _, attributes = (tensor.to(device) for tensor in batch)

            attributes_pred = model(images).logits

            predicted = attributes_pred.argmax(dim=1)
            val_correct += predicted.eq(attributes).sum().item()
            val_total += attributes.shape[0]

    val_epoch_accuracy = 100 * val_correct / val_total
    val_accuracy.append(val_epoch_accuracy)

    val_time = time.time() - val_epoch_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {val_epoch_accuracy:.2f}%")

    # Wall-clock summary for the whole epoch and for each phase.
    epoch_time = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s "
          f"(Train: {train_time:.2f}s, Val: {val_time:.2f}s)")

# Adversarial pipeline

In [None]:
# training parameters

# Number of passes over the training loader and the AdamW learning rate.
num_epochs = 20
lr = 1e-6

# Weight applied to the adversary loss when it is combined with the primary loss.
lambda_ = 1

# NOTE(review): neither of the two values below is referenced by the adversarial
# cells visible in this notebook -- the loaders were built with a literal
# batch_size=32 and no dropout probability is assigned in this section. Confirm
# whether they are meant to be wired in.
batch_size = 64
dropout = 0.3

In [None]:
# model definition/parameters

# Pretrained ViT, moved to the target device, with its classification head
# replaced by torch.nn.Identity().
model = AutoModelForImageClassification.from_pretrained(google_vit_huge)
model = model.to(device)
model.classifier = torch.nn.Identity()

# Dump the module tree for inspection.
print(model)

# Shared feature extractor plus the two heads that consume its embeddings: the
# clinical-outcome predictor and the protected-attribute adversary.
extractor = FeatureExtractor(model).to(device)
predictor = ClinicalOutcomePredictor(
    embedding_dim=extractor.embedding_dim,
    num_outcomes=num_classes,
).to(device)
adversary = Adversary(
    embedding_dim=extractor.embedding_dim,
    num_protected_attributes=num_protected_attributes,
).to(device)

primary_model = dict(extractor=extractor, predictor=predictor)

# Every ViT parameter and every parameter of the predictor's `fc` layer is
# marked as trainable.
for param in model.parameters():
    param.requires_grad_(True)

for param in predictor.fc.parameters():
    param.requires_grad_(True)

criterion_primary = torch.nn.CrossEntropyLoss()
criterion_adversary = torch.nn.CrossEntropyLoss()

# A single AdamW instance drives the extractor, the predictor and the adversary.
trainable_parameters = [
    *extractor.parameters(),
    *predictor.parameters(),
    *adversary.parameters(),
]
optimizer = torch.optim.AdamW(trainable_parameters, lr=lr)

In [None]:
# training loop
#
# Adversarial training: the extractor + predictor minimise
#     loss_primary - lambda_ * loss_adversary
# (predict the clinical outcome while making the protected attribute hard to
# recover from the embeddings), whereas the adversary must MINIMISE
# loss_adversary so that it remains a meaningful opponent.
#
# A single optimizer holds all three parameter groups, so stepping on the
# combined loss alone would push the adversary to maximise its own loss. The
# adversary's gradients are therefore sign-flipped between backward() and step().
# NOTE(review): this assumes Adversary does not already apply a gradient-reversal
# layer internally and shares no parameters with the extractor/predictor --
# confirm against src.models.

# Accuracy (in percent) of the outcome predictor, recorded once per epoch.
train_accuracy = []
val_accuracy = []

for epoch in tqdm(range(num_epochs), desc="Overall Training Progress"):

    start_time = time.time()

    # TRAINING
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 20)

    primary_model['extractor'].train()
    primary_model['predictor'].train()
    adversary.train()

    # Per-epoch accumulators; the loss totals are sums of per-batch losses.
    total_loss = 0.0
    total_primary_loss = 0.0
    total_adversary_loss = 0.0
    correct = 0
    total_samples = 0

    train_epoch_time = time.time()

    for batch in tqdm(train_loader, desc=f"Training Epoch [{epoch+1}/{num_epochs}]"):
        images, outcomes, protected_attributes = [x.to(device) for x in batch]

        optimizer.zero_grad()

        # Both heads read the same embeddings, so the adversary loss also
        # back-propagates into the extractor.
        embeddings = primary_model['extractor'](images)
        outcomes_pred = primary_model['predictor'](embeddings)
        protected_pred = adversary(embeddings)

        loss_primary = criterion_primary(outcomes_pred, outcomes)
        loss_adversary = criterion_adversary(protected_pred, protected_attributes)
        loss = loss_primary - lambda_ * loss_adversary

        loss.backward()

        # After backward() the adversary holds -lambda_ * d(loss_adversary)/d(theta).
        # Negating it makes the optimizer step descend on the adversary loss,
        # while the extractor keeps the reversed signal and the predictor is
        # unaffected.
        for adversary_param in adversary.parameters():
            if adversary_param.grad is not None:
                adversary_param.grad.neg_()

        optimizer.step()

        predicted = torch.argmax(outcomes_pred, dim=1)
        total_samples += outcomes.size(0)
        correct += (predicted == outcomes).sum().item()

        total_loss += loss.item()
        total_primary_loss += loss_primary.item()
        total_adversary_loss += loss_adversary.item()

    epoch_accuracy = 100 * correct / total_samples
    train_accuracy.append(epoch_accuracy)

    train_time = time.time() - train_epoch_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}, "
          f"Primary Loss: {total_primary_loss:.4f}, Adversary Loss: {total_adversary_loss:.4f}, "
          f"Accuracy: {epoch_accuracy:.2f}%")

    # VALIDATION
    # Only the outcome predictor is scored here; the adversary is switched to
    # eval mode but not evaluated.
    val_correct = 0
    val_total = 0

    val_epoch_time = time.time()
    primary_model['extractor'].eval()
    primary_model['predictor'].eval()
    adversary.eval()

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch [{epoch+1}/{num_epochs}]"):
            images, outcomes, protected_attributes = [x.to(device) for x in batch]

            embeddings = primary_model['extractor'](images)
            outcomes_pred = primary_model['predictor'](embeddings)

            predicted = torch.argmax(outcomes_pred, dim=1)
            val_total += outcomes.size(0)
            val_correct += (predicted == outcomes).sum().item()

    val_epoch_accuracy = 100 * val_correct / val_total
    val_accuracy.append(val_epoch_accuracy)

    val_time = time.time() - val_epoch_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {val_epoch_accuracy:.2f}%")

    # Wall-clock summary for the whole epoch and for each phase.
    epoch_time = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s "
          f"(Train: {train_time:.2f}s, Val: {val_time:.2f}s)")