In [2]:
import scanpy as sc
import numpy as np
import pandas as pd
import os
from SDMBench import *
import palettable
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# Per-method ranking table. Columns read later in this notebook:
# 'Biotech', 'Method', 'Accuracy', 'Continuity'.
rank_path = 'performance_summary_rank.feather'
pd_df_rank = pd.read_feather(rank_path)

In [5]:
# Canonical order of every benchmarked method. A method's position in this list
# decides which palette colour it receives when zipped with the colour palette.
method_order = [
    'louvain',
    'leiden',
    'SpaGCN_without',
    'SpaGCN_with',
    'BayesSpace',
    'stLearn',
    'SEDR',
    'CCST',
    'SCAN-IT',
    'STAGATE',
    'SpaceFlow',
    'conST_nopre',
    'BASS',
    'DeepST',
]

In [6]:
# Fixed method -> colour mapping, assigned by position in method_order, so a
# method keeps the same colour in every figure.
cmp = palettable.tableau.Tableau_20.mpl_colors
# zip() silently truncates to the shorter input: without this check, methods
# beyond the palette size would get no colour and disappear from every plot.
assert len(method_order) <= len(cmp), (
    f'{len(method_order)} methods but only {len(cmp)} palette colours'
)
method_color_dict = dict(zip(method_order, cmp))
def make_cmp(method_color_dict, method_list):
    """Select the methods that occur in ``method_list`` together with their colours.

    Parameters
    ----------
    method_color_dict : dict
        Mapping of method name -> colour. Its insertion order defines the
        order of the returned lists.
    method_list : container
        Method names to keep (e.g. the result of ``np.unique``); only ``in``
        membership tests are performed on it.

    Returns
    -------
    tuple[list, list]
        ``(methods, colors)``: the keys of ``method_color_dict`` that occur in
        ``method_list``, in dictionary order, and their colours in the same
        order. Names in ``method_list`` that are not keys of
        ``method_color_dict`` are dropped.
    """
    selected = [
        (method, color)
        for method, color in method_color_dict.items()
        if method in method_list
    ]
    methods = [method for method, _ in selected]
    colors = [color for _, color in selected]
    return methods, colors

In [8]:
# `method_list` is only assigned inside the per-biotech plotting loop further
# down, so it is undefined on a fresh kernel. Show the methods present in the
# loaded ranking table instead.
np.unique(pd_df_rank['Method'])

array(['BASS', 'CCST', 'DeepST', 'SCAN-IT', 'SEDR', 'STAGATE',
       'SpaGCN_without', 'SpaceFlow', 'conST_nopre', 'leiden', 'louvain'],
      dtype=object)

In [18]:
# For each Biotech value, plot every method's mean Accuracy against its mean
# Continuity (one point per method), with +/- population-std error bars over
# that Biotech's rows, and save the figure as a PDF.
scatter_dir = os.path.join('figures', 'biotech_rank_scatter')
# savefig does not create missing parent directories.
os.makedirs(scatter_dir, exist_ok=True)

for biotech in np.unique(pd_df_rank['Biotech']):
    pd_df_rank_biotech = pd_df_rank[pd_df_rank['Biotech'] == biotech]
    method_list = np.unique(pd_df_rank_biotech['Method'])
    # Restrict the global method order/colours to the methods present here,
    # so each method keeps the same colour across all figures.
    cur_method_order, cur_cmp = make_cmp(method_color_dict, method_list)

    grouped = pd_df_rank_biotech.groupby('Method')[['Accuracy', 'Continuity']]
    means = grouped.mean().reindex(cur_method_order)
    # ddof=0 reproduces np.std (population std); pandas defaults to ddof=1.
    stds = grouped.std(ddof=0).reindex(cur_method_order)
    agg_dict = {
        'Accuracy': means['Accuracy'].tolist(),
        'Continuity': means['Continuity'].tolist(),
        'Method': list(cur_method_order),
        'Accuracy_err': stds['Accuracy'].tolist(),
        'Continuity_err': stds['Continuity'].tolist(),
    }

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # Draw only the black error bars here (markersize=0); the coloured
    # markers are drawn on top by the scatter call below (zorder=3).
    ax.errorbar(
        agg_dict['Accuracy'],
        agg_dict['Continuity'],
        xerr=agg_dict['Accuracy_err'],
        yerr=agg_dict['Continuity_err'],
        fmt='o',
        markersize=0,
        ecolor='k',
    )
    ax.scatter(
        agg_dict['Accuracy'],
        agg_dict['Continuity'],
        c=cur_cmp,
        s=100,
        zorder=3,
    )
    ax.set_xlabel('Accuracy')
    ax.set_ylabel('Continuity')
    ax.set_title(str(biotech))

    fig.savefig(
        os.path.join(scatter_dir, f'{biotech}.pdf'),
        dpi=400,
        bbox_inches='tight',
        transparent=True,
    )
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)