# Cancer Dataset - Training and Plasticity Analysis

This notebook demonstrates LSD training on the KP_tracer cancer dataset and correlation analysis with scPlasticity scores.

**Key Features:**
- Training on KP_tracer cancer data with phylogenetic pseudotime prior
- Correlation between LSD entropy and scPlasticity
- Correlation between LSD pseudotime and scPlasticity

**Data Requirements:**
- KP_tracer dataset: `../../data/KP_tracer/train_adata.h5ad`
- Plasticity scores: `plasticity_scores.tsv` (included in this folder; tab-separated, with cell IDs in the unnamed first column and an `scPlasticity` column)

## Setup

In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
from scipy.stats import pearsonr, spearmanr

# Import sclsd (LSD) components
from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state

# Fix the global seed once, up front, through sclsd's seeding helper
SEED = 42
set_all_seeds(SEED)

# Select the compute device: a CUDA GPU when one is visible, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Load and Inspect Data

In [None]:
# Read the KP_tracer training AnnData from the shared data directory
data_path = "../../data/KP_tracer/train_adata.h5ad"
adata = sc.read_h5ad(data_path)

n_cells, n_genes = adata.n_obs, adata.n_vars
print(f"Loaded {n_cells} cells x {n_genes} genes")
print(f"\nObs columns: {list(adata.obs.columns)}")

In [None]:
# Visualize the phylogenetic pseudotime prior and the cluster labels on the UMAP.
# scanpy applies `title` per panel, so give one title per entry in `color`;
# a single string only labels the first panel and logs a
# "title list is shorter than the number of panels" warning for the rest.
sc.pl.umap(
    adata,
    color=["transferred_phylo_pseudotime", "clusters"],
    title=["Phylovelo pseudotime", "Clusters"],
)

## Configure and Train LSD Model

In [None]:
# LSD hyperparameters chosen for the KP_tracer cancer data
cfg = LSDConfig()

# Model architecture and potential regularisation
model_cfg = cfg.model
model_cfg.V_coeff = 1e-4
model_cfg.layer_dims.x_encoder = [512, 256, 128]
model_cfg.layer_dims.potential = [16, 16]

# Optimiser and loss-weight schedules
optim_cfg = cfg.optimizer
optim_cfg.adam.lr = 2e-3
optim_cfg.wasserstein_schedule.max_W = 1e-2
optim_cfg.kl_schedule.af = 2

# Random-walk sampling (seeded with the notebook-wide SEED)
walk_cfg = cfg.walks
walk_cfg.batch_size = 256
walk_cfg.path_len = 8
walk_cfg.num_walks = 8192
walk_cfg.random_state = SEED

# Output folder used for training saves and the figures/summary written below;
# exist_ok lets the cell be re-run without error
model_dir = "./cancer_model"
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Reset Pyro's global state (sclsd helper) before building a new LSD model
clear_pyro_state()
lsd = LSD(
    adata=adata,
    config=cfg,
    device=device,
    raw_count_key="raw",
    lib_size_key="librarysize",
)

# Prior transitions come from the phylogenetic pseudotime transferred onto each cell
lsd.set_prior_transition(prior_time_key="transferred_phylo_pseudotime")

# Sample the random walks through the state space (sizes set in cfg.walks)
lsd.prepare_walks()

print("Model initialized and walks prepared")

In [None]:
# Training length and save cadence; save_dir/save_interval control where and how
# often lsd.train writes its outputs (interval presumably in epochs — confirm in sclsd)
NUM_EPOCHS = 100
SAVE_INTERVAL = 50

clear_pyro_state()
lsd.train(num_epochs=NUM_EPOCHS, save_dir=model_dir, save_interval=SAVE_INTERVAL)

## Extract and Visualize Results

In [None]:
# Pull the trained model's per-cell outputs into an AnnData object
final_adata = lsd.get_adata()

obs_cols = list(final_adata.obs.columns)
obsm_keys = list(final_adata.obsm.keys())
print(f"Result obs columns: {obs_cols}")
print(f"Result obsm keys: {obsm_keys}")

In [None]:
# UMAP coloured by cluster and by each per-cell LSD output.
# scanpy applies `title` per panel; a single string would label only the first
# (clusters) panel "Final UMAP: LSD outputs" and warn for the remaining three.
sc.pl.umap(
    final_adata,
    color=["clusters", "entropy", "potential", "lsd_pseudotime"],
    title=[
        "UMAP: clusters",
        "UMAP: LSD entropy",
        "UMAP: LSD potential",
        "UMAP: LSD pseudotime",
    ],
    ncols=2,
)

In [None]:
# Same four panels on the LSD diffusion-state embedding (obsm["X_diff_state"]).
# One title per panel so that the clusters panel is not mislabelled with the
# overall title and scanpy does not warn about a short title list.
sc.pl.embedding(
    final_adata,
    basis="X_diff_state",
    color=["clusters", "entropy", "potential", "lsd_pseudotime"],
    title=[
        "Diffusion state: clusters",
        "Diffusion state: LSD entropy",
        "Diffusion state: LSD potential",
        "Diffusion state: LSD pseudotime",
    ],
    ncols=2,
)

In [None]:
# Overlay the LSD stream lines on the UMAP embedding
STREAM_EMBEDDING = "X_umap"
lsd.stream_lines(embedding=STREAM_EMBEDDING)

## Entropy vs Pseudotime

In [None]:
# Helper function for scatter plots
def scatter_time_feature(df, feature, ylabel, title):
    """Scatter one per-cell feature against LSD pseudotime, coloured by cluster.

    The figure is shown immediately; nothing is returned.

    Args:
        df: DataFrame with a "time" column (x-axis), a "cluster" column (hue)
            and the column named by `feature` (y-axis).
        feature: Column of `df` to plot on the y-axis.
        ylabel: Y-axis label.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.scatterplot(data=df, x="time", y=feature, hue="cluster", s=20, alpha=0.8, ax=ax)
    ax.set_xlabel("LSD Pseudotime", fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_title(title, fontsize=16)
    ax.grid(True, linestyle="--", alpha=0.7)
    fig.tight_layout()
    plt.show()

# Per-cell LSD outputs used by the pseudotime scatter plots.
# Named lsd_obs_df rather than `df`: later cells rebind `df` to unrelated
# correlation frames, which would silently clobber this one.
lsd_obs_df = pd.DataFrame({
    "time": final_adata.obs["lsd_pseudotime"].values,
    "potential": final_adata.obs["potential"].values,
    "entropy": final_adata.obs["entropy"].values,
    "cluster": final_adata.obs["clusters"].values,
})

# Plot entropy vs time
scatter_time_feature(lsd_obs_df, "entropy", "Entropy", "Entropy vs LSD Pseudotime")

## scPlasticity Correlation Analysis

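This section analyzes the correlation between LSD outputs (entropy, pseudotime) and scPlasticity scores. Only clusters in which at least half of the cells have an scPlasticity score are included, and cells missing either value are dropped before the Pearson and Spearman correlations are computed.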

In [None]:
# Read the per-cell scPlasticity scores shipped next to this notebook (tab-separated)
plasticity_df = pd.read_csv("plasticity_scores.tsv", sep="\t")

print(f"Loaded {plasticity_df.shape[0]} plasticity scores")
print(f"Columns: {list(plasticity_df.columns)}")

In [None]:
# Merge plasticity scores with LSD results.
# The TSV's cell IDs live in its unnamed first column, which pandas labels "Unnamed: 0";
# give the obs index the same name and turn it into a column so the frames can be joined.
obs_df = (
    final_adata.obs[['entropy', 'transferred_phylo_pseudotime']]
    .rename_axis("Unnamed: 0")
    .reset_index()
)

# Inner join on cell IDs. validate="one_to_one" fails loudly on duplicated cell IDs,
# which would otherwise duplicate rows here and break the index-aligned assignment
# of scPlasticity into final_adata.obs.
merged_df = (
    plasticity_df
    .merge(obs_df, on="Unnamed: 0", how="inner", validate="one_to_one")
    .rename(columns={"entropy": "Entropy", "transferred_phylo_pseudotime": "Time"})
)

# Min-max normalize entropy to [0, 1]. A constant or empty column has zero/undefined
# range and would divide by zero, so it is mapped to 0 instead.
entropy = merged_df['Entropy']
entropy_range = entropy.max() - entropy.min()
merged_df['Entropy'] = (entropy - entropy.min()) / entropy_range if entropy_range > 0 else 0.0

print(f"Merged {len(merged_df)} cells with scPlasticity scores")

In [None]:
# Add scPlasticity to final_adata.obs, aligned on cell ID; cells without a score get NaN.
# A new name is used instead of reassigning merged_df: set_index consumes the
# "Unnamed: 0" column, so re-running the original cell raised a KeyError.
plasticity_by_cell = merged_df.set_index("Unnamed: 0")['scPlasticity']
final_adata.obs['scPlasticity'] = plasticity_by_cell.reindex(final_adata.obs_names)

### Entropy vs scPlasticity Correlation

In [None]:
# Keep only clusters where at least half of the cells carry an scPlasticity score,
# then drop cells missing either value
df_full = final_adata.obs[['clusters', 'entropy', 'scPlasticity']].copy()
df_full['clusters'] = df_full['clusters'].astype(str)

has_score = df_full['scPlasticity'].notna()
valid_cluster_percent = has_score.groupby(df_full['clusters']).mean()
valid_clusters = valid_cluster_percent.loc[lambda frac: frac >= 0.5].index
in_valid_cluster = df_full['clusters'].isin(valid_clusters)
df = df_full.loc[in_valid_cluster].dropna(subset=['entropy', 'scPlasticity'])

# Linear and rank correlation of scPlasticity with entropy
pearson_r, pearson_p = pearsonr(df['scPlasticity'], df['entropy'])
spearman_r, spearman_p = spearmanr(df['scPlasticity'], df['entropy'])

# Scatter with a fitted regression line
fig, ax = plt.subplots(figsize=(8, 6))
sns.regplot(
    data=df,
    x='entropy',
    y='scPlasticity',
    scatter_kws={'alpha': 0.6, 's': 30},
    line_kws={'color': 'red'},
    ax=ax,
)

# Pearson r in the top-left corner (axes coordinates)
ax.text(
    0.05, 0.95,
    f'Pearson r = {pearson_r:.3f}',
    ha='left', va='top', transform=ax.transAxes,
    fontsize=13,
    bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
)

ax.set_ylabel('Effective Plasticity (scPlasticity)', fontsize=14)
ax.set_xlabel('Entropy', fontsize=14)
ax.set_title('Correlation between Entropy and Effective Plasticity', fontsize=16)
fig.tight_layout()
fig.savefig(os.path.join(model_dir, 'entropy_vs_plasticity.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Pearson r = {pearson_r:.3f} (p = {pearson_p:.2e})")
print(f"Spearman r = {spearman_r:.3f} (p = {spearman_p:.2e})")

### LSD Pseudotime vs scPlasticity Correlation

In [None]:
# Keep only clusters where at least half of the cells carry an scPlasticity score,
# then drop cells missing either value
df_full = final_adata.obs[['clusters', 'lsd_pseudotime', 'scPlasticity']].copy()
df_full['clusters'] = df_full['clusters'].astype(str)

has_score = df_full['scPlasticity'].notna()
valid_cluster_percent = has_score.groupby(df_full['clusters']).mean()
valid_clusters = valid_cluster_percent.loc[lambda frac: frac >= 0.5].index
in_valid_cluster = df_full['clusters'].isin(valid_clusters)
df = df_full.loc[in_valid_cluster].dropna(subset=['lsd_pseudotime', 'scPlasticity'])

# Linear and rank correlation of scPlasticity with LSD pseudotime
pearson_r, pearson_p = pearsonr(df['scPlasticity'], df['lsd_pseudotime'])
spearman_r, spearman_p = spearmanr(df['scPlasticity'], df['lsd_pseudotime'])

# Scatter with a fitted regression line
fig, ax = plt.subplots(figsize=(8, 6))
sns.regplot(
    data=df,
    x='lsd_pseudotime',
    y='scPlasticity',
    scatter_kws={'alpha': 0.6, 's': 30},
    line_kws={'color': 'red'},
    ax=ax,
)

# Pearson r in the top-left corner (axes coordinates)
ax.text(
    0.05, 0.95,
    f'Pearson r = {pearson_r:.3f}',
    ha='left', va='top', transform=ax.transAxes,
    fontsize=13,
    bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
)

ax.set_ylabel('Effective Plasticity (scPlasticity)', fontsize=14)
ax.set_xlabel('LSD Pseudotime', fontsize=14)
ax.set_title('Correlation between LSD Pseudotime and Effective Plasticity', fontsize=16)
fig.tight_layout()
fig.savefig(os.path.join(model_dir, 'pseudotime_vs_plasticity.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Pearson r = {pearson_r:.3f} (p = {pearson_p:.2e})")
print(f"Spearman r = {spearman_r:.3f} (p = {spearman_p:.2e})")

## Save Results

In [None]:
# Save a JSON summary of both scPlasticity correlations.
# Both correlations are recomputed here from final_adata.obs: the two plotting cells
# above reuse the names pearson_r / pearson_p / df, so by this point those names hold
# the *pseudotime* results and would be mislabelled as the entropy correlation.
def _plasticity_correlation(obs, feature, min_valid_fraction=0.5):
    """Correlate `feature` with scPlasticity using the same filter as the plotting cells.

    Clusters in which fewer than `min_valid_fraction` of cells have an scPlasticity
    score are excluded, and cells missing either value are dropped.

    Returns:
        dict with pearson_r, pearson_p, spearman_r, spearman_p and n_cells (the
        number of cells that entered the correlation).
    """
    sub = obs[['clusters', feature, 'scPlasticity']].copy()
    sub['clusters'] = sub['clusters'].astype(str)
    valid_fraction = sub['scPlasticity'].notna().groupby(sub['clusters']).mean()
    keep = valid_fraction.loc[lambda frac: frac >= min_valid_fraction].index
    sub = sub[sub['clusters'].isin(keep)].dropna(subset=[feature, 'scPlasticity'])

    r, p = pearsonr(sub['scPlasticity'], sub[feature])
    rho, rho_p = spearmanr(sub['scPlasticity'], sub[feature])
    return {
        "pearson_r": float(r),
        "pearson_p": float(p),
        "spearman_r": float(rho),
        "spearman_p": float(rho_p),
        "n_cells": int(len(sub)),
    }


results_summary = {
    "entropy_plasticity_correlation": _plasticity_correlation(final_adata.obs, "entropy"),
    "pseudotime_plasticity_correlation": _plasticity_correlation(final_adata.obs, "lsd_pseudotime"),
    # All cells with a score, before the per-cluster filter applied to the correlations
    "cells_with_plasticity": int(final_adata.obs['scPlasticity'].notna().sum()),
    "total_cells": int(len(final_adata)),
}

with open(os.path.join(model_dir, "analysis_summary.json"), "w") as f:
    json.dump(results_summary, f, indent=2)

print(f"Results saved to {model_dir}/")
print("\nAnalysis complete!")