In [1]:
import pandas as pd
import scanpy as sc
import numpy as np
import h5py

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from matplotlib.pyplot import plot,savefig
from sklearn import metrics

import warnings
warnings.filterwarnings("ignore")
from read_count import read_data

In [2]:
# Categorical palette (RGBA hex) passed to scanpy as `palette=col` when
# colouring panels by ground-truth cell type.
# NOTE(review): only 14 colours — a dataset with more cell types than this
# relies on scanpy's handling of short palettes; confirm if that happens.
col = [
    "#E64B35CC",
    "#0072B5CC",
    "#00A087CC",
    "#3C5488CC",
    "#F39B7FCC",
    "#F7DC05FF",
    "#FD7446E5",
    "#8491B4CC",
    "#7E6148CC",
    "#B09C85CC",
    "#E18727CC",
    "#FFDC91E5",
    "#6A6599E5",
    "#9467BDB2",
]

In [3]:
def plot_cluster(df, method_name, n, y_true, by, ax, tsne=None):
    """Draw one t-SNE panel for a clustering result and print its ARI/NMI.

    Parameters
    ----------
    df : mapping
        A loaded results archive. ``df['Clusters'][0]`` is used as the
        predicted cluster labels.
    method_name : str
        One of 'SCCAF', 'ADClust' or 'scAce'. When ``tsne`` is not given it
        selects the precomputed coordinate list (``tsne_sccaf``,
        ``tsne_adclust`` or ``tsne_scace``); it is also used in the printout.
    n : int
        Index of the dataset in [Human1, Human2, Human3]; picks the entry
        from the per-method t-SNE list.
    y_true : array-like
        Ground-truth cell-type labels, one per cell.
    by : {'pred', 'true'}
        'pred' colours cells by predicted cluster (title ``K=<#clusters>``);
        'true' colours them by ground-truth cell type using the ``col``
        palette (title ``(<#types> Cell types)``).
    ax : matplotlib Axes
        Axes to draw into.
    tsne : array-like, optional
        Explicit t-SNE coordinates, one row per cell. When given, the global
        lookup by ``method_name``/``n`` is skipped.

    Raises
    ------
    ValueError
        If ``by`` is not 'pred' or 'true', or if ``tsne`` is not given and
        ``method_name`` is not one of the three known methods.
    """
    if by not in ('pred', 'true'):
        raise ValueError("by must be 'pred' or 'true', got {!r}".format(by))

    y_pred = df['Clusters'][0]

    if tsne is None:
        # Globals are looked up lazily so only the list actually needed has to exist.
        if method_name == 'SCCAF':
            tsne = tsne_sccaf[n]
        elif method_name == 'ADClust':
            tsne = tsne_adclust[n]
        elif method_name == 'scAce':
            tsne = tsne_scace[n]
        else:
            # Previously any unknown name silently fell through to scAce's embedding.
            raise ValueError("unknown method_name {!r}; expected 'SCCAF', 'ADClust' or 'scAce'"
                             .format(method_name))

    K = len(np.unique(y_pred))

    y_pred = np.asarray(y_pred, dtype='int').squeeze()
    ari = np.round(metrics.adjusted_rand_score(y_pred, y_true), 2)
    nmi = np.round(metrics.normalized_mutual_info_score(y_pred, y_true), 2)
    print('Method: {}, ARI={}, NMI={}, k={}'.format(method_name, ari, nmi, K))

    # Only obs/obsm are needed for plotting, so X is a deterministic zero
    # placeholder (np.random.rand here needlessly consumed the global RNG).
    adata = sc.AnnData(np.zeros((len(y_pred), 1), dtype=np.float32))
    adata.obs['pred'] = y_pred
    adata.obs['pred'] = adata.obs['pred'].astype(str).astype('category')
    adata.obs['true'] = y_true
    adata.obs['true'] = adata.obs['true'].astype(str).astype('category')

    adata.obsm['X_tsne'] = tsne

    if by == 'pred':
        sc.pl.tsne(adata, color=['pred'], ax=ax, show=False, legend_loc='None', size=30)
        title = 'K={}'.format(K)
    else:
        sc.pl.tsne(adata, color=['true'], ax=ax, show=False, legend_loc='None', size=30, palette=col)
        title = '({} Cell types)'.format(len(np.unique(y_true)))

    ax.set_title(title, fontsize=15, family='Arial')
    # Hide the top and right frame lines for a cleaner panel.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [4]:
def _read_h5_labels(path):
    """Return the 'Y' dataset of the HDF5 file at ``path`` as an int array.

    The file is opened explicitly read-only (older h5py defaulted to the
    writable 'a' mode) and is closed even if reading raises.
    """
    with h5py.File(path, 'r') as h5:
        return np.array(h5['Y'], dtype='int')


y_true_human1 = _read_h5_labels('dataset/Human1.h5')
y_true_human2 = _read_h5_labels('dataset/Human2.h5')
y_true_human3 = _read_h5_labels('dataset/Human3.h5')

In [5]:
def _load_npz(path, keys=('Embedding', 'Clusters')):
    """Read the given arrays from an ``.npz`` archive into a dict and close it.

    ``np.load`` on an ``.npz`` returns a lazy ``NpzFile`` that keeps the file
    open and re-reads an array on every key access. Materialising the needed
    keys inside a ``with`` block releases the handle immediately. Only the keys
    used downstream ('Embedding' and 'Clusters') are read by default, so other
    entries the archive may hold are never touched.
    """
    with np.load(path) as archive:
        return {key: archive[key] for key in keys}


human1_sccaf = _load_npz('results/default/Human1/SCCAF_wo_sample.npz')
human2_sccaf = _load_npz('results/default/Human2/SCCAF_wo_sample.npz')
human3_sccaf = _load_npz('results/default/Human3/SCCAF_wo_sample.npz')

human1_adclust = _load_npz('results/default/Human1/ADClust_wo_sample.npz')
human2_adclust = _load_npz('results/default/Human2/ADClust_wo_sample.npz')
human3_adclust = _load_npz('results/default/Human3/ADClust_wo_sample.npz')

human1_scace = _load_npz('results/default/Human1/scAce_wo_sample.npz')
human2_scace = _load_npz('results/default/Human2/scAce_wo_sample.npz')
human3_scace = _load_npz('results/default/Human3/scAce_wo_sample.npz')

In [6]:
def _tsne_coords(embedding):
    """Run scanpy t-SNE (random_state=0) on an embedding matrix; return the 2-D coordinates."""
    emb_adata = sc.AnnData(embedding)
    sc.tl.tsne(emb_adata, random_state=0)
    return np.array(emb_adata.obsm['X_tsne'])


sccaf_data = [human1_sccaf, human2_sccaf, human3_sccaf]
adclust_data = [human1_adclust, human2_adclust, human3_adclust]
scace_data = [human1_scace, human2_scace, human3_scace]

# The SCCAF archives' 'Embedding' is used as-is; the ADClust and scAce
# archives are indexed with [0] before use.
tsne_sccaf = [_tsne_coords(result['Embedding']) for result in sccaf_data]
tsne_adclust = [_tsne_coords(result['Embedding'][0]) for result in adclust_data]
tsne_scace = [_tsne_coords(result['Embedding'][0]) for result in scace_data]

         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


In [7]:
fig = plt.figure(figsize=(10, 8))
# Three stacked sub-figures (one per method row), each with a 1x3 row of axes
# (one per dataset); axs is indexed as axs[row][column].
sub_figs = fig.subfigures(3, 1)
axs = np.array([row_fig.subplots(1, 3) for row_fig in sub_figs])

In [8]:
# Rows: methods; columns: datasets Human1..Human3. Panels are coloured by predicted cluster.
pred_rows = [
    ('SCCAF', [human1_sccaf, human2_sccaf, human3_sccaf]),
    ('ADClust', [human1_adclust, human2_adclust, human3_adclust]),
    ('scAce', [human1_scace, human2_scace, human3_scace]),
]
pred_labels = [y_true_human1, y_true_human2, y_true_human3]

# Loop names avoid `col`, which is the global palette plot_cluster reads.
for row_idx, (method, results) in enumerate(pred_rows):
    for ds_idx, (result, labels) in enumerate(zip(results, pred_labels)):
        plot_cluster(result, method, ds_idx, labels, 'pred', axs[row_idx][ds_idx])

Method: SCCAF, ARI=0.42, NMI=0.73, k=17
Method: SCCAF, ARI=0.54, NMI=0.77, k=16
Method: SCCAF, ARI=0.6, NMI=0.8, k=11
Method: ADClust, ARI=0.21, NMI=0.65, k=30
Method: ADClust, ARI=0.22, NMI=0.64, k=34
Method: ADClust, ARI=0.23, NMI=0.64, k=26
Method: scAce, ARI=0.3, NMI=0.68, k=22
Method: scAce, ARI=0.32, NMI=0.68, k=26
Method: scAce, ARI=0.34, NMI=0.69, k=18


In [9]:
# Render the assembled figure through its rich repr (last expression of the cell).
fig

<Figure size 1000x800 with 9 Axes>

In [10]:
# Save the explicit `fig`: plt.savefig writes whatever pyplot's *current*
# figure is, which depends on hidden state (and with the inline backend the
# figure may already have been closed after its cell ran).
# For SVG output, dpi only affects any rasterized artists.
fig.savefig('Figures/FigureS3A.svg', dpi=300, format='svg', bbox_inches='tight')

In [11]:
fig = plt.figure(figsize=(10, 8))
# Row order of the panel grid: one stacked sub-figure per method.
methods = ['SCCAF', 'ADClust', 'scAce']

# Each row holds a 1x3 row of axes (one per dataset); axs is indexed axs[row][column].
sub_figs = fig.subfigures(len(methods), 1)
axs = np.array([row_fig.subplots(1, 3) for row_fig in sub_figs])

In [12]:
# Rows: methods; columns: datasets Human1..Human3. Panels are coloured by ground-truth cell type.
true_rows = [
    ('SCCAF', [human1_sccaf, human2_sccaf, human3_sccaf]),
    ('ADClust', [human1_adclust, human2_adclust, human3_adclust]),
    ('scAce', [human1_scace, human2_scace, human3_scace]),
]
true_labels = [y_true_human1, y_true_human2, y_true_human3]

# Loop names avoid `col`, which is the global palette plot_cluster reads.
for row_idx, (method, results) in enumerate(true_rows):
    for ds_idx, (result, labels) in enumerate(zip(results, true_labels)):
        plot_cluster(result, method, ds_idx, labels, 'true', axs[row_idx][ds_idx])

Method: SCCAF, ARI=0.42, NMI=0.73, k=17
Method: SCCAF, ARI=0.54, NMI=0.77, k=16
Method: SCCAF, ARI=0.6, NMI=0.8, k=11
Method: ADClust, ARI=0.21, NMI=0.65, k=30
Method: ADClust, ARI=0.22, NMI=0.64, k=34
Method: ADClust, ARI=0.23, NMI=0.64, k=26
Method: scAce, ARI=0.3, NMI=0.68, k=22
Method: scAce, ARI=0.32, NMI=0.68, k=26
Method: scAce, ARI=0.34, NMI=0.69, k=18


In [13]:
# Render the assembled figure through its rich repr (last expression of the cell).
fig

<Figure size 1000x800 with 9 Axes>

In [14]:
# Save the explicit `fig`: plt.savefig writes whatever pyplot's *current*
# figure is, which depends on hidden state (and with the inline backend the
# figure may already have been closed after its cell ran).
# For SVG output, dpi only affects any rasterized artists.
fig.savefig('Figures/FigureS3B.svg', dpi=300, format='svg', bbox_inches='tight')