# CellOracle GRN Calculation - Per Cell Type
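Loop through each cell type in `cluster_annot`, subset the data to that cell type, and build GRNs using `sample` as the CellOracle cluster unit, so that each cell type gets one network per sample (e.g. WT vs KO).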

## 0. Import libraries

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns

import celloracle as co
co.__version__

In [None]:
# Visualization settings
%config InlineBackend.figure_format = 'retina'
%matplotlib inline

plt.rcParams['figure.figsize'] = [6, 4.5]
plt.rcParams["savefig.dpi"] = 300

## 1. Load full dataset and base GRN

In [None]:
# Read the complete snRNA-seq AnnData, then report its size and the
# cell-type / sample labels available in .obs.
adata_full = sc.read_h5ad("CTR9_snRNASeq/CTR9_snRNASeq_full.h5ad")
print(f"Loaded data: {adata_full.n_obs} cells x {adata_full.n_vars} genes")
print(f"\nCell types (cluster_annot): {adata_full.obs['cluster_annot'].unique().tolist()}")
print(f"Samples: {adata_full.obs['sample'].unique().tolist()}")
adata_full

In [None]:
# The base GRN is loaded a single time here and passed to every
# per-cell-type run below.
base_GRN = co.data.load_mouse_scATAC_atlas_base_GRN()
print(f"Base GRN shape: {base_GRN.shape}")
base_GRN.head(5)

In [None]:
# Every distinct cluster_annot label becomes one unit of work; print each
# with its index and cell count.
cell_types = adata_full.obs['cluster_annot'].unique().tolist()
print(f"Will process {len(cell_types)} cell types:")
for idx, ct in enumerate(cell_types):
    n_cells = adata_full.obs['cluster_annot'].eq(ct).sum()
    print(f"  {idx}: {ct} ({n_cells} cells)")

## 2. Create output directories

In [None]:
# Root folders: oracle/links objects, per-cell-type figures, gene tables.
base_results = "celloracle_results/per_celltype"
base_figures = "celltype_figures"
base_genes = "celltype_genes"

for _folder in (base_results, base_figures, base_genes):
    os.makedirs(_folder, exist_ok=True)
print("\u2713 Base directories created")

## 3. Define the per-cell-type pipeline function

In [None]:
def _safe_filename(name, *, replace_spaces=False):
    """Return ``name`` as a string that is safe to use as one path component.

    ``/`` and ``\\`` are always replaced with ``_``; spaces are replaced too
    when ``replace_spaces`` is True. ``str()`` is applied first so non-string
    labels (e.g. integer sample IDs) do not raise ``AttributeError``.
    """
    safe = str(name).replace("/", "_").replace("\\", "_")
    return safe.replace(" ", "_") if replace_spaces else safe


def _is_integer_valued(matrix):
    """Return True if every stored entry of a dense or scipy-sparse matrix is whole."""
    from scipy import sparse

    values = matrix.data if sparse.issparse(matrix) else np.asarray(matrix)
    return bool(np.all(np.mod(values, 1) == 0))


def _preprocess_expression(adata, n_top_genes):
    """Normalize a single-cell-type AnnData and keep only highly variable genes.

    Recipe: total-count normalization, ``cell_ranger`` dispersion-based HVG
    selection on the non-log data, then a second total-count normalization on
    the kept genes.

    When ``adata.raw`` is present, its values (aligned to ``adata.var_names``)
    are treated as counts and become ``adata.X``; the incoming ``adata.X`` is
    kept in ``layers['log1p']``. Without ``.raw``, ``adata.X`` itself is used.

    Returns a new (non-view) AnnData restricted to the selected genes.
    """
    if adata.raw is not None:
        # .raw may hold more genes than .X; align it to X's genes so the layer
        # assignment cannot fail with a shape mismatch.
        counts = adata.raw[:, adata.var_names].X.copy()
        adata.layers['log1p'] = adata.X.copy()
        adata.layers['counts'] = counts
        # Normalization/HVG selection below (and CellOracle's own log
        # transform) must start from counts, not the log1p X of this branch.
        adata.X = counts.copy()
    else:
        adata.layers['counts'] = adata.X.copy()

    if not _is_integer_valued(adata.X):
        print("  \u26a0 Expression matrix is not integer-valued; "
              "import_anndata_as_raw_count expects count data")

    sc.pp.filter_genes(adata, min_counts=1)
    print(f"  After gene filtering: {adata.shape}")

    sc.pp.normalize_total(adata, target_sum=1e4)

    # Cap the HVG count below the number of genes that survived filtering.
    actual_n_top = min(n_top_genes, adata.n_vars - 1)
    filter_result = sc.pp.filter_genes_dispersion(
        adata.X,
        flavor='cell_ranger',
        n_top_genes=actual_n_top,
        log=False,
    )
    print(f"  Selected {filter_result.gene_subset.sum()} highly variable genes")

    # .copy() materializes the view so the in-place normalization below works
    # on a real AnnData instead of an implicitly converted view.
    adata = adata[:, filter_result.gene_subset].copy()

    sc.pp.normalize_total(adata, target_sum=1e4)
    print(f"  After HVG selection: {adata.shape}")
    return adata


def _compute_embedding(adata, cell_type, save_folder, recompute_pca=False):
    """Compute PCA (if needed), neighbors and UMAP in place; save UMAP by sample."""
    # A PCA already present in obsm (e.g. inherited from the full dataset,
    # whose obsm rows are subset along with the cells) is reused unless
    # recompute_pca is True.
    if recompute_pca or 'X_pca' not in adata.obsm:
        sc.pp.pca(adata, n_comps=min(50, adata.n_obs - 1, adata.n_vars - 1))
    print(f"  \u2713 PCA")

    n_neighbors = min(30, adata.n_obs - 1)
    sc.pp.neighbors(adata, n_pcs=min(30, adata.obsm['X_pca'].shape[1]), n_neighbors=n_neighbors)
    print(f"  \u2713 Neighbors (n={n_neighbors})")

    sc.tl.umap(adata)
    print(f"  \u2713 UMAP")

    sc.pl.umap(adata, color='sample', title=f"{cell_type} - by sample", show=False)
    plt.savefig(f"{save_folder}/umap_by_sample.png", dpi=150, bbox_inches='tight')
    plt.close()


def _save_top_degree_centrality(links, cell_type, save_folder):
    """Plot the top-30 genes by degree_centrality_all for each sample and write
    the top-50 genes of every sample to one combined CSV."""
    all_top_genes = []
    for sample_name in links.cluster:
        df = links.merged_score[links.merged_score["cluster"] == sample_name]
        df_sorted = df.sort_values("degree_centrality_all", ascending=False).head(50)

        df_top = df_sorted[["degree_centrality_all"]].copy()
        df_top["sample"] = sample_name
        df_top["cell_type"] = cell_type
        df_top["gene"] = df_top.index
        df_top["rank"] = range(1, len(df_top) + 1)
        all_top_genes.append(df_top)

        df_plot = df_sorted.head(30)
        fig, ax = plt.subplots(figsize=(6, 8))
        ax.scatter(df_plot["degree_centrality_all"].values, range(len(df_plot)))
        ax.set_yticks(range(len(df_plot)))
        ax.set_yticklabels(df_plot.index)
        ax.invert_yaxis()
        ax.set_xlabel("degree_centrality_all")
        ax.set_title(f"degree_centrality_all\ntop 30 in {cell_type} - {sample_name}")
        fig.tight_layout()
        fig.savefig(f"{save_folder}/top30_degree_centrality/{_safe_filename(sample_name)}.png", dpi=150)
        plt.close(fig)

    if all_top_genes:
        combined_df = pd.concat(all_top_genes, ignore_index=True)
        combined_df = combined_df[["cell_type", "sample", "rank", "gene", "degree_centrality_all"]]
        combined_df.to_csv(f"{save_folder}/top50_degree_centrality_by_sample.csv", index=False)
        print(f"  \u2713 Top genes CSV saved")


def _plot_score_heatmap(links, cell_type, save_folder, n_genes=50):
    """Save a gene x sample heatmap of degree_centrality_all for the union of
    every sample's top ``n_genes`` genes; skipped when there is nothing to plot."""
    scores = links.merged_score
    top_genes = set()
    for cluster in links.cluster:
        cluster_scores = scores.loc[scores['cluster'] == cluster, 'degree_centrality_all']
        top_genes.update(cluster_scores.sort_values(ascending=False).head(n_genes).index)

    score_df = scores.loc[scores.index.isin(top_genes), ['cluster', 'degree_centrality_all']]
    pivot_table = score_df.pivot(columns='cluster', values='degree_centrality_all')
    if pivot_table.empty:
        # sns.heatmap cannot draw an empty frame.
        print(f"  \u26a0 No network scores to plot - skipping heatmap")
        return

    plt.figure(figsize=(10, 16))
    sns.heatmap(pivot_table, cmap='viridis',
                cbar_kws={'label': 'Degree Centrality'},
                linewidths=0.5, linecolor='gray')
    plt.title(f'Network Degree Centrality - {cell_type}\nTop {n_genes} Genes per Sample',
              fontsize=14, fontweight='bold')
    plt.xlabel('Sample', fontsize=12)
    plt.ylabel('Gene', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(f"{save_folder}/network_score_heatmap.png", dpi=300, bbox_inches='tight')
    plt.close()
    print(f"  \u2713 Heatmap saved")


def run_celloracle_for_celltype(adata_full, cell_type, base_GRN,
                                 base_results, base_figures, base_genes,
                                 min_cells=50, n_top_genes=3000,
                                 recompute_pca=False):
    """
    Run the full CellOracle pipeline for a single cell type.

    The data are subset to ``cell_type`` and the GRN is built with ``sample``
    as the cluster unit, giving one network per sample (WT vs KO comparison).

    Parameters
    ----------
    adata_full : AnnData
        Full dataset; ``obs`` must contain ``cluster_annot`` and ``sample``.
    cell_type : str
        Cell type to subset from cluster_annot.
    base_GRN : pd.DataFrame
        Base GRN from CellOracle.
    base_results, base_figures, base_genes : str
        Root folders for oracle/links objects, figures and gene-score CSVs.
    min_cells : int
        Minimum number of cells required to run the pipeline.
    n_top_genes : int
        Number of highly variable genes to select.
    recompute_pca : bool
        If True, always fit a fresh PCA on the subset for neighbors/UMAP.
        By default an ``X_pca`` already present in ``obsm`` is reused.

    Returns
    -------
    tuple
        ``(oracle, links)`` on success, or ``(None, None)`` when the cell type
        is skipped (fewer than ``min_cells`` cells or fewer than two samples).
    """
    safe_name = _safe_filename(cell_type, replace_spaces=True)

    # Cell-type-specific output folders
    save_folder = f"{base_figures}/{safe_name}"
    os.makedirs(save_folder, exist_ok=True)
    os.makedirs(f"{save_folder}/degree_distribution", exist_ok=True)
    os.makedirs(f"{save_folder}/ranked_score", exist_ok=True)
    os.makedirs(f"{save_folder}/score_comparison", exist_ok=True)
    os.makedirs(f"{save_folder}/top30_degree_centrality", exist_ok=True)
    os.makedirs(f"{base_genes}/{safe_name}", exist_ok=True)

    print(f"\n{'='*70}")
    print(f"  Processing: {cell_type}")
    print(f"{'='*70}")

    # 1. Subset to this cell type ----------------------------------------
    adata = adata_full[adata_full.obs['cluster_annot'] == cell_type, :].copy()
    print(f"\n[1] Subset: {adata.shape[0]} cells x {adata.shape[1]} genes")

    if adata.shape[0] < min_cells:
        print(f"  ⚠ Skipping {cell_type}: only {adata.shape[0]} cells (min={min_cells})")
        return None, None

    # A categorical 'sample' column can report unused categories with 0
    # cells; drop them so the two-sample guard counts only present samples.
    sample_counts = adata.obs['sample'].value_counts()
    sample_counts = sample_counts[sample_counts > 0]
    print(f"  Sample distribution:")
    for s, c in sample_counts.items():
        print(f"    {s}: {c} cells")

    if len(sample_counts) < 2:
        print(f"  ⚠ Skipping {cell_type}: only 1 sample present")
        return None, None

    # 2. Preprocessing ----------------------------------------------------
    print(f"\n[2] Preprocessing...")
    adata = _preprocess_expression(adata, n_top_genes)

    # 3. Embeddings -------------------------------------------------------
    print(f"\n[3] Computing embeddings...")
    _compute_embedding(adata, cell_type, save_folder, recompute_pca=recompute_pca)

    # 4. Oracle object ----------------------------------------------------
    print(f"\n[4] Creating Oracle object...")
    oracle = co.Oracle()
    # 'sample' is the cluster column so one GRN is built per sample.
    oracle.import_anndata_as_raw_count(
        adata=adata,
        cluster_column_name="sample",
        embedding_name="X_umap"
    )
    print(f"  \u2713 Data imported (cluster_column='sample')")

    oracle.import_TF_data(TF_info_matrix=base_GRN)
    print(f"  \u2713 TF data imported")

    # 5. KNN imputation ---------------------------------------------------
    print(f"\n[5] KNN imputation...")
    oracle.perform_PCA()

    # Pick the PCA dimensionality at the first point where the increase in
    # cumulative explained variance drops to <= 0.002.
    try:
        n_comps = np.where(
            np.diff(np.diff(np.cumsum(oracle.pca.explained_variance_ratio_)) > 0.002)
        )[0][0]
    except IndexError:
        n_comps = min(20, oracle.adata.shape[0] - 1)
    n_comps = max(n_comps, 5)  # Ensure at least 5 components
    print(f"  PCA components: {n_comps}")

    n_cell = oracle.adata.shape[0]
    k = max(int(0.025 * n_cell), 5)  # Ensure k >= 5
    print(f"  Cell count: {n_cell}, k: {k}")

    oracle.knn_imputation(
        n_pca_dims=n_comps,
        k=k,
        balanced=True,
        # b_sight is a neighbour count; keep it below the number of cells
        # (only reachable when min_cells is lowered below ~41).
        b_sight=min(k * 8, n_cell - 1),
        b_maxl=k * 4,
        n_jobs=4
    )
    print(f"  \u2713 KNN imputation complete")

    oracle_path = f"{base_results}/{safe_name}.celloracle.oracle"
    oracle.to_hdf5(oracle_path)
    print(f"  \u2713 Oracle saved: {oracle_path}")

    # 6. GRN calculation (sample = cluster unit) --------------------------
    print(f"\n[6] GRN calculation (by sample)...")
    links = oracle.get_links(
        cluster_name_for_GRN_unit="sample",
        alpha=10,
        verbose_level=10
    )

    links_path = f"{base_results}/{safe_name}.celloracle.links"
    links.to_hdf5(file_path=links_path)
    print(f"  \u2713 Links saved: {links_path}")

    # 7. Network preprocessing --------------------------------------------
    print(f"\n[7] Network preprocessing...")
    links.filter_links(p=0.001, weight="coef_abs", threshold_number=2000)
    print(f"  \u2713 Links filtered")

    # Per-sample folders under degree_distribution/ (plot_degree_distributions
    # itself is only handed the parent folder).
    for cluster in links.cluster:
        os.makedirs(
            f"{save_folder}/degree_distribution/degree_dist_{_safe_filename(cluster)}",
            exist_ok=True,
        )

    # rc_context restores the previous rcParams even if plotting raises.
    with plt.rc_context({"figure.figsize": [9, 4.5]}):
        links.plot_degree_distributions(
            plot_model=True,
            save=f"{save_folder}/degree_distribution/"
        )

    # 8. Network scores ---------------------------------------------------
    print(f"\n[8] Calculating network scores...")
    links.get_network_score()
    print(f"  \u2713 Network scores calculated")

    filtered_links_path = f"{base_results}/{safe_name}_filtered.celloracle.links"
    links.to_hdf5(file_path=filtered_links_path)
    print(f"  \u2713 Filtered links saved: {filtered_links_path}")

    # 9. Gene scores per sample -------------------------------------------
    print(f"\n[9] Saving gene scores...")
    merged_scores = links.merged_score
    for clust in links.cluster:
        filepath = f"{base_genes}/{safe_name}/{_safe_filename(clust)}_all_genes.csv"
        scores = merged_scores.loc[merged_scores['cluster'] == clust]
        scores.sort_values('degree_centrality_all', ascending=False).to_csv(filepath)
        print(f"  \u2713 Saved: {filepath}")

    # 10. Top genes by degree centrality ----------------------------------
    print(f"\n[10] Visualizing top genes...")
    _save_top_degree_centrality(links, cell_type, save_folder)

    # 11. Heatmap of network scores ---------------------------------------
    print(f"\n[11] Creating heatmap...")
    _plot_score_heatmap(links, cell_type, save_folder, n_genes=50)

    # 12. Score comparison between the first two samples ------------------
    print(f"\n[12] Score comparison between samples...")
    if len(links.cluster) >= 2:
        CLUSTER1 = links.cluster[0]
        CLUSTER2 = links.cluster[1]
        print(f"  Comparing: {CLUSTER1} vs {CLUSTER2}")

        links.plot_score_comparison_2D(
            value="degree_centrality_all",
            cluster1=CLUSTER1,
            cluster2=CLUSTER2,
            percentile=97,
            save=f"{save_folder}/score_comparison"
        )
        print(f"  \u2713 Comparison plot saved")
    else:
        print(f"  \u26a0 Only {len(links.cluster)} sample(s) - skipping comparison")

    print(f"\n{'='*70}")
    print(f"  \u2713 COMPLETED: {cell_type}")
    print(f"{'='*70}\n")

    return oracle, links

## 4. Run pipeline for all cell types

In [None]:
# Run the per-cell-type pipeline for every cell type. A failure in one cell
# type is reported with its traceback and recorded in `skipped`, and the loop
# moves on to the next cell type.
# NOTE(review): every Oracle object is kept in `results`; the summary cells
# below only read 'links', so dropping 'oracle' could save memory if needed.
import traceback

results = {}
skipped = []

for i, cell_type in enumerate(cell_types):
    print(f"\n\n>>> [{i+1}/{len(cell_types)}] Starting: {cell_type}")

    try:
        oracle, links = run_celloracle_for_celltype(
            adata_full=adata_full,
            cell_type=cell_type,
            base_GRN=base_GRN,
            base_results=base_results,
            base_figures=base_figures,
            base_genes=base_genes,
            min_cells=50,
            n_top_genes=3000
        )
    except Exception as e:
        print(f"\n  ✗ ERROR processing {cell_type}: {e}")
        traceback.print_exc()
        skipped.append(cell_type)
        # Figures opened before the failure are never closed by the pipeline;
        # close them so they do not accumulate across cell types.
        plt.close('all')
        continue

    # (None, None) means the pipeline skipped this cell type on purpose.
    if oracle is not None:
        results[cell_type] = {'oracle': oracle, 'links': links}
    else:
        skipped.append(cell_type)

print(f"\n\n{'='*70}")
print(f"SUMMARY")
print(f"{'='*70}")
print(f"Successfully processed: {len(results)}/{len(cell_types)}")
print(f"  Completed: {list(results.keys())}")
if skipped:
    print(f"  Skipped/Failed: {skipped}")

## 5. Cross-cell-type summary

In [None]:
# Build one table holding the 20 highest degree_centrality_all genes for
# every processed (cell type, sample) pair, save it as CSV and preview it.
all_summaries = []

for ct_key, entry in results.items():
    links = entry['links']
    merged = links.merged_score
    for sample_name in links.cluster:
        ranked = merged.loc[merged['cluster'] == sample_name].sort_values(
            'degree_centrality_all', ascending=False
        ).head(20)
        top20 = ranked[['degree_centrality_all']].copy()
        top20['cell_type'] = ct_key
        top20['sample'] = sample_name
        top20['gene'] = top20.index
        top20['rank'] = range(1, len(top20) + 1)
        all_summaries.append(top20)

if not all_summaries:
    print("No results to summarize")
else:
    summary_df = pd.concat(all_summaries, ignore_index=True)[
        ['cell_type', 'sample', 'rank', 'gene', 'degree_centrality_all']
    ]
    summary_df.to_csv(f"{base_genes}/cross_celltype_top20_summary.csv", index=False)
    print(f"\u2713 Cross-cell-type summary saved")
    print(f"  Shape: {summary_df.shape}")
    display(summary_df.head(20))

In [None]:
# One heatmap per sample: rows are the hub genes collected in `summary_df`,
# columns are cell types; gene/cell-type pairs missing from a top-20 list are
# drawn as 0.
if all_summaries:
    for sample_name in summary_df['sample'].unique():
        grid = summary_df.loc[summary_df['sample'] == sample_name].pivot_table(
            index='gene',
            columns='cell_type',
            values='degree_centrality_all',
            aggfunc='first',
        )
        fig_height = max(8, len(grid) * 0.3)

        plt.figure(figsize=(12, fig_height))
        sns.heatmap(
            grid.fillna(0),
            cmap='viridis',
            cbar_kws={'label': 'Degree Centrality'},
            linewidths=0.5,
            linecolor='gray',
        )
        plt.title(f'Top Hub Genes Across Cell Types - {sample_name}',
                  fontsize=14, fontweight='bold')
        plt.xlabel('Cell Type', fontsize=12)
        plt.ylabel('Gene', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        out_png = f"{base_figures}/cross_celltype_heatmap_{sample_name.replace('/', '_')}.png"
        plt.savefig(out_png, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"\u2713 Cross-cell-type heatmap saved for {sample_name}")

In [None]:
# Closing report: where each kind of output was written.
print("\n=== ALL DONE ===")
print(f"\nResults saved to:")
for _label, _folder in (
    ("Oracle/Links objects", base_results),
    ("Figures", base_figures),
    ("Gene lists", base_genes),
):
    print(f"  {_label}: {_folder}/")