In [None]:
# Reload edited modules (e.g. the scmg package) automatically before running code,
# and render inline matplotlib figures at high ("retina") resolution.
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

In [None]:
import os

import numpy as np
import pandas as pd

import anndata
import scanpy as sc

import torch

from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_adata
from scmg.preprocessing.data_standardization import GeneNameMapper

# Shared mapper that converts gene identifiers between species and naming schemes
# (e.g. gene IDs -> gene names); used further down to relabel genes for plotting.
gene_name_mapper = GeneNameMapper()

In [None]:
# Input dataset: healthy human bone marrow (Triana 2021, per the file name).
# Change DATA_DIR to point at your local copy of the SCMG test datasets.
DATA_DIR = '/GPUData_xingjie/SCMG/test_datasets'
DATASET_PATH = os.path.join(DATA_DIR, 'blood_dev', 'Triana_bone_marrow_HS_2021_healthy.h5ad')

adata = sc.read_h5ad(DATASET_PATH)
# Work from the matrix stored in .raw (presumably unprocessed counts, since they are
# normalized and log-transformed later); fail with a clear message if it is absent.
if adata.raw is None:
    raise ValueError(f'{DATASET_PATH} has no .raw layer to load the expression matrix from')
adata = adata.raw.to_adata()
adata

In [None]:
# Alternative input (disabled): mouse organogenesis atlas, with mouse gene IDs mapped
# to human gene IDs, unmappable genes ('na') dropped and the cell-type column renamed.
# Uncomment to run the rest of the notebook on it instead of the dataset loaded above.
# NOTE(review): this path uses a different root ('cytofuture') than the dataset above
# ('SCMG') -- confirm it still exists before enabling.
#adata = sc.read_h5ad('/GPUData_xingjie/cytofuture/test_datasets/organogenesis/Pijuan-Sala_organogenesis_MM_2019.h5ad')
#adata.var.index = gene_name_mapper.map_gene_names(adata.var.index, 'mouse', 'human', 'id', 'id')
#adata = adata[:, adata.var.index != 'na'].copy()
#adata.obs['cell_type'] = adata.obs['celltype']
#adata.var_names_make_unique()
#adata

In [None]:
# Load the trained embedder: the pickled model object, then its best checkpoint weights.
model_path = 'trained_embedder'

# Fall back to CPU when no GPU is present instead of failing on model.to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# map_location places the stored tensors on `device`, so files saved from a GPU
# session also load on CPU-only machines.
# NOTE(review): model.pt is a full pickled object -- unpickling can run arbitrary code,
# so only load trusted files. On torch >= 2.6, torch.load defaults to weights_only=True,
# which cannot load a full model object; pass weights_only=False there.
model = torch.load(os.path.join(model_path, 'model.pt'), map_location=device)
model.load_state_dict(
    torch.load(os.path.join(model_path, 'best_state_dict.pth'), map_location=device))

model.to(device)
model.eval()

In [None]:
# Embed every cell with the trained model. Presumably this writes the latent
# representation to adata.obsm['X_ce_latent'], which the neighbors step in the next
# cell reads -- TODO confirm against scmg.model.contrastive_embedding.embed_adata.
embed_adata(model, adata)

In [None]:
# Build the kNN graph on the model embedding stored in obsm, then compute a UMAP layout.
LATENT_KEY = 'X_ce_latent'
N_NEIGHBORS = 50

sc.pp.neighbors(adata, n_neighbors=N_NEIGHBORS, use_rep=LATENT_KEY)
sc.tl.umap(adata)

In [None]:
# UMAP of the embedding coloured by annotated cell type, with labels drawn
# directly on the clusters in a very small font.
sc.pl.umap(
    adata,
    color='cell_type',
    legend_loc='on data',
    legend_fontsize=3,
)

In [None]:
# Display the model's dataset_id_map -- presumably the dataset labels the decoder can
# be conditioned on (the decode cell below passes one of them); confirm.
model.dataset_id_map

In [None]:
# Number of cells to decode; None decodes all cells (set e.g. 10_000 for a quick run).
N_SUBSAMPLE = None
# Dataset label the decoder is conditioned on, applied to every cell.
# NOTE(review): presumably must be one of the labels in model.dataset_id_map -- confirm.
DECODE_DATASET = 'Suo_ImmuneDev_HS_2022:all'

if N_SUBSAMPLE is None:
    adata_ss = adata.copy()
else:
    # Clamp so asking for more cells than exist does not error; fixed seed for reproducibility.
    adata_ss = sc.pp.subsample(adata, n_obs=min(N_SUBSAMPLE, adata.n_obs),
                               random_state=0, copy=True)

adata_pred = decode_adata(model, adata_ss, dataset_names=[DECODE_DATASET] * adata_ss.n_obs)
adata_pred

In [None]:


# Relabel genes in both the observed and the decoded data from gene IDs to gene names
# so marker genes can be plotted by name.

# Observed data: build a new AnnData rather than renaming adata_ss in place, so
# re-running this cell never maps already-converted names a second time.
observed_var = adata_ss.var.copy()
observed_var.index = gene_name_mapper.map_gene_names(
    adata_ss.var.index, 'human', 'human', 'id', 'name')

# Copy X: normalize_total / log1p below modify the matrix in place. Without a copy they
# would also rewrite adata_ss.X, compounding the transform on every re-run.
adata_ss_named = anndata.AnnData(adata_ss.X.copy(),
                                 obs=adata_ss.obs.copy(), var=observed_var)
# Name mapping may produce duplicates (e.g. several IDs without a name); make unique.
adata_ss_named.var_names_make_unique()
adata_ss_named.obsm['X_umap'] = adata_ss.obsm['X_umap']
sc.pp.normalize_total(adata_ss_named, target_sum=1e4)
sc.pp.log1p(adata_ss_named)

# Decoded data: renamed in place because the plotting cell uses adata_pred directly.
# Stash the original IDs once so a re-run maps from IDs again, not from names.
if 'gene_id_orig' not in adata_pred.var.columns:
    adata_pred.var['gene_id_orig'] = adata_pred.var.index
adata_pred.var.index = gene_name_mapper.map_gene_names(
    pd.Index(adata_pred.var['gene_id_orig']), 'human', 'human', 'id', 'name')
# Same de-duplication as for the observed data.
adata_pred.var_names_make_unique()

In [None]:
# Marker genes to compare between measured and decoded expression.
genes_to_plot = ['GATA1', 'SPI1', 'IKZF1']

# Top row: observed expression (normalized + log1p). Bottom row: decoded prediction.
# Per-panel titles keep the two otherwise identical-looking rows distinguishable.
sc.pl.umap(adata_ss_named, color=genes_to_plot,
           title=[f'{gene} (observed)' for gene in genes_to_plot])
sc.pl.umap(adata_pred, color=genes_to_plot,
           title=[f'{gene} (predicted)' for gene in genes_to_plot])