In [1]:
import pandas as pd
import numpy as np
import h5py
import scanpy as sc

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from matplotlib.pyplot import plot,savefig

import warnings
warnings.filterwarnings("ignore")
from read_count import read_data

In [2]:
from collections import Counter

In [3]:
# Palette for the ground-truth panels: 8 colours as 8-digit (RGBA) hex strings.
col_2 = """
    #E64B35CC #0072B5CC #00A087CC #3C5488CC
    #F39B7FCC #F7DC05FF #FD7446E5 #8491B4CC
""".split()

In [4]:
# Palette for the predicted-cluster panels: 29 colours as 6-digit (RGB) hex strings.
col_1 = """
    #2175B1 #FC7E10 #2A9E30 #D82725 #9466BE
    #8C5749 #DF7BBB #808081 #BDBC23 #28B5C8
    #AEC7E8 #FFBB78 #98DF8A #FF9896 #C5B0D5
    #C49C94 #F7B6D2 #DBDB8D #9EDAE5 #AD494A
    #4FC601 #006FA6 #E31C1F #0000A6 #B79761
    #5A0007 #3B5DFF #BA0900 #07C4A2
""".split()

In [5]:
# Saved scAce results; the 'Clusters' and 'Embedding' arrays are read by key below.
# allow_pickle=False is numpy's default, spelled out so no object array is ever unpickled.
scace = np.load('results/default/Adam/scAce_wo_sample.npz', allow_pickle=False)

In [6]:
# Keep entries 1-4 (entry 0 is skipped); they are plotted below as merge
# iterations 1-4. One shared slice keeps labels and embeddings index-aligned.
merge_steps = slice(1, 5)
clusters_merge = scace['Clusters'][merge_steps]
embedded_merge = scace['Embedding'][merge_steps]

In [7]:
def align_merge_labels(before, after):
    """Give one merge step's 'before' and 'after' labels a shared numbering.

    Every 'after' cluster is named after the 'before' label of its first cell.
    Those surviving labels are renumbered 0..m-1 in order of first appearance.
    The remaining 'before' labels (merged away) take m, m+1, ... in ascending
    order. A surviving cluster therefore has the same id, and so the same
    palette colour, on both sides of the merge.

    Args:
        before: per-cell cluster labels before the merge.
        after: per-cell cluster labels after the merge (same length).

    Returns:
        (before_aligned, after_aligned) as new integer arrays. Labels that are
        already aligned come back unchanged, so re-running this cell is harmless.

    Raises:
        ValueError: if ``before`` and ``after`` differ in shape.
    """
    before = np.asarray(before)
    after = np.asarray(after)
    if before.shape != after.shape:
        raise ValueError("before/after label arrays differ in shape: "
                         "{} vs {}".format(before.shape, after.shape))

    # after-label -> 'before' label of the first cell carrying that after-label
    survivor_of = {}
    for b, a in zip(before.tolist(), after.tolist()):
        survivor_of.setdefault(a, b)

    # De-duplicate while keeping first-appearance order; a repeated survivor
    # would otherwise shift every later id and leave some labels unmapped.
    survivors = list(dict.fromkeys(survivor_of.values()))
    # Use the labels actually present (not range(k)) so non-contiguous labels
    # still get an id; sort them so their order is explicit.
    merged_away = sorted(set(np.unique(before).tolist()) - set(survivors))

    new_id = {label: k for k, label in enumerate(survivors + merged_away)}
    before_aligned = np.array([new_id[b] for b in before.tolist()])
    after_aligned = np.array([new_id[survivor_of[a]] for a in after.tolist()])
    return before_aligned, after_aligned


# NOTE: writes the aligned labels back into clusters_merge (sliced from the npz above).
for t in range(len(clusters_merge)):
    before_aligned, after_aligned = align_merge_labels(clusters_merge[t][0],
                                                       clusters_merge[t][-1])
    clusters_merge[t][0] = before_aligned
    clusters_merge[t][-1] = after_aligned

In [8]:
def _umap_coords(embedding, seed=0):
    """Return 2-D UMAP coordinates for one embedding matrix (cells x features).

    Builds a scanpy kNN graph on the embedding, then runs UMAP on it. Both
    steps are seeded with ``seed`` so the layout is reproducible.
    """
    adata = sc.AnnData(embedding)
    sc.pp.neighbors(adata, random_state=seed)
    sc.tl.umap(adata, random_state=seed)
    return np.array(adata.obsm['X_umap'])


# One layout per merge iteration; the count follows the data instead of a hard-coded 4.
umap_f = [_umap_coords(embedding) for embedding in embedded_merge]

In [9]:
def plot_merge(t, phase, ax, labels_true=None):
    """Draw one UMAP panel for merge iteration ``t`` into ``ax``.

    Args:
        t: 0-based index of the merge iteration; titles show ``t + 1``.
        phase: which labelling colours the cells:
            'pred_before' -- clusters before iteration t,
            'pred_after'  -- clusters after iteration t,
            'true'        -- ground-truth cell types.
        ax: matplotlib Axes to draw into.
        labels_true: integer ground-truth label per cell; defaults to the
            notebook-level ``y_true``.

    Raises:
        ValueError: if ``phase`` is not one of the three values above.
    """
    panels = {
        'pred_before': ('Before Iteration {}'.format(t + 1), col_1),
        'pred_after': ('After Iteration {}'.format(t + 1), col_1),
        'true': ('Ground Truth', col_2),
    }
    if phase not in panels:
        raise ValueError("phase must be 'pred_before', 'pred_after' or 'true', "
                         "got {!r}".format(phase))
    title, palette = panels[phase]
    if labels_true is None:
        labels_true = y_true  # notebook-level global, defined in a later cell

    pred_before = clusters_merge[t][0]
    pred_after = clusters_merge[t][-1]

    # X is only a placeholder: the panel is positioned by .obsm['X_umap'] and
    # coloured from .obs. Zeros (instead of np.random.rand) leave the global
    # NumPy RNG state untouched.
    adata = sc.AnnData(pd.DataFrame(np.zeros((len(pred_before), 1))))
    for key, labels in (('pred_before', pred_before),
                        ('pred_after', pred_after),
                        ('true', labels_true)):
        adata.obs[key] = labels
        adata.obs[key] = adata.obs[key].astype(int).astype('category')
    adata.obsm['X_umap'] = umap_f[t]

    # 'none' is scanpy's spelling for "no legend". The string 'None' is not a
    # documented value, and newer scanpy versions may pass it on to matplotlib.
    sc.pl.umap(adata, color=[phase], ax=ax, show=False, legend_loc='none',
               size=12, palette=palette)
    ax.set_title(title, fontsize=16, family='Arial')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [10]:
# A 20x5 figure made of two stacked sub-figures, each holding a row of six axes.
# axs[row][col] addresses a single panel.
fig = plt.figure(figsize=(20, 5))
sub_figs = fig.subfigures(2, 1)
axs = np.array([sub_fig.subplots(1, 6) for sub_fig in sub_figs])

In [11]:
# Ground truth: integer-encode the annotated cell types so they can be coloured.
mat, obs, var, uns = read_data(
    'dataset/Mouse_k.h5',
    sparsify=False,
    skip_exprs=False,
)
cell_name = np.array(obs["cell_type1"])
# cell_type holds the sorted unique names; cell_type[cell_label] reproduces cell_name.
cell_type, cell_label = np.unique(cell_name, return_inverse=True)
y_true = cell_label

In [12]:
# Each merge iteration t gets three consecutive panels: before / after / truth.
# Reading the 2x6 grid row-major in groups of three gives one group per
# iteration (t = 0, 1 on the top row; t = 2, 3 on the bottom row).
PHASES = ('pred_before', 'pred_after', 'true')
panel_groups = axs.reshape(-1, len(PHASES))
for t, group in enumerate(panel_groups[:len(clusters_merge)]):
    for phase, ax in zip(PHASES, group):
        plot_merge(t, phase, ax)
fig

<Figure size 2000x500 with 12 Axes>

In [13]:
import os

# Save the assembled figure explicitly. plt.savefig targets pyplot's *current*
# figure, which need not be `fig` once the cell that drew it has finished.
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/FigureS6.svg', dpi=300, format='svg', bbox_inches='tight')