In [1]:
import SpaGCN as spg
import os,csv,re
import pandas as pd
import numpy as np
import scanpy as sc
import math
from scipy.sparse import issparse
import random, torch
import warnings
warnings.filterwarnings("ignore")
import matplotlib.colors as clr
import matplotlib.pyplot as plt

  from pandas import Int64Index as NumericIndex


In [30]:
# Per-replicate SpaGCN SVG detection on the simulated data: for each replicate,
# cluster spots into spatial domains, run domain-vs-neighbour DE tests for every
# domain, and export the smallest adjusted p-value per gene together with the
# ground-truth `spatially_variable` label.

# ---- configuration ----
p = 0.5                          # target value of p that spg.search_l tunes `l` towards
n_clusters = 7                   # number of spatial domains requested from SpaGCN
r_seed = t_seed = n_seed = 100   # seeds for `random`, torch and numpy respectively
# Filtering thresholds from the SpaGCN SVG workflow. They are NOT applied in this
# cell (every gene's adjusted p-value is exported).
# NOTE(review): confirm whether any later cell actually uses them.
min_in_group_fraction = 0.8
min_in_out_group_ratio = 1
min_fold_change = 1.5

N_REPLICATES = 10
DATA_TEMPLATE = "../../data/simulation/svgs/adata_rep_{}.h5ad"
OUT_DIR = "./svgs"

# DataFrame.to_csv does not create missing directories.
os.makedirs(OUT_DIR, exist_ok=True)

for i in range(N_REPLICATES):
    adata = sc.read_h5ad(DATA_TEMPLATE.format(i))
    # NOTE(review): presumably works around scanpy's rank_genes_groups reading
    # uns['log1p']['base'] (KeyError after an h5ad round-trip) — confirm.
    adata.uns['log1p'] = {'base': None}

    # Spatial coordinates and ids, reused by every SpaGCN helper below.
    x_coord = adata.obsm['spatial'][:, 0]
    y_coord = adata.obsm['spatial'][:, 1]
    cell_ids = adata.obs.index.tolist()

    adj = spg.calculate_adj_matrix(x=x_coord, y=y_coord, histology=False)
    l = spg.search_l(p, adj, start=0.01, end=1000, tol=0.01, max_run=100)

    # Resolution for the louvain initialisation used by clf.train below;
    # presumably searched so that SpaGCN yields n_clusters domains.
    res = spg.search_res(adata,
                         adj,
                         l, n_clusters,
                         start=0.7, step=0.1,
                         tol=5e-3, lr=0.05,
                         max_epochs=20,
                         r_seed=r_seed,
                         t_seed=t_seed,
                         n_seed=n_seed)

    clf = spg.SpaGCN()
    clf.set_l(l)
    # Re-seed right before training so each replicate's run is reproducible.
    random.seed(r_seed)
    torch.manual_seed(t_seed)
    np.random.seed(n_seed)
    clf.train(adata, adj, init_spa=True, init="louvain", res=res,
              tol=5e-3, lr=0.05, max_epochs=200)
    y_pred, prob = clf.predict()
    adata.obs["pred"] = y_pred
    pred_labels = adata.obs["pred"].tolist()

    # Radius search bounds depend only on adj, so compute them once per
    # replicate instead of once per domain.
    nonzero_adj = adj[adj != 0]
    start, end = np.quantile(nonzero_adj, q=0.001), np.quantile(nonzero_adj, q=0.1)

    de_genes_all = []
    for target in range(n_clusters):
        print(f"target: {target}")
        r = spg.search_radius(target_cluster=target,
                              cell_id=cell_ids,
                              x=x_coord,
                              y=y_coord,
                              pred=pred_labels, start=start, end=end,
                              num_min=10, num_max=14, max_run=100)

        try:
            nbr_domains = spg.find_neighbor_clusters(target_cluster=target,
                                                     cell_id=cell_ids,
                                                     x=x_coord,
                                                     y=y_coord,
                                                     pred=pred_labels,
                                                     radius=r,
                                                     ratio=1/2)
            de_genes_info = spg.rank_genes_groups(input_adata=adata,
                                                  target_cluster=target,
                                                  nbr_list=nbr_domains,
                                                  label_col="pred",
                                                  adj_nbr=True,
                                                  log=True)
        except (RuntimeError, TypeError, NameError) as err:
            # Best effort: a domain whose radius/neighbour search failed is
            # skipped, but report it instead of silently dropping it.
            # NOTE(review): catching NameError can also hide typos in this cell.
            print(f"target {target}: skipped ({type(err).__name__}: {err})")
            continue
        de_genes_all.append(de_genes_info)

    if not de_genes_all:
        # pd.concat([]) would raise an opaque error here; skipping instead would
        # let a stale rep_{i}.csv from an earlier run be read downstream.
        raise ValueError(f"replicate {i}: no domain produced DE results")

    # Per gene, keep the most significant adjusted p-value across all domains.
    df = pd.concat(de_genes_all)[['genes', 'pvals_adj']]
    df = df.groupby(['genes']).min()
    df = df.loc[adata.var_names]
    # np.int was removed in NumPy 1.24; the builtin int is the same type.
    df['spatially_variable'] = adata.var.spatially_variable.astype(int).values
    df = df[['pvals_adj', 'spatially_variable']]
    df.to_csv(f"{OUT_DIR}/rep_{i}.csv")

Calculateing adj matrix using xy only...
Run 1: l [0.01, 1000], p [0.0, 2497.9591184697747]
Run 2: l [0.01, 500.005], p [0.0, 2494.841064453125]
Run 3: l [0.01, 250.0075], p [0.0, 2482.434814453125]
Run 4: l [0.01, 125.00874999999999], p [0.0, 2433.849609375]
Run 5: l [0.01, 62.509375], p [0.0, 2254.9013671875]
Run 6: l [0.01, 31.2596875], p [0.0, 1725.561767578125]
Run 7: l [0.01, 15.63484375], p [0.0, 864.577392578125]
Run 8: l [0.01, 7.822421875], p [0.0, 293.59014892578125]
Run 9: l [0.01, 3.9162109375], p [0.0, 83.75708770751953]
Run 10: l [0.01, 1.9631054687499998], p [0.0, 21.752925872802734]
Run 11: l [0.01, 0.9865527343749999], p [0.0, 4.941504001617432]
Run 12: l [0.01, 0.49827636718749996], p [0.0, 0.5931874513626099]
Run 13: l [0.25413818359374996, 0.49827636718749996], p [0.0017036199569702148, 0.5931874513626099]
Run 14: l [0.37620727539062493, 0.49827636718749996], p [0.11784100532531738, 0.5931874513626099]
Run 15: l [0.4372418212890624, 0.49827636718749996], p [0.30740