In [28]:
import os 
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import matplotlib
from scgen_model import SCGEN
# import scgen
matplotlib.rc('ytick', labelsize=18)
matplotlib.rc('xtick', labelsize=18)
sc.set_figure_params(dpi_save=300)

In [17]:
path_to_save = "../results/Figures/Supplemental Figure 2/"
os.makedirs(path_to_save, exist_ok=True)
sc.settings.figdir = path_to_save

In [24]:
pbmc = sc.read("tests/data/train_kang.h5ad")
pbmc

AnnData object with n_obs × n_vars = 16893 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'

In [45]:
pbmc.obs['cell_type'].unique().tolist()

['NK', 'Dendritic', 'CD4T', 'B', 'FCGR3A+Mono', 'CD14+Mono', 'CD8T']

In [40]:
pbmc.X.shape[1]

6998

In [46]:
pbmc.obs.groupby(['cell_type', 'condition']).size()

cell_type    condition 
CD4T         control       2437
             stimulated    3127
CD14+Mono    control       1946
             stimulated     615
B            control        818
             stimulated     993
CD8T         control        574
             stimulated     541
NK           control        517
             stimulated     646
FCGR3A+Mono  control       1100
             stimulated    2501
Dendritic    control        615
             stimulated     463
dtype: int64

In [50]:
pbmc

AnnData object with n_obs × n_vars = 16893 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'

In [47]:
pbmc[pbmc.obs['cell_type'] == 'NK']

View of AnnData object with n_obs × n_vars = 1163 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'

In [48]:
pbmc[((pbmc.obs['cell_type'] == 'NK') & (pbmc.obs['condition'] == 'control'))]

View of AnnData object with n_obs × n_vars = 517 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'

In [49]:
pbmc[~((pbmc.obs['cell_type'] =='NK') & (pbmc.obs['condition'] == 'control'))]

View of AnnData object with n_obs × n_vars = 16376 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'

In [22]:
scgen_recon = sc.read("valid_pbmc.h5ad")
ctrl = scgen_recon[scgen_recon.obs["condition"] == "CD4T_ctrl"]
real_stim = scgen_recon[scgen_recon.obs["condition"] == "CD4T_real_stim"]
scgen_recon = scgen_recon[scgen_recon.obs["condition"].isin(["CD4T_real_stim", "CD4T_pred_stim", "CD4T_ctrl"])]
scgen_recon

View of AnnData object with n_obs × n_vars = 0 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'

In [56]:
pbmc.X

<16893x6998 sparse matrix of type '<class 'numpy.float32'>'
	with 5533579 stored elements in Compressed Sparse Row format>

In [55]:
pbmc.X.A

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [33]:
cd4t = pbmc[pbmc.obs["cell_type"] == "CD4T"]
sc.tl.rank_genes_groups(cd4t, groupby="condition", n_genes=100, method="wilcoxon")
diff_genes = cd4t.uns["rank_genes_groups"]["names"]["stimulated"]

  self.data[key] = value


In [39]:
def make_plots(adata, conditions, model_name, figure, x_coeff=0.3, y_coeff=0.1):
    print(adata)
    if model_name == "RealCD4T":
        mean_labels = {"x": "ctrl mean", "y": "stim mean"}
        var_labels = {"x": "ctrl var", "y": "stim var"}
    else:
        mean_labels = {"x": "pred mean", "y": "stim mean"}
        var_labels = {"x": "pred var", "y": "stim var"}
    print(adata.obs.groupby(['condition']).size())
    SCGEN.reg_mean_plot(adata,
                                 condition_key="condition",
                                 axis_keys={"x": conditions["pred_stim"], "y": conditions["real_stim"]},
                                 gene_list=diff_genes[:5],
                                 top_100_genes=diff_genes,
                                 path_to_save=os.path.join(path_to_save, f"SupplFig2{figure}_{model_name}_reg_mean.pdf"),
                                 legend=False,
                                 title="",
                                 labels=mean_labels,
                                 fontsize=26,
                                 textsize=18,
                                 x_coeff=x_coeff,
                                 y_coeff=y_coeff,
                                 show=True,
                                 range=[0, 5, 1])
    
    SCGEN.reg_var_plot(adata, 
                                condition_key="condition",
                                axis_keys={"x": conditions["pred_stim"], "y": conditions["real_stim"]},
                                gene_list=diff_genes[:5],
                                top_100_genes=diff_genes,
                                path_to_save=os.path.join(path_to_save, f"SupplFig2{figure}_{model_name}_reg_var.pdf"),
                                legend=False,
                                labels=var_labels,
                                title="",
                                fontsize=26,
                                textsize=18,
                                x_coeff=x_coeff,
                                y_coeff=y_coeff,
                                show=True)
    if model_name == "scGen":
        adata = adata[adata.obs["condition"].isin(["CD4T_ctrl", "CD4T_pred_stim", "CD4T_real_stim"])]
        adata.obs["condition"].replace("CD4T_ctrl", "ctrl", inplace=True)
        adata.obs["condition"].replace("CD4T_real_stim", "real_stim", inplace=True)
        adata.obs["condition"].replace("CD4T_pred_stim", "pred_stim", inplace=True)
    sc.pp.neighbors(adata, n_neighbors=20)
    sc.tl.umap(adata, min_dist=1.1)
    import matplotlib.style
    plt.style.use('default')
    if model_name == "scGen":
        sc.pl.umap(adata, color=["condition"],
                   legend_loc=False,
                   frameon=False,
                   title="",
                   palette=matplotlib.rcParams["axes.prop_cycle"],
                   save=f"_latent_conditions.png",
                   show=True)
    else:
        if model_name == "RealCD4T":
            sc.pl.umap(adata, color=["condition"],
                   legend_loc=False,
                   frameon=False,
                   title="",
                   palette=['#1f77b4', '#2ca02c'],
                   save=f"_latent_conditions.png",
                   show=True)
        else:
            
            sc.pl.umap(adata, color=["condition"],
                       legend_loc=False,
                       frameon=False,
                       title="",
                       palette=matplotlib.rcParams["axes.prop_cycle"],
                       save=f"_latent_conditions.png",
                       show=True)
    
    os.rename(src=os.path.join(path_to_save, "umap_latent_conditions.png"), 
              dst=os.path.join(path_to_save, f"SupplFig2{figure}_{model_name}_umap.png"))

In [38]:
conditions = {"real_stim": "CD4T_real_stim", "pred_stim": "CD4T_pred_stim"}
make_plots(scgen_recon, conditions, "scGen", "a", 0.45, 0.8)

Series([], dtype: int64)


TypeError: reg_mean_plot() missing 1 required positional argument: 'adata'