# SPACER Quick Start Notebook


In [None]:
·

In [9]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import scanpy as sc
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from model.dataset import BagsDataset, custom_collate_fn
from model.model import MIL, EarlyStopping

In [2]:

def load_all_genes(reference_gene_file):
    """Return the reference gene names stored in a CSV file.

    Args:
        reference_gene_file: Path (or any source accepted by ``pd.read_csv``)
            of a CSV file that contains a ``Gene`` column.

    Returns:
        list: The values of the ``Gene`` column, in file order.

    Raises:
        KeyError: If the file has no ``Gene`` column.
    """
    gene_column = pd.read_csv(reference_gene_file)['Gene']
    return gene_column.values.tolist()


Define your gene set of interest. In this study, we use all human/mouse genes as the reference gene set.


In [3]:
# Reference gene set for this study: all human/mouse genes. Point this at your own
# CSV (it needs a 'Gene' column) to use a different gene set of interest.
reference_gene_csv = 'data/human_filtered.csv'
all_genes = pd.read_csv(reference_gene_csv)
all_genes

Unnamed: 0,Gene
0,C1orf141
1,PKP1
2,HIVEP3
3,GLMN
4,SLC44A5
...,...
23177,EPCAM
23178,CEACAM21
23179,CEACAM6
23180,KRT8


In [4]:
# Build the plain list of gene names straight from the reference file.
# Re-deriving it from the file (instead of indexing the `all_genes` DataFrame and
# rebinding the same name) keeps this cell idempotent: it can be re-run without
# failing on `list['Gene']`.
all_genes = load_all_genes('data/human_filtered.csv')

Load the data

In [12]:
# Path of the preprocessed AnnData file; replace it with the location of your own .h5ad file.
adata_path = '/work/OSPH/s439765/data4spacer/spatial_transcriptomics/VisiumHD/processed/Colon_Cancer_P2T_cell.h5ad'
adata = sc.read_h5ad(adata_path)

# Build the bag dataset from the loaded AnnData object (the expected data structure is
# described in the data preparation section).
bag_options = dict(
    immune_cell='tcell',
    radius=50,
    max_instances=500,
    n_genes=3000,
    resolution='high',
    k=2,  # Ensure 'k' matches the number of bags per batch
)
dataset = BagsDataset(adata, **bag_options)

Immune cell: T
[1 0 2]
Tumor cells shape after filtering: (172613, 18085)
Selecting top 3000 genes based on mean expression
Preprocessed data: (334914, 3000)


Creating Bags with radius 50: 100%|███████████████████████| 334914/334914 [03:27<00:00, 1613.95it/s]


Total batches created: 6612


If you want to run the model on multiple datasets:

In [15]:
# Sample sheet with one dataset per row. Your own sheet must have the same
# structure as data/sample.csv.
sample_sheet_csv = 'data/sample.csv'
adatas = pd.read_csv(sample_sheet_csv)
adatas

Unnamed: 0,adata,radius,resolution
0,/project/shared/cli_wang/spatial_TCR/data/trai...,150,low
1,/project/shared/cli_wang/spatial_TCR/data/trai...,150,low
2,/project/shared/cli_wang/spatial_TCR/data/trai...,150,low
3,/project/shared/cli_wang/spatial_TCR/data/trai...,150,low


In [17]:
# Multi-sample variant: pass the sample sheet path instead of an AnnData object.
# No radius/resolution is given here; the log output shows them being reported per
# sample (presumably read from the sheet's columns - confirm in BagsDataset).
multi_sample_options = dict(
    immune_cell='tcell',
    max_instances=500,
    n_genes=3000,
    k=2,  # Ensure 'k' matches the number of bags per batch
)
dataset = BagsDataset('data/sample.csv', **multi_sample_options)

Immune cell: T
Reading adata from /project/shared/cli_wang/spatial_TCR/data/train_validate/Visium/HumanOvarianCancer/T_cell.h5ad
['0', '1']
Categories (2, object): ['0', '1']
Tumor cells shape after filtering: (1226, 17943)
Selecting top 3000 genes based on mean expression
Percentile value: 4.840213894844055
adata.obs[T] after binarization: AAACAAGTATCTCCCA-1    0
AAACAATCTACTAGCA-1    0
AAACACCAATAACTGC-1    0
AAACAGAGCGACTCCT-1    0
AAACAGCTTTCAGAAG-1    0
Name: T, dtype: int64
Processing: adata=T_cell.h5ad, radius=150, resolution=low
Reading adata from /project/shared/cli_wang/spatial_TCR/data/train_validate/Visium/HumanOvarianCancerWholeTranscriptome/T_cell.h5ad
['1', '0']
Categories (2, object): ['0', '1']
Tumor cells shape after filtering: (3043, 36601)
Selecting top 3000 genes based on mean expression
Percentile value: 0.0
adata.obs[T] after binarization: AAACAAGTATCTCCCA-1    0
AAACACCAATAACTGC-1    1
AAACAGGGTCTATATT-1    0
AAACATTTCCCGGATT-1    0
AAACCCGAACGAAATC-1    0
Name:

Creating Bags with radius 150: 100%|█████████████████████████| 3455/3455 [00:00<00:00, 19170.34it/s]
Creating Bags with radius 150: 100%|██████████████████████████| 3491/3491 [00:00<00:00, 8365.62it/s]
Creating Bags with radius 150: 100%|█████████████████████████| 3138/3138 [00:00<00:00, 12010.87it/s]
Creating Bags with radius 150: 100%|██████████████████████████| 4671/4671 [00:00<00:00, 8497.31it/s]


Total batches created: 2354


In [18]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda" if gpu_available else "cpu")

# Report the chosen device; on CUDA also include the name of the current GPU.
device_banner = f"Using device: {device}"
if device.type == "cuda":
    device_banner += f" ({torch.cuda.get_device_name(torch.cuda.current_device())})"
print(device_banner)
print("=====================================")

Using device: cpu


In [19]:
# Optimiser hyper-parameters, named so they are easy to find and tune.
learning_rate = 0.01
weight_decay = 0.02

# Model over the reference gene list, placed on the selected device.
model = MIL(all_genes).to(device)
# Binary cross-entropy on the model outputs (BCELoss expects probabilities in [0, 1]).
criterion = nn.BCELoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# NOTE(review): this helper is created here but not referenced by the training loop
# further down - confirm whether early stopping is meant to be active.
early_stopping = EarlyStopping(patience=5, delta=0.001)

In [22]:
# Folder that receives the model checkpoints and the per-epoch CSV reports.
output_dir = 'sample_output'
os.makedirs(name=output_dir, exist_ok=True)

In [23]:
# Hold out 10% of the dataset for validation.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so that Restart & Run All always produces the same train/validation
# partition; an unseeded random_split makes the reported metrics irreproducible.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
# Validation metrics (mean loss, AUROC) do not depend on sample order, so keep it fixed.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)

In [24]:
# Checkpoint location for the best model, and the lowest validation loss seen so far
# (infinity, so the first finite validation loss always counts as an improvement).
best_model_path = os.path.join(output_dir, 'best_model.pth')
best_val_loss = float('inf')

# Snapshot the SPACER scores (model.immunogenicity.ig) before any training step, as
# plain Python numbers, so later epochs can be compared against this baseline.
initial_score_tensor = model.immunogenicity.ig.clone().detach().cpu()
spacer_scores_before_training = [entry.item() for entry in initial_score_tensor]

In [29]:
def save_metrics(epoch, train_loss, val_loss, val_auroc, a, b, alpha, beta, output_dir):
    """Append one row of per-epoch metrics to ``<output_dir>/training_metrics.csv``.

    The header row is written first when the file does not exist yet (or is empty).
    Every value is rendered with the same text formatting as an f-string, so plain
    numbers are written exactly as before. Values whose text contains a comma, quote
    or newline (e.g. a multi-element array or tensor) are quoted by the CSV writer,
    so they can no longer shift or split the columns of the row.

    Args:
        epoch: Epoch identifier written to the ``Epoch`` column.
        train_loss: Value for the ``Train Loss`` column.
        val_loss: Value for the ``Val Loss`` column.
        val_auroc: Value for the ``Val AUROC`` column.
        a, b, alpha, beta: Model parameter values written to the columns of the same name.
        output_dir: Existing directory in which the CSV file lives.
    """
    import csv  # local import so this cell does not depend on the import cell being edited

    file_path = os.path.join(output_dir, 'training_metrics.csv')
    header = ['Epoch', 'Train Loss', 'Val Loss', 'Val AUROC', 'a', 'b', 'alpha', 'beta']
    # Format through an f-string (not str()) to keep the exact text the original wrote.
    row = [f'{value}' for value in (epoch, train_loss, val_loss, val_auroc, a, b, alpha, beta)]

    needs_header = not os.path.exists(file_path) or os.path.getsize(file_path) == 0

    # newline='' lets the csv module control line endings, as its documentation recommends.
    with open(file_path, 'a', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        if needs_header:
            writer.writerow(header)
        writer.writerow(row)

def save_spacer_scores(epoch, all_genes, spacer_scores_before_training, spacer_scores_after_training, output_dir):
    """Write a per-gene table of SPACER scores before/after training to a CSV file.

    The table has the columns ``Gene``, ``SPACER Score Before Training``,
    ``SPACER Score After Training`` and ``Difference`` (after minus before), with
    rows sorted by ``Difference`` in descending order.

    Args:
        epoch: Epoch index; the file is named ``spacer_score_changes_epoch_<epoch+1>.csv``,
            so a zero-based index produces one-based file names.
        all_genes: Gene names, one per score.
        spacer_scores_before_training: Scores captured before training.
        spacer_scores_after_training: Scores captured after the current epoch.
        output_dir: Existing directory that receives the CSV file.
    """
    before_col = 'SPACER Score Before Training'
    after_col = 'SPACER Score After Training'

    score_table = pd.DataFrame({
        'Gene': all_genes,
        before_col: spacer_scores_before_training,
        after_col: spacer_scores_after_training,
    })

    # Append the change per gene as the last column, largest increase first.
    score_table = score_table.assign(
        Difference=score_table[after_col] - score_table[before_col]
    ).sort_values(by='Difference', ascending=False)

    csv_name = f'spacer_score_changes_epoch_{epoch+1}.csv'
    score_table.to_csv(os.path.join(output_dir, csv_name), index=False)

Training

In [32]:
# Number of full passes over the training loader.
num_epochs = 2
# Training objective. The training loop flips the labels (1 - label) only when this is
# exactly 'negative'; any other value leaves the labels unchanged.
selection = 'negative' # Use 'positive' (induce) or 'negative' (repel), based on your research focus

In [33]:

def _run_epoch(loader, training, description=None):
    """Run one pass over `loader` and return `(mean_loss, auroc)`.

    Uses the notebook-level `model`, `criterion`, `optimizer`, `device` and `selection`.

    Args:
        loader: DataLoader whose batches unpack into six lists; by the names used in
            this notebook: distances, gene expressions, labels, core indices, gene
            names and cell ids.
        training: When True, gradients are enabled and the optimizer is stepped after
            every batch; when False the pass is evaluation-only (no gradients).
        description: Optional text shown in front of the progress bar.

    Returns:
        mean_loss: Sum of the batch losses divided by the number of batches that
            actually produced a loss. Skipped batches are not counted, and the value
            is NaN when every batch was skipped.
        auroc: `roc_auc_score` over all accumulated labels/outputs, or NaN when
            scikit-learn raises ValueError for them.
    """
    total_loss = 0.0
    batches_used = 0
    epoch_outputs = []
    epoch_labels = []
    postfix_key = 'loss' if training else 'val_loss'

    with torch.set_grad_enabled(training), tqdm(loader, unit="batch") as progress:
        for batch_data in progress:
            if description is not None:
                progress.set_description(description)
            if training:
                optimizer.zero_grad()

            # Unpack the batch; core indices and cell ids are not needed here.
            distances_list, gene_expressions_list, labels_list, _core_idxs, gene_names_list, _cell_ids = batch_data

            # Move data to device and prepare labels
            distances_list = [distances.to(device) for distances in distances_list]
            gene_expressions_list = [gene_exp.to(device) for gene_exp in gene_expressions_list]
            labels = torch.stack(labels_list).float().to(device)

            # Forward pass (gene_names_list holds the gene names for each bag)
            outputs = model(distances_list, gene_expressions_list, gene_names_list)

            # Skip the batch when the model returns nothing or the number of
            # predictions does not match the number of labels.
            if outputs is None or outputs.shape[0] != labels.shape[0]:
                continue

            # For the 'negative' objective the labels are inverted before the BCE loss.
            if selection == 'negative':
                labels = 1 - labels
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            batches_used += 1
            progress.set_postfix(**{postfix_key: batch_loss})

            # Accumulate outputs and labels for AUROC calculation
            epoch_outputs.extend(outputs.detach().cpu().numpy())
            epoch_labels.extend(labels.cpu().numpy())

    # Average over the batches that contributed a loss. Dividing by len(loader) would
    # under-report the loss whenever batches are skipped, and would report 0.0 (a false
    # "best" result) if all of them were skipped.
    mean_loss = total_loss / batches_used if batches_used else float('nan')

    try:
        auroc = roc_auc_score(epoch_labels, epoch_outputs)
    except ValueError:
        auroc = float('nan')  # Handle case where AUROC can't be computed

    return mean_loss, auroc


# NOTE(review): `early_stopping` is created together with the optimizer but is never
# consulted here, so training always runs for `num_epochs` epochs. Confirm the
# EarlyStopping API before wiring it into this loop.
for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss, epoch_auc = _run_epoch(
        train_loader, training=True, description=f"Epoch {epoch+1}/{num_epochs}"
    )
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, AUROC: {epoch_auc:.4f}')

    # Validation phase
    model.eval()
    val_loss, val_epoch_auc = _run_epoch(val_loader, training=False)
    print(f'Validation Loss: {val_loss:.4f}, Validation AUROC: {val_epoch_auc:.4f}')

    # Save the best model (a NaN validation loss never compares as smaller, so an
    # epoch without any usable validation batch cannot become the "best" model).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved with validation loss {val_loss:.4f}")

    # Also keep a checkpoint for every epoch.
    torch.save(model.state_dict(), os.path.join(output_dir, f'model_epoch_{epoch+1}.pth'))

    # Detached CPU copies of the learned model parameters for the metrics file.
    a = model.distance.a.clone().detach().cpu().numpy()
    b = model.gene_expression.b.clone().detach().cpu()
    alpha = model.alpha.clone().detach().cpu()
    beta = model.beta.clone().detach().cpu()
    # Save metrics (one-based epoch number)
    save_metrics(epoch+1, train_loss, val_loss, val_epoch_auc, a, b, alpha, beta, output_dir)

    # Save SPACER scores after each epoch. save_spacer_scores adds 1 to the epoch for
    # the file name itself, so the zero-based index is passed here.
    spacer_scores_after_training = model.immunogenicity.ig.clone().detach().cpu()
    spacer_scores_after_training = [score.item() for score in spacer_scores_after_training]
    save_spacer_scores(epoch, all_genes, spacer_scores_before_training, spacer_scores_after_training, output_dir)

Epoch 1/2: 100%|██████████| 2118/2118 [00:12<00:00, 172.67batch/s, loss=0.672]


Epoch [1/2], Loss: 0.6846, AUROC: 0.5780


100%|██████████| 236/236 [00:01<00:00, 213.92batch/s, val_loss=0.519]


Validation Loss: 0.6849, Validation AUROC: 0.5764
Best model saved with validation loss 0.6849


Epoch 2/2: 100%|██████████| 2118/2118 [00:12<00:00, 172.50batch/s, loss=0.587]


Epoch [2/2], Loss: 0.6830, AUROC: 0.5834


100%|██████████| 236/236 [00:00<00:00, 250.45batch/s, val_loss=0.62] 


Validation Loss: 0.6848, Validation AUROC: 0.5816
Best model saved with validation loss 0.6848
