In [1]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import celloracle as co
from celloracle import motif_analysis as ma
from celloracle.utility import save_as_pickled_object
from genomepy import install_genome

  from pkg_resources import get_distribution, DistributionNotFound


In [2]:
# Set plotting parameters
plt.rcParams['figure.figsize'] = [6, 4.5]
plt.rcParams["savefig.dpi"] = 300

# Scanpy settings
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80, facecolor='white', frameon=False)

print(f"CellOracle version: {co.__version__}")
print(f"Scanpy version: {sc.__version__}")

CellOracle version: 0.20.0
Scanpy version: 1.10.1


In [3]:
adata = sc.read_h5ad("CTR9_snRNASeq/CTR9_snRNASeq.h5ad")
print(f"Loaded data: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(adata)

Loaded data: 9869 cells x 2000 genes
AnnData object with n_obs × n_vars = 9869 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'RNA_snn_res.0.5', 'seurat_clusters', 'RNA_snn_res.0.1', 'RNA_snn_res.1', 'RNA_snn_res.0.2', 'cluster_annot'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'
    uns: 'neighbors'
    obsm: 'X_harmony', 'X_pca', 'X_umap'
    varm: 'HARMONY', 'PCs'
    obsp: 'distances'



This is where adjacency matrices should go now.
  warn(


In [4]:
# Visualize cell type and sample distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Cell type counts
adata.obs['cluster_annot'].value_counts().plot(kind='barh', ax=axes[0])
axes[0].set_xlabel('Number of cells')
axes[0].set_title('Cell Type Distribution')

# Sample distribution per cell type
pd.crosstab(adata.obs['cluster_annot'], adata.obs['sample']).plot(kind='barh', stacked=True, ax=axes[1])
axes[1].set_xlabel('Number of cells')
axes[1].set_title('Sample Distribution by Cell Type')
axes[1].legend(title='Sample')

plt.tight_layout()
plt.savefig('figures/cell_type_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

In [5]:
if 'X_umap' in adata.obsm.keys():
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    sc.pl.umap(adata, color='cluster_annot', ax=axes[0], show=False, legend_loc='on data')
    sc.pl.umap(adata, color='sample', ax=axes[1], show=False)
    
    plt.tight_layout()
    plt.savefig('figures/umap_overview.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No UMAP found - will compute later")

... storing 'orig.ident' as categorical
... storing 'sample' as categorical
... storing 'cluster_annot' as categorical


In [6]:
# Load TF info which was made from mouse cell atlas dataset.
base_GRN = co.data.load_mouse_scATAC_atlas_base_GRN()

# Check data
base_GRN.head()

Unnamed: 0,peak_id,gene_short_name,9430076c15rik,Ac002126.6,Ac012531.1,Ac226150.2,Afp,Ahr,Ahrr,Aire,...,Znf784,Znf8,Znf816,Znf85,Zscan10,Zscan16,Zscan22,Zscan26,Zscan31,Zscan4
0,chr10_100050979_100052296,4930430F08Rik,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,chr10_101006922_101007748,SNORA17,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
2,chr10_101144061_101145000,Mgat4c,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
3,chr10_10148873_10149183,9130014G24Rik,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,chr10_10149425_10149815,9130014G24Rik,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [7]:
# Instantiate Oracle object
oracle = co.Oracle()

In [8]:
print("Metadata columns:")
print(adata.obs.columns.tolist())
print("\nDimensional reduction: ")
print(list(adata.obsm.keys()))
print("\nSample distribution:")
print(adata.obs['sample'].value_counts())
print("\nCell type distribution:")
print(adata.obs['cluster_annot'].value_counts())

Metadata columns:
['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'RNA_snn_res.0.5', 'seurat_clusters', 'RNA_snn_res.0.1', 'RNA_snn_res.1', 'RNA_snn_res.0.2', 'cluster_annot']

Dimensional reduction: 
['X_harmony', 'X_pca', 'X_umap']

Sample distribution:
WT_DM    4981
KO_DM    4888
Name: sample, dtype: int64

Cell type distribution:
Epi_Kit+Elf5+           1811
Adipocyte               1802
Tcells                  1333
BasalEpi_Acta2+Trp63    1066
Epi_Ctr9+                943
Fibroblasts              908
Bcells                   519
Endothelials             442
Myeloid_cells            404
Epi_proliferating        221
DCs                      162
Pericytes/SMC            142
SMC?                      79
Schwann?                  37
Name: cluster_annot, dtype: int64


In [9]:
# Load the updated h5ad file with raw counts layer
adata = sc.read_h5ad("CTR9_snRNASeq/CTR9_snRNASeq_with_raw.h5ad")

# In this notebook, we use the unscaled mRNA count for the input of Oracle object.
adata.X = adata.layers["raw_count"].copy()

# Instantiate Oracle object.
oracle.import_anndata_as_raw_count(adata=adata,
                                   cluster_column_name="cluster_annot",
                                   embedding_name="X_umap")

In [10]:
oracle.import_TF_data(TF_info_matrix=base_GRN)

In [11]:
!wget https://raw.githubusercontent.com/morris-lab/CellOracle/master/docs/demo_data/TF_data_in_Paul15.csv

--2026-02-05 12:29:58--  https://raw.githubusercontent.com/morris-lab/CellOracle/master/docs/demo_data/TF_data_in_Paul15.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1771 (1.7K) [text/plain]
Saving to: ‘TF_data_in_Paul15.csv.2’


2026-02-05 12:29:58 (46.9 MB/s) - ‘TF_data_in_Paul15.csv.2’ saved [1771/1771]



In [12]:
# Load the TF and target gene information from Paul et al. (2015).
Paul_15_data = pd.read_csv("TF_data_in_Paul15.csv")
Paul_15_data


Unnamed: 0,TF,Target_genes
0,Cebpa,"Abcb1b, Acot1, C3, Cnpy3, Dhrs7, Dtx4, Edem2, ..."
1,Irf8,"Abcd1, Aif1, BC017643, Cbl, Ccdc109b, Ccl6, d6..."
2,Irf8,"1100001G20Rik, 4732418C07Rik, 9230105E10Rik, A..."
3,Klf1,"2010011I20Rik, 5730469M10Rik, Acsl6, Add2, Ank..."
4,Spi1,"0910001L09Rik, 2310014H01Rik, 4632428N05Rik, A..."


In [13]:
# Make dictionary: dictionary key is TF and dictionary value is list of target genes.
TF_to_TG_dictionary = {}

for TF, TGs in zip(Paul_15_data.TF, Paul_15_data.Target_genes):
    # convert target gene to list
    TG_list = TGs.replace(" ", "").split(",")
    # store target gene list in a dictionary
    TF_to_TG_dictionary[TF] = TG_list

# We invert the dictionary above using a utility function in celloracle.
TG_to_TF_dictionary = co.utility.inverse_dictionary(TF_to_TG_dictionary)

  0%|          | 0/178 [00:00<?, ?it/s]

In [14]:
# Add TF information 
oracle.addTFinfo_dictionary(TG_to_TF_dictionary)

In [15]:
# Perform PCA
oracle.perform_PCA()

# Select important PCs
plt.plot(np.cumsum(oracle.pca.explained_variance_ratio_)[:100])
n_comps = np.where(np.diff(np.diff(np.cumsum(oracle.pca.explained_variance_ratio_))>0.002))[0][0]
plt.axvline(n_comps, c="k")
plt.show()

In [16]:
print(n_comps)
n_comps = min(n_comps, 50)

25


In [17]:
n_cell = oracle.adata.shape[0]
print(f"cell number is :{n_cell}")



k = int(0.025*n_cell)
print(f"Auto-selected k is :{k}")


cell number is :9869
Auto-selected k is :246


In [18]:
oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k*8,b_maxl=k*4, n_jobs=4)

In [19]:
# Save oracle object.
oracle.to_hdf5("ctr9.celloracle.oracle")

# Load file.
oracle = co.load_hdf5("ctr9.celloracle.oracle")

In [20]:
sc.pl.umap(oracle.adata, color="cluster_annot")

In [21]:
links = oracle.get_links(cluster_name_for_GRN_unit="cluster_annot", 
                         alpha=10, verbose_level=10)

  0%|          | 0/14 [00:00<?, ?it/s]

Inferring GRN for Adipocyte...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for BasalEpi_Acta2+Trp63...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Bcells...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for DCs...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Endothelials...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Epi_Ctr9+...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Epi_Kit+Elf5+...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Epi_proliferating...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Fibroblasts...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Myeloid_cells...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Pericytes/SMC...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for SMC?...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Schwann?...


  0%|          | 0/1467 [00:00<?, ?it/s]

Inferring GRN for Tcells...


  0%|          | 0/1467 [00:00<?, ?it/s]

In [22]:
# Save the links object
links.to_hdf5(file_path="celloracle_results/CTR9_links.celloracle.links")

In [23]:
# Check which clusters have GRNs
print("Clusters with GRNs:")
print(list(links.links_dict.keys()))
print(f"\nTotal clusters: {len(links.links_dict)}")

# Check number of links in each cluster (before filtering)
print("\nNumber of regulatory links per cluster (unfiltered):")
for cluster in links.links_dict.keys():
    n_links = len(links.links_dict[cluster])
    n_tfs = links.links_dict[cluster]['source'].nunique()
    n_targets = links.links_dict[cluster]['target'].nunique()
    print(f"{cluster:20s}: {n_links:6d} links | {n_tfs:4d} TFs | {n_targets:4d} targets")

Clusters with GRNs:
['Adipocyte', 'BasalEpi_Acta2+Trp63', 'Bcells', 'DCs', 'Endothelials', 'Epi_Ctr9+', 'Epi_Kit+Elf5+', 'Epi_proliferating', 'Fibroblasts', 'Myeloid_cells', 'Pericytes/SMC', 'SMC?', 'Schwann?', 'Tcells']

Total clusters: 14

Number of regulatory links per cluster (unfiltered):
Adipocyte           :  32140 links |   66 TFs | 1461 targets
BasalEpi_Acta2+Trp63:  32140 links |   66 TFs | 1461 targets
Bcells              :  32140 links |   66 TFs | 1461 targets
DCs                 :  32140 links |   66 TFs | 1461 targets
Endothelials        :  32140 links |   66 TFs | 1461 targets
Epi_Ctr9+           :  32140 links |   66 TFs | 1461 targets
Epi_Kit+Elf5+       :  32140 links |   66 TFs | 1461 targets
Epi_proliferating   :  32140 links |   66 TFs | 1461 targets
Fibroblasts         :  32140 links |   66 TFs | 1461 targets
Myeloid_cells       :  32140 links |   66 TFs | 1461 targets
Pericytes/SMC       :  32140 links |   66 TFs | 1461 targets
SMC?                :  32140 links

In [24]:
# Examine structure of GRN dataframe
cluster_example = list(links.links_dict.keys())[0]
print(f"\nExample GRN structure from '{cluster_example}':")
print(links.links_dict[cluster_example].head(10))
print(f"\nColumns: {links.links_dict[cluster_example].columns.tolist()}")


Example GRN structure from 'Adipocyte':
   source         target  coef_mean  coef_abs         p     -logp
0   Pparg  1110019D14Rik   0.057630  0.057630  0.000143  3.843907
1  Pou2f2  1110019D14Rik   0.010296  0.010296  0.015726  1.803369
2    Elf5  1110019D14Rik   0.011540  0.011540  0.008512  2.069988
3  Plagl1  1110019D14Rik   0.016529  0.016529  0.000002  5.763010
4    Ebf3  1110019D14Rik   0.035757  0.035757  0.000124  3.907540
5   Batf3  1110019D14Rik   0.006841  0.006841  0.013530  1.868691
6  Mlxipl  1110019D14Rik  -0.015434  0.015434  0.013857  1.858327
7   Gata3  1110019D14Rik   0.018611  0.018611  0.000797  3.098663
8    Lhx8  1110019D14Rik   0.012021  0.012021  0.000003  5.579679
9  Tfap2b  1110019D14Rik   0.003944  0.003944  0.003137  2.503520

Columns: ['source', 'target', 'coef_mean', 'coef_abs', 'p', '-logp']


In [25]:
cluster_name = "Epi_Ctr9+"  
grn_df = links.links_dict[cluster_name]

print(grn_df.head(20))
print(f"\nShape: {grn_df.shape}")
print(f"\nColumns: {grn_df.columns.tolist()}")

    source         target  coef_mean  coef_abs             p      -logp
0    Pparg  1110019D14Rik   0.029126  0.029126  2.497997e-11  10.602408
1   Pou2f2  1110019D14Rik  -0.008699  0.008699  1.508366e-04   3.821493
2     Elf5  1110019D14Rik  -0.005408  0.005408  1.833419e-05   4.736738
3   Plagl1  1110019D14Rik   0.036186  0.036186  4.385414e-09   8.357989
4     Ebf3  1110019D14Rik   0.014902  0.014902  1.485478e-05   4.828134
5    Batf3  1110019D14Rik  -0.006411  0.006411  1.064298e-05   4.972937
6   Mlxipl  1110019D14Rik   0.017153  0.017153  3.447175e-06   5.462537
7    Gata3  1110019D14Rik   0.016090  0.016090  2.069587e-05   4.684116
8     Lhx8  1110019D14Rik   0.017449  0.017449  2.114901e-05   4.674710
9   Tfap2b  1110019D14Rik   0.031759  0.031759  4.122304e-07   6.384860
10   Foxp3  1110019D14Rik   0.010950  0.010950  8.803736e-07   6.055333
11    Mbd1  1110019D14Rik   0.058011  0.058011  3.246451e-15  14.488591
12     Pgr  1110019D14Rik  -0.026251  0.026251  3.092109e-09   8

In [26]:
# Save GRN for specific cluster as CSV
# cluster_name = "Epi_Ctr9+"
# links.links_dict[cluster_name].to_csv(f"grn_results/GRN_{cluster_name}.csv", index=False)

# save all clusters
# for cluster in links.links_dict.keys():
#     links.links_dict[cluster].to_csv(f"grn_results/GRN_{cluster}.csv", index=False)

In [27]:
# Filter links based on p-value and coefficient threshold
links.filter_links(p=0.001,          # P-value threshold
                   weight="coef_abs", # Use absolute coefficient
                   threshold_number=2000)  # Keep top 2000 links per cluster

In [28]:
# Check number of links after filtering
print("\nNumber of regulatory links per cluster (after filtering):")
for cluster in links.links_dict.keys():
    n_links = len(links.links_dict[cluster])
    n_tfs = links.links_dict[cluster]['source'].nunique()
    n_targets = links.links_dict[cluster]['target'].nunique()
    print(f"{cluster:20s}: {n_links:6d} links | {n_tfs:4d} TFs | {n_targets:4d} targets")


Number of regulatory links per cluster (after filtering):
Adipocyte           :  32140 links |   66 TFs | 1461 targets
BasalEpi_Acta2+Trp63:  32140 links |   66 TFs | 1461 targets
Bcells              :  32140 links |   66 TFs | 1461 targets
DCs                 :  32140 links |   66 TFs | 1461 targets
Endothelials        :  32140 links |   66 TFs | 1461 targets
Epi_Ctr9+           :  32140 links |   66 TFs | 1461 targets
Epi_Kit+Elf5+       :  32140 links |   66 TFs | 1461 targets
Epi_proliferating   :  32140 links |   66 TFs | 1461 targets
Fibroblasts         :  32140 links |   66 TFs | 1461 targets
Myeloid_cells       :  32140 links |   66 TFs | 1461 targets
Pericytes/SMC       :  32140 links |   66 TFs | 1461 targets
SMC?                :  32140 links |   66 TFs | 1461 targets
Schwann?            :  32140 links |   66 TFs | 1461 targets
Tcells              :  32140 links |   66 TFs | 1461 targets


In [29]:
# Calculate network scores (degree centrality, betweenness, etc.)
print("Calculating network scores...")
links.get_network_score()
print("✓ Network scores calculated")

# Display score types
print("\nScore metrics available:")
print(links.merged_score.columns.tolist())

# Show example scores
print("\nExample network scores:")
print(links.merged_score.head(20))

Calculating network scores...
✓ Network scores calculated

Score metrics available:
['degree_all', 'degree_centrality_all', 'degree_in', 'degree_centrality_in', 'degree_out', 'degree_centrality_out', 'betweenness_centrality', 'eigenvector_centrality', 'cluster']

Example network scores:
         degree_all  degree_centrality_all  degree_in  degree_centrality_in  \
Rbpj             18               0.039823          0              0.000000   
Fabp4             3               0.006637          3              0.006637   
Irf8             31               0.068584          0              0.000000   
Aldh1a1           9               0.019912          9              0.019912   
Plin1             6               0.013274          6              0.013274   
Esr1             95               0.210177          6              0.013274   
Nnat             15               0.033186         15              0.033186   
Fos              47               0.103982          0              0.000000   
T

In [None]:
# Plot top genes by degree centrality for each cluster
N_TOP_GENES = 20

# Check if network scores exist
if not hasattr(links, 'merged_score') or links.merged_score is None:
    print("⚠ Error: Network scores not calculated yet!")
    print("Please run: links.get_network_score()")
else:
    fig, axes = plt.subplots(3, 4, figsize=(20, 15))
    axes = axes.flatten()
    
    for idx, cluster in enumerate(sorted(links.links_dict.keys())):
        print(cluster)
        if idx < len(axes):
            # Get scores for this cluster
            cluster_scores = links.merged_score.loc[
                links.merged_score['cluster'] == cluster,
                'degree_centrality_all'
            ].sort_values(ascending=False).head(N_TOP_GENES)
            
            # Plot
            cluster_scores.plot(kind='barh', ax=axes[idx], color='steelblue')
            axes[idx].set_title(f'{cluster}', fontsize=10, fontweight='bold')
            axes[idx].set_xlabel('Degree Centrality', fontsize=8)
            axes[idx].invert_yaxis()
            axes[idx].tick_params(labelsize=7)
    
    # Hide unused subplots
    for idx in range(len(links.links_dict.keys()), len(axes)):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.savefig('figures/network_analysis/top_genes_per_cluster.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved: figures/network_analysis/top_genes_per_cluster.png")


Adipocyte
BasalEpi_Acta2+Trp63
Bcells
DCs
Endothelials
Epi_Ctr9+
Epi_Kit+Elf5+
Epi_proliferating
Fibroblasts
Myeloid_cells
Pericytes/SMC
SMC?
Schwann?
Tcells
✓ Saved: figures/network_analysis/top_genes_per_cluster.png


In [35]:
# Create heatmap of network scores across clusters
N_GENES_HEATMAP = 50
if not hasattr(links, 'merged_score') or links.merged_score is None:
    print("⚠ Error: Network scores not calculated yet!")
    print("Please run: links.get_network_score()")
else:
    # Get top genes per cluster
    top_genes = set()
    for cluster in links.links_dict.keys():
        cluster_mask = links.merged_score['cluster'] == cluster
        cluster_scores = links.merged_score.loc[cluster_mask, 'degree_centrality_all'].sort_values(ascending=False)
        top_genes.update(cluster_scores.head(N_GENES_HEATMAP).index)

    print(f"Creating heatmap with {len(top_genes)} unique genes...")

    # Create pivot table for heatmap
    score_df = links.merged_score[['cluster', 'degree_centrality_all']].copy()
    score_df = score_df[score_df.index.isin(top_genes)]
    pivot_table = score_df.pivot(columns='cluster', values='degree_centrality_all')

    # Plot heatmap
    plt.figure(figsize=(14, 20))
    sns.heatmap(pivot_table, cmap='viridis', cbar_kws={'label': 'Degree Centrality'},
                linewidths=0.5, linecolor='gray')
    plt.title(f'Network Degree Centrality\nTop {N_GENES_HEATMAP} Genes per Cluster',
              fontsize=14, fontweight='bold')
    plt.xlabel('Cluster', fontsize=12)
    plt.ylabel('Gene', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig('figures/network_analysis/network_score_heatmap.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("✓ Saved: figures/network_analysis/network_score_heatmap.png")

Creating heatmap with 138 unique genes...
✓ Saved: figures/network_analysis/network_score_heatmap.png


In [41]:
# Select two clusters to compare (adjust these to your clusters of interest)
CLUSTER1 = "Epi_Ctr9+" 
CLUSTER2 = "Epi_Kit+Elf5+"

# Check if clusters exist
available_clusters = list(links.links_dict.keys())
if CLUSTER1 in available_clusters and CLUSTER2 in available_clusters:
    fig, ax = plt.subplots(figsize=(12, 10))
    
    links.plot_score_comparison_2D(
        value="degree_centrality_all",
        cluster1=CLUSTER1,
        cluster2=CLUSTER2,
        save=None
    )
    
    plt.tight_layout()
    plt.savefig(f'figures/network_analysis/score_comparison_{CLUSTER1}_vs_{CLUSTER2}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved comparison plot")
else:
    print(f"⚠ Cluster not found. Available clusters: {available_clusters}")
    print("Please modify CLUSTER1 and CLUSTER2 in the cell above.")

✓ Saved comparison plot
