### Tutorial 3: Training a WSI Classification Model with ABMIL

This tutorial guides you step by step through training an attention-based multiple instance learning (ABMIL) model on patch embeddings extracted with Trident.


#### A- Installation and patch feature extraction using UNI

#### Step 1: Download a dataset of whole-slide images

You can use your own WSIs or download a publicly available dataset, e.g. from:

- **CPTAC CCRCC WSIs**: Download from the [TCIA Cancer Imaging Archive](https://www.cancerimagingarchive.net/collection/cptac-ccrcc/).
- **Storage**: Save all WSIs into a local directory, e.g.,  
  ```bash
  ./CPTAC-CCRCC_v1/CCRCC
  ```

#### Step 2: Run UNI feature extraction

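Navigate to the base directory of Trident and execute the following command. The patch encoder and patch size chosen here must match the `feats_path` and `input_feature_dim` used in Part B (the training code below points at `features_conch_v15` extracted from 512-pixel patches, so adjust one or the other):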

```bash
python run_batch_of_slides.py --task all \
  --wsi_dir ./CPTAC-CCRCC_v1/CCRCC \
  --job_dir ./trident_processed \
  --patch_encoder uni_v1 \
  --mag 20 \
  --patch_size 256
```


#### B- Training an ABMIL model

In [None]:
import datasets
import pandas as pd

# Fetch the label/split table for the CPTAC-CCRCC BAP1-mutation task.
# The files are cached under ./tutorial-3, which is where the TSV is read from below.
# NOTE(review): trust_remote_code=True presumably lets the dataset repository's
# loading script run locally -- only keep it for sources you trust.
patho_bench_request = dict(
    cache_dir='./tutorial-3',
    dataset_to_download='cptac_ccrcc',
    task_in_dataset='BAP1_mutation',
    trust_remote_code=True,
)
datasets.load_dataset('MahmoodLab/Patho-Bench', **patho_bench_request)

# Load the table (one row per slide) and show how many slides carry each label.
labels_tsv = 'tutorial-3/cptac_ccrcc/BAP1_mutation/k=all.tsv'
df = pd.read_csv(labels_tsv, sep="\t")
label_counts = df.value_counts('BAP1_mutation')
print(label_counts)

# Bare expression: rendered as the cell output when run in a notebook.
df

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score

from trident.slide_encoder_models.load import ABMILSlideEncoder

# Slide-level binary classifier: ABMIL slide encoder + small MLP head
class BinaryClassificationModel(nn.Module):
    """Binary slide classifier made of an ``ABMILSlideEncoder`` and a two-layer MLP head.

    Args:
        input_feature_dim: Forwarded to the encoder and also used as the input
            width of the classification head.
        n_heads: Forwarded to the encoder.
        head_dim: Forwarded to the encoder.
        dropout: Forwarded to the encoder.
        gated: Forwarded to the encoder.
        hidden_dim: Width of the hidden layer of the classification head.
    """

    def __init__(self, input_feature_dim=768, n_heads=1, head_dim=512, dropout=0., gated=True, hidden_dim=256):
        super().__init__()

        encoder_kwargs = dict(
            input_feature_dim=input_feature_dim,
            n_heads=n_heads,
            head_dim=head_dim,
            dropout=dropout,
            gated=gated,
        )
        self.feature_encoder = ABMILSlideEncoder(**encoder_kwargs)

        # The head consumes `input_feature_dim` values per slide.
        # NOTE(review): this presumes the encoder's output width equals
        # input_feature_dim (not head_dim) -- confirm against ABMILSlideEncoder.
        head_layers = [
            nn.Linear(input_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # single logit for the binary task
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return one raw logit per slide in the batch.

        ``x`` is passed unchanged to the encoder; the training and evaluation
        code in this tutorial passes a dict of the form ``{'features': tensor}``.
        """
        slide_embedding = self.feature_encoder(x)
        head_out = self.classifier(slide_embedding)
        # Drop the trailing singleton dimension produced by the final Linear(…, 1).
        return head_out.squeeze(1)

# Pick the GPU when one is available (CPU otherwise), then build the classifier
# with its default hyper-parameters and place it on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BinaryClassificationModel().to(device)

# Custom dataset
class H5Dataset(Dataset):
    """Per-slide bags of patch features read from ``<slide_id>.h5`` files.

    Args:
        feats_path: Directory containing one ``<slide_id>.h5`` file per slide,
            each with a ``"features"`` dataset.
        df: Table with at least the columns ``slide_id``, ``fold_1`` and
            ``BAP1_mutation``.
        split: Value of the ``fold_1`` column selecting the rows to keep
            (the tutorial uses ``"train"`` and ``"test"``).
        num_features: Number of patch features sampled per slide when
            ``split == 'train'``; ignored for every other split.
    """

    def __init__(self, feats_path, df, split, num_features=512):
        self.df = df[df["fold_1"] == split]
        self.feats_path = feats_path
        self.num_features = num_features
        self.split = split

    def __len__(self):
        """Return the number of slides in the selected split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the slide at position ``idx``.

        For the ``'train'`` split exactly ``num_features`` rows are sampled at
        random along the first axis, so bags have a fixed length. For any other
        split the full bag is returned, so bag lengths may differ between
        slides (the tutorial evaluates with ``batch_size=1``).
        """
        row = self.df.iloc[idx]
        h5_path = os.path.join(self.feats_path, row['slide_id'] + '.h5')

        # Read inside a context manager so the file handle is closed right away
        # instead of leaking one open handle per sample. `[:]` copies the data
        # into memory, so the tensor stays valid after the file is closed.
        with h5py.File(h5_path, "r") as h5_file:
            features = torch.from_numpy(h5_file["features"][:])

        if self.split == 'train':
            num_available = features.shape[0]
            # Sample without replacement when the slide has enough patches;
            # otherwise oversample with replacement to reach num_features.
            indices = np.random.choice(
                num_available,
                self.num_features,
                replace=num_available < self.num_features,
            )
            features = features[indices]

        label = torch.tensor(row["BAP1_mutation"], dtype=torch.float32)
        return features, label

# Data loaders for the train and test splits.
# NOTE(review): this path points at `features_conch_v15` extracted from 512px
# patches, whereas the extraction command in Part A uses `uni_v1` with
# `--patch_size 256` and `--job_dir ./trident_processed` -- make sure the path
# (and the model's input_feature_dim) match the features you actually extracted.
feats_path = './cptac_ccrcc/20x_512px_0px_overlap/features_conch_v15'
batch_size = 8
train_set = H5Dataset(feats_path, df, "train")
test_set = H5Dataset(feats_path, df, "test")
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Test bags are not subsampled, so they are served one slide at a time.
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

# Loss on raw logits and optimizer over all model parameters.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=4e-4)

# Train on the split stored in the "fold_1" column of the label table.
# NOTE(review): the original comment called this "Fold 0" -- confirm which
# fold naming is intended.
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.
    for bag, labels in train_loader:
        # The model is called with a dict carrying the patch features.
        batch = {'features': bag.to(device)}
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(batch)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")


In [None]:
# Evaluation on the test loader: per-slide logits, then AUC and accuracy.
model.eval()
all_labels, all_outputs = [], []
correct = 0
total = 0

with torch.no_grad():
    for bag, labels in test_loader:
        batch = {'features': bag.to(device)}
        labels = labels.to(device)
        outputs = model(batch)

        # Threshold the raw logits at 0 to obtain hard 0/1 predictions
        # (a logit above 0 corresponds to a sigmoid probability above 0.5).
        hard_preds = outputs.gt(0).float()
        correct += hard_preds.eq(labels).sum().item()
        total += labels.shape[0]

        all_outputs.append(outputs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        # `.item()` requires a single element, i.e. the loader's batch size of 1.
        print(f"Logits: {outputs.item():.4f}, Label: {labels.item()}")

# AUC is computed from the raw logits, used directly as ranking scores.
all_outputs = np.concatenate(all_outputs)
all_labels = np.concatenate(all_labels)
auc = roc_auc_score(all_labels, all_outputs)

# Accuracy of the thresholded predictions.
accuracy = correct / total
print(f"Test AUC: {auc:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")
