## SCpliceVAE Interactive Test Notebook

#### 1. Setup and Imports

In [1]:
# ==============================================================================
# 1. CLASS DEFINITIONS HERE AND IMPORTS
# ==============================================================================

import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scipy.sparse as sp
from scipy.sparse import csr_matrix, issparse
from collections import defaultdict
import json 
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import anndata as ad
import os
import matplotlib.pyplot as plt  # For plotting results
from typing import Dict, List, Optional, Tuple
from sklearn.model_selection import train_test_split

# Register the directory holding the project's local modules so the
# `dataloaders` and `partial_vae` imports below can be resolved.
# Edit module_path to match where those sources live on your system.
module_path = '/gpfs/commons/home/kisaev/multivi_tools_splicing/src/SCplice_vae'  # Change this to your module path
path_already_registered = module_path in sys.path
if not path_already_registered:
    # Append only once so re-running the cell does not grow sys.path.
    sys.path.append(module_path)

from dataloaders import * 

from partial_vae import (
    PartialEncoder, 
    PartialDecoder, 
    PartialVAE, 
    binomial_loss_function, 
    beta_binomial_loss_function,
)

# Seed NumPy and PyTorch (CPU, plus the current CUDA device when present)
# so that repeated runs of the notebook are reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed(seed)

# Pick the compute device: GPU when CUDA is available, otherwise CPU.
device = torch.device('cuda' if cuda_available else 'cpu')
print(f"Using device: {device}")

Using device: cuda


#### 2. Load Some Simulated Data for Testing

In [2]:
# Read the simulated single-cell splicing dataset from disk into an AnnData object.
simulated_h5ad_path = '/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/SIMULATED/simulated_data_2025-03-12.h5ad'
adata = ad.read_h5ad(simulated_h5ad_path)

# Quick overview: the AnnData repr, then the container type and shape of every layer.
print("\nAnnData Summary:")
print(adata)
print("\nLayers:")
for layer_name in adata.layers.keys():
    layer = adata.layers[layer_name]
    print(f"  {layer_name}: {type(layer)}, shape {layer.shape}")


AnnData Summary:
AnnData object with n_obs × n_vars = 19942 × 9798
    obs: 'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id', 'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped', 'cell_type'
    var: 'junction_id', 'event_id', 'splice_motif', 'label_5_prime', 'label_3_prime', 'annotation_status', 'gene_name', 'gene_id', 'num_junctions', 'position_off_5_prime', 'position_off_3_prime', 'CountJuncs', 'non_zero_count_cells', 'non_zero_cell_prop', 'annotation_status_score', 'non_zero_cell_prop_score', 'splice_motif_score', 'junction_id_index', 'chr', 'start', 'end', 'index', '0', '1', '2', '3', '4', '5', '6', '7', '8', 'sample_label', 'difference', 'true_label'
    uns: 'age_colors', 'cell_type_colors', 'neighbors', 'pca_explained_variance_ratio', 'tissue_colors', 'umap'
    obsm: 'X_leafletFA', 'X_pca', 'X_umap', 'phi_init_100_waypoints', 'phi_init_30_waypoints'
    varm: 'psi_init_100_waypoints', 'psi_init_30_w



In [3]:
# Layers consumed by the rest of the notebook; every other layer is dropped
# from the trimmed object.
layers_to_keep = ("junc_ratio", "cell_by_cluster_matrix", "cell_by_junction_matrix")
trimmed_layers = {}
for layer_key in layers_to_keep:
    trimmed_layers[layer_key] = adata.layers[layer_key]

# Build a lighter AnnData: no X matrix, copies of obs/var, and only the kept layers.
adata_trimmed = ad.AnnData(
    X=None,
    obs=adata.obs.copy(),
    var=adata.var.copy(),
    layers=trimmed_layers,
)

In [4]:
# --- Configuration ---
X_LAYER_NAME = 'junc_ratio'
JUNCTION_COUNTS_LAYER_NAME = 'cell_by_junction_matrix'
CLUSTER_COUNTS_LAYER_NAME = 'cell_by_cluster_matrix'
BATCH_SIZE = 512   # Cells per mini-batch; adjust as needed
NUM_WORKERS = 2    # DataLoader worker processes; adjust based on your system

# --- Create Train/Validation Split ---
# Split over cell (obs) positions. test_size=0.3 holds out 30% of the cells
# for validation; random_state pins the split so it is reproducible.
all_indices = np.arange(adata_trimmed.n_obs)
train_indices, val_indices = train_test_split(all_indices, test_size=0.3, random_state=42)

# --- Create Datasets ---
# Both datasets read the same three layers and differ only in which cells they expose.
layer_kwargs = dict(
    x_layer=X_LAYER_NAME,
    junction_counts_layer=JUNCTION_COUNTS_LAYER_NAME,
    cluster_counts_layer=CLUSTER_COUNTS_LAYER_NAME,
)
train_dataset = AnnDataDataset(adata_trimmed, obs_indices=train_indices.tolist(), **layer_kwargs)
val_dataset = AnnDataDataset(adata_trimmed, obs_indices=val_indices.tolist(), **layer_kwargs)

# --- Create DataLoaders ---
# Identical settings for both loaders: fixed batch order, pinned host memory
# (can speed up CPU->GPU transfer), and the final short batch is kept.
# NOTE(review): shuffle=False also applies to the training loader, so batch
# composition is the same every epoch -- confirm this is intended.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=False,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# --- Get dataset size and number of batches (for loss function scaling) ---
n_train_samples, k_train_batches = len(train_dataset), len(train_loader)
n_val_samples, k_val_batches = len(val_dataset), len(val_loader)

print(f"Training samples: {n_train_samples}, Batches: {k_train_batches}")
print(f"Validation samples: {n_val_samples}, Batches: {k_val_batches}")

Training samples: 13959, Batches: 28
Validation samples: 5983, Batches: 12


In [5]:
# ==============================================================================
# 2. CONFIGURATION & HYPERPARAMETERS
# ==============================================================================

# --- Data: AnnData layer names and the model input width ---
X_LAYER_NAME = 'junc_ratio'
JUNCTION_COUNTS_LAYER_NAME = 'cell_by_junction_matrix'
CLUSTER_COUNTS_LAYER_NAME = 'cell_by_cluster_matrix'
# The input dimension is read from the data: one feature per variable in adata_trimmed.
INPUT_DIM = adata_trimmed.n_vars
print(f"Input Dimension (n_vars): {INPUT_DIM}")

# --- Model Architecture ---
# Dimension K of the per-feature (junction) embeddings.
CODE_DIM = 16
# Hidden width of the encoder's h_layer.
H_HIDDEN_DIM = 64
# Hidden width of the encoder's final MLP.
ENCODER_HIDDEN_DIM = 128
# Dimension Z of the latent space.
LATENT_DIM = 10
# Hidden width of the decoder.
DECODER_HIDDEN_DIM = 128
DROPOUT_RATE = 0.01

# --- Training ---
# Likelihood used for the loss: 'binomial' or 'beta_binomial'.
LOSS_TYPE = 'binomial'
# Beta-binomial only: True makes the concentration a learnable parameter.
LEARN_CONCENTRATION = True
# Beta-binomial only: set to a float (e.g., 10.0) to use a FIXED concentration.
# When set, it overrides LEARN_CONCENTRATION=True.
FIXED_CONCENTRATION = None

# Upper bound on the number of training epochs.
NUM_EPOCHS = 100
LEARNING_RATE = 0.01
# Early-stopping patience, in epochs.
PATIENCE = 10
# Learning-rate scheduler: step size (epochs) and multiplicative factor.
SCHEDULE_STEP_SIZE = 10
SCHEDULE_GAMMA = 0.1

# --- Output & Logging ---
# Directory where training results are saved.
OUTPUT_DIR = "./vae_training_output"

Input Dimension (n_vars): 9798


In [6]:
# ==============================================================================
# 3. MODEL INITIALIZATION & LOSS FUNCTION SELECTION
# ==============================================================================

# --- Setup Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Instantiate Model ---
# The concentration is learned only for the beta-binomial loss, and only when
# no fixed value has been supplied and learning was requested in the config.
uses_beta_binomial = LOSS_TYPE == 'beta_binomial'
concentration_is_free = FIXED_CONCENTRATION is None
should_learn_concentration = uses_beta_binomial and concentration_is_free and LEARN_CONCENTRATION

model_kwargs = dict(
    input_dim=INPUT_DIM,
    code_dim=CODE_DIM,
    h_hidden_dim=H_HIDDEN_DIM,
    encoder_hidden_dim=ENCODER_HIDDEN_DIM,
    latent_dim=LATENT_DIM,
    decoder_hidden_dim=DECODER_HIDDEN_DIM,
    dropout_rate=DROPOUT_RATE,
    learn_concentration=should_learn_concentration,
)
model = PartialVAE(**model_kwargs)
model.to(device)
print("Model initialized:")
print(model)

# --- Choose Loss Function ---
# Map each supported LOSS_TYPE to its loss function; anything else is a config error.
loss_functions_by_type = {
    'binomial': binomial_loss_function,
    'beta_binomial': beta_binomial_loss_function,
}
if LOSS_TYPE not in loss_functions_by_type:
    raise ValueError(f"Unknown LOSS_TYPE: '{LOSS_TYPE}'. Choose 'binomial' or 'beta_binomial'.")
chosen_loss_function = loss_functions_by_type[LOSS_TYPE]

if uses_beta_binomial:
    print("Using Beta-Binomial Loss.")
    if not concentration_is_free:
        print(f"Using FIXED concentration: {FIXED_CONCENTRATION}")
    elif should_learn_concentration:
        print("Using LEARNABLE concentration.")
    else:
        print("Warning: Beta-binomial selected but no concentration specified (fixed or learnable). Check config.")
else:
    print("Using Binomial Loss.")
    # Concentration settings have no effect on the plain binomial loss.
    if LEARN_CONCENTRATION or not concentration_is_free:
        print("Warning: Concentration parameters ignored for binomial loss.")


Using device: cuda
Model initialized:
PartialVAE(
  (encoder): PartialEncoder(
    (h_layer): Sequential(
      (0): Linear(in_features=18, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.01, inplace=False)
      (3): Linear(in_features=64, out_features=16, bias=True)
      (4): ReLU()
    )
    (encoder_mlp): Sequential(
      (0): Linear(in_features=16, out_features=128, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.01, inplace=False)
      (3): Linear(in_features=128, out_features=20, bias=True)
    )
  )
  (decoder): PartialDecoder(
    (z_processor): Sequential(
      (0): Linear(in_features=10, out_features=128, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.01, inplace=False)
    )
    (j_layer): Sequential(
      (0): Linear(in_features=145, out_features=128, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.01, inplace=False)
      (3): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
Using Binomial Loss.


In [7]:
# Sanity check: pull only the first batch from the training loader and report
# which keys it contains and the tensor shape stored under each key.
for batch in train_loader:
    batch_shapes = {}
    for batch_key, batch_tensor in batch.items():
        batch_shapes[batch_key] = batch_tensor.shape
    print("Batch keys:", batch.keys())
    print("Batch shape:", batch_shapes)
    break  # one batch is enough for inspection

Batch keys: dict_keys(['x', 'mask', 'junction_counts', 'cluster_counts'])
Batch shape: {'x': torch.Size([512, 9798]), 'mask': torch.Size([512, 9798]), 'junction_counts': torch.Size([512, 9798]), 'cluster_counts': torch.Size([512, 9798])}


In [None]:
# ==============================================================================
# 4. TRAINING EXECUTION
# ==============================================================================

# Create output directory if it doesn't exist
if OUTPUT_DIR:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Output will be saved to: {os.path.abspath(OUTPUT_DIR)}")

# --- Start Training ---
# train_model's result is unpacked into the per-epoch training losses, the
# per-epoch validation losses, and the number of epochs that were run; the
# post-training cell relies on all three names.
try:
    train_losses, val_losses, epochs_trained = model.train_model(
        loss_function=chosen_loss_function,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        patience=PATIENCE,
        fixed_concentration=FIXED_CONCENTRATION,  # Pass the fixed value if set
        schedule_step_size=SCHEDULE_STEP_SIZE,
        schedule_gamma=SCHEDULE_GAMMA,
        output_dir=OUTPUT_DIR,
        # --- Specify keys matching AnnDataDataset output ---
        input_key='x',
        mask_key='mask',
        junction_counts_key='junction_counts',
        cluster_counts_key='cluster_counts'
    )
except Exception as e:
    print(f"\nAn error occurred during training: {e}")
    # Re-raise instead of calling exit(): inside a notebook, exit() shuts the
    # kernel down and throws away the loaded data and the model. Re-raising
    # stops execution at this cell, shows the full traceback, and keeps the
    # kernel state available for debugging.
    raise


Output will be saved to: /gpfs/commons/home/kisaev/multivi_tools_splicing/testing/vae_training_output
Beginning training on device: cuda:0
Epoch 001/100 | Train Loss: 11405411.6925 | LR: 1.0e-02
          | Val Loss:   81754.1829
          | Val loss improved (inf -> 81754.1829). Saving model to ./vae_training_output/model/best_model.pth
Epoch 002/100 | Train Loss: 245180.5818 | LR: 1.0e-02
          | Val Loss:   145214.1185
          | Val loss did not improve. Bad epochs: 1/10
Epoch 003/100 | Train Loss: 115836.0516 | LR: 1.0e-02
          | Val Loss:   117241.7461
          | Val loss did not improve. Bad epochs: 2/10
Epoch 004/100 | Train Loss: 215245.4170 | LR: 1.0e-02
          | Val Loss:   115809.9805
          | Val loss did not improve. Bad epochs: 3/10
Epoch 005/100 | Train Loss: 102042.7833 | LR: 1.0e-02
          | Val Loss:   46610.8962
          | Val loss improved (81754.1829 -> 46610.8962). Saving model to ./vae_training_output/model/best_model.pth
Epoch 006/100 | Tra

In [None]:
# ==============================================================================
# 5. POST-TRAINING (Example: Plot Losses)
# ==============================================================================
print("\nTraining complete.")

# Plot only when both loss histories are non-empty.
if train_losses and val_losses:
    plt.figure(figsize=(10, 6))
    epochs = range(1, epochs_trained + 1)
    # Draw the training and validation curves on the same axes.
    for loss_history, curve_label in (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    ):
        plt.plot(epochs, loss_history, label=curve_label)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('VAE Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    # Save next to the other training outputs when OUTPUT_DIR is set,
    # otherwise into the current working directory.
    plot_path = (
        os.path.join(OUTPUT_DIR, "training_loss_plot.png")
        if OUTPUT_DIR
        else "training_loss_plot.png"
    )
    plt.savefig(plot_path)
    print(f"Loss plot saved to {plot_path}")
    plt.show()  # Render the figure in the notebook

In [None]:
# Display the obs table of the full AnnData object to see which per-cell
# annotation columns (e.g. 'cell_type') are available for labelling plots.
adata.obs

In [None]:
# --- Select Data for Visualization (e.g., all data or a subset) ---
# Using all cells here; adjust if needed (e.g., use val_indices from training).
indices_to_use = np.arange(adata.n_obs)
print(f"Getting latent representations for {len(indices_to_use)} cells...")

CELL_TYPE_COLUMN = "cell_type"

# Labels as strings, selected by POSITION with .iloc. indices_to_use holds
# positional row numbers, and plain `series[int_array]` is label-based on an
# integer index (and a deprecated positional fallback otherwise), so .iloc is
# the only form that always lines up with the dataset's obs_indices.
cell_labels = adata.obs[CELL_TYPE_COLUMN].iloc[indices_to_use].astype(str).values

# --- Create Dataset and DataLoader for Inference ---
# NOTE(review): the original comment asks to ensure AnnDataDataset uses the
# boolean-mask fix -- verify against the dataloaders module.
inference_dataset = AnnDataDataset(
    adata,
    x_layer=X_LAYER_NAME,
    junction_counts_layer=JUNCTION_COUNTS_LAYER_NAME,  # Still needed by dataset init
    cluster_counts_layer=CLUSTER_COUNTS_LAYER_NAME,    # Still needed by dataset init
    obs_indices=indices_to_use.tolist()
)

inference_loader = DataLoader(
    inference_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # Keep order fixed so rows stay aligned with cell_labels
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=False  # Every cell must get a latent representation
)

# --- Iterate and Collect Latent Representations ---
latent_reps_list = []
with torch.no_grad():  # No gradients are needed for inference
    for i, batch in enumerate(inference_loader):
        # '\r' rewrites the same output line to show progress.
        print(f"  Processing batch {i+1}/{len(inference_loader)}...", end='\r')
        x_batch = batch['x']
        mask_batch = batch['mask']

        # Per the original note, get_latent_rep is expected to handle device
        # placement and eval mode itself -- confirm in partial_vae. Its result
        # is collected as-is and concatenated with NumPy below.
        latent_batch_np = model.get_latent_rep(x_batch, mask_batch)
        latent_reps_list.append(latent_batch_np)

print("\nFinished collecting latent representations.")

# Concatenate all batch results along the cell axis
all_latent_reps = np.concatenate(latent_reps_list, axis=0)
print(f"Shape of collected latent representations: {all_latent_reps.shape}") # Should be (n_obs, latent_dim)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import anndata as ad
import os
import matplotlib.pyplot as plt
import seaborn as sns # For better plotting aesthetics
from umap import UMAP # Import UMAP
from typing import Dict, List, Optional, Tuple

# --- Visualization ---
# Tunable UMAP settings, passed to the UMAP constructor in the next cell.
# Neighbourhood size used by UMAP.
N_NEIGHBORS = 15
# Minimum distance between points in the embedding.
MIN_DIST = 0.1
# Distance metric applied in the latent space.
UMAP_METRIC = 'euclidean'

In [None]:
# ==============================================================================
# 6. DIMENSIONALITY REDUCTION (UMAP)
# ==============================================================================

print("Running UMAP for dimensionality reduction...")
# Project the latent vectors to 2 dimensions for plotting; random_state pins
# the result for reproducibility.
umap_kwargs = dict(
    n_components=2,
    n_neighbors=N_NEIGHBORS,   # Controls local vs global structure
    min_dist=MIN_DIST,         # Controls tightness of clusters
    metric=UMAP_METRIC,        # Distance metric in the latent space
    random_state=42,
)
reducer = UMAP(**umap_kwargs)

embedding_2d = reducer.fit_transform(all_latent_reps)
print(f"Shape of 2D UMAP embedding: {embedding_2d.shape}") # Should be (n_obs, 2)

# ==============================================================================
# 7. PLOTTING
# ==============================================================================

print("Generating plot...")

# --- Create the plot ---
plt.figure(figsize=(12, 10))

# One palette entry per distinct cell label. 'tab20' provides 20 distinct
# colours; consider another palette if there are more labels than that.
num_unique_labels = len(np.unique(cell_labels))
palette = sns.color_palette('tab20', n_colors=num_unique_labels)

umap_x = embedding_2d[:, 0]
umap_y = embedding_2d[:, 1]
# sns.scatterplot draws on the current axes and returns them.
scatter = sns.scatterplot(
    x=umap_x,
    y=umap_y,
    hue=cell_labels,   # Colour points by cell type
    palette=palette,
    s=5,               # Point size
    alpha=0.7,         # Point transparency
    linewidth=0,       # No border around points
)

# --- Customize plot ---
scatter.set_title(f'UMAP Projection of VAE Latent Space (Z={LATENT_DIM})', fontsize=16)
scatter.set_xlabel('UMAP 1', fontsize=12)
scatter.set_ylabel('UMAP 2', fontsize=12)
# UMAP coordinates have no interpretable scale, so hide the ticks.
scatter.set_xticks([])
scatter.set_yticks([])
