In [1]:
import os
import torch
import argparse
import matplotlib.pyplot as plt

# Notebook Config

In [2]:
# Which trunk of the Schiebinger dataset to process; validated so a typo fails
# here instead of silently producing oddly named cache files downstream.
TRUNK_OPTIONS = ("serum", "2i", "both")
trunk = "2i"
if trunk not in TRUNK_OPTIONS:
    raise ValueError(f"trunk must be one of {TRUNK_OPTIONS}, got {trunk!r}")

# Set a flag to True to force that stage to be recomputed even when a cached
# output from a previous run is already on disk.
ReextractHVGs = False
RetrainAutoencoder = False
RetrainDensityModel = False

# 0: Downloading the Schiebinger Dataset

In [3]:
%%bash
# Fetch/unpack the Schiebinger dataset into Data/Schiebinger.
# Per the script's own log, when that directory already exists and is non-empty
# it skips the download and only tries to extract any archives found there,
# so re-running this cell does not re-download the data.
./SchiebingerDownload.sh

--- Schiebinger Dataset Downloader ---
Directory Data/Schiebinger already exists and is not empty.
Continuing: will attempt to extract/unpack any archives found in this directory.
Skipping cleanup as no new download was performed.
---
Dataset is located in: Data/Schiebinger


# 1: Extracting Highly Variable Genes (HVGs) for the Chosen Trunk


In [4]:
# Where raw data is read from and where the HVG tensor is cached.
data_directory, output_directory = "Data/Schiebinger", "Pipeline/HVGs"
num_highly_variable_genes = 2000

# The cached tensor's name encodes both the trunk and the HVG count, so each
# configuration gets its own file.
tensor_filename = f"schiebinger_hvg_tensor_trunk-{trunk}_{num_highly_variable_genes}hvg.pt"
tensor_path = os.path.join(output_directory, tensor_filename)



In [5]:
from Pipeline.firstSelectHVGs import run_hvg_extraction

# Run HVG selection only when explicitly requested or when no cached tensor exists.
if ReextractHVGs or not os.path.exists(tensor_path):
    print("Starting HVG extraction...")
    hvg_fig = run_hvg_extraction(
        data_dir=data_directory,
        output_dir=output_directory,
        output_file=tensor_filename,
        trunk=trunk,
        n_hvg=num_highly_variable_genes,
        # Filtering thresholds forwarded to run_hvg_extraction (see its docs).
        min_counts=2000,
        max_counts=50000,
        min_cells=50,
        debug=True
    )

    # A truthy return value means a diagnostic figure was produced.
    if hvg_fig:
        print("\nDisplaying diagnostic plot:")
        plt.show()
else:
    print(f"File {tensor_filename} already exists. Skipping HVG extraction.")



File schiebinger_hvg_tensor_trunk-2i_2000hvg.pt already exists. Skipping HVG extraction.


# 2: Training the Autoencoder

In [6]:
# Output locations for the trained autoencoder and its latent encodings.
model_save_path = f"Models/Autoencoder/trunk-{trunk}.pt"
latent_save_path = f"LatentSpace/trunk-{trunk}_latent.pt"

# Layer widths passed to the trainer as `latent_dims`, ending at the bottleneck
# size (presumably the encoder stack — confirm in secondTrainAutoencoder).
bottleneck = 24
latent_dims = [660, 220, 66, bottleneck]

# Training hyperparameters.
batch_size, overdispersion, num_epochs = 64, 0.3, 30



In [7]:
from Pipeline.secondTrainAutoencoder import run_autoencoder_training

# Train only when requested, or when no saved autoencoder exists yet.
if RetrainAutoencoder or not os.path.exists(model_save_path):
    autoencoder_kwargs = dict(
        tensor_file=tensor_path,
        model_save_path=model_save_path,
        latent_save_path=latent_save_path,
        latent_dims=latent_dims,
        num_epochs=num_epochs,
        batch_size=batch_size,
        overdispersion=overdispersion,
        lr=5e-4,
        val_split=0.2,
        debug=False,
    )
    run_autoencoder_training(**autoencoder_kwargs)
else:
    print(f"Model already exists at {model_save_path}. Skipping training.")

Model already exists at Models/Autoencoder/trunk-2i.pt. Skipping training.


## 2.1: Visualizing the Latent Space

In [8]:
from Genodesic.Visualizers import UMAP3D

# Set to True to render the 3D UMAP projection of the latent space.
PLOT_LATENT_UMAP = False

# map_location="cpu" guarantees the `.numpy()` calls below work even if the
# bundle was saved from CUDA tensors (`.numpy()` raises on GPU tensors).
latent_data_bundle = torch.load(latent_save_path, map_location="cpu")

# Extract the numpy arrays for plotting
latent_reps = latent_data_bundle['latent_reps'].numpy()
timepoints = latent_data_bundle['timepoints'].numpy().flatten()

print(f"Loaded {latent_reps.shape[0]} latent vectors.")

if PLOT_LATENT_UMAP:
    UMAP3D(
        latent_reps=latent_reps,
        color_by_timepoints=timepoints,
        title=f"Latent Space UMAP (Trunk: {trunk})"
    )

cuML found. Using GPU for UMAP acceleration.
Loaded 172756 latent vectors.


# 3: Setting up Density Models

In [9]:
from Scripts.train import run_training

density_model_type = "vpsde"  # Options: "vpsde", "otcfm", "rqnsf"

# Checkpoint name follows the model type, so switching types does not reuse
# (or overwrite) another model's checkpoint.
# NOTE(review): this reassigns `model_save_path`, which the autoencoder cells
# also use. Re-running the autoencoder training cell on its own after this one
# would check this density-model path instead. A distinct name would be safer
# once every reader of `model_save_path` is updated.
# NOTE(review): the path does not include `trunk`, so a checkpoint trained on
# one trunk's latents will be reused for another — confirm this is intended.
model_save_path = f"Models/DensityModels/{density_model_type}.pt"

notebook_overrides = {
    "model_type": density_model_type,
    "data_file": latent_save_path,
    "model_save_path": model_save_path,
    "dim": bottleneck,
    "num_epochs": 50,
    "batch_size": 64
}

# Mirror the HVG / autoencoder cells: train when explicitly requested OR when
# no checkpoint exists yet, so a fresh checkout does not fail at load time.
if RetrainDensityModel or not os.path.exists(model_save_path):
    trained_model = run_training(config_overrides=notebook_overrides)
else:
    print(f"Model already exists at {model_save_path}. Skipping training.")

In [10]:
from Genodesic.DensityModels import OptimalFlowModel, ScoreSDEModel, RQNSFModel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Read the checkpoint path from the density-training config rather than the
# shared `model_save_path` name. The autoencoder config cell also assigns that
# name, so re-running it would otherwise make this load the wrong checkpoint.
# NOTE(review): this always loads a ScoreSDEModel. If
# notebook_overrides["model_type"] is switched away from "vpsde", the matching
# model class must be used here instead.
density_checkpoint_path = notebook_overrides["model_save_path"]
model = ScoreSDEModel.from_checkpoint(density_checkpoint_path, device=DEVICE)


In [11]:
import torch
# Load the pre-computed tutorial paths and unpack the two stored entries.
saved_paths = torch.load('Genodesic_tutorial_paths.pt')
phi_initial, phi_relaxed = (saved_paths[key] for key in ('phi_initial', 'phi_relaxed'))


In [18]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, List

# Assume these functions are defined in your Genodesic library
from Genodesic.PathTools.PathEvaluation import (
    calculate_euclidean_segment_lengths,
    calculate_path_density,
    calculate_fermat_length,
)
from Genodesic.PathTools import DensityBasedResampling # And other utilities

# ===================================================================
#
# 1. DATA PROCESSING FUNCTION
#
# ===================================================================

def analyze_paths(
    paths_to_analyze: Dict[str, torch.Tensor],
    model,
    beta: float,
    density_s_steps: int,
    fermat_s_steps: int
) -> Dict[str, Dict[str, Any]]:
    """
    Compute geometry and density metrics for every named path.

    For each path this records its Euclidean segment lengths and their sum,
    the values returned by ``calculate_path_density`` and their mean, and the
    Fermat length for the given ``beta``.

    Args:
        paths_to_analyze: Mapping from a display name (e.g. 'Initial') to the
            path tensor to evaluate.
        model: Density/score model forwarded to the density and Fermat helpers.
        beta: Beta parameter of the Fermat metric.
        density_s_steps: Interpolation steps for ``calculate_path_density``.
        fermat_s_steps: Integration steps for ``calculate_fermat_length``.

    Returns:
        Mapping from each path name (in input order) to a dict with keys
        'path_tensor', 'lengths', 'total_length', 'densities',
        'mean_log_density' and 'fermat_length'.
    """
    print("--- Starting Path Analysis ---")
    summary = {}
    for path_name, path in paths_to_analyze.items():
        print(f"Processing '{path_name}' path...")

        seg_lengths = calculate_euclidean_segment_lengths(path)
        log_densities = calculate_path_density(path, model, density_s_steps)
        fermat = calculate_fermat_length(path, model, beta, fermat_s_steps)

        summary[path_name] = dict(
            path_tensor=path,
            lengths=seg_lengths,
            total_length=np.sum(seg_lengths),
            densities=log_densities,
            mean_log_density=np.mean(log_densities),
            fermat_length=fermat,
        )
    print("--- Analysis Complete ---\n")
    return summary

# ===================================================================
#
# 2. MODULAR PLOTTING & REPORTING FUNCTIONS
#
# ===================================================================

def plot_density_distributions(
    analysis_results: Dict[str, Dict[str, Any]],
    colors: Dict[str, str] = None
):
    """
    Overlay normalized log-density histograms for an arbitrary number of paths.

    Each path gets a 40-bin, density-normalized histogram labelled with its
    mean log-density. A path's color is taken from ``colors`` when that mapping
    has an entry for the path name. Otherwise a fixed default palette is cycled
    by position. This means a partial ``colors`` mapping (e.g. one that lacks a
    newly added path) still gives every path a deliberate color.

    Args:
        analysis_results: The output from the analyze_paths function; each
            value must provide 'densities' and 'mean_log_density'.
        colors: Optional mapping from path name to a matplotlib color.
    """
    print("Plotting log-density distributions...")
    default_colors = ['#FF6347', '#3CB371', '#1E90FF', '#FFD700', '#DB7093']
    color_overrides = colors or {}

    _, ax = plt.subplots(figsize=(12, 8))
    for i, (name, data) in enumerate(analysis_results.items()):
        # Names missing from `colors` fall back to the default palette rather
        # than to None (which would leave the choice to matplotlib's cycle).
        color = color_overrides.get(name, default_colors[i % len(default_colors)])
        ax.hist(
            data['densities'],
            bins=40,
            alpha=0.7,
            label=f"{name} Path (Mean Log-Density: {data['mean_log_density']:.2f})",
            density=True,
            color=color
        )

    ax.set_title('Distribution of Midpoint Log-Densities', fontsize=16)
    ax.set_xlabel('Log-Density', fontsize=12)
    ax.set_ylabel('Normalized Frequency', fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    plt.show()

def print_path_statistics(analysis_results: Dict[str, Dict[str, Any]], beta: float):
    """
    Print a per-path summary of the metrics produced by ``analyze_paths``.

    For each path this reports the total and mean per-segment Euclidean length,
    the mean and standard deviation (NumPy's default population std) of the
    log-densities, and the Fermat length along with the ``beta`` it used.
    """
    print("--- Path Statistics ---")
    for path_name, stats in analysis_results.items():
        print(f"\n>> {path_name} Path:")

        total_len = stats['total_length']
        mean_seg = np.mean(stats['lengths'])
        print(f"  Euclidean Length: Total={total_len:.3f}, Mean Seg={mean_seg:.3f}")

        mean_ld = stats['mean_log_density']
        std_ld = np.std(stats['densities'])
        print(f"  Log-Density:      Mean={mean_ld:.3f}, Std={std_ld:.3f}")

        fermat = stats['fermat_length']
        print(f"  Fermat Length:    {fermat:.3f} (beta={beta})")

# ===================================================================
#
# 3. MAIN SCRIPT
#
# ===================================================================

# --- Configuration ---
BETA_PARAM = 0.33
DENSITY_STEPS = 3
INTEGRATION_STEPS = 4
SAVED_PATHS_FILE = 'Genodesic_tutorial_paths.pt'

# --- Load and Define Paths ---
# This is the single place to add or remove paths for analysis.
saved_paths = torch.load(SAVED_PATHS_FILE)

# Add any other paths you want to compare, e.g.:
# resampled_path = DensityBasedResampling(...)
# test_path = DensityBasedResampling(...)

paths = {
    "Initial": saved_paths['phi_initial'],
    "Relaxed": saved_paths['phi_relaxed'],
    # FIXME: `negative_beta` is not defined anywhere in this notebook, so this
    # entry raised NameError on Restart & Run All. Define it in a cell above
    # before re-enabling.
    # "Negative": negative_beta,
    # "Test": test_path,           # Example: easy to add new paths
}

# Colors per path name (optional); names without an entry use the default palette.
path_colors = {
    "Initial": "#FF6347",   # Tomato Red
    "Relaxed": "#3CB371",   # Medium Sea Green
    "Resampled": "#1E90FF", # Dodger Blue
    "Test": "#FFD700",      # Gold
    "Negative": "#DB7093",  # Pale Violet Red
}


# --- Run Analysis and Reporting ---
# 1. Process all paths in one go
path_analysis_results = analyze_paths(
    paths, model, BETA_PARAM, DENSITY_STEPS, INTEGRATION_STEPS
)

# 2. Print the summary statistics
print_path_statistics(path_analysis_results, BETA_PARAM)

# 3. Plot the density distributions
plot_density_distributions(path_analysis_results, colors=path_colors)

# Note: Your pseudotime plotting functions are already well-refactored
# to handle multiple paths, so they can be used as-is with the new structure.

Calculating segment lengths...
Calculating midpoint log-densities...
Calculating Fermat lengths...
Done.

