In [3]:
import pandas as pd
import scanpy as sc
import numpy as np
import h5py

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from matplotlib.pyplot import plot,savefig
from sklearn import metrics

import warnings
warnings.filterwarnings("ignore")
from read_count import read_data

In [4]:
# Categorical colour palette (hex RGBA strings) handed to scanpy as `palette`
# when panels are coloured by true cell type; entries are presumably consumed
# in category order, so the order below matters.
# NOTE(review): entry 14 ("#FFFFFFFF") is opaque white and will likely be
# invisible on a white axes background -- consider swapping it out.
col = [
    # 0-13: mostly semi-transparent colours (alpha CC/E5/B2; entry 5 is opaque)
    "#E64B35CC", "#0072B5CC", "#00A087CC", "#3C5488CC", "#F39B7FCC",
    "#F7DC05FF", "#FD7446E5", "#8491B4CC", "#7E6148CC", "#B09C85CC",
    "#E18727CC", "#FFDC91E5", "#6A6599E5", "#9467BDB2",
    # 14-45: fully opaque colours (alpha FF)
    "#FFFFFFFF", "#0000FFFF", "#FF0000FF", "#00FF00FF", "#000033FF",
    "#FF00B6FF", "#005300FF", "#FFD300FF", "#009FFFFF", "#9A4D42FF",
    "#00FFBEFF", "#783FC1FF", "#1F9698FF", "#FFACFDFF", "#B1CC71FF",
    "#F1085CFF", "#FE8F42FF", "#DD00FFFF", "#201A01FF", "#720055FF",
    "#766C95FF", "#02AD24FF", "#C8FF00FF", "#886C00FF", "#FFB79FFF",
    "#858567FF", "#A10300FF", "#14F9FFFF", "#00479EFF", "#DC5E93FF",
    "#93D4FFFF", "#004CFFFF",
]

In [5]:
def plot_cluster(df, method_name, n, y_true, by, ax, umap=None):
    """Draw one UMAP panel for a clustering result and print its ARI / NMI / K.

    Parameters
    ----------
    df : mapping
        Loaded result archive; ``df['Clusters'][0]`` holds the predicted cluster labels.
    method_name : str
        'SCCAF', 'ADClust' or 'scAce'. Printed in the summary line and, when *umap*
        is None, used to pick coordinates from ``umap_sccaf``, ``umap_adclust`` or
        ``umap_scace``.
    n : int
        n-th dataset in [Human pancreas, Human kidney, Mouse hypothalamus, Turtle brain].
    y_true : array-like
        True cell-type labels, one per cell.
    by : {'pred', 'true'}
        'pred' colours points by predicted cluster; 'true' colours them by true cell type.
    ax : matplotlib.axes.Axes
        Axes the panel is drawn into.
    umap : array-like, optional
        Explicit UMAP coordinates, one row per cell. Overrides the lookup by
        *method_name* and *n*.

    Raises
    ------
    ValueError
        If *by* is not 'pred'/'true', or if *umap* is None and *method_name* is unknown.
    """
    if by not in ('pred', 'true'):
        raise ValueError("by must be 'pred' or 'true', got {!r}".format(by))

    if umap is None:
        # Resolved at call time: these module-level lists are built in a later cell.
        if method_name == 'SCCAF':
            umap = umap_sccaf[n]
        elif method_name == 'ADClust':
            umap = umap_adclust[n]
        elif method_name == 'scAce':
            umap = umap_scace[n]
        else:
            # Fail loudly instead of silently plotting another method's coordinates.
            raise ValueError(
                'unknown method_name {!r}; expected SCCAF, ADClust or scAce'.format(method_name))

    y_pred = np.asarray(df['Clusters'][0], dtype='int').squeeze()
    y_true = np.asarray(y_true).squeeze()
    K = len(np.unique(y_pred))

    # ARI and NMI are symmetric; arguments follow sklearn's (labels_true, labels_pred) order.
    ari = np.round(metrics.adjusted_rand_score(y_true, y_pred), 2)
    nmi = np.round(metrics.normalized_mutual_info_score(y_true, y_pred), 2)
    print('Method: {}, ARI={}, NMI={}, k={}'.format(method_name, ari, nmi, K))

    # X is never plotted; a deterministic zero placeholder just gives AnnData one row per cell.
    adata = sc.AnnData(np.zeros((len(y_pred), 1)))
    # String categories make scanpy colour the labels as discrete groups.
    adata.obs['pred'] = y_pred
    adata.obs['pred'] = adata.obs['pred'].astype(str).astype('category')
    adata.obs['true'] = y_true
    adata.obs['true'] = adata.obs['true'].astype(str).astype('category')
    adata.obsm['X_umap'] = umap

    # legend_loc=None (not the string 'None') is scanpy's documented "no legend" value.
    if by == 'pred':
        sc.pl.umap(adata, color=['pred'], ax=ax, show=False, legend_loc=None, size=8)
        title = 'K={}'.format(K)
    else:
        sc.pl.umap(adata, color=['true'], ax=ax, show=False, legend_loc=None, size=8, palette=col)
        title = '({} Cell types)'.format(len(np.unique(y_true)))

    ax.set_title(title, fontsize=15, family='Arial')
    # Drop the top/right frame lines for a cleaner panel.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [6]:
# Human pancreas / Human kidney: true labels are stored under 'Y'.
# Open read-only explicitly (h5py < 3.0 defaulted to append mode) and let
# `with` guarantee the file is closed even if reading fails.
with h5py.File('dataset/Human_p.h5', 'r') as data_mat:
    y_true_human = np.array(data_mat['Y'], dtype='int')

with h5py.File('dataset/Human_k.h5', 'r') as data_mat:
    y_true_kidney = np.array(data_mat['Y'], dtype='int')

# Mouse hypothalamus (Chen) and Turtle brain are read with the project's read_data
# helper; integer labels are each cell's index into the sorted unique 'cell_type1' names.
mat, obs, var, uns = read_data('dataset/Mouse_h.h5', sparsify=False, skip_exprs=False)
x_chen = np.array(mat.toarray())
cell_name = np.array(obs["cell_type1"])
cell_type, y_true_chen = np.unique(cell_name, return_inverse=True)

mat, obs, var, uns = read_data('dataset/Turtle_b.h5', sparsify=False, skip_exprs=False)
x = np.array(mat.toarray())  # NOTE(review): not referenced by the visible later cells
cell_name = np.array(obs["cell_type1"])
cell_type, y_true_turtle = np.unique(cell_name, return_inverse=True)

In [7]:
# Wrap the Chen expression matrix in AnnData, attach the integer cell-type labels,
# then drop genes detected in fewer than 3 cells and cells expressing fewer than
# 200 genes. Both filters modify `adata` in place.
adata = sc.AnnData(x_chen)
adata.obs['celltype'] = y_true_chen
sc.pp.filter_genes(adata, min_cells=3, inplace=True)
sc.pp.filter_cells(adata, min_genes=200, inplace=True)

In [8]:
# Labels of the Chen cells that survived QC filtering, as a flat NumPy array.
y_true_2 = adata.obs['celltype'].to_numpy().squeeze()

In [9]:
def _load_npz(path):
    """Load every array stored in the ``.npz`` archive at *path* into a plain dict.

    Reading eagerly lets the archive be closed immediately instead of keeping an
    open ``NpzFile`` handle alive for the rest of the session; the dict supports
    the same ``result['key']`` access used by the later cells.
    """
    with np.load(path) as archive:
        return {key: archive[key] for key in archive.files}


human_sccaf = _load_npz('results/default/Human/SCCAF_wo_sample.npz')
kidney_sccaf = _load_npz('results/default/Kidney/SCCAF_wo_sample.npz')
chen_sccaf = _load_npz('results/default/Chen/SCCAF_wo_sample.npz')
turtle_sccaf = _load_npz('results/default/Turtle/SCCAF_wo_sample.npz')

human_adclust = _load_npz('results/default/Human/ADClust_wo_sample.npz')
kidney_adclust = _load_npz('results/default/Kidney/ADClust_wo_sample.npz')
chen_adclust = _load_npz('results/default/Chen/ADClust_wo_sample.npz')
turtle_adclust = _load_npz('results/default/Turtle/ADClust_wo_sample.npz')

human_scace = _load_npz('results/default/Human/scAce_wo_sample.npz')
kidney_scace = _load_npz('results/default/Kidney/scAce_wo_sample.npz')
chen_scace = _load_npz('results/default/Chen/scAce_wo_sample.npz')
turtle_scace = _load_npz('results/default/Turtle/scAce_wo_sample.npz')

In [None]:
def _umap_from_embedding(embedding):
    """Return 2-D UMAP coordinates for *embedding* (one row per cell).

    Builds a scanpy kNN graph with default settings and runs UMAP with a fixed
    random_state so the layout is reproducible.
    """
    adata_emb = sc.AnnData(embedding)
    sc.pp.neighbors(adata_emb)
    sc.tl.umap(adata_emb, random_state=0)
    return np.array(adata_emb.obsm['X_umap'])


# Dataset order everywhere: Human pancreas, Human kidney, Chen, Turtle.
sccaf_data = [human_sccaf, kidney_sccaf, chen_sccaf, turtle_sccaf]
adclust_data = [human_adclust, kidney_adclust, chen_adclust, turtle_adclust]
scace_data = [human_scace, kidney_scace, chen_scace, turtle_scace]

# SCCAF embeddings are used as stored; ADClust and scAce embeddings are indexed
# with [0] (presumably an extra leading axis in how those archives were saved).
# NOTE: the next cell replaces these lists with the cached coordinates on disk.
umap_sccaf = [_umap_from_embedding(data['Embedding']) for data in sccaf_data]
umap_adclust = [_umap_from_embedding(data['Embedding'][0]) for data in adclust_data]
umap_scace = [_umap_from_embedding(data['Embedding'][0]) for data in scace_data]

In [10]:
# Load cached UMAP coordinates so the figures do not depend on re-running UMAP.
# The cache was presumably written from the cell above with:
#   np.savez("umap/umap_merge_init_others.npz",
#            umap_sccaf=umap_sccaf, umap_adclust=umap_adclust, umap_scace=umap_scace)
# Indexing an NpzFile materialises each array, so the archive can be closed afterwards.
with np.load("umap/umap_merge_init_others.npz") as umap_all:
    umap_sccaf = umap_all['umap_sccaf']
    umap_adclust = umap_all['umap_adclust']
    umap_scace = umap_all['umap_scace']

In [12]:
# One 15x8-inch figure split into 3 stacked subfigures (one per method row),
# each holding 4 panels (one per dataset); axs[row][col] addresses a panel.
fig = plt.figure(figsize=(15, 8))
sub_figs = fig.subfigures(3, 1)
axs = np.array([sub_fig.subplots(1, 4) for sub_fig in sub_figs])

In [13]:
# Figure S7A: colour every panel by predicted cluster.
# Rows are methods (SCCAF, ADClust, scAce); columns are datasets.
# NOTE(review): the Chen column uses the QC-filtered labels `y_true_2` for SCCAF and
# scAce but the unfiltered `y_true_chen` for ADClust -- presumably ADClust was run on
# unfiltered cells; check the result lengths before unifying.
pred_panels = [
    ('SCCAF', [human_sccaf, kidney_sccaf, chen_sccaf, turtle_sccaf],
     [y_true_human, y_true_kidney, y_true_2, y_true_turtle]),
    ('ADClust', [human_adclust, kidney_adclust, chen_adclust, turtle_adclust],
     [y_true_human, y_true_kidney, y_true_chen, y_true_turtle]),
    ('scAce', [human_scace, kidney_scace, chen_scace, turtle_scace],
     [y_true_human, y_true_kidney, y_true_2, y_true_turtle]),
]
for row, (method, results, labels) in enumerate(pred_panels):
    for n_data, (result, labels_true) in enumerate(zip(results, labels)):
        plot_cluster(result, method, n_data, labels_true, 'pred', axs[row][n_data])

Method: SCCAF, ARI=0.54, NMI=0.77, k=16
Method: SCCAF, ARI=0.37, NMI=0.7, k=33
Method: SCCAF, ARI=0.6, NMI=0.77, k=28
Method: SCCAF, ARI=0.39, NMI=0.74, k=26
Method: ADClust, ARI=0.21, NMI=0.63, k=36
Method: ADClust, ARI=0.23, NMI=0.66, k=48
Method: ADClust, ARI=0.27, NMI=0.71, k=55
Method: ADClust, ARI=0.21, NMI=0.65, k=51
Method: scAce, ARI=0.33, NMI=0.69, k=27
Method: scAce, ARI=0.28, NMI=0.66, k=46
Method: scAce, ARI=0.33, NMI=0.71, k=44
Method: scAce, ARI=0.23, NMI=0.66, k=44


In [14]:
fig  # display the figure whose panels were drawn by the previous cell

<Figure size 1500x800 with 12 Axes>

In [15]:
import os

# Save via the explicit `fig` handle rather than pyplot's implicit "current figure",
# and create the output directory if it does not exist yet.
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/FigureS7A.svg', dpi=300, format='svg', bbox_inches='tight')

In [16]:
# Second 15x8-inch figure with a 3x4 panel grid: rows are methods, columns datasets.
fig = plt.figure(figsize=(15, 8))
methods = ['SCCAF', 'ADClust', 'scAce']  # row labels, top to bottom

sub_figs = fig.subfigures(3, 1)
axs = np.array([sub_fig.subplots(1, 4) for sub_fig in sub_figs])

In [17]:
# Figure S7B: same layout, coloured by true cell type.
# Rows are methods (SCCAF, ADClust, scAce); columns are datasets.
# NOTE(review): the Chen column uses the QC-filtered labels `y_true_2` for SCCAF and
# scAce but the unfiltered `y_true_chen` for ADClust -- presumably ADClust was run on
# unfiltered cells; check the result lengths before unifying.
true_panels = [
    ('SCCAF', [human_sccaf, kidney_sccaf, chen_sccaf, turtle_sccaf],
     [y_true_human, y_true_kidney, y_true_2, y_true_turtle]),
    ('ADClust', [human_adclust, kidney_adclust, chen_adclust, turtle_adclust],
     [y_true_human, y_true_kidney, y_true_chen, y_true_turtle]),
    ('scAce', [human_scace, kidney_scace, chen_scace, turtle_scace],
     [y_true_human, y_true_kidney, y_true_2, y_true_turtle]),
]
for row, (method, results, labels) in enumerate(true_panels):
    for n_data, (result, labels_true) in enumerate(zip(results, labels)):
        plot_cluster(result, method, n_data, labels_true, 'true', axs[row][n_data])

Method: SCCAF, ARI=0.54, NMI=0.77, k=16
Method: SCCAF, ARI=0.37, NMI=0.7, k=33
Method: SCCAF, ARI=0.6, NMI=0.77, k=28
Method: SCCAF, ARI=0.39, NMI=0.74, k=26
Method: ADClust, ARI=0.21, NMI=0.63, k=36
Method: ADClust, ARI=0.23, NMI=0.66, k=48
Method: ADClust, ARI=0.27, NMI=0.71, k=55
Method: ADClust, ARI=0.21, NMI=0.65, k=51
Method: scAce, ARI=0.33, NMI=0.69, k=27
Method: scAce, ARI=0.28, NMI=0.66, k=46
Method: scAce, ARI=0.33, NMI=0.71, k=44
Method: scAce, ARI=0.23, NMI=0.66, k=44


In [18]:
fig  # display the figure whose panels were drawn by the previous cell

<Figure size 1500x800 with 12 Axes>

In [19]:
import os

# Save via the explicit `fig` handle rather than pyplot's implicit "current figure",
# and create the output directory if it does not exist yet.
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/FigureS7B.svg', dpi=300, format='svg', bbox_inches='tight')