In [None]:
import seml
import seaborn as sns
from poisson_atac.utils import model_type_map
import itertools
import matplotlib.pyplot as plt
import pandas as pd
import os
import numpy as np

In [None]:
from poisson_atac.utils import load_experiment

In [None]:
# Output locations: save_path receives the exported metric table (CSV),
# fig_path receives the rendered figure files.
save_path, fig_path = (
    '/storage/groups/ml01/workspace/laura.martens/atac_poisson_data/processed/scib/',
    '/storage/groups/ml01/workspace/laura.martens/atac_poisson_data/panels/Figure2/scib',
)

In [None]:
def load_seml(seml_database):
    """Load all runs of a seml collection as a DataFrame with plot-ready labels.

    Relabelling applied to the raw seml results:
      * list-valued ``config.data.batch`` entries get the NeurIPS all-batches label,
      * missing ``config.data.batch`` entries get the Satpathy et al. label,
      * runs on the ``neurips_count`` dataset get model type ``cistopic_count``,
      * every ``config.model.model_type`` is translated through ``model_type_map``
        (types missing from the map become NaN, as with ``Series.map``).

    Args:
        seml_database: name of the seml collection to query.

    Returns:
        The seml results DataFrame (config, config_hash, result and batch_id fields).
    """
    batch_col = 'config.data.batch'
    model_col = 'config.model.model_type'

    results = seml.get_results(
        seml_database,
        to_data_frame=True,
        fields=["config", "config_hash", "result", "batch_id"],
    )

    # Batch configs given as a list are shown as the all-batches NeurIPS setting.
    multi_batch = results[batch_col].apply(lambda value: isinstance(value, list))
    results.loc[multi_batch, batch_col] = 'Neurips\nall batches'

    # Runs without any batch entry are shown as the Satpathy et al. data.
    no_batch = results[batch_col].isna()
    results.loc[no_batch, batch_col] = 'Satpathy\net al.'

    # Runs on the count-based NeurIPS data are reported under cistopic_count.
    count_dataset = results['config.data.dataset'] == "neurips_count"
    results.loc[count_dataset, model_col] = "cistopic_count"

    results[model_col] = results[model_col].map(model_type_map)
    return results

In [None]:
# Display names of the two compared models, in plotting order.
model_order = [model_type_map[key] for key in ("count", "peakvi")]

In [None]:
# Name of the seml collection whose runs are loaded with load_seml.
seml_database = "cv_atac_atac_prediction"

In [None]:
results = load_seml(seml_database)

# Keep the two compared models on the two labelled batch settings only.
df = results.loc[
    results["config.data.batch"].isin(['Neurips\nall batches', 'Satpathy\net al.'])
    & results['config.model.model_type'].isin(model_order)
]

# Drop runs whose use_layer_norm is "both" and runs without an n_latent value.
df = df[
    df["config.setup.model_params.use_layer_norm"].ne("both")
    & df["config.setup.model_params.n_latent"].notna()
]

df.shape

# Export data for scib plotting in R

# One row per run: the per-run embedding metrics (presumably one Series per run)
# are placed side by side and transposed.
scib_table = pd.concat(list(df['result.embedding']), axis=1).T

scib_table.shape

# Row keys of the form /<dataset>/metrics/unscaled/full_feature/<model>_embed,
# presumably the layout plotSingleTaskATAC.R expects.
scib_table.index = (
    "/" + df['config.data.dataset']
    + "/metrics/unscaled/full_feature/"
    + df['config.model.model_type'] + "_embed"
)
scib_table.index.name = "X"

# Average all runs that share the same dataset/model key.
scib_table = scib_table.reset_index().groupby('X').agg("mean")

scib_table

metrics_csv = os.path.join(save_path, "integration_metrics_mean.csv")
scib_table.to_csv(metrics_csv)

metrics_csv


# NOTE(review): this is R code inside an otherwise-Python notebook; it has to run
# in an R kernel (or an rpy2 R cell) -- in the Python kernel `source` is undefined.
# The two paths are hard-coded copies of the Python `save_path` CSV
# (integration_metrics_mean.csv) and `fig_path`; keep them in sync if those change.
# Presumably renders the scIB-style summary figure from the exported CSV into outdir.
source("plotSingleTaskATAC.R") 
plotSingleTaskATAC('/storage/groups/ml01/workspace/laura.martens/atac_poisson_data/processed/scib/integration_metrics_mean.csv',
                  outdir='/storage/groups/ml01/workspace/laura.martens/atac_poisson_data/panels/Figure2/scib')


# Plot distribution

# Per-run metric table: one row per run (in the row order of df), one column per
# metric. Columns that are NaN for every run are dropped.
scib_table = pd.concat(df['result.embedding'].values, axis=1).T.dropna(axis=1, how='all')

batch_metrics = ['PCR_batch', 'ASW_label/batch', 'graph_conn']
bio_metrics = ['NMI_cluster/label', 'ARI_cluster/label', 'ASW_label', 'isolated_label_F1', 'isolated_label_silhouette', 'trajectory']

# Explicit copy so the Model/Dataset columns added below go into an independent
# frame rather than a selection of the concatenated table.
scib_table = scib_table.loc[:, batch_metrics + bio_metrics].copy()

# Run labels are attached positionally: rows of scib_table follow the row order of df.
scib_table["Model"] = df['config.model.model_type'].values
scib_table["Dataset"] = df['config.data.dataset'].values

# Overall score = weighted mean of the batch-correction and bio-conservation
# group means (0.4 / 0.6, presumably following the scIB overall-score convention).
# mean() skips NaN metrics.
weight_batch = 0.4
score_group_batch = scib_table.loc[:, batch_metrics].mean(axis=1)
score_group_bio = scib_table.loc[:, bio_metrics].mean(axis=1)
score_all = weight_batch * score_group_batch + (1 - weight_batch) * score_group_bio

scib_table.insert(0, "Overall", score_all)

# Display names for the metrics. The keys must use spaces, not underscores: the
# "_" -> " " replacement below runs BEFORE this mapping is applied, so a key like
# "ASW_label" would never match and the column would keep its raw name.
metrics_map = dict(zip(
  ["ASW label", "ASW label/batch", "cell cycle conservation", "hvg overlap", "trajectory", "graph conn", "iLISI", "cLISI"],
  ["Cell type ASW", "Batch ASW", "CC conservation", "HVG conservation", "trajectory conservation", "graph connectivity", "graph iLISI", "graph cLISI"]
))

scib_table.columns = scib_table.columns.str.replace("_", " ")

scib_table.columns = pd.Series(scib_table.columns).replace(metrics_map)

# Plot scatterplot

count_label = model_type_map["count"]
peakvi_label = model_type_map["peakvi"]

# The scatter pairs the i-th count run with the i-th peakvi run, so both models
# must contribute the same number of runs per dataset, and the runs are sorted by
# dataset (a stable sort keeps the original order within each dataset) so that a
# pair never mixes runs from different datasets.
runs_per_model = scib_table.groupby(["Dataset", "Model"]).size().unstack("Model")
assert (runs_per_model[count_label] == runs_per_model[peakvi_label]).all(), (
    "count and peakvi need the same number of runs per dataset to be paired"
)
runs_by_dataset = scib_table.sort_values("Dataset", kind="mergesort")

sns_df = runs_by_dataset.drop("Dataset", axis=1).melt(id_vars=["Model"])

df1 = sns_df[sns_df["Model"] == count_label]
df2 = sns_df[sns_df["Model"] == peakvi_label]

merged_df = pd.DataFrame({count_label: df1.value.values, peakvi_label: df2.value.values, 'Metric': df1.variable.values})

sns.set_style("white")
g = sns.FacetGrid(merged_df, col="Metric", hue="Metric", aspect=1, height=5, sharex=False, sharey=False, col_wrap=4)
g.map(sns.scatterplot, peakvi_label, count_label, alpha=0.8, linewidth=1.5)

# Label every panel with its metric and give it identical x/y limits so that the
# y = x reference line marks "both models score the same".
for label, ax in g.axes_dict.items():
    ax.text(1, 0.01, label, fontweight="normal", color="black",
            ha="right", va="bottom", transform=ax.transAxes)
    x_lim = ax.get_xlim()
    y_lim = ax.get_ylim()
    lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1]))
    ax.axline((0.5, 0.5), slope=1, color="lightgrey", linewidth=.5)
    ax.set_xlim(lim)
    ax.set_ylim(lim)

g.set_titles("")
plt.tight_layout()
plt.savefig(os.path.join(fig_path, "integration_metrics_dist.pdf"))
plt.savefig(os.path.join(fig_path, "integration_metrics_dist.png"))

# Plot boxplot
# Distribution of every metric over runs of a single dataset, split by model.
dataset = "neurips"
sns_df = scib_table[scib_table["Dataset"] == dataset]

sns_df = sns_df.melt(id_vars=["Dataset", "Model"], var_name="Metric", value_name="Value")

sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(data=sns_df, x="Value", y="Metric", hue="Model", ax=ax)
ax.set_title(f"Integration metrics per run ({dataset})");

# Preview of the paired model-vs-model values used for the scatter plot.
merged_df.head()

# Ridge plot: density of every metric over runs, one row per metric, coloured by model.
test = scib_table.melt(id_vars=['Dataset', "Model"])

# NOTE(review): mpatches is no longer used in this cell; move it to the imports
# cell or delete it if nothing else in the notebook needs it.
import matplotlib.patches as mpatches
g = sns.FacetGrid(test.dropna(), row="variable", hue="Model", aspect=10, height=1.5, sharex=False, sharey=False)
g.map(sns.kdeplot, "value",
      bw_adjust=.5, clip_on=False,
      fill=True, alpha=0.5, linewidth=1.5)
g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)

# Metric name at the left of each row (the default row titles are removed below).
for label, ax in g.axes_dict.items():
    ax.text(0, .2, label, fontweight="normal", color="black",
            ha="left", va="center", transform=ax.transAxes)

# Remove axes details that don't play well with overlap
g.set_titles("")
g.set(yticks=[], ylabel="")
g.despine(bottom=False, left=True)

# One legend for the whole grid, attached to the "Overall" row; ax.legend() picks
# up the per-Model labels that FacetGrid.map assigns through hue="Model".
g.axes_dict["Overall"].legend(bbox_to_anchor=(1, 3.5), loc='best');