import modules:

In [1]:
import os

import torch
import scvi
import pandas as pd
import scanpy as sc
from scvi import REGISTRY_KEYS
from captum.attr import FeatureAblation

# Notebook-wide figure defaults: 100 dpi, no axis frames, red colormap.
sc.set_figure_params(
    dpi=100,
    frameon=False,
    facecolor=None,
    color_map='Reds',
)
# Print scanpy and dependency versions so the run environment is recorded.
sc.logging.print_header()

Global seed set to 0


scanpy==1.9.1 anndata==0.8.0 umap==0.5.3 numpy==1.21.5 scipy==1.9.1 pandas==1.4.4 scikit-learn==1.1.2 statsmodels==0.13.2 pynndescent==0.5.7


## load model and data

In [2]:
# NOTE(review): absolute, user-specific path; adjust for your environment.
base_path = '/home/icb/yuge.ji/projects/feature-attribution-sc'
hlca_path = base_path + '/datasets/hlca_subset.h5ad'
# Subsetted HLCA (cells x genes) used for the attribution analysis.
adata = sc.read(filename=hlca_path)
adata

AnnData object with n_obs × n_vars = 14500 × 2000
    obs: 'sample', 'original_celltype_ann', 'study_long', 'study', 'last_author_PI', 'subject_ID', 'subject_ID_as_published', 'pre_or_postnatal', 'age_in_years', 'age_range', 'sex', 'ethnicity', 'mixed_ethnicity', 'smoking_status', 'smoking_history', 'BMI', 'known_lung_disease', 'condition', 'subject_type', 'cause_of_death', 'sample_type', 'anatomical_region_coarse', 'anatomical_region_detailed', 'tissue_dissociation_protocol', 'cells_or_nuclei', 'single_cell_platform', "3'_or_5'", 'enrichment', 'sequencing_platform', 'reference_genome_coarse', 'ensembl_release_reference_genome', 'cell_ranger_version', 'disease_status', 'fresh_or_frozen', 'cultured', 'cell_viability_%', 'comments', 'Processing_site', 'dataset', 'anatomical_region_level_1', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_new

In [3]:
# Pre-trained scANVI model; loaded onto the HLCA subset read above.
# NOTE(review): absolute, user-specific path; adjust for your environment.
scanvi_model_dir = '/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/scanvi_model/'
model = scvi.model.SCANVI.load(scanvi_model_dir, adata=adata)
model

[34mINFO    [0m File [35m/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/sca[0m
         [35mnvi_model/[0m[95mmodel.pt[0m already downloaded                                               






Get the cell type names corresponding to the model's integer labels, dropping the unlabeled category:

In [4]:
# Label index -> cell type name, as registered by scvi-tools for this model.
labels_registry = model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
# Read the unlabeled category from the registry instead of hard-coding "unlabeled",
# so this also works for models trained with a different unlabeled category name.
unlabeled_category = labels_registry['unlabeled_category']
ct_names = [ct for ct in labels_registry['categorical_mapping'] if ct != unlabeled_category]

Retrieve the data as a tensor from the dataloader. To get all cells in a single batch, we set the batch size to the size of the whole dataset (here, the entire subsetted HLCA):

In [5]:
# One batch containing every cell (n_obs == number of rows of adata).
batch_size = adata.n_obs

create a dataloader and load your first batch (in this case all the cells):

In [6]:
# Data loader over all cells; since batch_size == n_obs, its first batch is the full dataset.
scdl = model._make_data_loader(
    adata=adata,
    indices=list(range(adata.n_obs)),
    batch_size=batch_size,
)
batch = next(iter(scdl))

### measure against posterior (not done yet)

Wrap `model.module.forward`, because captum internally checks that the inputs passed to it are tensors.

### Measure feature attribution with respect to classification probabilities

Create the ablator around the classifier's forward function (`model.module.classify`):

In [7]:
# Feature ablation over the scANVI classifier head (cell-type classification output).
ablator = FeatureAblation(forward_func=model.module.classify)

Run the feature attribution. `attribute` returns one tensor per input. The first ablates the gene features one by one for every cell and has shape $(n_{cells} \cdot n_{ct\_classes}) \times n_{genes}$. The second ablates the batch covariate (the experimental batch, not the data-loader minibatch). That covariate is passed as a single column per cell, even though it represents a one-hot encoded $n_{datasets}$-dimensional variable. It is therefore ablated once per cell, giving a $(n_{cells} \cdot n_{ct\_classes}) \times 1$ tensor. We'll ignore that one.

In [8]:
%%time
# per feature per output
attribution_map = ablator.attribute((batch['X'], batch['batch']))

CPU times: user 1h 40min 41s, sys: 1h 4min 23s, total: 2h 45min 4s
Wall time: 1h 23min 40s


In [9]:
# Position of batch['X'] in the inputs tuple passed to ablator.attribute; the
# other tensor (batch covariate) is ignored.
gene_input_position = 0
attribution_map_genes = attribution_map[gene_input_position]

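Reshape so that the flattened $n_{cells} \cdot n_{classes}$ axis is split into two separate axes: $(n_{cells} \cdot 28) \times n_{genes}$ becomes $n_{cells} \times 28 \times n_{genes}$ (28 cell type classes here).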

In [10]:
# Split the flattened (n_cells * n_classes) axis into separate cell and class axes.
# Sizes are derived from the data rather than hard-coded (previously 28 classes, 2000 genes).
n_classes = len(ct_names)
n_genes = adata.n_vars
assert attribution_map_genes.numel() == batch_size * n_classes * n_genes, (
    f"unexpected attribution shape {tuple(attribution_map_genes.shape)} for "
    f"{batch_size} cells x {n_classes} classes x {n_genes} genes"
)
attribution_map_genes_3d = attribution_map_genes.reshape((batch_size, n_classes, n_genes))

TODO: save this attribution map for further analyses

For each class, compute the mean attribution only across cells of that class (= cell type), ignoring all other cells, and keep only the feature importances for that class's own output. Reasoning: we want to learn which features were important for classifying a cell of cell type a *as* cell type a, and not the features that made the model *not* classify it as cell type a (the latter would give negative markers rather than positive ones).

In [11]:
# Genes x cell types table of mean attributions, initialised to NaN and filled per class.
# dtype=float keeps columns numeric instead of pandas' default object dtype for empty frames.
means = pd.DataFrame(index=adata.var_names, columns=ct_names, dtype=float)

In [12]:
# For each cell type, average attributions over the cells of that type only, keeping
# just that type's own classifier output. Cell types absent from `batch` stay NaN.
n_classes = len(ct_names)
cell_labels = batch['labels'].reshape(-1)  # one label index per cell (float-valued, hence int())
for label in cell_labels.unique():
    ct_int = int(label.item())
    if ct_int >= n_classes:
        # Labels without a classifier output / table column (e.g. the unlabeled
        # category) cannot be attributed, so skip them instead of raising IndexError.
        continue
    ct_mask = cell_labels == label
    ct_mean = torch.mean(attribution_map_genes_3d[ct_mask, ct_int, :], dim=0)
    # Convert explicitly to a CPU numpy array so pandas stores plain floats.
    means.iloc[:, ct_int] = ct_mean.detach().cpu().numpy()

### measure against latent (to do)

## Store results:

In [13]:
# Export the genes x cell types mean-attribution table.
# Create the output directory first so the export does not fail on a fresh checkout.
output_csv = "../outputs/ablation/task2.csv"
os.makedirs(os.path.dirname(output_csv), exist_ok=True)
means.to_csv(output_csv)