In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import anndata as ad
import scanpy as sc
import os
import numpy as np
import pandas as pd
import squidpy as sq
import os
from sklearn.preprocessing import LabelEncoder
import scprep
from scvi.model import SCVI
import torch
import torchsde
from torchdyn.core import NeuralODE
from tqdm import tqdm
from torchcfm.conditional_flow_matching import *
from torchcfm.models import MLP, GradModel
from torchcfm.utils import plot_trajectories, torch_wrapper
import pandas as pd
import seaborn as sns
from umap import UMAP
import torch.nn as nn
from torchdyn.core import NeuralODE
import seaborn as sns
from scvi.model import SCVI
import scvi
from scipy.stats import wasserstein_distance
from scipy.stats import energy_distance
from sklearn.metrics import r2_score
from sklearn.neighbors import NearestNeighbors

import matplotlib.colors as mcolors




In [2]:
def compute_local_mean(scRNA, representation='X_pca', spatial_key='spatial', radius=50):
    """
    Compute mean vector of the chosen representation for spatial neighbors.
    
    Parameters:
        scRNA: AnnData object
        representation: str, key in .obsm, e.g. 'X_pca', 'X_umap', 'X_scVI'
        spatial_key: str, key in .obsm with spatial coordinates
        radius: float, neighborhood radius in same units as coordinates
    """
    coords = scRNA.obsm[spatial_key]
    X = scRNA.obsm[representation]

    nbrs = NearestNeighbors(radius=radius).fit(coords)
    neighbors_idx = nbrs.radius_neighbors(coords, return_distance=False)

    local_means = np.zeros_like(X)

    for i, idx in enumerate(neighbors_idx):
        if len(idx) > 0:
            local_means[i] = X[idx].mean(axis=0)
        else:
            local_means[i] = X[i]  

    # Store the result in .obsm
    scRNA.obsm[f"local_mean_{representation}"] = local_means

    return scRNA


def get_scVI_latent_representation(scRNA, cell_type_key, spatial_key):
        print ("Using scVI for input data")
        #read the data (again cause scVI requires unormalized data) and set up scVI
        le = LabelEncoder()
        scRNA.obs[cell_type_key] = le.fit_transform(scRNA.obs[cell_type_key])
        scvi.model.SCVI.setup_anndata(scRNA)
        model = SCVI(scRNA)
        model.train()
        latent = model.get_latent_representation()
        scRNA.obsm["X_scVI"] = latent
        scRNA = compute_local_mean(scRNA, representation='X_scVI', radius=50)
        return scRNA

In [3]:
def load_scRNA_GSE232025(data_path):
    """
    This function processes the GSE232025 dataset and returns the concatenated AnnData object.
    """
    sc1= sc.read_h5ad(os.path.join(data_path, "d0_spatial_scRNAseq.h5ad"))
    sc1.obs["Batch"] = "0"
    sc2= sc.read_h5ad(os.path.join(data_path, "d1_spatial_scRNAseq.h5ad"))
    sc2.obs["Batch"] = "1"
    sc3= sc.read_h5ad(os.path.join(data_path, "d2_spatial_scRNAseq.h5ad"))
    sc3.obs["Batch"] = "2"
    sc4= sc.read_h5ad(os.path.join(data_path, "d3_spatial_scRNAseq.h5ad"))
    sc4.obs["Batch"] = "3"
    sc5= sc.read_h5ad(os.path.join(data_path, "d4_spatial_scRNAseq.h5ad"))
    sc5.obs["Batch"] = "4"

    adatas = [sc1, sc2, sc3, sc4, sc5]

    ## cell color mapping
    all_cell_types = set()
    for adata in adatas:
        all_cell_types.update(adata.obs["Annotation"].unique())

    # Create a consistent color map for all cell types
    palette = sns.color_palette("tab20", len(all_cell_types))
    color_dict = dict(zip(sorted(all_cell_types), palette))

    for sci in adatas:
        sci.obs = sci.obs.rename(columns={"Annotation": "celltype", "Batch": "day"})
        sci.obs["celltype"] = sci.obs["celltype"].astype(str)
        sci.obs["color"] = sci.obs["celltype"].map(color_dict)
        sci.obs["color"] = sci.obs["color"].apply(lambda x: mcolors.to_hex(x) if isinstance(x, (tuple, list, np.ndarray)) else x)
    scRNA = ad.concat(adatas, label='day', keys=[a.obs['day'][0] for a in adatas])

    return scRNA

In [4]:
def apply_processing_steps(scRNA, GSE_ID, num_genes=2000, n_comps=50):
    """
    This function applies the processing steps to the AnnData object.
    There needs to standardized processing steps for all datasets, this is a temporary solution.
    """
    if GSE_ID == "GSE232025":
        sc.pp.highly_variable_genes(scRNA, batch_key='day', n_top_genes=num_genes, flavor='seurat_v3')
        scRNA = scRNA[:, scRNA.var.highly_variable]
        sc.tl.pca(scRNA, n_comps=n_comps)
        sc.pp.neighbors(scRNA)
        sc.tl.umap(scRNA)
        scRNA= compute_local_mean(scRNA, representation='X_pca')
        # scRNA = get_scVI_latent_representation(scRNA, cell_type_key='celltype', spatial_key='spatial')

    elif GSE_ID == "GSE149457":
        sc.pp.normalize_total(scRNA, target_sum=1e4)
        sc.pp.log1p(scRNA)
        sc.pp.highly_variable_genes(scRNA, n_top_genes=num_genes, subset=True)
        sc.pp.scale(scRNA, max_value=10)
        sc.tl.pca(scRNA, n_comps=n_comps)
        sc.pp.neighbors(scRNA)
        sc.tl.umap(scRNA)

    return scRNA


In [5]:
base_path= "/Users/rssantanu/Desktop/codebase/constrained_FM/datasets"

def get_concatenated_dynamic_data(GSE_ID, save_path=None, data_type="scRNA", num_genes=2000, n_comps=50):
    """
    This function concatenates the dynamic data from the raw datasets and saves it as a csv file.
    """
    # Get the raw data

    if GSE_ID == "GSE149457" and data_type == "scRNA":
        data_path = os.path.join(base_path, "raw_datasets", "GSE149457")
        scRNA = load_scRNA_GSE149457(data_path)
        scRNA = apply_processing_steps(scRNA, GSE_ID, num_genes=num_genes, n_comps=n_comps)
    
    elif GSE_ID == "GSE232025" and data_type == "scRNA":
        data_path = os.path.join(base_path, "raw_datasets", "GSE232025")
        scRNA = load_scRNA_GSE232025(data_path)
        scRNA = apply_processing_steps(scRNA, GSE_ID, num_genes=num_genes, n_comps=n_comps)
    
    return scRNA

In [6]:
GSE_ID = "GSE232025"
scRNA= get_concatenated_dynamic_data(GSE_ID, num_genes=2000, n_comps=50)

  scRNA = ad.concat(adatas, label='day', keys=[a.obs['day'][0] for a in adatas])
  utils.warn_names_duplicates("obs")
  adata.obsm[key_obsm] = X_pca
  utils.warn_names_duplicates("obs")


In [11]:
scRNA

AnnData object with n_obs × n_vars = 28757 × 2000
    obs: 'CellID', 'day', 'cell_id', 'celltype', 'color'
    var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: 'hvg', 'pca', 'neighbors', 'umap'
    obsm: 'spatial', 'X_pca', 'X_umap', 'local_mean_X_pca'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'distances', 'connectivities'

In [None]:
lrdata_by_day= []

for stage in scRNA.obs['day'].unique():
    adata_stage = scRNA[scRNA.obs['day'] == stage].copy()
    lrdata = li.mt.bivariate(adata_stage,
                resource_name='consensus', # NOTE: uses HUMAN gene symbols!
                local_name='cosine', # Name of the function
                global_name="morans", # Name global function
                n_perms=100, # Number of permutations to calculate a p-value
                mask_negatives=False, # Whether to mask LowLow/NegativeNegative interactions
                add_categories=True, # Whether to add local categories to the results
                nz_prop=0.01, # Minimum expr. proportion for ligands/receptors and their subunits
                use_raw=False,
                verbose=True
                )
    # Save or analyze results for this stage
    # adata_stage.uns['liana_res'].to_csv(f'liana_results_stage_{stage}.csv')
    lrdata_by_day.append(lrdata)