In [1]:
import pandas as pd
import scanpy as sc
import numpy as np
import h5py

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from matplotlib.pyplot import plot,savefig
from sklearn import metrics

import warnings
warnings.filterwarnings("ignore")
from read_count import read_data

# Prepare: plotting helper, panel grid, and per-method clustering results

In [2]:
def plot_cluster(df, method_name, y_true, by, n, ax, tsne=None, palette=None):
    """Draw one t-SNE panel for a clustering result and print its ARI/NMI.

    Parameters
    ----------
    df : pandas.DataFrame or NpzFile-like mapping
        Clustering result. For 'Seurat' and 'CIDR' it must provide the
        columns 'cluster', 'tSNE_1' and 'tSNE_2'; for any other method it
        must provide 'Clusters' (for 'scAce' the labels are read from
        ``df['Clusters'][-1][-1]``).
    method_name : str
        Method label; used in the printed summary and to decide how ``df``
        is read.
    y_true : array-like
        Ground-truth cell-type labels, one per cell.
    by : {'pred', 'true'}
        'pred' colours cells by predicted cluster and titles the panel with
        K and ARI; 'true' colours cells by ground-truth label using
        ``palette`` and removes the x ticks.
    n : int or None
        Position of the method in [scScope, scDeepCluster, DESC, graph-sc,
        SCCAF, ADClust, scAce]; selects the global ``tsne_all[n]`` when
        ``tsne`` is not given. Unused for Seurat and CIDR.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    tsne : array-like, optional
        Precomputed per-cell 2-D coordinates that override the default
        lookup (``tsne_all[n]``, or the CSV columns for Seurat/CIDR).
    palette : sequence of colours, optional
        Colours used when ``by == 'true'``; defaults to the global ``col``.

    Raises
    ------
    ValueError
        If ``by`` is not 'pred' or 'true'.
    """
    if by not in ('pred', 'true'):
        raise ValueError("by must be 'pred' or 'true', got {!r}".format(by))

    if method_name in ['Seurat', 'CIDR']:
        # These results are tables that carry their own t-SNE coordinates.
        y_pred = np.asarray(df['cluster'])
        if tsne is None:
            tsne = np.column_stack((np.asarray(df['tSNE_1']), np.asarray(df['tSNE_2'])))
    else:
        if method_name == 'scAce':
            # scAce's 'Clusters' is nested; [-1][-1] takes the last element of
            # the last entry (NOTE(review): presumably the final labels — confirm).
            y_pred = df['Clusters'][-1][-1]
        else:
            y_pred = df['Clusters']
        if tsne is None:
            tsne = tsne_all[n]

    # Flatten to 1-D vectors; ravel (unlike squeeze) keeps a single-cell input 1-D.
    y_pred = np.ravel(np.asarray(y_pred, dtype='int'))
    y_true = np.ravel(np.asarray(y_true))
    K = len(np.unique(y_pred))

    # scikit-learn's signature is (labels_true, labels_pred). ARI and NMI are
    # symmetric, so the scores are unchanged, but the call now matches the API.
    ari = np.round(metrics.adjusted_rand_score(y_true, y_pred), 2)
    nmi = np.round(metrics.normalized_mutual_info_score(y_true, y_pred), 2)
    print('Method: {}, ARI={}, NMI={}, k={}'.format(method_name, ari, nmi, K))

    # Only obs labels and obsm coordinates are plotted, so X is a constant
    # placeholder (this also avoids drawing from the global NumPy RNG).
    adata = sc.AnnData(pd.DataFrame(np.zeros((len(y_pred), 1))))
    for key, labels in (('pred', y_pred), ('true', y_true)):
        adata.obs[key] = labels
        # String categories so the labels are coloured as discrete groups.
        adata.obs[key] = adata.obs[key].astype(str).astype('category')
    adata.obsm['X_tsne'] = np.asarray(tsne)

    if by == 'pred':
        sc.pl.tsne(adata, color=['pred'], ax=ax, show=False, legend_loc='None', size=20)
        ax.set_title('K={} ARI={}'.format(K, ari), fontsize=15, family='Arial')
    else:
        if palette is None:
            palette = col  # module-level palette defined in a later cell
        sc.pl.tsne(adata, color=['true'], ax=ax, show=False, legend_loc='None', size=20, palette=palette)
        ax.set_xticks([])

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [3]:
# Figure 2A layout: three stacked subfigures (rows), each holding 1x3 axes,
# collected into a 3x3 array so panels are addressed as axs[row][col].
fig = plt.figure(figsize=(10, 8))
sub_figs = fig.subfigures(3, 1)
axs = np.array([row_fig.subplots(1, 3) for row_fig in sub_figs])

In [5]:
# Per-method clustering results for the 'Sim' dataset, kept in one place so
# another run/dataset can be plotted by changing a single constant.
RESULTS_DIR = 'results/default/Sim'

# Seurat and CIDR results are CSV tables (cluster labels plus tSNE_1/tSNE_2).
seurat = pd.read_csv(f'{RESULTS_DIR}/Seurat_wo_sample.csv', header=0, index_col=0)
cidr = pd.read_csv(f'{RESULTS_DIR}/CIDR_wo_sample.csv', header=0, index_col=0)

# The remaining methods are .npz archives; np.load returns a lazy NpzFile
# whose arrays are read on key access.
scscope = np.load(f'{RESULTS_DIR}/scScope_wo_sample.npz')
scd = np.load(f'{RESULTS_DIR}/scDeepCluster_wo_sample.npz')
desc = np.load(f'{RESULTS_DIR}/DESC_wo_sample.npz')
graphsc = np.load(f'{RESULTS_DIR}/graphsc_wo_sample.npz')
sccaf = np.load(f'{RESULTS_DIR}/SCCAF_wo_sample.npz')
adclust = np.load(f'{RESULTS_DIR}/ADClust_wo_sample.npz')
scace = np.load(f'{RESULTS_DIR}/scAce_wo_sample.npz')

# Calculate t-SNE

In [6]:
# Order matters: a method's position here is the `n` passed to plot_cluster,
# and the last entry is treated specially when extracting embeddings.
methods = [
    scscope,  # 0
    scd,      # 1
    desc,     # 2
    graphsc,  # 3
    sccaf,    # 4
    adclust,  # 5
    scace,    # 6
]

In [7]:
# For the last method (scAce) only Embedding[-1] is used; every other
# method's 'Embedding' array is used as-is.
last_pos = len(methods) - 1
embedding = [
    result['Embedding'][-1] if pos == last_pos else result['Embedding']
    for pos, result in enumerate(methods)
]

In [8]:
def _tsne_coords(matrix):
    """Return scanpy t-SNE coordinates (random_state=0) for one embedding matrix."""
    tmp = sc.AnnData(matrix)
    sc.tl.tsne(tmp, random_state=0)
    return np.array(tmp.obsm['X_tsne'])


# One coordinate array per entry of `embedding`, in the same order as `methods`.
tsne_all = [_tsne_coords(matrix) for matrix in embedding]

         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


# Plot: Figure 2A (cells coloured by predicted cluster) and Figure 2B (cells coloured by true cell type)

In [9]:
# Ground-truth labels for the 'Sim' dataset. Open explicitly read-only
# (h5py < 3.0 defaulted to mode 'a', i.e. writable) and use a context
# manager so the HDF5 handle is closed even if the read raises.
with h5py.File('dataset/Sim.h5', 'r') as data_mat:
    y_true = np.array(data_mat['Y'], dtype='int')

In [10]:
# Figure 2A panels, filled row by row across the 3x3 grid.
# Each entry: (result, method name, index into tsne_all or None for CSV-based methods).
pred_panels = [
    (scscope, 'scScope', 0),
    (scd, 'scDeepCluster', 1),
    (sccaf, 'SCCAF', 4),
    (seurat, 'Seurat', None),
    (adclust, 'ADClust', 5),
    (cidr, 'CIDR', None),
    (graphsc, 'graph-sc', 3),
    (desc, 'DESC', 2),
    (scace, 'scAce', 6),
]
for slot, (result, name, tsne_idx) in enumerate(pred_panels):
    row, column = divmod(slot, 3)
    plot_cluster(result, name, y_true, 'pred', tsne_idx, axs[row][column])

Method: scScope, ARI=0.27, NMI=0.49, k=8
Method: scDeepCluster, ARI=0.67, NMI=0.82, k=5
Method: SCCAF, ARI=0.67, NMI=0.88, k=6
Method: Seurat, ARI=0.75, NMI=0.89, k=6
Method: ADClust, ARI=0.77, NMI=0.8, k=3
Method: CIDR, ARI=0.78, NMI=0.84, k=3
Method: graph-sc, ARI=0.93, NMI=0.93, k=5
Method: DESC, ARI=0.98, NMI=0.96, k=5
Method: scAce, ARI=1.0, NMI=1.0, k=5


In [11]:
# Display Figure 2A (the cell's last expression is rendered as its output).
fig

<Figure size 1000x800 with 9 Axes>

In [12]:
# Save through the figure handle: plt.savefig writes pyplot's *current*
# figure, which is not guaranteed to be `fig` once another figure has been
# created or a notebook backend has closed/replaced figures between cells.
fig.savefig('Figures/Figure2A.svg', dpi=300, format='svg', bbox_inches='tight')

In [13]:
# Figure 2B layout: a fresh 3x3 grid (three subfigure rows of 1x3 axes),
# addressed as axs[row][col].
fig = plt.figure(figsize=(10, 8))
sub_figs = fig.subfigures(3, 1)
axs = np.array([row_fig.subplots(1, 3) for row_fig in sub_figs])

In [14]:
# Palette for ground-truth labels, read as a global by plot_cluster(by='true').
# Colours are #RRGGBBAA; the trailing 'CC' alpha is ~80% opacity.
col = [
    "#E64B35CC",
    "#0072B5CC",
    "#00A087CC",
    "#3C5488CC",
    "#F39B7FCC",
]

In [15]:
# Figure 2B panels: same layout as Figure 2A, coloured by ground-truth label.
# Each entry: (result, method name, index into tsne_all or None for CSV-based methods).
true_panels = [
    (scscope, 'scScope', 0),
    (scd, 'scDeepCluster', 1),
    (sccaf, 'SCCAF', 4),
    (seurat, 'Seurat', None),
    (adclust, 'ADClust', 5),
    (cidr, 'CIDR', None),
    (graphsc, 'graph-sc', 3),
    (desc, 'DESC', 2),
    (scace, 'scAce', 6),
]
for slot, (result, name, tsne_idx) in enumerate(true_panels):
    row, column = divmod(slot, 3)
    plot_cluster(result, name, y_true, 'true', tsne_idx, axs[row][column])

Method: scScope, ARI=0.27, NMI=0.49, k=8
Method: scDeepCluster, ARI=0.67, NMI=0.82, k=5
Method: SCCAF, ARI=0.67, NMI=0.88, k=6
Method: Seurat, ARI=0.75, NMI=0.89, k=6
Method: ADClust, ARI=0.77, NMI=0.8, k=3
Method: CIDR, ARI=0.78, NMI=0.84, k=3
Method: graph-sc, ARI=0.93, NMI=0.93, k=5
Method: DESC, ARI=0.98, NMI=0.96, k=5
Method: scAce, ARI=1.0, NMI=1.0, k=5


In [16]:
# Display Figure 2B (the cell's last expression is rendered as its output).
fig

<Figure size 1000x800 with 9 Axes>

In [17]:
# Save through the figure handle rather than pyplot's implicit current figure,
# so the file written is guaranteed to be this `fig` (Figure 2B).
fig.savefig('Figures/Figure2B.svg', dpi=300, format='svg', bbox_inches='tight')