### Tutorial 3: Training a WSI Classification Model with ABMIL

This tutorial will guide you step-by-step through training an attention-based multiple instance learning (ABMIL) model on Trident patch embeddings. We will then generate attention heatmaps using the trained model.


#### A- Installation and patch feature extraction using CONCH v1.5

#### Step 1: Download a dataset of whole-slide images

You can use your own WSIs or download a publicly available dataset, e.g. from:

- **CPTAC CCRCC WSIs**: Download from the [TCIA Cancer Imaging Archive](https://www.cancerimagingarchive.net/collection/cptac-ccrcc/).
- **Store WSIs**: Save all WSIs into a local directory, e.g.,  
  ```bash
  ./CPTAC-CCRCC_v1/CCRCC
  ```

#### Step 2: Run CONCH v1.5 feature extraction

Navigate to the base directory of Trident and execute the following command:

```bash
python run_batch_of_slides.py --task all \
  --wsi_dir ./CPTAC-CCRCC_v1/CCRCC \
  --job_dir ./tutorial-3 \
  --patch_encoder conch_v15 \
  --mag 20 \
  --patch_size 512
```


#### B- Download labels with data splits

Here, we use Patho-Bench CPTAC-CCRCC labels for predicting BAP1 mutation. 

In [None]:
import datasets
import pandas as pd

# Download the Patho-Bench labels for the CPTAC-CCRCC BAP1 task.
# The return value is not used; the label table is read back from disk below.
patho_bench_request = dict(
    cache_dir='./tutorial-3',
    dataset_to_download='cptac_ccrcc',
    task_in_dataset='BAP1_mutation',
    trust_remote_code=True,
)
datasets.load_dataset('MahmoodLab/Patho-Bench', **patho_bench_request)

# Load the tab-separated table holding labels and split assignments, then
# leave it as the last expression so the notebook renders it.
labels_tsv = 'tutorial-3/cptac_ccrcc/BAP1_mutation/k=all.tsv'
df = pd.read_csv(labels_tsv, sep="\t")
df

In [None]:
# Count how many rows carry each BAP1_mutation label value
label_tally = df['BAP1_mutation'].value_counts()
df_counts = label_tally.reset_index()
# Name both columns explicitly: the label value and its number of rows
df_counts.columns = ['BAP1_mutation', 'Count']
df_counts


#### C- Training an ABMIL model

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score

from trident.slide_encoder_models import ABMILSlideEncoder

# Reproducibility: seed NumPy, the torch CPU RNG and all CUDA RNGs with one
# shared constant that the rest of the notebook reuses.
SEED = 1234
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Request deterministic cuDNN behaviour and disable its benchmark mode
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class BinaryClassificationModel(nn.Module):
    """ABMIL slide encoder followed by a two-layer MLP head that emits one logit.

    Args:
        input_feature_dim: Patch-embedding size handed to the encoder; it is
            also used as the input width of the classification head.
        n_heads: Forwarded unchanged to ``ABMILSlideEncoder``.
        head_dim: Forwarded unchanged to ``ABMILSlideEncoder``.
        dropout: Forwarded unchanged to ``ABMILSlideEncoder``.
        gated: Forwarded unchanged to ``ABMILSlideEncoder``.
        hidden_dim: Width of the hidden layer of the classification head.
    """

    def __init__(self, input_feature_dim=768, n_heads=1, head_dim=512, dropout=0., gated=True, hidden_dim=256):
        super().__init__()

        # The encoder is created with freeze=False so it is trained with the head
        encoder_kwargs = dict(
            freeze=False,
            input_feature_dim=input_feature_dim,
            n_heads=n_heads,
            head_dim=head_dim,
            dropout=dropout,
            gated=gated,
        )
        self.feature_encoder = ABMILSlideEncoder(**encoder_kwargs)

        # Linear -> ReLU -> Linear, ending in a single output unit
        head_layers = [
            nn.Linear(input_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, return_raw_attention=False):
        """Encode ``x`` and classify the resulting slide-level features.

        Args:
            x: Input handed straight to the encoder (the notebook passes a
                dict with a ``'features'`` entry).
            return_raw_attention: When true, also ask the encoder for its
                attention output and return it alongside the logits.

        Returns:
            The classifier output with dimension 1 squeezed away, or a
            ``(logits, attention)`` tuple when ``return_raw_attention`` is set.
        """
        if not return_raw_attention:
            return self.classifier(self.feature_encoder(x)).squeeze(1)

        slide_features, attn = self.feature_encoder(x, return_raw_attention=True)
        return self.classifier(slide_features).squeeze(1), attn

# Build the model with its default hyper-parameters and place it on the GPU
# when one is visible, otherwise on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model = BinaryClassificationModel()
model = model.to(device)

# Custom dataset
class H5Dataset(Dataset):
    """Per-slide bags of patch features read from HDF5 files, with a binary label.

    Each item is ``(features, label)``. ``features`` comes from the
    ``"features"`` dataset of ``<feats_path>/<slide_id>.h5`` and ``label`` from
    the ``BAP1_mutation`` column of the row. For the ``'train'`` split every bag
    is resampled to exactly ``num_features`` rows; other splits return the full
    bag.

    Args:
        feats_path: Directory containing one ``<slide_id>.h5`` file per slide.
        df: Table with at least ``slide_id``, ``fold_0`` and ``BAP1_mutation``
            columns.
        split: Value of ``fold_0`` to keep (the notebook uses ``"train"`` and
            ``"test"``).
        num_features: Number of rows drawn per bag for the training split.
        seed: Seed for the sampling generator. ``None`` (default) falls back to
            the notebook-level ``SEED`` constant.
    """

    def __init__(self, feats_path, df, split, num_features=512, seed=None):
        self.df = df[df["fold_0"] == split]
        self.feats_path = feats_path
        self.num_features = num_features
        self.split = split
        # Seed ONE generator for the dataset's lifetime. Re-creating a freshly
        # seeded generator on every __getitem__ call yields the same index
        # pattern each time, so no new patches would ever be sampled across
        # epochs. A single generator stays reproducible run-to-run while its
        # state advances between draws.
        self.generator = torch.Generator().manual_seed(SEED if seed is None else seed)

    def __len__(self):
        return len(self.df)

    def _sample_indices(self, num_available):
        """Return ``num_features`` row indices drawn from ``range(num_available)``.

        Samples without replacement when the bag is large enough and with
        replacement (oversampling) otherwise.

        Raises:
            ValueError: If the bag is empty, since nothing can be sampled.
        """
        if num_available == 0:
            raise ValueError("Cannot sample patch features from an empty bag.")
        if num_available >= self.num_features:
            return torch.randperm(num_available, generator=self.generator)[:self.num_features]
        return torch.randint(num_available, (self.num_features,), generator=self.generator)  # Oversampling

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        with h5py.File(os.path.join(self.feats_path, row['slide_id'] + '.h5'), "r") as f:
            # Cast to float32 regardless of the on-disk dtype, mirroring the
            # cast applied to features before the model is called at inference.
            features = torch.from_numpy(f["features"][:]).float()

        if self.split == 'train':
            # Fixed-size bags let the default collate stack a batch of slides
            features = features[self._sample_indices(features.shape[0])]

        label = torch.tensor(row["BAP1_mutation"], dtype=torch.float32)
        return features, label

# Create dataloaders
# NOTE(review): the extraction command above writes to `--job_dir ./tutorial-3`,
# while this path has an extra `cptac_ccrcc/` component — confirm the feature
# files actually live here.
feats_path = './tutorial-3/cptac_ccrcc/20x_512px_0px_overlap/features_conch_v15'
batch_size = 8


def _seed_numpy_in_worker(_worker_id):
    """Reset NumPy's global RNG to SEED when a DataLoader worker starts."""
    np.random.seed(SEED)


# Training bags are resampled to a fixed size by H5Dataset, so several slides
# can share a batch; the order is reshuffled every epoch.
train_loader = DataLoader(
    H5Dataset(feats_path, df, "train"),
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=_seed_numpy_in_worker,
)
# Test bags keep all their patches, so they are served one slide at a time
test_loader = DataLoader(
    H5Dataset(feats_path, df, "test"),
    batch_size=1,
    shuffle=False,
    worker_init_fn=_seed_numpy_in_worker,
)

# Training setup: binary cross-entropy computed on raw logits, and Adam over
# every parameter of the model (encoder and classification head).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=4e-4)

# Training loop: one pass over train_loader per epoch, reporting the mean
# per-batch loss at the end of each epoch.
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    batch_losses = []
    for bag_feats, bag_labels in train_loader:
        # The model expects its input wrapped in a dict under 'features'
        batch = {'features': bag_feats.to(device)}
        targets = bag_labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    total_loss = sum(batch_losses, 0.)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")


#### D- Evaluating the ABMIL model

In [None]:
# Evaluation: score every test slide once, without tracking gradients
model.eval()
all_outputs, all_labels = [], []
correct, total = 0, 0

with torch.no_grad():
    for bag_feats, bag_labels in test_loader:
        targets = bag_labels.to(device)
        logits = model({'features': bag_feats.to(device)})

        # The model outputs raw logits (no sigmoid is applied here); a logit
        # above 0 is taken as the positive class.
        hits = (logits > 0).float().eq(targets)
        correct += hits.sum().item()
        total += targets.size(0)

        all_outputs.append(logits.cpu().numpy())
        all_labels.append(targets.cpu().numpy())

# AUC is computed directly from the raw logits collected above
all_outputs = np.concatenate(all_outputs)
all_labels = np.concatenate(all_labels)
auc = roc_auc_score(all_labels, all_outputs)

# Accuracy is the fraction of slides whose thresholded logit matches the label
accuracy = correct / total
print(f"Test AUC: {auc:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")


#### E- Extract attention heatmap for the freshly trained model

In [None]:
from trident import OpenSlideWSI, visualize_heatmap
from trident.segmentation_models import segmentation_model_factory
from trident.patch_encoder_models import encoder_factory as patch_encoder_factory

# a. Load WSI to process
job_dir = './tutorial-3/heatmap_viz'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
slide = OpenSlideWSI(slide_path='./CPTAC-CCRCC_v1/CCRCC/C3L-00418-22.svs', lazy_init=False)

# b. Run segmentation 
segmentation_model = segmentation_model_factory("hest")
geojson_contours = slide.segment_tissue(segmentation_model=segmentation_model, job_dir=job_dir, device=device)

# c. Run patch coordinate extraction
# Same magnification and patch size as the training features, but with a
# 256 px overlap (half a patch) between neighbouring patches.
coords_path = slide.extract_tissue_coords(
    target_mag=20,
    patch_size=512,
    save_coords=job_dir,
    overlap=256, 
)

# d. Run patch feature extraction
# Uses the same patch encoder ("conch_v15") as the features the model was trained on
patch_encoder = patch_encoder_factory("conch_v15").eval().to(device)
patch_features_path = slide.extract_patch_features(
    patch_encoder=patch_encoder,
    coords_path=coords_path,
    save_features=os.path.join(job_dir, "features_conch_v15"),
    device=device
)

#  e. Run inference 
with h5py.File(patch_features_path, 'r') as f:
    coords = f['coords'][:]
    patch_features = f['features'][:]
    coords_attrs = dict(f['coords'].attrs)

# Add a leading batch dimension: the whole slide is a single bag
batch = {'features': torch.from_numpy(patch_features).float().to(device).unsqueeze(0)}

# Inference only: put the model in eval mode and disable autograd, as the
# evaluation cell does. Without no_grad() the attention tensor requires grad,
# and calling .numpy() on it below raises a RuntimeError.
model.eval()
with torch.no_grad():
    logits, attention = model(batch, return_raw_attention=True)

# f. generate heatmap
heatmap_save_path = visualize_heatmap(
    wsi=slide,
    scores=attention.cpu().numpy().squeeze(),  
    coords=coords,
    vis_level=1,
    patch_size_level0=coords_attrs['patch_size_level0'],
    normalize=True,
    num_top_patches_to_save=10,
    output_dir=job_dir
)

