In [None]:
import scipy as sp
from numba import jit
import pickle
import utilities as ut
import sklearn.metrics as met
from SAM import SAM
import utilities_full as ut2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from graph_tool.all import Graph, local_clustering
from sklearn.preprocessing import Normalizer

def scatter(self, projection=None, c=None, cmap='rainbow', linewidth=0.0,
            edgecolor='k', axes=None, colorbar=True, s=10, **kwargs):
    """Scatter-plot a 2-D embedding stored on a SAM-like object.

    Despite the ``self`` parameter this is a module-level function; pass any
    object exposing ``adata`` with ``obsm`` and ``obs`` mappings.

    Parameters
    ----------
    projection : str, array-like or None
        A key into ``self.adata.obsm``. If None, ``'X_umap'`` is tried first,
        then ``'X_tsne'``. Any other value is used directly as coordinates
        (columns 0 and 1 are plotted).
    c : str, array-like or None
        Point colors. A string is first looked up as a column of
        ``self.adata.obs``; if that fails it is passed to matplotlib as-is.
        Arrays/lists of strings are converted to integer codes with
        ``ut.convert_annotations`` and the colorbar is labelled with the
        original strings.
    cmap, linewidth, edgecolor, s, **kwargs
        Forwarded to ``Axes.scatter``.
    axes : matplotlib Axes or None
        Target axes; a new figure is created when None.
    colorbar : bool
        Draw a colorbar. It is always suppressed when ``c`` is not an
        array/list.

    Returns
    -------
    None. If the requested projection is missing, a message is printed and
    nothing is drawn.
    """
    if isinstance(projection, str):
        try:
            dt = self.adata.obsm[projection]
        except KeyError:
            print('Please create a projection first using run_umap or '
                  'run_tsne')
            return  # nothing to plot; previously fell through with `dt` unbound
    elif projection is None:
        try:
            dt = self.adata.obsm['X_umap']
        except KeyError:
            try:
                dt = self.adata.obsm['X_tsne']
            except KeyError:
                print("Please create either a t-SNE or UMAP projection "
                      "first.")
                return
    else:
        dt = projection

    if axes is None:
        plt.figure()
        axes = plt.gca()

    if c is None:
        # Draw on the requested axes, not whatever pyplot considers current.
        axes.scatter(dt[:, 0], dt[:, 1], s=s,
                     linewidth=linewidth, edgecolor=edgecolor, **kwargs)
        return

    if isinstance(c, str):
        try:
            # np.asarray replaces Series.get_values(), removed in pandas 1.0.
            c = np.asarray(self.adata.obs[c])
        except KeyError:
            pass  # not an obs column: treat the string as a matplotlib color

    is_sequence = isinstance(c, (np.ndarray, list))
    if is_sequence and len(c) > 0 and isinstance(c[0], (str, np.str_)):
        # Categorical string labels -> integer codes for the colormap.
        codes = ut.convert_annotations(c)
        ticks, first_idx = np.unique(codes, return_index=True)
        cax = axes.scatter(dt[:, 0], dt[:, 1], c=codes, cmap=cmap, s=s,
                           linewidth=linewidth,
                           edgecolor=edgecolor,
                           **kwargs)

        if colorbar:
            cbar = plt.colorbar(cax, ax=axes, ticks=ticks)
            # np.asarray so plain lists also support fancy indexing.
            cbar.ax.set_yticklabels(np.asarray(c)[first_idx])
    else:
        if not is_sequence:
            colorbar = False

        cax = axes.scatter(dt[:, 0], dt[:, 1], c=c, cmap=cmap, s=s,
                           linewidth=linewidth,
                           edgecolor=edgecolor,
                           **kwargs)

        if colorbar:
            plt.colorbar(cax, ax=axes)
def DARMANIS(**kwargs):
    """Build a SAM object for the Darmanis dataset.

    Counts and cell annotations are loaded first; ``kwargs`` are then
    forwarded to ``SAM.preprocess_data``.
    """
    data_file = 'darmanis/darmanis_data.csv'
    annotation_file = 'darmanis/darmanis_ann.csv'
    dataset = SAM()
    dataset.load_data(data_file)
    dataset.load_obs_annotations(annotation_file)
    dataset.preprocess_data(**kwargs)
    return dataset

def WANG(**kwargs):
    """Build a SAM object for the Wang (GSE83139) dataset.

    Unlike the other loaders, ``kwargs`` go to both ``SAM.load_data`` and
    ``SAM.preprocess_data``. After preprocessing, labels are read from a
    header-less CSV indexed by its first column. They are stored in
    ``adata.obs['ann']``, and variable names are then made unique.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/GSE83139/wang_data_sparse.p', **kwargs)
    dataset.preprocess_data(**kwargs)
    labels = pd.read_csv('final_datasets/GSE83139/wang_ann.csv',
                         header=None, index_col=0)
    # Fixed-width unicode index, presumably to match obs_names' dtype — TODO confirm.
    labels.index = labels.index.astype("<U100")
    dataset.adata.obs['ann'] = labels
    dataset.adata.var_names_make_unique()
    return dataset

def human1(**kwargs):
    """Build a SAM object for Baron human donor 1 (GSE84133_1).

    Data are preprocessed with ``kwargs`` before annotations are attached.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/GSE84133_1/human1_sparse.p')
    dataset.preprocess_data(**kwargs)
    dataset.load_obs_annotations(
        'final_datasets/GSE84133_1/human1_ann.csv')
    return dataset


def human2(**kwargs):
    """Build a SAM object for Baron human donor 2 (GSE84133_2).

    Data are preprocessed with ``kwargs`` before annotations are attached.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/GSE84133_2/human2_sparse.p')
    dataset.preprocess_data(**kwargs)
    dataset.load_obs_annotations(
        'final_datasets/GSE84133_2/human2_ann.csv')
    return dataset

def human3(**kwargs):
    """Build a SAM object for Baron human donor 3 (GSE84133_3).

    Data are preprocessed with ``kwargs`` before annotations are attached.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/GSE84133_3/human3_sparse.p')
    dataset.preprocess_data(**kwargs)
    dataset.load_obs_annotations(
        'final_datasets/GSE84133_3/human3_ann.csv')
    return dataset
def human4(**kwargs):
    """Build a SAM object for Baron human donor 4 (GSE84133_4).

    Data are preprocessed with ``kwargs`` before annotations are attached.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/GSE84133_4/human4_sparse.p')
    dataset.preprocess_data(**kwargs)
    dataset.load_obs_annotations(
        'final_datasets/GSE84133_4/human4_ann.csv')
    return dataset


def KOH(**kwargs):
    """Build a SAM object for the Koh (SRP073808) dataset.

    Counts and annotations are loaded first; ``kwargs`` are then forwarded
    to ``SAM.preprocess_data``.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/SRP073808/SRP073808_data.csv')
    dataset.load_obs_annotations(
        'final_datasets/SRP073808/SRP073808_ann.csv')
    dataset.preprocess_data(**kwargs)
    return dataset


def SEGER(**kwargs):
    """Build a SAM object for the Segerstolpe dataset.

    Counts and annotations are loaded first; ``kwargs`` are then forwarded
    to ``SAM.preprocess_data``.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/seger/seger_sparse.p')
    dataset.load_obs_annotations('final_datasets/seger/seger_ann.csv')
    dataset.preprocess_data(**kwargs)
    return dataset


def MURARO(**kwargs):
    """Build a SAM object for the Muraro dataset.

    Counts and annotations are loaded first; ``kwargs`` are then forwarded
    to ``SAM.preprocess_data``.
    """
    dataset = SAM()
    dataset.load_data('final_datasets/muraro/muraro_sparse.p')
    dataset.load_obs_annotations('final_datasets/muraro/muraro_ann.csv')
    dataset.preprocess_data(**kwargs)
    return dataset
def nmi(x, y):
    """Agreement between two labelings via *adjusted* mutual information.

    Note: despite its name this returns sklearn's AMI (arithmetic-mean
    normalization), not plain normalized mutual information.
    """
    score = met.adjusted_mutual_info_score(x, y, average_method='arithmetic')
    return score


def ari(x, y):
    """Adjusted Rand index between labelings ``x`` and ``y`` (sklearn)."""
    score = met.adjusted_rand_score(x, y)
    return score
def SEURAT(adata, npcs, ngenes):
    """Run ``ut2.do_SEURAT4`` on a copy of ``adata`` and cluster the result.

    Only the first of the four returned values (the PCA embedding) is kept;
    it is clustered with ``hdbknn``. Returns ``(pca, cluster_labels)``.
    """
    embedding, _, _, _ = ut2.do_SEURAT4(adata.copy(), npcs=npcs, NN=ngenes)
    labels = hdbknn(embedding)
    return embedding, labels
def hdbknn(X, k=20):
    """Cluster the rows of ``X`` with HDBSCAN, then re-assign noise by kNN vote.

    HDBSCAN labels outliers ``-1``. Each outlier is assigned to the cluster
    most common among its ``k`` nearest *clustered* points; distances come
    from ``ut.generate_euclidean_map``. Ties go to the lowest cluster label.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
    k : int, default 20
        Number of clustered neighbours that vote for each outlier.

    Returns
    -------
    cl : ndarray of int, shape (n_samples,)
        Cluster labels. ``-1`` survives only if HDBSCAN found no clusters at
        all (then there is nothing to vote for).
    """
    import hdbscan

    cl = hdbscan.HDBSCAN(metric='euclidean').fit_predict(X)

    clustered = np.where(cl != -1)[0]
    noise = np.where(cl == -1)[0]
    if noise.size == 0 or clustered.size == 0:
        return cl

    # After the transpose: one row per noise point, one column per clustered point.
    dists = ut.generate_euclidean_map(X[clustered, :], X[noise, :]).T
    knn = np.argsort(dists, axis=1)[:, :k]
    neighbour_labels = cl[clustered][knn]

    # votes[r, c] = how many of noise point r's k neighbours are in cluster c.
    votes = np.zeros((noise.size, cl.max() + 1))
    np.add.at(votes, (np.arange(noise.size)[:, None], neighbour_labels), 1)

    cl[noise] = np.argmax(votes, axis=1)
    return cl
@jit(nopython=True)
def permute(D, npermutes):
    """Shuffle the 2-D array ``D`` in place with ``npermutes`` random swaps.

    Each swap draws two uniformly random (row, column) positions, possibly
    identical, and exchanges their values. Compiled by numba in nopython mode.
    """
    n_rows = D.shape[0]
    n_cols = D.shape[1]
    for _ in range(npermutes):
        r1 = np.random.randint(0, n_rows)
        c1 = np.random.randint(0, n_cols)
        r2 = np.random.randint(0, n_rows)
        c2 = np.random.randint(0, n_cols)
        held = D[r1, c1]
        D[r1, c1] = D[r2, c2]
        D[r2, c2] = held
        
def permute2(D, n_elements):
    """Randomly permute the values of ``n_elements`` distinct entries of ``D``.

    Positions are sampled without replacement over the flattened 2-D array.
    The values at those positions are then shuffled among themselves, in
    place. All other entries are left untouched.
    """
    picked = np.random.choice(D.size, size=n_elements, replace=False)
    rows, cols = np.unravel_index(picked, D.shape)
    order = np.random.permutation(rows.size)
    values = D[rows, cols]  # fancy indexing -> copy, taken before writing
    D[rows[order], cols[order]] = values
    
def modularity(graph, cl):
    """Directed Newman modularity of a dense adjacency matrix.

    Computes
        Q = (1/m) * sum_ij [A_ij - kin_i * kout_j / m] * delta_ij
    where ``kin`` are column sums, ``kout`` are row sums and ``m`` is the
    total edge weight. ``delta_ij`` is 1 when i and j share a label in
    ``0..cl.max()``; negative labels (e.g. -1) never count as a community.
    """
    col_sums = graph.sum(0).flatten()
    row_sums = graph.sum(1).flatten()
    total = graph.sum()

    same_cluster = np.zeros(graph.shape)
    for label in range(cl.max() + 1):
        members = np.where(cl == label)[0]
        same_cluster[np.ix_(members, members)] = 1

    expected = col_sums[:, None] * row_sums[None, :] / total
    return ((graph - expected) * same_cluster / total).sum()


def generate_graph(graph):
    """Build a directed graph_tool ``Graph`` from an adjacency matrix.

    Works on a copy: the diagonal is zeroed (self-loops dropped). One edge
    (row -> column) is then added per remaining nonzero entry; weights are
    not stored.
    """
    adjacency = graph.copy()
    diag = np.arange(adjacency.shape[0])
    adjacency[diag, diag] = 0

    edges = np.transpose(adjacency.nonzero())
    G = Graph(directed=True)
    G.add_edge_list(edges)
    return G


def local_clust(graph):
    """Mean directed local clustering coefficient of ``generate_graph(graph)``."""
    coefficients = local_clustering(generate_graph(graph), undirected=False)
    return np.mean(list(coefficients))

def l2disp(data, graph, k, N):
    """L2 norm of the ``N`` largest per-column dispersions after graph smoothing.

    ``data`` is aggregated over the graph as ``graph.dot(data / k)``. With a
    binary adjacency having ``k`` neighbours per row this is presumably a
    neighbour average — TODO confirm for the graphs passed in. The Fano
    factor var/mean is then computed for every column with positive mean.

    Parameters
    ----------
    data : ndarray, shape (n_rows, n_cols)
        Assumed cells x genes (callers pass ``adata.X.toarray()``).
    graph : ndarray or sparse matrix, shape (n_rows, n_rows)
    k : number
        Divisor applied to ``data`` before aggregation.
    N : int
        Number of top dispersions to include. ``N <= 0`` returns 0.0.

    Returns
    -------
    float
    """
    davg = graph.dot(data / k)
    mu = davg.mean(0)

    expressed = mu > 0
    disp = davg.var(0)[expressed] / mu[expressed]
    disp[np.isnan(disp)] = 0

    if N <= 0:
        # np.sort(disp)[-0:] would select *every* entry rather than none.
        return 0.0
    top = np.sort(disp)[-N:]
    return np.sqrt(np.sum(top ** 2))

In [None]:
# One preprocessing choice per dataset, aligned with `names` below:
# Darmanis/Wang -> Normalizer, Baron human1-4 -> StandardScaler,
# Koh/Segerstolpe/Muraro -> Normalizer.
preproc = ['Normalizer'] * 2 + ['StandardScaler'] * 4 + ['Normalizer'] * 3
d = {'min_expression': 1, 'filter_genes': False}
# NOTE: despite the name, `funcs` holds fully loaded SAM objects, not loaders.
funcs = [
    loader(**d)
    for loader in (DARMANIS, WANG, human1, human2, human3, human4,
                   KOH, SEGER, MURARO)
]
names = ['DARMANIS', 'WANG', 'human1', 'human2', 'human3', 'human4',
         'KOH', 'SEGER', 'MURARO']

In [None]:
# Seurat hyper-parameter grid; sweep results are indexed as
# RECORDd[dataset_name][gene_idx, pc_idx]. -1 presumably means "all genes" — TODO confirm.
ngenes = [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000,
          5500, 6000, 6500, 7000, -1]
npcs = [6, 10, 15, 20, 25, 30, 35, 40, 45, 50]

# Context manager so the file handle is closed deterministically.
# (pickle.load executes arbitrary code — only load trusted, locally produced files.)
with open('paper_scripts/seurat_param_sweep_bigger_fixed.p', 'rb') as sweep_file:
    RECORDd = pickle.load(sweep_file)

nge = []
npc = []
for ds_name in names:
    sweep = RECORDd[ds_name]
    # First grid cell (row-major) that attains the maximum value.
    best = np.where(sweep >= sweep.max())
    nge.append(ngenes[best[0][0]])
    npc.append(npcs[best[1][0]])

# Manual override of the sweep optimum for WANG; the rationale is not
# recorded here — TODO document.
nge[names.index('WANG')] = 3000
npc[names.index('WANG')] = 6

In [None]:
import copy

# Corruption benchmark. For each dataset, shuffle an increasing number of
# entries of the raw count matrix (NTRIALS levels x NREPLICATES repeats). At
# each level, score four pipelines against the dataset's annotations:
#   sam   - SAM (hdbknn and Louvain clusters)
#   seur  - Seurat pipeline with fixed npcs=15, NN=3000        (""" L """)
#   seuro - Seurat with the per-dataset optimum npc[i]/nge[i]  (""" O """)
#   seurr - Seurat "rescued" with SAM's top-3000 ranked genes  (""" RESCUE """)
# After every dataset, the results for all datasets so far are re-pickled.
ARIsamL=[]
ARIseurR=[]
ARIseurL=[]
ARIseurO=[]
for i in range(len(funcs)):
    # deepcopy so the preloaded SAM object in `funcs` is not mutated below.
    sam = copy.deepcopy(funcs[i])
    # Raw counts; corruption is applied to a dense copy of these.
    D = sam.adata_raw.X.copy()

    NTRIALS=50
    NREPLICATES=10

    # Score arrays indexed [replicate, corruption level].
    # Unsuffixed AMI*/ARI* use hdbknn clusters; the *l variants use Louvain.
    AMIsam = np.zeros((NREPLICATES,NTRIALS))
    AMIseur = np.zeros((NREPLICATES,NTRIALS))
    AMIseuro = np.zeros((NREPLICATES,NTRIALS))
    AMIseurr = np.zeros((NREPLICATES,NTRIALS))

    ARIsam = np.zeros((NREPLICATES,NTRIALS))
    ARIseur = np.zeros((NREPLICATES,NTRIALS))
    ARIseuro = np.zeros((NREPLICATES,NTRIALS))
    ARIseurr = np.zeros((NREPLICATES,NTRIALS))


    AMIsaml = np.zeros((NREPLICATES,NTRIALS))
    AMIseurl = np.zeros((NREPLICATES,NTRIALS))
    AMIseurlo = np.zeros((NREPLICATES,NTRIALS))
    AMIseurlr = np.zeros((NREPLICATES,NTRIALS))

    ARIsaml = np.zeros((NREPLICATES,NTRIALS))
    ARIseurl = np.zeros((NREPLICATES,NTRIALS))
    ARIseurlo = np.zeros((NREPLICATES,NTRIALS))
    ARIseurlr = np.zeros((NREPLICATES,NTRIALS))

    # Graph metrics [metric, replicate, level]:
    # 0 = local_clust, 1 = modularity, 2 = l2disp (top-100 dispersions).
    METRICSsam = np.zeros((3,NREPLICATES,NTRIALS))
    METRICSseur = np.zeros((3,NREPLICATES,NTRIALS))
    METRICSseuro = np.zeros((3,NREPLICATES,NTRIALS))
    METRICSseurr = np.zeros((3,NREPLICATES,NTRIALS))

    # Number of entries to shuffle at each level, from 0 to every entry.
    # NOTE(review): sized from sam.adata but applied to D (adata_raw). If the
    # shapes ever differ, permute2 may fail or not reach 100% — confirm they match.
    stds = np.round(np.linspace(0,np.prod(sam.adata.shape),NTRIALS)).astype('int64')
    # Ground-truth labels: first obs column. (.get_values() was removed in pandas 1.0.)
    ann = sam.adata.obs.iloc[:,0].get_values()

    # Neighbour count used to build the kNN graphs scored below.
    sam.k=20

    for j in range(NTRIALS):
        for k in range(NREPLICATES):
            print(str(i) + ' --- ' + str (j) + ' --- ' + str(k))
            # Fresh dense copy of the raw counts, then corrupt stds[j] entries.
            Ds = D.A.copy()
            permute2(Ds,stds[j])
            # log2(x+1)-transform the stored values, zero those < 1, and drop explicit zeros.
            sam.adata.X=sp.sparse.csr_matrix(Ds)
            sam.adata.X.data[:] = np.log2(sam.adata.X.data+1)
            sam.adata.X.data[sam.adata.X.data < 1] = 0
            sam.adata.X.eliminate_zeros()

            #"""
            # ---- SAM ----
            sam.adata.layers['X_disp'] = sam.adata.X
            sam.run(stopping_condition=5e-3,preprocessing=preproc[i],projection=None)
            sam.hdbknn_clustering(npcs=15)
            sam.louvain_clustering()
            sam_hdb = sam.adata.obs['hdbscan_clusters'].get_values()
            sam_louv = sam.adata.obs['louvain_clusters'].get_values()


            AMIsam[k,j] = nmi(sam_hdb,ann)
            ARIsam[k,j] = ari(sam_hdb,ann)
            AMIsaml[k,j] = nmi(sam_louv,ann)
            ARIsaml[k,j] = ari(sam_louv,ann)

            print('ARIsam: ' + str(ARIsam[k,j]))
            print('ARIsaml: ' + str(ARIsaml[k,j]))

            # Graph metrics on SAM's own connectivity graph (densified).
            X=sam.adata.uns['neighbors']['connectivities'].A
            clx = sam.louvain_clustering(X=X)
            Qx = modularity(X, clx)

            METRICSsam[0,k,j] = local_clust(X)
            METRICSsam[1,k,j] = Qx
            METRICSsam[2,k,j] = l2disp(sam.adata.X.toarray(),X,sam.k,100)
            print('METRICS sam: ' + str(METRICSsam[:,k,j]))

            #"""
            #"""
            adata = sam.adata.copy() # put raw, unfiltered expressions in for seurat
            adata.X = sp.sparse.csr_matrix(Ds)
            # WANG only (index 1): zero raw values below 1 before running Seurat.
            if i == 1:
                adata.X[adata.X<1]=0
                adata.X.eliminate_zeros()
            #adata.X.data[:] = 2**adata.X.data-1

            """ L """
            # ---- Seurat, fixed parameters ----
            pca,_,_,seur_louv = ut2.do_SEURAT4(adata.copy(),npcs=15,NN=3000)
            seur_hdb = hdbknn(Normalizer().fit_transform(pca))
            ARIseur[k,j] = ari(seur_hdb,ann)
            AMIseur[k,j] = nmi(seur_hdb,ann)
            AMIseurl[k,j] = nmi(seur_louv,ann)
            ARIseurl[k,j] = ari(seur_louv,ann)

            print('ARIseur: ' + str(ARIseur[k,j]))
            print('ARIseurl: ' + str(ARIseurl[k,j]))

            # kNN graph (k = sam.k) from correlation distances between PCA rows.
            Y = ut.dist_to_nn(ut.compute_distances(pca,'correlation'),sam.k)
            cly = sam.louvain_clustering(X=Y)
            Qy = modularity(Y, cly)


            METRICSseur[0,k,j] = local_clust(Y)
            METRICSseur[1,k,j] = Qy
            METRICSseur[2,k,j] = l2disp(sam.adata.X.toarray(),Y,sam.k,100)

            print('METRICS seur: ' + str(METRICSseur[:,k,j]))

            """ O """
            # ---- Seurat, per-dataset optimum from the sweep ----
            pca,_,_,seur_louv = ut2.do_SEURAT4(adata.copy(),npcs=npc[i],NN=nge[i])
            seur_hdb = hdbknn(Normalizer().fit_transform(pca))

            AMIseuro[k,j] = nmi(seur_hdb,ann)
            ARIseuro[k,j] = ari(seur_hdb,ann)
            AMIseurlo[k,j] = nmi(seur_louv,ann)
            ARIseurlo[k,j] = ari(seur_louv,ann)

            print('ARIseurO: ' + str(ARIseuro[k,j]))
            print('ARIseurlO: ' + str(ARIseurlo[k,j]))

            Y = ut.dist_to_nn(ut.compute_distances(pca,'correlation'),sam.k)
            cly = sam.louvain_clustering(X=Y)
            Qy = modularity(Y, cly)

            METRICSseuro[0,k,j] = local_clust(Y)
            METRICSseuro[1,k,j] = Qy
            METRICSseuro[2,k,j] = l2disp(sam.adata.X.toarray(),Y,sam.k,100)
            print('METRICS seurO: ' + str(METRICSseuro[:,k,j]))
            #"""
            """ RESCUE """
            # ---- Seurat restricted to SAM's top-3000 ranked genes ----
            pca,_,_,seur_louv = ut2.do_SEURAT4(adata.copy(),npcs=15,genes = sam.adata.uns['ranked_genes'][:3000])
            seur_hdb = hdbknn(Normalizer().fit_transform(pca))
            ARIseurr[k,j] = ari(seur_hdb,ann)
            AMIseurr[k,j] = nmi(seur_hdb,ann)
            AMIseurlr[k,j] = nmi(seur_louv,ann)
            ARIseurlr[k,j] = ari(seur_louv,ann)

            print('ARIseurR: ' + str(ARIseurr[k,j]))
            print('ARIseurlR: ' + str(ARIseurlr[k,j]))

            Y = ut.dist_to_nn(ut.compute_distances(pca,'correlation'),sam.k)
            cly = sam.louvain_clustering(X=Y)
            Qy = modularity(Y, cly)

            METRICSseurr[0,k,j] = local_clust(Y)
            METRICSseurr[1,k,j] = Qy
            METRICSseurr[2,k,j] = l2disp(sam.adata.X.toarray(),Y,sam.k,100)

            print('METRICS seurR: ' + str(METRICSseurr[:,k,j]))


    # Per-dataset tuple layout: (AMI_hdb, ARI_hdb, AMI_louv, ARI_louv, METRICS).
    ARIsamL.append((AMIsam,ARIsam,AMIsaml,ARIsaml,METRICSsam))
    ARIseurL.append((AMIseur,ARIseur,AMIseurl,ARIseurl,METRICSseur))
    ARIseurR.append((AMIseurr,ARIseurr,AMIseurlr,ARIseurlr,METRICSseurr))
    ARIseurO.append((AMIseuro,ARIseuro,AMIseurlo,ARIseurlo,METRICSseuro))
    # NOTE(review): the file handle is never closed explicitly; flushing relies
    # on CPython refcounting. Prefer `with open(...) as fh: pickle.dump(..., fh)`.
    pickle.dump((ARIsamL,ARIseurL,ARIseurR,ARIseurO),open('paper_scripts/8_23_2019_all_permute2.p','wb'))

In [None]:
import numpy as np
def mm2inch(*tupl):
    """Convert millimetre lengths to inches (e.g. for ``fig.set_size_inches``).

    Accepts separate numbers, ``mm2inch(85, 100)``, or a single tuple,
    ``mm2inch((85, 100))``. Returns a tuple of inch values.
    """
    mm_per_inch = 25.4
    values = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(v / mm_per_inch for v in values)
    
#ARIseurR = pickle.load(open('paper_scripts/8_17_2019_seurat_rescued.p','rb'))
#ARIseurL,ARIseurO = pickle.load(open('paper_scripts/to_100p_elife_rev_w_optimize_fixed.p','rb'))
#ARIsamL = pickle.load(open('paper_scripts/8_17_elife_rev_corruption.p','rb'))
# Load the corruption-sweep results. Each list has one entry per dataset:
# (AMI_hdb, ARI_hdb, AMI_louv, ARI_louv, METRICS). The score arrays are
# (NREPLICATES, NTRIALS); METRICS is (3, NREPLICATES, NTRIALS) holding
# local clustering, modularity and the top-100 dispersion norm.
# NOTE(review): file handle not closed explicitly; pickle.load must only see trusted files.
ARIsamL,ARIseurL,ARIseurR,ARIseurO = pickle.load(open('paper_scripts/8_23_2019_all_permute2.p','rb'))

# Display names, in the same order as `names` in the dataset-loading cell.
name = ['Darmanis','Wang','Baron1','Baron2','Baron3','Baron4','Koh','Segerstolpe','Muraro']#ARI.index[::-1].values
# AUC over corruption level, shaped (replicate, dataset, pipeline); the
# pipeline order is SAM, Seurat, Seurat rescued, Seurat optimized.
# The leading 10 must equal NREPLICATES from the benchmark loop.
AUC1 = np.zeros((10,len(name),4))  # ARI (hdbknn clusters)
AUC2 = np.zeros((10,len(name),4))  # local clustering ("NACC")
AUC3 = np.zeros((10,len(name),4))  # modularity
AUC4 = np.zeros((10,len(name),4))  # |dispersion|

# Sensitivity mean/std per corruption level (see the sensitivity cell below).
sensitivitycd,sensitivitycd2 = pickle.load(open('paper_scripts/ariamicollection_permutation_sensitivities_extra_darmanis_permute2.p','rb'))
sensitivitycd = sensitivitycd.flatten()
sensitivitycd2 = sensitivitycd2.flatten()

# Five stacked panels: sensitivity, ARI, NACC, modularity, |dispersion|.
fig,axs = plt.subplots(nrows=5,ncols=1)
fig.set_size_inches(mm2inch(42.4*2,106.214))

z=0  # NOTE(review): unused
for i in range(9):
    S1 = sensitivitycd
    S1e = sensitivitycd2

    # Per-replicate AUCs; x spans corruption 0..1 over 50 levels (= NTRIALS).
    for M in range(10):
        A1 = ARIsamL[i][1][M,:]#.mean(0)
        A2 = ARIseurL[i][1][M,:]#.mean(0)
        A3 = ARIseurR[i][1][M,:]#.mean(0)
        A4 = ARIseurO[i][1][M,:]

        B1 = ARIsamL[i][-1][0,:][M,:]#.mean(0)
        C1 = ARIsamL[i][-1][1,:][M,:]#.mean(0)
        D1 = ARIsamL[i][-1][2,:][M,:]#.mean(0)

        B2 = ARIseurL[i][-1][0,:][M,:]#.mean(0)
        C2 = ARIseurL[i][-1][1,:][M,:]#.mean(0)
        D2 = ARIseurL[i][-1][2,:][M,:]#.mean(0)

        B3 = ARIseurR[i][-1][0,:][M,:]#.mean(0)
        C3 = ARIseurR[i][-1][1,:][M,:]#.mean(0)
        D3 = ARIseurR[i][-1][2,:][M,:]#.mean(0)

        B4 = ARIseurO[i][-1][0,:][M,:]#.mean(0)
        C4 = ARIseurO[i][-1][1,:][M,:]#.mean(0)
        D4 = ARIseurO[i][-1][2,:][M,:]#.mean(0)

        AUC1[M,i,0] = met.auc(np.linspace(0,1,50),A1)
        AUC1[M,i,1] = met.auc(np.linspace(0,1,50),A2)
        AUC1[M,i,2] = met.auc(np.linspace(0,1,50),A3)
        AUC1[M,i,3] = met.auc(np.linspace(0,1,50),A4)

        AUC2[M,i,0] = met.auc(np.linspace(0,1,50),B1)
        AUC2[M,i,1] = met.auc(np.linspace(0,1,50),B2)
        AUC2[M,i,2] = met.auc(np.linspace(0,1,50),B3)
        AUC2[M,i,3] = met.auc(np.linspace(0,1,50),B4)

        AUC3[M,i,0] = met.auc(np.linspace(0,1,50),C1)
        AUC3[M,i,1] = met.auc(np.linspace(0,1,50),C2)
        AUC3[M,i,2] = met.auc(np.linspace(0,1,50),C3)
        AUC3[M,i,3] = met.auc(np.linspace(0,1,50),C4)

        AUC4[M,i,0] = met.auc(np.linspace(0,1,50),D1)
        AUC4[M,i,1] = met.auc(np.linspace(0,1,50),D2)
        AUC4[M,i,2] = met.auc(np.linspace(0,1,50),D3)
        AUC4[M,i,3] = met.auc(np.linspace(0,1,50),D4)

    ms = 1
    # Only Darmanis is drawn: mean +/- std over replicates at each corruption level.
    # Colors: SAM blue, Seurat red, Seurat rescued c2 (purple), Seurat optimized c3 (black).
    if name[i] =='Darmanis':# or name[i] == 'Baron2' or name[i] == 'Koh' or name[i] == 'Baron3':# or name[i] == 'Baron4':
        c2='#9481c4'
        c3='black'
        x = np.linspace(0,1.0,A1.size)
        A1 = ARIsamL[i][1].mean(0)
        A2 = ARIseurL[i][1].mean(0)
        A3 = ARIseurR[i][1].mean(0)
        A4 = ARIseurO[i][1].mean(0)

        B1 = ARIsamL[i][-1][0,:].mean(0)
        C1 = ARIsamL[i][-1][1,:].mean(0)
        D1 = ARIsamL[i][-1][2,:].mean(0)

        B2 = ARIseurL[i][-1][0,:].mean(0)
        C2 = ARIseurL[i][-1][1,:].mean(0)
        D2 = ARIseurL[i][-1][2,:].mean(0)

        B3 = ARIseurR[i][-1][0,:].mean(0)
        C3 = ARIseurR[i][-1][1,:].mean(0)
        D3 = ARIseurR[i][-1][2,:].mean(0)

        B4 = ARIseurO[i][-1][0,:].mean(0)
        C4 = ARIseurO[i][-1][1,:].mean(0)
        D4 = ARIseurO[i][-1][2,:].mean(0)

        A1s = ARIsamL[i][1].std(0)
        A2s = ARIseurL[i][1].std(0)
        A3s = ARIseurR[i][1].std(0)
        A4s = ARIseurO[i][1].std(0)

        B1s = ARIsamL[i][-1][0,:].std(0)
        C1s = ARIsamL[i][-1][1,:].std(0)
        D1s = ARIsamL[i][-1][2,:].std(0)

        B2s = ARIseurL[i][-1][0,:].std(0)
        C2s = ARIseurL[i][-1][1,:].std(0)
        D2s = ARIseurL[i][-1][2,:].std(0)

        B3s = ARIseurR[i][-1][0,:].std(0)
        C3s = ARIseurR[i][-1][1,:].std(0)
        D3s = ARIseurR[i][-1][2,:].std(0)

        B4s = ARIseurO[i][-1][0,:].std(0)
        C4s = ARIseurO[i][-1][1,:].std(0)
        D4s = ARIseurO[i][-1][2,:].std(0)

        # Panel 1: ARI
        axs[1].errorbar(x,A1,yerr = A1s,color = 'blue',marker='.',linewidth=0.5,markersize=ms)
        axs[1].errorbar(x,A2,yerr = A2s,color = 'red',marker='.',linewidth=0.5,markersize=ms)
        axs[1].errorbar(x,A3,yerr = A3s,color = c2,marker='.',linewidth=0.5,markersize=ms)
        axs[1].errorbar(x,A4,yerr = A4s,color = c3,marker='.',linewidth=0.5,markersize=ms)

        axs[0].tick_params(pad=1)
        axs[1].tick_params(pad=1)
        axs[2].tick_params(pad=1)
        axs[3].tick_params(pad=1)
        axs[4].tick_params(pad=1)

        axs[1].set_ylabel('ARI',fontsize=7)
        axs[1].set_xticks([])
        axs[1].set_yticks([0,0.5,1.0])
        f=axs[1].get_xticklabels()
        f2=axs[1].get_yticklabels()
        for ii in f: ii.set_fontsize(7)
        for ii in f2: ii.set_fontsize(7)

        axs[0].set_title(name[i],fontsize=8)

        # Panel 2: local clustering (labelled NACC)
        axs[2].errorbar(x,B1,yerr = B1s,color = 'blue',marker='.',linewidth=0.5,markersize=ms)
        axs[2].errorbar(x,B2,yerr = B2s,color = 'red',marker='.',linewidth=0.5,markersize=ms)
        axs[2].errorbar(x,B3,yerr = B3s,color = c2,marker='.',linewidth=0.5,markersize=ms)
        axs[2].errorbar(x,B4,yerr = B4s,color = c3,marker='.',linewidth=0.5,markersize=ms)

        axs[2].set_ylabel('NACC',fontsize=7)
        axs[2].set_xticks([])

        f=axs[2].get_xticklabels()
        f2=axs[2].get_yticklabels()
        for ii in f: ii.set_fontsize(7)
        for ii in f2: ii.set_fontsize(7)

        # Panel 3: modularity
        axs[3].errorbar(x,C1,yerr = C1s,color = 'blue',marker='.',linewidth=0.5,markersize=ms)
        axs[3].errorbar(x,C2,yerr = C2s,color = 'red',marker='.',linewidth=0.5,markersize=ms)
        axs[3].errorbar(x,C3,yerr = C3s,color = c2,marker='.',linewidth=0.5,markersize=ms)
        axs[3].errorbar(x,C4,yerr = C4s,color = c3,marker='.',linewidth=0.5,markersize=ms)

        axs[3].set_ylabel('Modularity',fontsize=7)
        axs[3].set_xticks([])

        f=axs[3].get_xticklabels()
        f2=axs[3].get_yticklabels()
        for ii in f: ii.set_fontsize(7)
        for ii in f2: ii.set_fontsize(7)

        #plt.savefig('paper_scripts/FIGURE4/'+name[i]+'_C.pdf')

        #plt.figure(figsize=mm2inch(36,21));
        # Panel 4: top-100 dispersion norm
        axs[4].errorbar(x,D1,yerr = D1s,color = 'blue',marker='.',linewidth=0.5,markersize=ms)
        axs[4].errorbar(x,D2,yerr = D2s,color = 'red',marker='.',linewidth=0.5,markersize=ms)
        axs[4].errorbar(x,D3,yerr = D3s,color = c2,marker='.',linewidth=0.5,markersize=ms)
        axs[4].errorbar(x,D4,yerr = D4s,color = c3,marker='.',linewidth=0.5,markersize=ms)

        axs[4].set_ylabel('| Dispersion |',fontsize=7)
        f=axs[4].get_xticklabels()
        f2=axs[4].get_yticklabels()
        for ii in f: ii.set_fontsize(7)
        for ii in f2: ii.set_fontsize(7)

        # Panel 0: embedding sensitivity (pipeline-independent)
        axs[0].errorbar(x,S1,yerr=S1e,color='black',marker='.',linewidth=0.5,markersize=ms)
        axs[4].set_xlabel('Corruption',fontsize=7)
        axs[0].set_ylabel('Sensitivity',fontsize=7)
        axs[4].set_xticks([0,0.5,1.0])
        axs[0].set_xticks([])

        f=axs[0].get_xticklabels()
        f2=axs[0].get_yticklabels()
        for ii in f: ii.set_fontsize(7)
        for ii in f2: ii.set_fontsize(7)

        fig.subplots_adjust(wspace=0.15,hspace=0.2,left=0.3,right=1,top=0.94,bottom=0.06)
        fig.align_ylabels(axs[:])
        # NOTE(review): hard-coded absolute output path; make configurable.
        plt.savefig('/media/storage/dbox/Dropbox/paper_scripts/FIGURE4_fixed/corruptions_seurO_final.pdf',transparent=True)

#plt.savefig('paper_scripts/FIGURE4/corruptions.pdf')
#plt.savefig('/media/storage/dbox/Dropbox/paper_scripts/FIGURE4_fixed/corruptions2.pdf',transparent=True)

# Mean and std over replicates, per dataset (rows) and pipeline (columns).
AUC1pm = pd.DataFrame(data=AUC1.mean(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])
AUC2pm = pd.DataFrame(data=AUC2.mean(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])
AUC3pm = pd.DataFrame(data=AUC3.mean(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])
AUC4pm = pd.DataFrame(data=AUC4.mean(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])

AUC1ps = pd.DataFrame(data=AUC1.std(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])
AUC2ps = pd.DataFrame(data=AUC2.std(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])
AUC3ps = pd.DataFrame(data=AUC3.std(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])
AUC4ps = pd.DataFrame(data=AUC4.std(0),index=name,columns = ['SAM','Seurat','Seurat rescued','Seurat optimized'])
#I=
#plt.figure(); plt.scatter(s2[0,I,:,:].flatten(),ss2[I,:,:].flatten()); plt.xlabel('NACC'); plt.ylabel('ARI')
#plt.figure(); plt.scatter(s2[1,I,:,:].flatten(),ss2[I,:,:].flatten()); plt.xlabel('Modularity'); plt.ylabel('ARI')
#plt.figure(); plt.scatter(s2[2,I,:,:].flatten(),ss2[I,:,:].flatten()); plt.xlabel('l2disp'); plt.ylabel('ARI')
#sam.save_figures('ami_metric_scatter.pdf')

In [None]:
# One horizontal bar chart per metric (ARI, NACC, modularity, |dispersion|):
# mean AUC per dataset and pipeline, with std error bars.
fig,axs = plt.subplots(nrows=1,ncols=4)
fig.set_size_inches(mm2inch(85,102.5))

# Colors follow the reversed column order below:
# Seurat optimized=black, Seurat rescued=purple, Seurat=red, SAM=blue.
colors = ['black','#9481c4','red','blue']
auc = [AUC1pm,AUC2pm,AUC3pm,AUC4pm]
aucs = [AUC1ps,AUC2ps,AUC3ps,AUC4ps]
nm = ['ARI','NACC','Modularity','|Dispersion|']

for I in range(len(axs)):
    f=axs[I].get_xticklabels()
    f2=axs[I].get_yticklabels()
    for ii in f: ii.set_fontsize(7)
    for ii in f2: ii.set_fontsize(7)
    axs[I].tick_params(pad=1)

for I in range(4):
    au = auc[I].copy()
    aus = aucs[I].copy()

    # Fix the dataset order, then reverse rows and columns so the first
    # dataset/pipeline ends up at the top of each barh group.
    au = au.T[['Darmanis','Wang','Segerstolpe','Muraro','Koh','Baron1','Baron2','Baron3','Baron4']].T
    au = au.iloc[:,::-1]
    au = au.iloc[::-1,:]

    aus = aus.T[['Darmanis','Wang','Segerstolpe','Muraro','Koh','Baron1','Baron2','Baron3','Baron4']].T
    aus = aus.iloc[:,::-1]
    aus = aus.iloc[::-1,:]

    barlist=au.plot.barh(ax=axs[I],linewidth=0.0,xerr=aus,error_kw=dict(ecolor='black',elinewidth=0.5))#,color = 'black')

    # Recolor bars column by column. The +4 offset presumably skips the
    # error-bar artists that precede the bar patches in get_children() — TODO confirm.
    z=0
    for i in range(au.shape[1]):
        c = colors[i]
        for j in range(au.shape[0]):
            barlist.get_children()[z+4].set_color(c)
            z+=1

    axs[I].get_legend().remove()
    box = axs[I].get_position()
    axs[I].set_position([box.x0,box.y0,box.width,box.height])
    #if I == 0:
    #    axs[I].set_ylabel('AUC',fontsize=7,fontname='Arial')

    # Dataset labels only on the leftmost panel.
    if I > 0:
        axs[I].set_yticks([])
        axs[I].set_yticklabels([])

    axs[I].set_title(nm[I],fontsize=8)
    #axs[I].set_xlabel('AUC',fontsize=7)



# NOTE(review): han/lab are computed but the legend call below is commented out.
han,lab = axs[-1].get_legend_handles_labels()
lab = lab[::-1]
han = han[::-1]
fig.subplots_adjust(top=0.95,bottom=0.05)
#lab,han = zip(*sorted(zip(lab,han),key = lambda t: t[0]))
#axs[-1].legend(han,lab,loc='upper right',bbox_to_anchor=(2,0.5))#,bbox_to_anchor=(1,0.5))

#plt.savefig('paper_scripts/FIGURE4/AUC_corruption.pdf')
# NOTE(review): hard-coded absolute output path; make configurable.
plt.savefig('/media/storage/dbox/Dropbox/paper_scripts/FIGURE4_fixed/AUC_corruptionO.pdf',transparent=True)


In [None]:
# In[]
import utilities_full as ut2

# Embedding sensitivity vs. corruption. At each corruption level, embed the
# data from NTR random 2000-gene subsets and measure how much the resulting
# cell-cell correlation-distance matrices disagree (1 - mean correlation).
NTRIALS=50

# NOTE(review): unlike the loading cell, `funcs` here holds loader functions
# and overwrites that global; re-running the benchmark cell afterwards breaks.
funcs = [DARMANIS]#,human1,human2,human3,human4,KOH,GOLDSTAND]
sensitivitycd=np.zeros((len(funcs),NTRIALS))   # mean disagreement per level
sensitivitycd2=np.zeros((len(funcs),NTRIALS))  # std of disagreement per level


for i in range(len(funcs)):
    print(i)
    sam=funcs[i]()

    D = sam.adata.X.copy()
    D=D.toarray()
    # Keep the (up to) 10000 expressed genes with the highest Fano z-score.
    idx = np.where(D.mean(0)>0)[0]
    z = ut2.get_fano_zscore2(D[:,idx].mean(0),D[:,idx].var(0))
    z[np.isnan(z)]=0
    idx = idx[np.argsort(-z)][:10000]
    D = D[:,idx]
    sam.adata = sam.adata[:,idx]
    # Number of entries to shuffle at each level, 0 .. every entry.
    stds = np.round(np.linspace(0,np.prod(sam.adata.shape),NTRIALS)).astype('int64')

    for mmm in range(stds.size):
        Ds = D.copy()
        permute2(Ds,stds[mmm])
        sam.adata.X=sp.sparse.csr_matrix(Ds)


        print(mmm)

        DDIF=[]

        NTR=10
        for I in range(NTR):
            # Random 2000-gene subset -> L2-normalized rows -> 15 PCs -> correlation distances.
            W=np.random.choice(sam.adata.shape[1],size = int(2000),replace=False)
            #W=sam.normalizer(W)
            g=ut.weighted_PCA(Normalizer().fit_transform(sam.adata.X[:,W].A),do_weight=False,npcs=15)[0]
            d=ut.compute_distances(g,'correlation')
            Nn=ut.dist_to_nn(d,20)[0]  # NOTE(review): unused
            #NDIF.append(Nn)
            DDIF.append(d)
            #WDIF.append(g)
        print(g.shape)


        # ddif[j, z]: mean of the diagonal of the correlation map between the
        # j-th and z-th distance matrices (presumably per-cell row correlations — TODO confirm).
        # NOTE(review): loop variable `z` shadows the Fano z-score array above (not reused afterwards).
        xx=NTR
        ddif=np.zeros((xx,xx))
        for j in range(xx):
            for z in range(xx):
                ddif[j,z]=np.diag(ut2.generate_correlation_map(DDIF[j],DDIF[z])).mean()

        # `ddif < 1` drops self-pairs (correlation 1) — and any other pair that is exactly 1.
        sensitivitycd[i,mmm] = np.mean(1-ddif[ddif<1])
        sensitivitycd2[i,mmm] = np.std(1-ddif[ddif<1])

In [None]:
#pickle.dump((sensitivitycd,sensitivitycd2),open('paper_scripts/ariamicollection_permutation_sensitivities_extra_darmanis_permute2.p','wb'))

In [4]:
from SAMGUI import SAMGUI
from SAM import SAM
import pandas as pd
# Reload a previously saved SAM object (the path suggests Figure 1 material).
sam = SAM()
sam.load('paper_scripts/FIGURE1/manifold_graph_convergence_figures.p')


In [5]:
# Take the stored UMAP coordinates and cluster labels from the loaded object
# (attribute names presumably from an older SAM API). Then rebuild a fresh SAM
# from its sparse counts and attach the old embedding as 'X_umap'.
# Order matters: `sam` is reassigned only after its attributes are read.
co = sam.co  # NOTE(review): not used in the visible cells
dt = sam.umap2d
cl = sam.cluster_labels_k
sam=SAM(counts=(sam.sparse_data,sam.all_gene_names,sam.all_cell_names))
sam.adata.obsm['X_umap'] = dt

In [6]:
# Store the old cluster labels as a categorical obs column (for coloring in the GUI).
sam.adata.obs['cl']=pd.Categorical(cl)

In [7]:
# Launch the interactive SAM GUI; the bare last expression displays the widget.
sg = SAMGUI(sam)
sg.SamPlot

HBox(children=(Tab(children=(FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'size'…

In [None]:
# Scale a 0-1 RGB triple to 0-255: ~(126, 53, 162), the purple 'rgb(126,53,162)' colorscale stop.
np.array([(0.49411765, 0.20784314, 0.63529412)])*255

In [9]:
# Override the GUI plot's colorscale with four fixed stops (red, blue, green,
# purple), presumably one per cluster in `cl` — TODO confirm the cluster count.
sg.stab.children[0].data[0].marker.colorscale = [[0.0,'rgb(237, 30, 66)'],[0.3333333333 ,'rgb(42, 84,165)'
                                        ],[0.666666666, 'rgb(15,178,64)'],[1.0 ,'rgb(126,53,162)']]