In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")
import pickle
import sys
sys.path.append("../../../Benchmark/")
from benchmarker import Benchmarker

In [None]:
import os
# Switch the working directory so the bare file names used below
# (e.g. "rna.h5ad", "concat.h5ad") resolve against this dataset folder.
# NOTE(review): "datatset" looks like a misspelling of "dataset" — confirm it
# matches the folder name on disk before changing it.
os.chdir("../../../datatset/MultiEmbryo1/")

Prepare data

In [3]:
# Load the RNA and ATAC AnnData objects from the current working directory.
rna = sc.read_h5ad("rna.h5ad")
atac = sc.read_h5ad("atac.h5ad") 

In [5]:
# Load gs.h5ad. In the visible part of this notebook `gs` is only displayed,
# never used by another active cell.
gs = sc.read_h5ad("gs.h5ad")

In [4]:
# Display the summary of the RNA AnnData object.
rna

AnnData object with n_obs × n_vars = 2187 × 17058
    obs: 'orig.ident', 'nCount_Spatial', 'nFeature_Spatial', 'nCount_SCT', 'nFeature_SCT', 'nCount_ATAC', 'nFeature_ATAC', 'nCount_peaks', 'nFeature_peaks', 'RNA_clusters', 'ATAC_clusters', 'Joint_clusters', 'spatial1', 'spatial2'
    obsm: 'spatial'

In [6]:
# Display the summary of the `gs` AnnData object.
gs

AnnData object with n_obs × n_vars = 2187 × 55291
    obs: 'tsse', 'n_fragment', 'frac_dup', 'frac_mito'
    uns: 'reference_sequences'
    obsm: 'insertion'

In [None]:
# Preprocess the RNA modality in place. Order matters: total-count
# normalisation, log1p, PCA, then the neighbour graph needed by Leiden.
for preprocess_step in (
    sc.pp.normalize_total,
    sc.pp.log1p,
    sc.pp.pca,
    sc.pp.neighbors,
):
    preprocess_step(rna)

In [83]:
def search_resolution(adata, fixed_clus_count, increment=0.01,
                      res_min=0.1, res_max=2.0, random_state=0):
    """Find the Leiden resolution whose cluster count is closest to a target.

    The grid ``res_min, res_min + increment, ...`` (``res_max`` excluded) is
    scanned from the highest resolution down to the lowest. Leiden is run at
    each value and the scan stops early as soon as a resolution yields exactly
    ``fixed_clus_count`` clusters.

    Parameters
    ----------
    adata
        AnnData object handed to ``sc.tl.leiden``.
    fixed_clus_count : int
        Desired number of clusters.
    increment : float, default 0.01
        Step between candidate resolutions. Must be positive.
    res_min, res_max : float, default 0.1 and 2.0
        Bounds of the resolution grid (``res_max`` is exclusive).
    random_state : int, default 0
        Seed forwarded to every ``sc.tl.leiden`` call.

    Returns
    -------
    float or None
        The scanned resolution with the smallest absolute difference to the
        target cluster count. On ties the highest resolution wins because it
        is visited first. ``None`` if the grid is empty.

    Raises
    ------
    ValueError
        If ``increment`` is not positive.

    Notes
    -----
    ``adata.obs["temp_label"]`` is overwritten on every iteration and is left
    holding the clustering of the last resolution that was tried.
    """
    if not increment > 0:
        raise ValueError(f"increment must be positive, got {increment!r}")

    best_gap = np.inf
    best_res = None

    # np.arange is ascending; reverse it to scan from fine to coarse.
    for res in np.arange(res_min, res_max, increment)[::-1]:
        sc.tl.leiden(adata, random_state=random_state, resolution=res,
                     key_added="temp_label")
        n_clusters = adata.obs["temp_label"].nunique()
        gap = abs(n_clusters - fixed_clus_count)
        # Strict comparison keeps the first (highest) resolution on ties.
        if gap < best_gap:
            best_gap, best_res = gap, res
        if gap == 0:
            break

    return best_res

# Pick the Leiden resolution whose cluster count is closest to 14 for the RNA
# modality, then store the final clustering in rna.obs["label"].
res = search_resolution(rna, fixed_clus_count=14)
# Pin the seed to the one used inside the search (0) instead of relying on the
# library default, so the stored partition is the one that was counted.
sc.tl.leiden(rna, key_added="label", resolution=res, random_state=0)

In [85]:
# Preprocess the ATAC modality in place with the same recipe as RNA:
# total-count normalisation, log1p, PCA, then the neighbour graph.
for preprocess_step in (
    sc.pp.normalize_total,
    sc.pp.log1p,
    sc.pp.pca,
    sc.pp.neighbors,
):
    preprocess_step(atac)

In [86]:
# Pick the Leiden resolution whose cluster count is closest to 14 for the ATAC
# modality, then store the final clustering in atac.obs["label"].
res = search_resolution(atac, fixed_clus_count=14)
# Pin the seed to the one used inside the search (0) instead of relying on the
# library default, so the stored partition is the one that was counted.
sc.tl.leiden(atac, key_added="label", resolution=res, random_state=0)

In [91]:
# Subset/reorder the ATAC cells to the RNA cell order. Indexing an AnnData
# returns a view, and writing to a view triggers an implicit copy (normally
# announced by a warning, which this notebook silences), so copy explicitly.
atac = atac[rna.obs_names].copy()
# Rows now follow the RNA order, so the coordinates can be assigned
# positionally. Copy them so the two modalities do not share one array.
atac.obsm["spatial"] = rna.obsm["spatial"].copy()

In [93]:
# Tag every cell with the modality it came from, then stack both modalities
# row-wise (RNA rows first, ATAC rows second).
for modality_name, modality_adata in (("rna", rna), ("atac", atac)):
    modality_adata.obs["omic"] = modality_name
concat = sc.concat([rna, atac])

In [102]:
# Per-cell tables for the unintegrated ("single") baseline, taken from the
# concatenated object: modality tag, Leiden label and PCA embedding.
batch = pd.DataFrame(concat.obs["omic"])
cluster = pd.DataFrame(concat.obs["label"])
embed = pd.DataFrame(concat.obsm["X_pca"])
# The UMAP table must hold UMAP coordinates rather than a second copy of the
# PCA matrix. Compute them from the PCA embedding with the same
# neighbors -> umap recipe this notebook uses for the paired methods.
sc.pp.neighbors(concat, use_rep="X_pca")
sc.tl.umap(concat)
umap = pd.DataFrame(concat.obsm["X_umap"])

In [None]:
# Write the unintegrated ("single") baseline tables as CSV files.
# NOTE(review): other cells in this notebook read results from
# "../../result/MultiEmbryo1/..." (two levels up) while this cell writes three
# levels up — confirm which depth is intended.
single_out_dir = "../../../result/MultiEmbryo1/embed"
# to_csv does not create missing folders, so make sure the target exists.
os.makedirs(single_out_dir, exist_ok=True)
for table_name, table in (("batch", batch), ("cluster", cluster),
                          ("embed", embed), ("umap", umap)):
    table.to_csv(f"{single_out_dir}/{table_name}_single.csv")

In [3]:
# files = ["rna_sp.h5ad", "gs.h5ad"]
# adatas = []
# for i in files:
#     adata = sc.read_h5ad(i)
#     adatas.append(adata)

In [5]:
# adatas[1].obs_names = [i.split("-")[0] for i in adatas[1].obs_names]

In [6]:
# adatas[1] = adatas[1][adatas[0].obs_names]
# adatas[0].obs["cell_type"] = list(adatas[0].obs["Joint_clusters"])
# adatas[1].obs["cell_type"] = list(adatas[0].obs["cell_type"])
# adatas[1].obsm["spatial"] = adatas[0].obsm["spatial"].copy()

In [7]:
# import numpy as np
# for i, ad in enumerate(adatas):
#     ad.obs_names = [j+f"_{i}" for j in ad.obs_names]
# adatas[0].obs["batch"] = "RNA"
# adatas[1].obs["batch"] = "ATAC"

In [8]:
# concat = sc.concat(adatas)
# concat.obs["spatial1"] =  list(concat.obsm["spatial"][:,0])
# concat.obs["spatial2"] =  list(concat.obsm["spatial"][:,1])

In [9]:
# concat.write("concat.h5ad")

In [3]:
# Create the benchmark driver used by every bm.* call below.
# NOTE(review): R_conda_env presumably names the conda environment used for
# R-based steps — confirm in the Benchmarker implementation.
bm = Benchmarker(R_conda_env="Rbase")

In [11]:
# bm.h5ad2rds(in_file=concat, out_file="concat.rds", verbose=True)

Run all methods

In [None]:
# bm.run(RDS_file_path="concat.rds", H5AD_file_path="concat.h5ad",
#        save_path="/../../../result/MultiEmbryo1/embed/", 
#        n_cluster=14, verbose=True, workers=4, methods=["PRECAST"])

Evaluate

In [4]:
# Load the concatenated object from concat.h5ad in the working directory.
# This is read from disk, not taken from the in-memory `concat` built earlier;
# the only cell that writes concat.h5ad is currently commented out.
adata = sc.read_h5ad("concat.h5ad")

In [None]:
# res_dict = bm.read_result("../../result/MultiEmbryo1/embed/", 
#                           index=list(adata.obs_names),
#                           reindex=False,
#                           save="../../result/MultiEmbryo1/embed_dict.pkl",
#                           methods=bm.all_methods+["BindSC", "GLUE", "Monae", "SCALEX", "scConfluence", "SIMBA"])

In [105]:
# for key in res_dict["Cluster"].keys():
#     print(key, len(set(res_dict["Cluster"][key].flatten())))

In [None]:
# Compute benchmark metrics for every method in `res_dict`, using
# adata.obs["batch"] as the batch key and adata.obs["cell_type"] as the label,
# and save the result to metrics.pkl.
# NOTE(review): `res_dict` is not assigned by any active cell in the visible
# part of this notebook (the bm.read_result cell is commented out), so this
# raises NameError on a fresh kernel unless that cell is restored.
metrics = bm.cal_metrics(adata=adata, batch_key="batch", label_key="cell_type",
                         res_dict=res_dict, methods="all", verbose=True, rep=1,
                         min_max_scale=False, save="../../result/MultiEmbryo1/metrics.pkl")

In [33]:
# Reload `metrics` from the same path the cal_metrics cell saves to; this
# overwrites any in-memory `metrics`.
# pickle.load can execute arbitrary code, so only load files you produced.
with open("../../result/MultiEmbryo1/metrics.pkl", "rb") as f:
    metrics = pickle.load(f)

In [None]:
# Configure plotting: 300 dpi figures and a custom font loaded from
# ./Helvetica.ttf (relative to the current working directory).
bm.set_plot_params(params_dict={"figure.dpi": 300}, font_file_path="./Helvetica.ttf")

2025-12-21 17:41:52 - INFO - Custom font 'Helvetica' has been set


In [None]:
# Output folder for figures. In the visible cells it is only referenced by
# commented-out save= arguments.
save_dir = "../../../figures/Embryo1/"

In [11]:
# sorted(set(adata.obs["cell_type"]), key=lambda x: int(x.replace("J", "")))

In [177]:
# bm.plot_legend(category_lst=adata.obs["cell_type"], marker="o", ncol=5, borderpad=0,
#                handletextpad = 0, labelspacing = 0.5, columnspacing = 0,
#                order=sorted(list(set(adata.obs["cell_type"])), key=lambda x: int(x.replace("J", ""))),
#                save=f"{save_dir}/annot_legend_5col.pdf")

In [None]:
# bm.plot_legend(category_lst=adata.obs["batch"], marker="o", ncol=1, save=f"{save_dir}/batch_legend.pdf")

In [34]:
# bm.plot_legend(category_lst=list(range(14)), marker="o", ncol=4, save=f"{save_dir}/cluster_legend.pdf")

In [37]:
# bm.plot_legend(category_lst=adata.obs["cell_type"], marker="o", ncol=4, save=f"{save_dir}/annot_legend2.pdf")

In [83]:
# bm.plot_heatmap_legend(max_rank="22/26", save=f"{save_dir}/heatmap_legend2.pdf")

In [34]:
# Keep only the batch-correction columns of the first metrics table and break
# each "Metric Type" entry onto several lines (spaces become newlines).
batch_columns = [
    "Silhouette batch", "iLISI", "KBET",
    "Graph connectivity", "PCR comparison", "Batch correction",
]
batch_metric = metrics[0].copy().loc[:, batch_columns]
metric_type_wrapped = []
for metric_type in batch_metric.loc["Metric Type"]:
    metric_type_wrapped.append(metric_type.replace(" ", "\n"))
batch_metric.loc["Metric Type"] = metric_type_wrapped

In [38]:
# bm.plot_heatmap(metric_df=batch_metric, save=f"{save_dir}/summary_batch_heatmap_all.pdf", total_name="Batch correction",)# insert_marker_row=8, show_top=7, show_bottom=5)


In [8]:
from benchmarker import split_adata, transform_coord
import numpy as np
# Collect one coordinate array per AnnData returned by split_adata(adata).
# NOTE(review): split_adata presumably splits by batch/modality — confirm in
# the benchmarker module.
spatial = [ad.obsm["spatial"] for ad in split_adata(adata)]
# Re-layout the coordinate sets with the project helper (angle=90, axis="y",
# margin_size=0.15); the exact geometry is defined in transform_coord.
spatial = transform_coord(spatial, vertical=False, margin_size=0.15, axis="y", horizontal=False, angle=90)
# spatial = [np.concatenate(spatial)]

In [23]:
# spatial = [adata.obsm["spatial"]]
# spatial = transform_coord(spatial, vertical=False, )

In [9]:
# Panel background colour per method: one colour for methods listed in
# bm.spatial_methods, another for all remaining methods.
spatial_bg, other_bg = "#D4B483", "#5873a4"
bg_dict = {
    method: (spatial_bg if method in bm.spatial_methods else other_bg)
    for method in res_dict["UMAP"].keys()
}
# STADIA always gets the spatial colour, whatever the lookup above produced.
bg_dict.update({"STADIA": spatial_bg})

In [25]:
# bm.plot_legend(category_lst=adata.obs["cell_type"], marker="o", ncol=4, save=f"{save_dir}/annot_legend.pdf",
#                handletextpad=0, borderpad=0,
#                labelspacing=0.6, columnspacing=0,)

In [None]:
# Draw the annotation (adata.obs["cell_type"]) on the spatial coordinates.
# The labels are passed as a single-column array under the key "annot".
# save=None means the figure is only displayed, not written to disk.
bm.plot_spatial(spatial=spatial, label_dict={"annot": np.array(adata.obs["cell_type"]).reshape(-1,1)},
                figsize=(2, 4), frameon=True, inner_gs_row=2, inner_gs_col=1, size=15, ncol=1,
                xlabel=["Annotation"], ylabel=None, only_show_left=True,
                axis_width = 1.2, axis_color="lightgrey",
                outer_row_hspace=0.15, outer_col_wspace=0.1,
                background_color = lambda x: bg_dict[x] if x in bg_dict else None,
                xlabel_pad=0.015,
                save=None)#f"{save_dir}/annot_spatial_plot.pdf")

In [11]:
# Map the cluster ids "0".."13" (as strings) to the first 14 colours of
# scanpy's default 20-colour palette.
palette = sc.pl.palettes.default_20
cmap = {}
for cluster_idx in range(14):
    cmap[str(cluster_idx)] = palette[cluster_idx]

In [13]:
# bm.plot_spatial(spatial=spatial, label_dict=res_dict["Cluster"],
#                 figsize=(18, 14), frameon=True, inner_gs_row=2, inner_gs_col=1, size=20, ncol=8,
#                 xlabel="name", ylabel=["RNA", "ATAC"], only_show_left=True,
#                 axis_width = 1.2, axis_color="lightgrey",
#                 # order=["SLAT", "Seurat", "GLUE", "SCALEX", "BindSC"],
#                 outer_row_hspace=0.2, outer_col_wspace=0.05, inner_common_camp=True,
#                 background_color = lambda x: bg_dict[x] if x in bg_dict else None,
#                 xlabel_pad=0.012, palette=cmap, save_dpi=600,
#                 save=f"{save_dir}/all_spatial.pdf")

In [39]:
# Load the embedding and cluster CSVs of the paired (joint RNA+ATAC) methods.
# `embed` maps method -> DataFrame, `cluster` maps method -> 1-D label array.
paired_methods = ["SpatialGlue", "COSMOS", "MISO", "PRESENT"]
embed, cluster = {}, {}
paired_result_dir = "../../result/MultiEmbryo1/embed"
for method in paired_methods:
    method_embed = pd.read_csv(f"{paired_result_dir}/embed_{method}.csv", index_col=0)
    method_cluster = pd.read_csv(f"{paired_result_dir}/cluster_{method}.csv", index_col=0)
    if method != "PRESENT":
        # Rows are taken to be in RNA order already: the index is overwritten
        # positionally with the RNA cell names.
        method_embed.index = rna.obs_names.copy()
    else:
        # PRESENT: strip everything after the first "-" from the cell names,
        # give the embedding the same index (positional), then reorder both
        # tables to the RNA cell order.
        barcodes = [barcode.split("-")[0] for barcode in method_cluster.index]
        method_cluster.index = barcodes
        method_embed.index = method_cluster.index.copy()
        method_embed = method_embed.reindex(rna.obs_names)
        method_cluster = method_cluster.reindex(rna.obs_names)

    embed[method] = method_embed
    # Only the first column of the cluster table is kept, as a plain array.
    cluster[method] = np.array(method_cluster.iloc[:, 0])


In [40]:
# Bundle the paired-method results under the "Cluster" and "Embed" keys
# passed to bm.cal_metrics. The dicts are shallow copies, so the per-method
# arrays and frames are shared.
paired_res_dict = {
    "Cluster": cluster.copy(),
    "Embed": embed.copy(),
}

In [43]:
# Label the first 1200 rows "b1" and the remaining rows "b2".
# NOTE(review): this looks like a placeholder batch column so cal_metrics has
# a batch_key for the single RNA object; the 1200 cut-off is not derived from
# the data — confirm it is intentional.
rna.obs["batch"] = ["b1" if i<1200  else "b2" for i in range(rna.shape[0])] 

In [41]:
# Keep only the bio-conservation columns of the first metrics table.
bio_columns = [
    "CHAOS", "PAS", "Isolated labels", "NMI", "ARI",
    "Silhouette label", "cLISI", "Domain continuity", "Bio conservation",
]
bio_metric = metrics[0].copy().loc[:, bio_columns]
# bio_metric.loc["Metric Type"] = [i.replace(" ", "\n") for i in bio_metric.loc["Metric Type"]]

In [None]:
# Compute the same metrics for the paired methods on the RNA object, using
# rna.obs["batch"] as the batch key and rna.obs["Joint_clusters"] as the label,
# and save the result to paired_metrics.pkl.
paired_metric = bm.cal_metrics(adata=rna, batch_key="batch", label_key="Joint_clusters",
               res_dict=paired_res_dict, methods="all", verbose=True, rep=1,
               min_max_scale=False, save="../../result/MultiEmbryo1/paired_metrics.pkl")

In [45]:
# Take the bio-conservation columns of the paired-method metrics, drop their
# "Metric Type" row (the table it is appended to already has one), and append
# the paired rows below the other methods.
paired_bio_columns = [
    "CHAOS", "PAS", "Isolated labels", "NMI", "ARI", "Silhouette label",
    "cLISI", "Domain continuity", "Bio conservation",
]
bio_paired_metirc = paired_metric[0].loc[:, paired_bio_columns].drop("Metric Type")
bio_metric_all = pd.concat((bio_metric, bio_paired_metirc))#.sort_values("Bio conservation", ascending=False)

In [46]:
# Prefix the row labels of the paired methods with "*" so they stand out.
bio_metric_all.index = [
    f"*{row_name}" if row_name in paired_methods else row_name
    for row_name in bio_metric_all.index
]

In [48]:
# bm.plot_heatmap(metric_df=bio_metric_all, save=f"{save_dir}/summary_bio_heatmap_all.pdf", total_name="Bio conservation")#, insert_marker_row=8, show_top=7, show_bottom=5)

In [266]:
# Use only the coordinates of the first AnnData returned by split_adata(adata),
# wrapped in a one-element list, and re-layout them with the same
# transform_coord settings as the two-panel layout above.
spatial = [split_adata(adata)[0].obsm["spatial"]]
spatial = transform_coord(spatial, vertical=False, margin_size=0.15, axis="y", horizontal=False, angle=90)

In [268]:
# bm.plot_spatial(spatial=spatial,
#                 label_dict=cluster,
#                 figsize=(4.3, 4.5), frameon=True, inner_gs_row=1, inner_gs_col=1, size=20, ncol=2,
#                 xlabel="name",
#                 only_show_left=True,
#                 axis_width = 1.2, axis_color="lightgrey",
#                 order=["SpatialGlue", "MISO", "COSMOS", "PRESENT"],
#                 outer_row_hspace=0.2, outer_col_wspace=0.05, inner_common_camp=True,
#                 background_color = lambda x: "#D4B483",
#                 xlabel_pad=0.012,
#                 palette=cmap,
#                 save_dpi=600,
#                 save=f"{save_dir}/paired_spatial.pdf"
#                 )

In [251]:
# Compute a UMAP for each paired method directly from its embedding: wrap the
# embedding in a temporary AnnData, build the neighbour graph on X, run UMAP.
umap = {}
for method_name, method_embed in embed.items():
    embed_adata = sc.AnnData(X=method_embed)
    sc.pp.neighbors(embed_adata, use_rep="X")
    sc.tl.umap(embed_adata)
    umap[method_name] = embed_adata.obsm["X_umap"]

In [275]:
# Background colour for each paired method: "#D4B483" when the method is NOT
# in bm.spatial_methods, "#5873a4" otherwise.
# NOTE(review): the condition is inverted relative to the earlier bg_dict
# (where members of bm.spatial_methods get "#D4B483"). This may be a deliberate
# way to give all paired methods the "#D4B483" colour — confirm.
bg_dict = {i:"#D4B483" if i not in bm.spatial_methods else "#5873a4" for i in embed.keys() }

In [283]:
# bm.plot_umap(embed_dict=umap, annot_list=list(rna.obs["Joint_clusters"]), 
#              figsize=(3.8, 4.1), frameon=True, inner_gs_row=1, inner_gs_col=1, size=5, ncol=2,
#              ylabel=["Cell type"], only_show_top=False,  xlabel="name", only_show_left=True,
#              background_color = lambda x: bg_dict[x] if x in bg_dict else None,
#             order=["SpatialGlue", "MISO", "COSMOS", "PRESENT"],
#              axis_width=1.2, axis_color="lightgrey",
#              outer_col_wspace=0.05, save_dpi=600,  outer_row_hspace=0.2,
#              ylabel_pad=0.018, xlabel_pad=0.012, merge=False, merge_margin_size=0.4,
#              save=f"{save_dir}/paired_umap.pdf")

In [29]:
# Flatten the "Unintegrated" cluster labels into a plain list of strings.
# `res_dict` is not assigned by any active cell in the visible part of this
# notebook (the bm.read_result cell is commented out).
clus = list(res_dict["Cluster"]["Unintegrated"].flatten().astype(str))

In [30]:
# adata.obs["clus"] = clus
# adata.obsm["spatial"] = np.concatenate(spatial)
# sc.pl.spatial(adata, color="clus", spot_size=1)

In [15]:
# bm.plot_umap(embed_dict=res_dict["UMAP"], batch_dict=res_dict["Batch"], annot_list=list(adata.obs["cell_type"]), 
#              figsize=(14, 12), frameon=True, inner_gs_row=2, inner_gs_col=1, size=3, ncol=8,
#              ylabel=["Batch", "Cell type"], only_show_top=False,  xlabel="name", only_show_left=True,
#              background_color = lambda x: bg_dict[x] if x in bg_dict else None,
#             #  order=["Unintegrated", "SLAT", "Seurat", "GLUE", "SCALEX", "BindSC"],
#              axis_width=1.2, axis_color="lightgrey",
#              outer_col_wspace=0.05, save_dpi=600, outer_row_hspace=0.15,
#              ylabel_pad=0.018, xlabel_pad=0.015, merge=True, merge_margin_size=0.4,
#              save=f"{save_dir}/all_umap.pdf")