In [10]:
import scanpy as sc
import pertpy as pt
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import distance, linkage, dendrogram
from seaborn import clustermap
import seaborn as sns
import matplotlib.pyplot as plt
import episcanpy as epi
from scipy import stats
import edistance as ed
from utils import equal_subsampling

In [11]:
# Notebook-wide plotting defaults, applied in a single rcParams update.
plt.rcParams.update({
    'figure.dpi': 300,
    'font.size': 7,
})

# Title prefix used on every figure, and the directory that both the input
# tables and the saved figures live in.
fig_title = "Norman et al"
workingdir = "output/metric-comp-norman/"

In [12]:
# Read one distance table per metric (first CSV column becomes the index).
# The dict keeps the order listed here, which downstream cells rely on for
# column order.
dist_dfs = {
    metric: pd.read_csv(workingdir + filename, index_col=0)
    for metric, filename in (
        # ('mmd', "mmd.csv"),  # currently not loaded
        ('pseudobulk', "pseudobulk.csv"),
        ('wasserstein', "wasserstein.csv"),
        ('edist', "edist.csv"),
        ('mmd_gauss', "mmd_rbf_gamma_0pt05.csv"),
    )
}

# AnnData object stored alongside the distance tables.
adata = sc.read_h5ad(workingdir + "processed_subset.h5ad")


In [13]:
# Take the 'control' column of every distance table and place them side by
# side, one column per metric (named after the dist_dfs key).
magnitudes = pd.concat(
    [table['control'] for table in dist_dfs.values()],
    axis=1,
    keys=list(dist_dfs),
)

In [14]:
# Drop every row whose 'edist' value is exactly 0 (presumably the
# control-vs-control entry -- verify against the input table).
# .copy() makes the filtered frame independent of the original, so adding a
# column below does not raise SettingWithCopyWarning on pandas versions
# without copy-on-write.
magnitudes = magnitudes.loc[magnitudes['edist'] != 0].copy()

# Expose the index labels as a regular column so the table can be melted and
# used as a plotting axis later on.
magnitudes['perturbation'] = magnitudes.index


In [15]:
# Pairwise scatter plots of the metrics, with a kernel density estimate of
# each metric on the diagonal; the figure is written to the working directory.
pd.plotting.scatter_matrix(
    frame=magnitudes,
    alpha=.7,
    diagonal="kde",
)
plt.suptitle(t=fig_title + " scatter matrix")
plt.savefig(workingdir + "scatter_matrix.pdf")

In [16]:
# Long format: one row per (perturbation, metric) pair with its distance.
magma = magnitudes.melt(
    id_vars="perturbation",
    var_name='metric',
    value_name='distance',
)
# Leave out any row that belongs to the control itself.
magma = magma.loc[magma['perturbation'].ne("control")]


In [None]:
# Categorical plot of every metric's distance per perturbation, with the
# perturbations ordered by increasing 'edist' and a logarithmic y axis.
g2 = sns.catplot(
    data=magma,
    x='perturbation',
    y='distance',
    hue='metric',
    order=magnitudes.sort_values('edist')['perturbation'],
    aspect=2,
)
# Fix: the title previously lacked the space between fig_title and
# "Distances"; it now matches the other catplot cells.
# Tick labels are blanked on the x axis.
g2.set(title=fig_title + " Distances", xticklabels=[])
plt.xticks(rotation=90)
plt.yscale("log")
plt.savefig(workingdir + "catplot_log.pdf", bbox_inches="tight")

In [None]:
# Same categorical plot as the log-scale version, but on a linear y axis.
# Perturbations are ordered by increasing 'edist'.
g2 = sns.catplot(
    data=magma,
    x='perturbation',
    y='distance',
    hue='metric',
    order=magnitudes.sort_values(by='edist')['perturbation'],
    aspect=2,
)
# Tick labels are blanked on the x axis.
g2.set(
    title=fig_title + " Distances",
    xticklabels=[],
)
plt.xticks(rotation=90)
plt.savefig(workingdir + "catplot.pdf", bbox_inches="tight")

#plt.yscale("log")

In [26]:
def normalize(df, leaveout= ["perturbation"]):
    """Min-max scale the columns of ``df`` to the range [0, 1].

    Parameters
    ----------
    df : pandas.DataFrame
        Table to scale. It is not modified.
    leaveout : container of column names
        Columns to copy through unchanged. The default is only read, never
        mutated.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` in which every column not in ``leaveout`` has been
        replaced by ``(col - col.min()) / (col.max() - col.min())``.
        A column whose values are all equal (zero range) is mapped to 0.0
        instead of NaN; missing values stay missing.
    """
    result = df.copy()
    for feature_name in df.columns:
        if feature_name in leaveout:
            # Already carried over unchanged by df.copy().
            continue
        column = df[feature_name]
        min_value = column.min()
        value_range = column.max() - min_value
        if value_range == 0:
            # Constant column: dividing by the zero range would turn every
            # entry into NaN. Multiplying by 0.0 yields 0.0 and keeps NaNs.
            result[feature_name] = (column - min_value) * 0.0
        else:
            result[feature_name] = (column - min_value) / value_range
    return result

In [27]:
# Min-max scale every metric column; 'perturbation' is left as is because it
# is normalize()'s default leaveout.
normed_mag = normalize(df=magnitudes)

In [28]:
# Long format of the normalized table: one row per (perturbation, metric)
# pair with its scaled distance.
normed_magma = normed_mag.melt(
    id_vars="perturbation",
    var_name='metric',
    value_name='distance',
)


In [35]:
# Categorical plot of the min-max normalized distances on a log y axis.
# Rows whose metric is "mmd" (if any) are excluded; perturbations are ordered
# by increasing un-normalized 'edist'.
# NOTE(review): min-max scaling maps each metric's minimum to 0, which has no
# position on a logarithmic axis -- confirm that losing those points is intended.
g2 = sns.catplot(
    data=normed_magma.loc[normed_magma['metric'].ne("mmd")],
    x='perturbation',
    y='distance',
    hue='metric',
    order=magnitudes.sort_values(by='edist')['perturbation'],
    aspect=2,
)
# Tick labels are blanked on the x axis.
g2.set(
    title=fig_title + " Distances",
    xticklabels=[],
)
plt.xticks(rotation=90)
plt.yscale("log")
plt.savefig(workingdir + "catplot_log_minmax_norm.pdf", bbox_inches="tight")

In [36]:
# Categorical plot of the min-max normalized distances on a linear y axis,
# with perturbations ordered by increasing un-normalized 'edist'.
g2 = sns.catplot(
    data=normed_magma,
    x='perturbation',
    y='distance',
    hue='metric',
    order=magnitudes.sort_values(by='edist')['perturbation'],
    aspect=2,
)
# Tick labels are blanked on the x axis.
g2.set(
    title=fig_title +" Distances",
    xticklabels=[],
)
plt.xticks(rotation=90)
plt.savefig(workingdir + "catplot_minmax_norm.pdf", bbox_inches="tight")
