In [1]:
import os

import pandas as pd
import scanpy as sc
import numpy as np
import h5py

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from matplotlib.pyplot import plot,savefig
from sklearn import metrics

import warnings
warnings.filterwarnings("ignore")
from read_count import read_data

In [2]:
# Fourteen RGBA hex colours (last two hex digits = alpha); passed as `palette`
# when panels are coloured by true cell type.
col = [
    "#E64B35CC",
    "#0072B5CC",
    "#00A087CC",
    "#3C5488CC",
    "#F39B7FCC",
    "#F7DC05FF",
    "#FD7446E5",
    "#8491B4CC",
    "#7E6148CC",
    "#B09C85CC",
    "#E18727CC",
    "#FFDC91E5",
    "#6A6599E5",
    "#9467BDB2",
]

In [3]:
def plot_cluster(df, method_name, n, y_true, by, ax, umap=None, palette=None):
    """Draw one UMAP panel for a method/dataset pair and print its clustering scores.

    Prints ARI, NMI and the number of predicted clusters K, then scatters the UMAP
    coordinates on ``ax``. Points are coloured by predicted cluster when
    ``by == 'pred'`` and by true cell type for any other value of ``by``.

    Parameters
    ----------
    df : mapping
        Loaded result file; ``df['Clusters'][0]`` holds the predicted cluster labels.
    method_name : str
        'SCCAF', 'ADClust', or any other name (treated as scAce). Used in the printed
        summary and, when ``umap`` is None, to pick the notebook-level UMAP list.
    n : int
        Index of the dataset in [Human PBMC, Mouse ES, Mouse kidney].
    y_true : array-like
        Ground-truth labels, one per cell, in the same order as the predictions.
    by : str
        'pred' to colour by predicted cluster labels; otherwise colour by true cell types.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    umap : array-like, optional
        UMAP coordinates, one row per cell. Defaults to ``umap_sccaf[n]``,
        ``umap_adclust[n]`` or ``umap_scace[n]`` from the notebook namespace.
    palette : sequence of colours, optional
        Colours for the true-cell-type panel. Defaults to the notebook-level ``col``.
    """
    y_pred_raw = df['Clusters'][0]
    # K is counted on the stored labels, before the integer cast below.
    K = len(np.unique(y_pred_raw))
    y_pred = np.asarray(y_pred_raw, dtype='int').squeeze()

    if umap is None:
        # Fall back to the UMAP lists computed/loaded earlier in the notebook.
        if method_name == 'SCCAF':
            umap = umap_sccaf[n]
        elif method_name == 'ADClust':
            umap = umap_adclust[n]
        else:
            umap = umap_scace[n]

    # Both scores are symmetric; (labels_true, labels_pred) is sklearn's documented order.
    ari = np.round(metrics.adjusted_rand_score(y_true, y_pred), 2)
    nmi = np.round(metrics.normalized_mutual_info_score(y_true, y_pred), 2)
    print('Method: {}, ARI={}, NMI={}, k={}'.format(method_name, ari, nmi, K))

    # The scatter only needs obs labels and obsm coordinates, so X is a constant
    # placeholder column (random values here would also advance the global NumPy RNG).
    adata = sc.AnnData(np.zeros((len(y_pred), 1)))
    adata.obs['pred'] = y_pred
    adata.obs['pred'] = adata.obs['pred'].astype(str).astype('category')
    adata.obs['true'] = y_true
    adata.obs['true'] = adata.obs['true'].astype(str).astype('category')
    adata.obsm['X_umap'] = umap

    # 'none' is scanpy's documented "no legend" value; the former 'None' string is not.
    if by == 'pred':
        sc.pl.umap(adata, color=['pred'], ax=ax, show=False, legend_loc='none', size=10)
        title = 'K={}'.format(K)
    else:
        if palette is None:
            palette = col
        sc.pl.umap(adata, color=['true'], ax=ax, show=False, legend_loc='none', size=10,
                   palette=palette)
        title = '({} Cell types)'.format(len(np.unique(y_true)))

    ax.set_title(title, fontsize=15, family='Arial')
    # Open-frame look: hide the top and right axis spines.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [4]:
# Load expression matrices and integer ground-truth labels for the three datasets.
# Human PBMC: plain HDF5 with 'X' and 'Y' datasets. Open it read-only explicitly
# (h5py < 3.0 defaulted to mode 'a', i.e. read/write/create). The context manager
# closes the file even if a read fails.
with h5py.File('dataset/Human_PBMC.h5', 'r') as data_mat:
    x_pbmc = np.array(data_mat['X'], dtype='int')
    y_true_pbmc = np.array(data_mat['Y'], dtype='int')

# Mouse ES (Klein) and Mouse kidney (Adam) are read with the project's `read_data`.
# The string labels in obs['cell_type1'] become integer codes via np.unique's inverse
# indices (codes follow the sorted order of the unique names).
mat, obs, var, uns = read_data('dataset/Mouse_E.h5', sparsify=False, skip_exprs=False)
x_klein = np.array(mat.toarray())
cell_name = np.array(obs["cell_type1"])
cell_type, y_true_klein = np.unique(cell_name, return_inverse=True)

mat, obs, var, uns = read_data('dataset/Mouse_k.h5', sparsify=False, skip_exprs=False)
x_adam = np.array(mat.toarray())
cell_name = np.array(obs["cell_type1"])
cell_type, y_true_adam = np.unique(cell_name, return_inverse=True)

In [5]:
# Per-method clustering results (without subsampling), grouped by dataset.
# np.load on an .npz returns an NpzFile; keys such as 'Clusters' and 'Embedding'
# are read from disk when indexed.
pbmc_sccaf, pbmc_adclust, pbmc_scace = [np.load(path) for path in (
    'results/default/PBMC/SCCAF_wo_sample.npz',
    'results/default/PBMC/ADClust_wo_sample.npz',
    'results/default/PBMC/scAce_wo_sample.npz',
)]

klein_sccaf, klein_adclust, klein_scace = [np.load(path) for path in (
    'results/default/Klein/SCCAF_wo_sample.npz',
    'results/default/Klein/ADClust_wo_sample.npz',
    'results/default/Klein/scAce_wo_sample.npz',
)]

adam_sccaf, adam_adclust, adam_scace = [np.load(path) for path in (
    'results/default/Adam/SCCAF_wo_sample.npz',
    'results/default/Adam/ADClust_wo_sample.npz',
    'results/default/Adam/scAce_wo_sample.npz',
)]

In [7]:
# 2-D UMAP of each method's learned embedding, one per dataset (PBMC, Klein, Adam).
# This is slow and the next cell loads cached coordinates when they exist, so the
# computation only runs when that cache file is missing.
UMAP_CACHE = 'umap/umap_merge_init.npz'


def _embedding_umap(embedding):
    """Return UMAP coordinates (one row per cell) for an embedding matrix."""
    emb_adata = sc.AnnData(embedding)
    # NOTE(review): with use_rep unset, scanpy.pp.neighbors switches to a PCA
    # representation when the input has more than 50 features -- confirm intended.
    sc.pp.neighbors(emb_adata)
    sc.tl.umap(emb_adata, random_state=0)
    return np.array(emb_adata.obsm['X_umap'])


sccaf_data = [pbmc_sccaf, klein_sccaf, adam_sccaf]
adclust_data = [pbmc_adclust, klein_adclust, adam_adclust]
scace_data = [pbmc_scace, klein_scace, adam_scace]

if os.path.exists(UMAP_CACHE):
    umap_sccaf, umap_adclust, umap_scace = [], [], []
else:
    # SCCAF stores the embedding directly; ADClust and scAce wrap it in a leading axis.
    umap_sccaf = [_embedding_umap(data['Embedding']) for data in sccaf_data]
    umap_adclust = [_embedding_umap(data['Embedding'][0]) for data in adclust_data]
    umap_scace = [_embedding_umap(data['Embedding'][0]) for data in scace_data]

In [6]:
# Load the cached UMAP coordinates the published panels were drawn from. The cache
# was written once via np.savez(path, umap_sccaf=..., umap_adclust=..., umap_scace=...)
# after computing the UMAPs in the previous cell. If the file is absent, keep the
# coordinates that cell just computed.
umap_cache_path = "umap/umap_merge_init.npz"
if os.path.exists(umap_cache_path):
    # Context manager closes the underlying zip file once the arrays are read.
    with np.load(umap_cache_path) as umap_all:
        umap_sccaf = umap_all['umap_sccaf']
        umap_adclust = umap_all['umap_adclust']
        umap_scace = umap_all['umap_scace']

In [7]:
# 3x3 panel grid: one sub-figure per method (rows), each with one axes per dataset
# (columns). Row labels via sub_fig.supylabel(methods[i], x=0.07, fontsize=15,
# family='Arial') are intentionally left off.
methods = ['SCCAF', 'ADClust', 'scAce']
fig = plt.figure(figsize=(10, 8))
sub_figs = fig.subfigures(len(methods), 1)
axs = np.array([row_fig.subplots(1, 3) for row_fig in sub_figs])

In [8]:
# Panel (row, n): row = method, n = dataset (PBMC, Klein, Adam); coloured by predicted cluster.
truth_by_dataset = [y_true_pbmc, y_true_klein, y_true_adam]
runs_by_method = [
    ('SCCAF', [pbmc_sccaf, klein_sccaf, adam_sccaf]),
    ('ADClust', [pbmc_adclust, klein_adclust, adam_adclust]),
    ('scAce', [pbmc_scace, klein_scace, adam_scace]),
]
for row, (method_name, runs) in enumerate(runs_by_method):
    for n, (run, y_true) in enumerate(zip(runs, truth_by_dataset)):
        plot_cluster(run, method_name, n, y_true, 'pred', axs[row][n])

Method: SCCAF, ARI=0.64, NMI=0.73, k=12
Method: SCCAF, ARI=0.64, NMI=0.79, k=9
Method: SCCAF, ARI=0.56, NMI=0.74, k=17
Method: ADClust, ARI=0.24, NMI=0.57, k=33
Method: ADClust, ARI=0.18, NMI=0.53, k=29
Method: ADClust, ARI=0.28, NMI=0.65, k=42
Method: scAce, ARI=0.43, NMI=0.69, k=19
Method: scAce, ARI=0.33, NMI=0.64, k=19
Method: scAce, ARI=0.45, NMI=0.72, k=25


In [9]:
fig  # display the 3x3 grid (rows: methods, columns: datasets) drawn in the previous cell

<Figure size 1000x800 with 9 Axes>

In [10]:
# Export panel A (coloured by predicted cluster). Save through the figure object rather
# than pyplot's implicit "current figure": with the inline backend, figures are closed at
# the end of the cell that drew them, so plt.savefig in a later cell writes a blank canvas.
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/FigureS5A.svg', dpi=300, format='svg', bbox_inches='tight')

In [11]:
# Fresh 3x3 panel grid for the cell-type view: one sub-figure per method (rows),
# one axes per dataset (columns).
methods = ['SCCAF', 'ADClust', 'scAce']
fig = plt.figure(figsize=(10, 8))
sub_figs = fig.subfigures(len(methods), 1)
axs = np.array([row_fig.subplots(1, 3) for row_fig in sub_figs])

In [12]:
# The cell-type palette `col` is defined once in cell [2] and reused by plot_cluster
# for the panels below; it is deliberately not re-declared here.

In [13]:
# Panel (row, n): row = method, n = dataset (PBMC, Klein, Adam); coloured by true cell type.
truth_by_dataset = [y_true_pbmc, y_true_klein, y_true_adam]
runs_by_method = [
    ('SCCAF', [pbmc_sccaf, klein_sccaf, adam_sccaf]),
    ('ADClust', [pbmc_adclust, klein_adclust, adam_adclust]),
    ('scAce', [pbmc_scace, klein_scace, adam_scace]),
]
for row, (method_name, runs) in enumerate(runs_by_method):
    for n, (run, y_true) in enumerate(zip(runs, truth_by_dataset)):
        plot_cluster(run, method_name, n, y_true, 'true', axs[row][n])

Method: SCCAF, ARI=0.64, NMI=0.73, k=12
Method: SCCAF, ARI=0.64, NMI=0.79, k=9
Method: SCCAF, ARI=0.56, NMI=0.74, k=17
Method: ADClust, ARI=0.24, NMI=0.57, k=33
Method: ADClust, ARI=0.18, NMI=0.53, k=29
Method: ADClust, ARI=0.28, NMI=0.65, k=42
Method: scAce, ARI=0.43, NMI=0.69, k=19
Method: scAce, ARI=0.33, NMI=0.64, k=19
Method: scAce, ARI=0.45, NMI=0.72, k=25


In [14]:
fig  # display the 3x3 cell-type grid (rows: methods, columns: datasets) drawn in the previous cell

<Figure size 1000x800 with 9 Axes>

In [15]:
# Export panel B (coloured by true cell type). Save through the figure object rather
# than pyplot's implicit "current figure": with the inline backend, figures are closed at
# the end of the cell that drew them, so plt.savefig in a later cell writes a blank canvas.
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/FigureS5B.svg', dpi=300, format='svg', bbox_inches='tight')