### Settings

In [1]:
import os
import time
import numpy as np
import scanpy as sc
import torch
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils_ae import (SCDataset, dense_and_unique, evaluate, normalization_pro, normalization_rna,setup_seed)
import warnings
warnings.filterwarnings("ignore")

class Args:
    def __init__(self):
        self.batch_size = 256 
        self.test_batch_size = 256
        self.epochs = 30 
        self.lr = 1e-3
        self.no_cuda = True
        self.seed = 1105
        self.repeat = 1
        self.frac_finetune_test = 0.1
        self.resume = False
        self.RNA_path = '/dataset/mBrain/adata_RNA.h5ad'
        self.Pro_path = '/dataset/mBrain/adata_ADT.h5ad'
        self.method_flag = 'CoSRep'
        self.shared_dim_rna = 128 
        self.specific_dim_rna = 128 
        self.shared_dim_pro = 128 
        self.specific_dim_pro = 128 
        self.num_hidden_pro = 128 
        self.num_hidden_rna = 512   
        self.dropout = 0.25  
        self.dim_rna = None
        self.dim_pro = None

args = Args()
setup_seed(args.seed+args.repeat)

### Preprocess

In [2]:
#---- Load Single Cell Data ----#
# Load single-cell RNA and protein data from specified paths
script_directory = os.getcwd()
base_path = os.path.abspath(os.path.join(script_directory, '../'))
scRNA_adata = sc.read_h5ad(base_path + args.RNA_path)
scP_adata = sc.read_h5ad(base_path + args.Pro_path)

#---- Convert to Dense Matrix ----#
# Ensure data is in dense format and has unique indices
scP_adata = dense_and_unique(scP_adata)
scRNA_adata = dense_and_unique(scRNA_adata)

# Print basic information about the original data
print('Total number of origin RNA genes: ', scRNA_adata.n_vars)
print('Total number of origin proteins: ', scP_adata.n_vars)
print('Total number of origin cells: ', scRNA_adata.n_obs)
print('# of NAN in X', np.isnan(scRNA_adata.X).sum())
print('# of NAN in X', np.isnan(scP_adata.X).sum())

#--- Separate Training and Testing Set ---
# Split data into training, validation, and testing sets
train_val_index, test_index = train_test_split(
    scRNA_adata.obs.index, 
    test_size=0.1, 
    random_state=args.seed + args.repeat
)
# Second split: Take 10% of the remaining 90% as validation set (effectively 9% of the original data)
train_index, val_index = train_test_split(
    train_val_index, 
    test_size=0.1111,  
    random_state=args.seed + args.repeat
)

# Assign data subsets based on indices
train_rna = scRNA_adata[train_index]
test_rna = scRNA_adata[test_index]
train_protein = scP_adata[train_index]
test_protein = scP_adata[test_index]
val_rna = scRNA_adata[val_index]
val_protein = scP_adata[val_index]

#---- Normalization ----#
# Normalize RNA and protein data for training, testing, and validation sets
train_rna = normalization_rna(train_rna)
train_protein = normalization_pro(train_protein)
test_rna = normalization_rna(test_rna)
test_protein = normalization_pro(test_protein)
val_rna = normalization_rna(val_rna)
val_protein = normalization_pro(val_protein)

Total number of origin RNA genes:  19848
Total number of origin proteins:  34
Total number of origin cells:  13052
# of NAN in X 0
# of NAN in X 0


### Build dataloaders

In [3]:
# Create data loaders for training, testing, and validation datasets.
# Each DataLoader wraps the RNA and protein data using the SCDataset class,
# and loads data in batches according to the specified batch size.
# The training loader shuffles the data, while the test and validation loaders maintain order.
train_loader = DataLoader(SCDataset(train_rna, train_protein), batch_size=args.batch_size, shuffle=True, drop_last=False)
test_loader = DataLoader(SCDataset(test_rna, test_protein), batch_size=args.test_batch_size, shuffle=False, drop_last=False)
val_loader = DataLoader(SCDataset(val_rna, val_protein), batch_size=args.test_batch_size, shuffle=False, drop_last=False)

# Set input and output dimensions
# Configure the model's input and output dimensions based on the number of features in scRNA_adata and scP_adata.
args.dim_rna = scRNA_adata.shape[1]
args.dim_pro = scP_adata.shape[1]


In [4]:
### Train the model

In [5]:
# Import the CoSRep model
from model import CoSRep

# Check if CUDA is available and set the device accordingly
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Initialize the model and move it to the selected device
model = CoSRep(args).to(device)

# Set up the Adam optimizer with the specified learning rate
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Record the start time for training duration calculation
start_time = time.time()

# Initialize variables for tracking the best validation loss and early stopping
best_loss = float('inf')
epochs_no_improve = 0
early_stop_patience = 3  # Number of epochs to wait before early stopping

# Create directory for saving the best model
os.makedirs('./best_model', exist_ok=True)

# Extract dataset name from the RNA path for model saving
dataset_name = args.RNA_path.strip('/').split('/')[1]
best_model_path = f'./best_model/{dataset_name}_{args.method_flag}_{args.repeat}_best_model.pth'

# Training loop over the specified number of epochs
for epoch in range(args.epochs):
    # Set model to training mode
    model.train()
    total_loss = 0
    
    # Iterate over the training data loader
    for rna, pro in train_loader:
        # Move data to the selected device and convert to float
        rna = rna.to(device).float()
        pro = pro.to(device).float()
        
        # Forward pass: compute predictions and loss
        pred, loss, _ = model(rna, pro)  # Use protein data as ground truth for loss calculation
        
        # Backward pass: compute gradients and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate the total loss
        total_loss += loss.item()

    # Calculate average training loss for the epoch
    train_loss = total_loss / len(train_loader)
    
    # Save the current best model based on validation loss

    # Set model to evaluation mode
    model.eval()
    all_val_preds, all_val_trues = [], []
    val_loss = float('inf')
    
    # Evaluate on the validation set
    with torch.no_grad():
        for rna, pro in val_loader:
            # Move data to the selected device and convert to float
            rna = rna.to(device).float()
            pro = pro.to(device).float()
            
            # Forward pass: compute predictions and validation loss
            pred, _, val_loss = model(rna, pro)
            all_val_preds.append(pred.cpu())
            all_val_trues.append(pro.cpu())
    
    # Concatenate predictions and ground truths for evaluation
    pred_all = torch.cat(all_val_preds)
    true_all = torch.cat(all_val_trues)
    
    # Update best model if validation loss improves
    if val_loss < best_loss:
        best_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), best_model_path)
    else:
        epochs_no_improve += 1
        # Trigger early stopping if no improvement for specified patience
        if epochs_no_improve >= early_stop_patience:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break
            
    # ðŸ§ª Evaluate on the test set (do not modify)
    model.eval()
    all_preds, all_trues = [], []
    
    with torch.no_grad():
        for rna, pro in test_loader:
            # Move data to the selected device and convert to float
            rna = rna.to(device).float()
            
            # Forward pass: compute predictions
            pred, _, _, _ = model(rna)
            all_preds.append(pred.cpu())
            all_trues.append(pro.cpu())
    
    # Concatenate predictions and ground truths for final evaluation
    pred_all = torch.cat(all_preds)
    true_all = torch.cat(all_trues)

    # Compute evaluation metrics
    rmse, ccc, pcc_cell, pcc_pro = evaluate(pred_all, true_all)
    
    # Print epoch results
    print(f'[CosRep] Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, RMSE={rmse:.4f}, CCC={ccc:.4f}, pcc_cell={pcc_cell:.4f}, pcc_pro={pcc_pro:.4f}')

[CosRep] Epoch 1, Train Loss: 1.7837, Val Loss: 0.5307, RMSE=0.4872, CCC=0.5342, pcc_cell=0.5583, pcc_pro=0.5794
[CosRep] Epoch 2, Train Loss: 1.2931, Val Loss: 0.4360, RMSE=0.4421, CCC=0.6110, pcc_cell=0.5929, pcc_pro=0.6230
[CosRep] Epoch 3, Train Loss: 1.1387, Val Loss: 0.4353, RMSE=0.4291, CCC=0.6440, pcc_cell=0.5901, pcc_pro=0.6329
[CosRep] Epoch 4, Train Loss: 1.0222, Val Loss: 0.3946, RMSE=0.4379, CCC=0.6355, pcc_cell=0.5755, pcc_pro=0.6233
[CosRep] Epoch 5, Train Loss: 0.9459, Val Loss: 0.4064, RMSE=0.4398, CCC=0.6460, pcc_cell=0.5656, pcc_pro=0.6339
[CosRep] Epoch 6, Train Loss: 0.8755, Val Loss: 0.4080, RMSE=0.4318, CCC=0.6402, pcc_cell=0.5691, pcc_pro=0.6355
[CosRep] Epoch 7, Train Loss: 0.8381, Val Loss: 0.3868, RMSE=0.4423, CCC=0.6466, pcc_cell=0.5659, pcc_pro=0.6367
[CosRep] Epoch 8, Train Loss: 0.8209, Val Loss: 0.3898, RMSE=0.4438, CCC=0.6530, pcc_cell=0.5687, pcc_pro=0.6452
[CosRep] Epoch 9, Train Loss: 1.0255, Val Loss: 0.3973, RMSE=0.4382, CCC=0.6367, pcc_cell=0.5587

### Test the model

In [6]:
# Initialize variable to store predictions
pred_all = None

#  Test set evaluation
#  Load the best model for final inference
model.load_state_dict(torch.load(best_model_path))
model.eval()

# Initialize lists to store predictions, ground truths, and latent representations
all_preds, all_trues, all_fused_rna, all_shared_rna, all_specific_rna = [], [], [], [], []

# Perform inference on the test set without gradient computation
with torch.no_grad():
    for rna, pro in test_loader:
        # Move RNA data to the selected device and convert to float
        rna = rna.to(device).float()
        
        # Forward pass: get predictions and latent representations
        pred, z_fused_rna, z_shared_rna, z_specific_rna = model(rna)
        
        # Store predictions, ground truths, and latent representations
        all_preds.append(pred.cpu())
        all_trues.append(pro.cpu())
        all_fused_rna.append(z_fused_rna.cpu())
        all_shared_rna.append(z_shared_rna.cpu())
        all_specific_rna.append(z_specific_rna.cpu())

# Concatenate all predictions, ground truths, and latent representations
pred_all = torch.cat(all_preds)
true_all = torch.cat(all_trues)
fused_rna_all = torch.cat(all_fused_rna)
shared_rna_all = torch.cat(all_shared_rna)
specific_rna_all = torch.cat(all_specific_rna)

# Evaluate the model's performance using custom metrics
rmse, ccc, pcc_cell, pcc_pro = evaluate(pred_all, true_all)

# Print evaluation results
print(f'[CosRep] RMSE={rmse:.4f}, CCC={ccc:.4f}, pcc_cell={pcc_cell:.4f}, pcc_pro={pcc_pro:.4f}')

# Calculate and print the total training time
elapsed_time = time.time() - start_time
print(f'Training Time: {elapsed_time // 60}m {elapsed_time % 60}s')

[CosRep] RMSE=0.4411, CCC=0.6519, pcc_cell=0.5559, pcc_pro=0.6413
Training Time: 1.0m 46.54414081573486s
