In [1]:
import numpy as np
import scanpy as sc
import h5py
import os
import matplotlib.pyplot as plt
from scgraphne import read_data
import warnings
# Suppress every warning category for the remainder of the kernel session.
# NOTE(review): this is a blanket filter, so deprecation and layout warnings
# raised by the plotting code further down are hidden as well.
warnings.filterwarnings('ignore')

In [2]:
import matplotlib
# Switch to the non-interactive Agg backend: figures are rendered off-screen
# and nothing is displayed in a window.
matplotlib.use('agg')
# Use Arial as the default font family for all text drawn by matplotlib.
plt.rcParams['font.family'] = 'Arial'

In [3]:
# Render one 2x5 grid of UMAP panels per real dataset (one panel per embedding
# method, coloured by the true cell labels) and save it as an SVG file.

# Project root, relative to the directory this notebook runs from.
dir0 = '../'

# Datasets to render; each one produces its own figure.
datasets = ['10X_PBMC', 'mouse_bladder_cell', 'Adam', 'Human_pancreatic_islets',
            'human_kidney_counts', 'mouse_ES_cell', 'Macosko_mouse_retina']

# Datasets whose figure also contains a ZIFA panel (10 panels instead of 9).
datasets_with_zifa = ['10X_PBMC', 'mouse_bladder_cell', 'Adam', 'Human_pancreatic_islets']

for dataset in datasets:
    print('----------------real data: {} ----------------- '.format(dataset))
    data_path = os.path.join(dir0, 'datasets/real/{}.h5'.format(dataset))

    if dataset in ['Adam']:
        # 'Adam' is read through the project's read_data helper and its labels
        # are derived from the "cell_type1" column of the returned obs table.
        mat, obs, var, uns = read_data(data_path, sparsify=False, skip_exprs=False)
        X = np.array(mat.toarray())
        cell_name = np.array(obs["cell_type1"])
        cell_type, cell_label = np.unique(cell_name, return_inverse=True)
        Y = cell_label
    else:
        # Every other dataset stores the matrix and labels directly under 'X' and 'Y'.
        # Open read-only explicitly so the input file can never be modified.
        with h5py.File(data_path, 'r') as data_mat:
            X = np.array(data_mat['X'])
            Y = np.array(data_mat['Y'])
        X = np.ceil(X).astype(np.int_)
        Y = np.array(Y).astype(np.int_).squeeze()

    # The AnnData object is only a container for the precomputed embeddings
    # and labels that are attached to it below.
    adata = sc.AnnData(X)

    # No constrained_layout here: plt.tight_layout() below decides the final
    # layout, and requesting both layout engines makes matplotlib warn and
    # drop the constrained one anyway.
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(16, 6))

    methods = ['Seurat', 'PCA', 'ICA', 'VASC', 'scGAE', 'scGNN', 'DCA', 'scVI', 'scGraphNE']
    if dataset in datasets_with_zifa:
        methods.insert(3, 'ZIFA')  # ZIFA sits between ICA and VASC

    for i, method in enumerate(methods):
        ax = axes[i // 5, i % 5]  # fill the grid row by row, five panels per row

        record_path = os.path.join(
            dir0, 'results/visualization/{}/record_{}_{}.npz'.format(dataset, dataset, method))
        # NOTE(review): allow_pickle=True can execute arbitrary code when loading an
        # untrusted file; presumably these records are produced by this project - confirm.
        # The context manager closes the underlying archive once the arrays are read.
        with np.load(record_path, allow_pickle=True) as r:
            umap = r['umap']
            true_labels = r["true"]
            louvain_labels = r['louvain']
            ari = r['ari']

        adata.obsm['X_umap'] = umap
        adata.obs['true'] = true_labels
        adata.obs['louvain'] = louvain_labels
        # Labels are cast to categorical strings so they are plotted as discrete groups.
        adata.obs['louvain'] = adata.obs['louvain'].astype(str).astype('category')
        adata.obs['true'] = adata.obs['true'].astype(str).astype('category')

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # show=False keeps scanpy from calling plt.show() for each individual panel.
        sc.pl.umap(adata, color="true", ax=ax, legend_loc=None, palette='tab20', show=False)
        ax.set_title('{}  ARI={:.4f}'.format(method, ari), fontproperties='Arial', fontsize=18)

    # Blank out grid cells that have no method (the last cell when ZIFA is absent).
    for unused_ax in axes.ravel()[len(methods):]:
        unused_ax.xaxis.set_ticks([])
        unused_ax.yaxis.set_ticks([])
        unused_ax.axis('off')

    plt.tight_layout(h_pad=2, w_pad=1.5)
    plt.savefig(os.path.join(dir0, 'figures/RD_{}.svg'.format(dataset)),
                dpi=300, format='svg', bbox_inches='tight')
    # Release the figure; otherwise every dataset leaves an open figure behind.
    plt.close(fig)

----------------real data: 10X_PBMC ----------------- 
----------------real data: mouse_bladder_cell ----------------- 
----------------real data: Adam ----------------- 
----------------real data: Human_pancreatic_islets ----------------- 
----------------real data: human_kidney_counts ----------------- 
----------------real data: mouse_ES_cell ----------------- 
----------------real data: Macosko_mouse_retina ----------------- 
