In [1]:
!pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps

Obtaining file:///dss/dsshome1/04/di93zer/git/cellnet
  Preparing metadata (setup.py) ... [?25ldone
[?25hInstalling collected packages: cellnet
  Running setup.py develop for cellnet
Successfully installed cellnet

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [78]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import pickle
import os
from os.path import join


import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import dask.dataframe as dd

In [3]:
# Root directory of the dataset. The cells below read 'train', 'var.parquet'
# and 'categorical_lookup/cell_type.parquet' relative to this path.
DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'

In [4]:
import dask.dataframe as dd
import dask.array as da

from scipy.sparse import csr_matrix

In [36]:
from cellnet.utils.tabnet_explain import explain
from cellnet.utils.data_loading import dataloader_factory

In [6]:
# Lookup table for cell types: row index -> value of the `label` column.
cell_type_mapping = pd.read_parquet(
    join(DATA_PATH, 'categorical_lookup/cell_type.parquet')
)
# Reverse lookup: label -> row index of the lookup table.
inverse_cell_type_mapping = dict(
    zip(cell_type_mapping.label.tolist(), cell_type_mapping.index.tolist())
)

In [25]:
# Cell types whose explanations are inspected below. According to the original
# note these are easy to predict, even with only a small amount of training data.

cell_types_to_check = [
    'Bergmann glial cell',
    'L6b glutamatergic cortical neuron',
    'bronchus fibroblast of lung',
    'cardiac neuron',
    'caudal ganglionic eminence derived GABAergic cortical interneuron',
    'chandelier pvalb GABAergic cortical interneuron',
    'ependymal cell',
    'alternatively activated macrophage',
    'alveolar macrophage',
    'central nervous system macrophage',
    'elicited macrophage',
    'inflammatory macrophage',
    'lung macrophage',
    'macrophage',
    'lung pericyte',
    'paneth cell',
    'renal interstitial pericyte',
    'retina horizontal cell'
]

In [8]:
def get_count_matrix_and_obs(ddf, n_genes=19331, partition_size=1024):
    """Split a dask dataframe into an expression matrix and a metadata table.

    Parameters
    ----------
    ddf
        Dask dataframe with a column 'X' holding one expression vector per
        row, plus the columns 'cell_type', 'tech_sample', 'assay',
        'dataset_id' and 'tissue'.
    n_genes : int
        Number of features per expression vector; used to declare the
        float32 column layout of the stacked matrix.
    partition_size : int
        Number of rows assumed for *every* partition when declaring the chunk
        sizes of the resulting dask array (no data is read to verify this).

    Returns
    -------
    x
        Dask array built by stacking the vectors of 'X', partition by
        partition.
    obs : pandas.DataFrame
        The metadata columns, materialised in memory, with a fresh 0..n-1
        index.
    """
    x = (
        ddf['X']
        .map_partitions(
            # Stack the per-row vectors of a partition into one 2-D frame.
            lambda xx: pd.DataFrame(np.vstack(xx.tolist())),
            meta={col: 'f4' for col in range(n_genes)}
        )
        .to_dask_array(lengths=[partition_size] * ddf.npartitions)
    )
    # Materialise the metadata exactly once: .compute() reads the whole
    # dataset, so the previous duplicate call doubled the loading cost.
    obs = (
        ddf[['cell_type', 'tech_sample', 'assay', 'dataset_id', 'tissue']]
        .compute()
        .reset_index(drop=True)
    )

    return x, obs


In [9]:
# Read the training split with dask, then derive the expression matrix
# (dask array) and the in-memory obs table from it.
ddf_train = dd.read_parquet(join(DATA_PATH, 'train'), split_row_groups=True)
x_train, obs_train = get_count_matrix_and_obs(ddf_train)

In [10]:
# Keep only the first 4,000,000 rows of the matrix and of the obs table to
# save memory. Both are cut at the same position so they stay row-aligned.
x_train = x_train[:4_000_000, :]
obs_train = obs_train[:4_000_000]

In [34]:
# Collect, for every cell type of interest, the row labels of at most the
# first 3000 matching cells in obs_train.
idxs = []
for celltype in cell_types_to_check:
    celltype_code = inverse_cell_type_mapping[celltype]
    is_celltype = (obs_train.cell_type == celltype_code).to_numpy()
    idxs.extend(obs_train.index[is_celltype][:3000].tolist())


In [37]:
from cellnet.models import TabnetClassifier
from cellnet.estimators import EstimatorCellTypeClassifier


# Single definition of the cache location for the explanation masks.
EXPLAIN_CACHE_PATH = 'model_eval_cache/explain_extended.pkl'

# Metadata of the selected cells, with cell_type decoded into a categorical
# whose categories are the labels of the lookup table.
obs_ = obs_train.iloc[idxs].copy()
obs_['cell_type'] = pd.Categorical(
    obs_.cell_type.replace(cell_type_mapping.label.to_dict()),
    cell_type_mapping.label.tolist(),
    ordered=False
)


if not os.path.isfile(EXPLAIN_CACHE_PATH):
    # Create the cache directory up front, so a missing folder cannot make the
    # write fail after the expensive explain run has already finished.
    os.makedirs(os.path.dirname(EXPLAIN_CACHE_PATH), exist_ok=True)

    x_ = x_train[idxs, :].map_blocks(csr_matrix).compute()

    CKPT_PATH_TABNET = '/mnt/dssfs02/tb_logs/cxg_2023_05_15_tabnet/default/w_augment_4/checkpoints/val_f1_macro_epoch=45_val_f1_macro=0.847.ckpt'
    estim = EstimatorCellTypeClassifier(DATA_PATH)
    estim.init_datamodule(batch_size=2048)
    model = TabnetClassifier.load_from_checkpoint(CKPT_PATH_TABNET, **estim.get_fixed_model_params('tabnet'))
    loader = dataloader_factory(x_, obs_)
    explain_masks = explain(model, loader, only_return_nnz_idxs=False, normalize=True)
    with open(EXPLAIN_CACHE_PATH, 'wb') as f:
        pickle.dump(explain_masks, f)
else:
    # Load the cached explanation masks.
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # were written by this notebook.
    with open(EXPLAIN_CACHE_PATH, 'rb') as f:
        explain_masks = pickle.load(f)
    # Guard against a stale cache produced for a different cell selection.
    if explain_masks.shape[0] != len(idxs):
        raise ValueError(
            f'Cached explain masks have {explain_masks.shape[0]} rows but '
            f'{len(idxs)} cells are selected; delete {EXPLAIN_CACHE_PATH} to recompute.'
        )

17it [00:12,  1.33it/s]


In [38]:
# Fraction of entries in the explanation masks that are non-zero.
(explain_masks != 0.).mean()

0.010272220712622897

In [39]:
# Shape of the mask array — presumably (selected cells, genes); TODO confirm.
explain_masks.shape

(33825, 19331)

In [40]:
# Materialise the selected rows as a sparse CSR matrix. This repeats the
# computation from the cache-miss branch of the explain cell, because x_ is
# not defined when the masks were loaded from the cache.
x_ = x_train[idxs, :].map_blocks(csr_matrix).compute()

In [41]:
# Element-wise product of the densified expression values and their
# explanation masks. NOTE: this allocates a full dense array.
input_features = x_.toarray() * explain_masks

In [42]:
# Fraction of entries that are non-zero after masking.
(input_features != 0.).mean()

0.0020453619239847857

In [43]:
# Fraction of columns whose mean masked value over all selected rows is at least 1e-4.
(input_features.mean(axis=0) >= 1e-4).mean()

0.011587605400651802

In [45]:
# For every cell type: column indices of the 200 features with the highest
# mean masked value among the cells of that type (descending order).
top_x_genes = {}
for cell_type in cell_types_to_check:
    in_cell_type = (obs_.cell_type == cell_type).to_numpy()
    mean_masked_value = input_features[in_cell_type, :].mean(axis=0)
    top_x_genes[cell_type] = np.argsort(-mean_masked_value)[:200]

# De-duplicated pool of all selected column indices across cell types.
genes = []
for gene_idxs in top_x_genes.values():
    genes.extend(gene_idxs.tolist())
genes = list(set(genes))

len(genes)

549

In [46]:
# Feature annotation table; it is looked up by the stringified column index below.
var = pd.read_parquet(join(DATA_PATH, 'var.parquet'))

# Translate the 25 top-ranked feature indices of every cell type into the
# values of the `feature_name` column.
top_x_genes_specific = {
    celltype: var.loc[gene_idxs[:25].astype(str)].feature_name.tolist()
    for celltype, gene_idxs in top_x_genes.items()
}

# One row per cell type, one column per rank (0 = highest-ranked feature).
top_marker_genes = pd.DataFrame.from_dict(top_x_genes_specific, orient='index')
top_marker_genes.to_csv('markers_extended.csv')
# Show only the 15 highest-ranked features per cell type.
top_marker_genes.iloc[:, :15]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
Bergmann glial cell,PTPRZ1,RGL1,KCNN2,SCRG1,FASTKD2,CSMD1,KLHL32,GPM6B,LCORL,ADGRL3,PCYT2,TERF2IP,SLC44A1,SUMF1,GABRB3
L6b glutamatergic cortical neuron,WDFY3,SH3GL2,CDH12,WRNIP1,GNAQ,CSMD1,SYT1,KCTD16,NRXN1,IGSF21,TANC2,EPHB1,RPL6,NTM,THY1
bronchus fibroblast of lung,PTGES3,DGCR6L,RPS3A,CETN2,CCDC159,STX2,FTH1,ABCA10,S100A4,B2M,NDN,CEP170,RPL12,SLIT2,RPL35A
cardiac neuron,NRXN1,XKR4,NTM,TTC37,HIBCH,LRCH1,TACC1,AUTS2,FASTKD2,FGF12,LCORL,NCALD,EPHB1,PDE1C,PTPRZ1
caudal ganglionic eminence derived GABAergic cortical interneuron,SYT1,FASTKD2,SYN3,ERC2,NRXN1,MLLT3,UBE2L3,CSMD1,HDAC8,KCND2,VWC2,GNAQ,CEP170,WDFY3,FRY
chandelier pvalb GABAergic cortical interneuron,SYT1,CSMD2,FASTKD2,ERC2,EDIL3,CSMD1,GNAQ,NRG3,KIF1B,WDFY3,NRXN1,GRIK1,ATRNL1,ADGRL3,GABRB3
ependymal cell,WDFY3,KIF1B,ADK,NRG3,GNAQ,PPFIA2,KCTD16,SCN1A,WRNIP1,SLIT2,STXBP5L,LRRC7,SYT1,NTM,PFKFB2
alternatively activated macrophage,RPL8,RAB31,NDUFA1,FGFR1,FCGR3A,SDCBP,B2M,GNAQ,PCMT1,WRNIP1,UBE2L3,ATP1B3,S100A4,RPL35A,ATXN1
alveolar macrophage,WRNIP1,B2M,NDUFA1,CFL1,EEF1B2,PSAP,RAB31,GPCPD1,SDHD,C1QB,FCGR3A,ARPP19,RPS23,PITPNB,HLA-DRB1
central nervous system macrophage,SLC9A9,AUTS2,PGM5,FASTKD2,TACC1,HIBCH,LRCH1,ADK,NTM,NRG3,WDFY3,NXPE3,LCORL,NRXN1,IGSF21
