In [None]:
import numpy as np
import pandas as pd
import os
import scanpy as sc

In [None]:
from self_supervision.paths import DATA_DIR, RESULTS_FOLDER

In [None]:
# Which data split's embeddings to visualise.
split = 'test'
embeddings_path = os.path.join(RESULTS_FOLDER, f'adata_{split}_embs_scib.h5ad')
adata = sc.read_h5ad(embeddings_path)

# Lookup table used to turn the values in obs['cell_type'] into readable labels.
# NOTE(review): assumes the parquet index matches the values stored in
# obs['cell_type'] (presumably integer category codes) — TODO confirm.
STORE_DIR = os.path.join(DATA_DIR, 'merlin_cxg_2023_05_15_sf-log1p')
lookup_path = os.path.join(STORE_DIR, 'categorical_lookup/cell_type.parquet')
cell_type_mapping = pd.read_parquet(lookup_path)

cell_type_codes = adata.obs['cell_type'].values
adata.obs['Cell Type'] = cell_type_mapping.loc[cell_type_codes, 'label'].values

# Replace raw tech_sample identifiers with compact "Batch <n>" labels for legends.
batch_codes = adata.obs['tech_sample'].astype('category').cat.codes
adata.obs['Batch'] = 'Batch ' + batch_codes.astype('str')

### Plot embeddings

In [None]:
# Keep only the most frequent cell types so the embedding plots stay legible.
N_TOP_CELL_TYPES = 10
top_cell_types = adata.obs['cell_type'].value_counts().index[:N_TOP_CELL_TYPES]
# Indexing an AnnData returns a *view*. The neighbors/UMAP/t-SNE cells below
# write into `adata`, which on a view triggers an implicit view-to-copy
# conversion (ImplicitModificationWarning). Materialise the subset explicitly.
adata = adata[adata.obs['cell_type'].isin(top_cell_types)].copy()

In [None]:
# UMAP plots: build a neighbor graph and a UMAP for every stored representation
# (skipping previously computed 2-D layouts), then colour by cell type and batch.
# Snapshot the keys up front because sc.tl.umap adds 'X_umap' to obsm.
umap_rep_keys = [
    key for key in reversed(list(adata.obsm.keys()))
    if key not in ['X_umap', 'X_tsne']
]
for key in umap_rep_keys:
    print('ploting for ', key)
    sc.pp.neighbors(adata, use_rep=key)
    sc.tl.umap(adata)  # overwrites adata.obsm['X_umap'] on every iteration
    # scanpy appends a string `save` to the plot's default name ('umap'), so only
    # the suffix is passed here; 'umap_...' would be written as 'umapumap_...png'.
    sc.pl.umap(adata, color=['Cell Type'], save=f'_{key}_celltype.png')
    sc.pl.umap(adata, color=['Batch'], save=f'_{key}_batch.png')

In [None]:
# t-SNE plots: compute a t-SNE for every stored representation (skipping
# previously computed 2-D layouts), then colour by cell type and batch.
# Snapshot the keys up front because sc.tl.tsne adds 'X_tsne' to obsm.
tsne_rep_keys = [
    key for key in reversed(list(adata.obsm.keys()))
    if key not in ['X_umap', 'X_tsne']
]
for key in tsne_rep_keys:
    print('ploting for ', key)
    sc.tl.tsne(adata, use_rep=key)  # overwrites adata.obsm['X_tsne'] each time
    # scanpy appends a string `save` to the plot's default name ('tsne'), so only
    # the suffix is passed here; 'tsne_...' would be written as 'tsnetsne_...png'.
    sc.pl.tsne(adata, color=['Cell Type'], save=f'_{key}_celltype.png')
    sc.pl.tsne(adata, color=['Batch'], save=f'_{key}_batch.png')