In [1]:
import pandas as pd
import scanpy as sc
import numpy as np
import h5py

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from matplotlib.pyplot import plot,savefig
from sklearn import metrics

import warnings
warnings.filterwarnings("ignore")
from read_count import read_data

In [2]:
# Categorical palette passed to scanpy when panels are coloured by true cell
# type. Entries are 8-digit hex colours (the last two digits are alpha).
col = [
    "#E64B35CC",
    "#0072B5CC",
    "#00A087CC",
    "#3C5488CC",
    "#F39B7FCC",
    "#F7DC05FF",
    "#FD7446E5",
    "#8491B4CC",
    "#7E6148CC",
    "#B09C85CC",
    "#E18727CC",
    "#FFDC91E5",
    "#6A6599E5",
    "#9467BDB2",
]

In [3]:
def plot_cluster(df, data_name, phase, by, y_true, n, ax):
    """Draw one t-SNE panel for a dataset on ``ax``.

    Parameters
    ----------
    df : mapping-like
        Result container read as ``df['Clusters']``; element ``[0]`` holds the
        labels used for the 'split' phase and element ``[1]`` the labels used
        otherwise.
    data_name : str
        Dataset name, used only in the printed summary line.
    phase : str
        'split' plots ``tsne_init_all[n]`` with ``df['Clusters'][0]``.
        Any other value (callers pass 'enhance') plots ``tsne_last_all[n]``
        with ``df['Clusters'][1]``.
    by : str
        'clusters' colours the points by predicted cluster, prints
        ARI / NMI / K and puts them in the panel title.
        Any other value (the callers in this notebook pass 'labels') colours
        the points by ``y_true`` using the module-level palette ``col``.
    y_true : array-like
        Ground-truth label per cell, in the same order as the predictions.
    n : int
        Index of the dataset in [Human1, Human2, Human3, Mouse1, Mouse2, Mouse3];
        used to look up the precomputed t-SNE coordinates.
    ax : matplotlib.axes.Axes
        Axes to draw into.

    Notes
    -----
    Reads the module-level lists ``tsne_init_all`` / ``tsne_last_all`` and the
    palette ``col``; they must exist before this function is called.
    """
    if phase == 'split':
        tsne = tsne_init_all[n]
        y_pred = df['Clusters'][0]
    else:
        tsne = tsne_last_all[n]
        y_pred = df['Clusters'][1]

    y_pred = np.asarray(y_pred, dtype='int').squeeze()
    color_by_clusters = (by == 'clusters')

    # The scores are only reported in the cluster-coloured panel, so they are
    # only computed there.
    if color_by_clusters:
        K_pred = len(np.unique(y_pred))
        ari_pred = np.round(metrics.adjusted_rand_score(y_pred, y_true), 2)
        nmi_pred = np.round(metrics.normalized_mutual_info_score(y_pred, y_true), 2)
        print('Datasets: {}_{}, ARI={}, NMI={}, k={}'.format(data_name, phase, ari_pred, nmi_pred, K_pred))

    # AnnData is used purely as a carrier for the labels and the t-SNE
    # coordinates; the expression matrix is a constant placeholder, so the
    # function does not consume the global NumPy random state.
    adata = sc.AnnData(pd.DataFrame(np.zeros((len(y_pred), 1))))
    adata.obs['pred'] = y_pred
    adata.obs['pred'] = adata.obs['pred'].astype(str).astype('category')
    adata.obs['true'] = y_true
    adata.obs['true'] = adata.obs['true'].astype(str).astype('category')

    adata.obsm['X_tsne'] = tsne

    # NOTE(review): legend_loc is the *string* 'None', not the None object.
    # Kept unchanged; confirm it suppresses the legend on the installed scanpy.
    if color_by_clusters:
        sc.pl.tsne(adata, color=['pred'], ax=ax, show=False, legend_loc='None', size=20)
        if phase == 'split':
            ax.set_title('K={}'.format(K_pred), fontsize=14, family='Arial')
        else:
            ax.set_title('K={} ARI={}'.format(K_pred, ari_pred), fontsize=14, family='Arial')
    else:
        sc.pl.tsne(adata, color=['true'], ax=ax, show=False, legend_loc='None', palette=col, size=20)

    # Same frame styling for both colourings: drop the top and right spines.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [4]:
def _load_h5_labels(path):
    """Return the 'Y' dataset of the HDF5 file at ``path`` as an integer array."""
    # Open read-only and let the context manager close the handle even if
    # reading 'Y' raises.
    with h5py.File(path, 'r') as data_mat:
        return np.array(data_mat['Y'], dtype='int')


# Datasets that store their labels in a 'Y' dataset.
y_true_human1 = _load_h5_labels('dataset/Human1.h5')
y_true_human2 = _load_h5_labels('dataset/Human2.h5')
y_true_human3 = _load_h5_labels('dataset/Human3.h5')

# Mouse1 / Mouse2 are read through the project helper `read_data`; labels come
# from the "cell_type1" column of `obs` (presumably cell-type names - verify
# against read_count.read_data). np.unique(..., return_inverse=True) turns
# them into integer codes that index the sorted unique names.
mat, obs, var, uns = read_data('dataset/Mouse1.h5', sparsify=False, skip_exprs=False)
cell_name = np.array(obs["cell_type1"])
cell_type, cell_label = np.unique(cell_name, return_inverse=True)
y_true_mouse1 = cell_label

mat, obs, var, uns = read_data('dataset/Mouse2.h5', sparsify=False, skip_exprs=False)
cell_name = np.array(obs["cell_type1"])
cell_type, cell_label = np.unique(cell_name, return_inverse=True)
y_true_mouse2 = cell_label

y_true_mouse3 = _load_h5_labels('dataset/Mouse3.h5')

In [5]:
# Figure for the split phase: two stacked sub-figures, each holding one row of
# six axes (one per dataset). `axs` ends up as a 2 x 6 array of Axes.
fig = plt.figure(figsize=(20, 5))
sub_figs = fig.subfigures(2, 1)
axs = np.array([sub_fig.subplots(1, 6) for sub_fig in sub_figs])

# Split

In [6]:
# Result archives, one per dataset, in the order Human1-3 then Mouse1-3.
_result_paths = (
    'results/enhancement/scAce_enhance_Seurat_human1.npz',
    'results/enhancement/scAce_enhance_Seurat_human2.npz',
    'results/enhancement/scAce_enhance_Seurat_human3.npz',
    'results/enhancement/scAce_enhance_Seurat_mouse1.npz',
    'results/enhancement/scAce_enhance_Seurat_mouse2.npz',
    'results/enhancement/scAce_enhance_Seurat_mouse3.npz',
)
human1_s, human2_s, human3_s, mouse1_s, mouse2_s, mouse3_s = map(np.load, _result_paths)

In [7]:
# The position in this list is the index later used to look up each dataset's
# t-SNE coordinates: Human1, Human2, Human3, Mouse1, Mouse2, Mouse3.
datasets = [
    human1_s, human2_s, human3_s,
    mouse1_s, mouse2_s, mouse3_s,
]

In [8]:
def _tsne_coords(embedding):
    """Run scanpy's t-SNE (random_state=0) on ``embedding`` and return the coordinates."""
    adata = sc.AnnData(embedding)
    sc.tl.tsne(adata, random_state=0)
    return np.array(adata.obsm['X_tsne'])


# For every dataset, embed both stored embeddings: index 0 feeds the 'split'
# panels and index 1 feeds the 'enhance' panels.
tsne_init_all, tsne_last_all = [], []
for data in datasets:
    tsne_init_all.append(_tsne_coords(data['Embedding'][0]))
    tsne_last_all.append(_tsne_coords(data['Embedding'][1]))

In [9]:
# Top row: split-phase t-SNE coloured by predicted cluster, one panel per dataset.
split_cluster_panels = (
    (human1_s, 'Human1', y_true_human1),
    (human2_s, 'Human2', y_true_human2),
    (human3_s, 'Human3', y_true_human3),
    (mouse1_s, 'Mouse1', y_true_mouse1),
    (mouse2_s, 'Mouse2', y_true_mouse2),
    (mouse3_s, 'Mouse3', y_true_mouse3),
)
for panel_idx, (result, label, truth) in enumerate(split_cluster_panels):
    plot_cluster(result, label, 'split', 'clusters', truth, panel_idx, axs[0][panel_idx])

Datasets: Human1_split, ARI=0.62, NMI=0.8, k=17
Datasets: Human2_split, ARI=0.57, NMI=0.79, k=18
Datasets: Human3_split, ARI=0.62, NMI=0.81, k=13
Datasets: Mouse1_split, ARI=0.64, NMI=0.78, k=18
Datasets: Mouse2_split, ARI=0.51, NMI=0.78, k=14
Datasets: Mouse3_split, ARI=0.51, NMI=0.72, k=13


In [10]:
# Display the figure as it stands after the cluster-coloured row has been drawn.
fig

<Figure size 2000x500 with 12 Axes>

In [11]:
# Bottom row: the same split-phase t-SNE, coloured by true cell type
# (any `by` value other than 'clusters' selects that colouring).
split_label_panels = (
    (human1_s, 'Human1', y_true_human1),
    (human2_s, 'Human2', y_true_human2),
    (human3_s, 'Human3', y_true_human3),
    (mouse1_s, 'Mouse1', y_true_mouse1),
    (mouse2_s, 'Mouse2', y_true_mouse2),
    (mouse3_s, 'Mouse3', y_true_mouse3),
)
for panel_idx, (result, label, truth) in enumerate(split_label_panels):
    plot_cluster(result, label, 'split', 'labels', truth, panel_idx, axs[1][panel_idx])

In [12]:
# Display the split-phase figure with both rows drawn.
fig

<Figure size 2000x500 with 12 Axes>

In [13]:
# Save through the Figure object: plt.savefig writes whatever pyplot currently
# considers the "current figure", which is only `fig` by implicit state.
fig.savefig('Figures/Figure5B.svg', dpi=300, format='svg', bbox_inches='tight')

# Enhance

In [14]:
# Fresh figure for the enhance phase, same layout as before: two stacked
# sub-figures with six axes each, collected into a 2 x 6 array of Axes.
fig = plt.figure(figsize=(20, 5))
sub_figs = fig.subfigures(2, 1)
axs = np.array([sub_fig.subplots(1, 6) for sub_fig in sub_figs])

In [15]:
# Top row: enhance-phase t-SNE coloured by predicted cluster, one panel per dataset.
enhance_cluster_panels = (
    (human1_s, 'Human1', y_true_human1),
    (human2_s, 'Human2', y_true_human2),
    (human3_s, 'Human3', y_true_human3),
    (mouse1_s, 'Mouse1', y_true_mouse1),
    (mouse2_s, 'Mouse2', y_true_mouse2),
    (mouse3_s, 'Mouse3', y_true_mouse3),
)
for panel_idx, (result, label, truth) in enumerate(enhance_cluster_panels):
    plot_cluster(result, label, 'enhance', 'clusters', truth, panel_idx, axs[0][panel_idx])

Datasets: Human1_enhance, ARI=0.59, NMI=0.77, k=12
Datasets: Human2_enhance, ARI=0.95, NMI=0.91, k=6
Datasets: Human3_enhance, ARI=0.92, NMI=0.9, k=7
Datasets: Mouse1_enhance, ARI=0.82, NMI=0.86, k=7
Datasets: Mouse2_enhance, ARI=0.99, NMI=0.99, k=6
Datasets: Mouse3_enhance, ARI=0.99, NMI=0.98, k=4


In [16]:
# Bottom row: the same enhance-phase t-SNE, coloured by true cell type
# (any `by` value other than 'clusters' selects that colouring).
enhance_label_panels = (
    (human1_s, 'Human1', y_true_human1),
    (human2_s, 'Human2', y_true_human2),
    (human3_s, 'Human3', y_true_human3),
    (mouse1_s, 'Mouse1', y_true_mouse1),
    (mouse2_s, 'Mouse2', y_true_mouse2),
    (mouse3_s, 'Mouse3', y_true_mouse3),
)
for panel_idx, (result, label, truth) in enumerate(enhance_label_panels):
    plot_cluster(result, label, 'enhance', 'labels', truth, panel_idx, axs[1][panel_idx])

In [17]:
# Display the enhance-phase figure with both rows drawn.
fig

<Figure size 2000x500 with 12 Axes>

In [18]:
# Save through the Figure object: plt.savefig writes whatever pyplot currently
# considers the "current figure", which is only `fig` by implicit state.
fig.savefig('Figures/Figure5C.svg', dpi=300, format='svg', bbox_inches='tight')