Inspired by the implementation proposed in Fisher, Rudin, Dominici (2018) https://arxiv.org/abs/1801.01489

import modules:

In [None]:
import torch
import scvi
import pandas as pd
import scanpy as sc
from scvi import REGISTRY_KEYS
from captum.attr import FeatureAblation

sc.set_figure_params(dpi=100, frameon=False, color_map='Reds', facecolor=None)
sc.logging.print_header()

## load model and data

In [None]:
base_path = '/home/icb/yuge.ji/projects/feature-attribution-sc'
hlca_path = f'{base_path}/datasets/hlca_subset.h5ad'
adata = sc.read(hlca_path)
adata

In [None]:
model = scvi.model.SCANVI.load('/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/scanvi_model/', adata)
model

get cell type names that match the labels (integers in the model):

In [None]:
ct_names = model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)['categorical_mapping']
ct_names = [ct for ct in ct_names if ct != "unlabeled"]
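
Dropping "unlabeled" keeps list positions equal to the model's integer label codes only if "unlabeled" comes after every real cell type in the mapping. The sketch below is one way to check this; the helper `codes_stay_aligned` and the example category names are hypothetical and not part of scvi-tools.

```python
def codes_stay_aligned(categories, dropped="unlabeled"):
    """Return True if removing `dropped` leaves every other category at its original index."""
    kept = [c for c in categories if c != dropped]
    return all(categories.index(c) == i for i, c in enumerate(kept))


# Hypothetical mappings: position in the list = integer code used by the model.
safe_mapping = ["AT1", "AT2", "Basal", "unlabeled"]
shifted_mapping = ["AT1", "unlabeled", "AT2", "Basal"]

print(codes_stay_aligned(safe_mapping))     # "unlabeled" is last, so nothing shifts
print(codes_stay_aligned(shifted_mapping))  # "AT2" and "Basal" would move down by one
```

If the check fails for the real mapping, the columns of any table indexed by integer label would be off by one from that point on.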

define batch size. In this case we'll set it to the entire size of the subsetted HLCA:

In [None]:
batch_size=adata.shape[0]

create a dataloader and load your first batch (in this case all the cells):

In [None]:
scdl = model._make_data_loader(adata=adata, indices=list(range(adata.shape[0])), batch_size=batch_size)
batch = next(iter(scdl))

### measure against posterior (not done yet)

Wrap `model.module.forward`, because captum has an internal check that the inputs passed are tensors.

Captum returns an attribution map of either `tensor(n_features * output_size, n_features), n_inputs` or ?

### Measure feature attribution with respect to classification probabilities

create the ablator, containing the forward function inside of it:

In [None]:
ablator = FeatureAblation(model.module.classify)

Run the feature attribution function. It returns two tensors, one per input. The first ablates the gene features one by one, for every cell, and has shape `(n_cells * n_ct_classes, n_genes)`. The second ablates the (biological) batch variable. Captum treats this input as a single, continuous variable (even though it represents a one-hot encoded, n_datasets-dimensional variable), so it is ablated once per cell and the result has shape `(n_cells * n_ct_classes, 1)`. We'll ignore the second tensor.

In [None]:
%%time
# per feature per output
attribution_map = ablator.attribute((batch['X'], batch['batch']))
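
To make the output shape concrete, here is a minimal pure-Python sketch of feature ablation on a toy problem. It is not Captum's implementation: `toy_classify`, the weights and the cell values are hypothetical, and the baseline of 0 is an assumption. Each gene is replaced by the baseline in turn, and the attribution is the original output minus the ablated output, stored with one row per (cell, class) pair.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_classify(cells, weights):
    """Hypothetical linear classifier: one list of class probabilities per cell."""
    return [
        softmax([sum(w * x for w, x in zip(class_weights, cell)) for class_weights in weights])
        for cell in cells
    ]


def ablate_genes(cells, weights, baseline=0.0):
    """Return a (n_cells * n_classes) x n_genes table of output changes."""
    n_cells, n_genes, n_classes = len(cells), len(cells[0]), len(weights)
    reference = toy_classify(cells, weights)
    attributions = [[0.0] * n_genes for _ in range(n_cells * n_classes)]
    for gene in range(n_genes):
        # Replace one gene with the baseline value in every cell.
        ablated = [cell[:gene] + [baseline] + cell[gene + 1:] for cell in cells]
        perturbed = toy_classify(ablated, weights)
        for i in range(n_cells):
            for c in range(n_classes):
                # Attribution = original output minus output with the gene ablated.
                attributions[i * n_classes + c][gene] = reference[i][c] - perturbed[i][c]
    return attributions


cells = [[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]]      # 2 cells x 3 genes
weights = [[1.0, -1.0, 0.5], [-0.5, 1.0, 0.0]]  # 2 classes x 3 genes
attributions = ablate_genes(cells, weights)
print(len(attributions), len(attributions[0]))  # rows = cells * classes, columns = genes
```

A gene that already equals the baseline in a cell gets an attribution of exactly zero for that cell, which is worth keeping in mind for sparse count data.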

In [None]:
attribution_map_genes = attribution_map[0] # take only the first tensor (explained above)

Reshape so that the flattened first dimension n_cells\*n_classes is split into two dimensions: (n_cells\*28, n_genes) becomes (n_cells, 28, n_genes).

In [None]:
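# -1 lets torch infer the number of classes (28 here); adata.n_vars is the number of genes (2000 here)
attribution_map_genes_3d = attribution_map_genes.reshape((batch_size, -1, adata.n_vars))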
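
The reshape is only valid if the flattened rows are ordered with the class index varying fastest, that is, row = cell * n_classes + class. That ordering is an assumption here (it is what a row-major flattening of an (n_cells, n_classes) output would give). The sketch below, with a hypothetical helper `reshape_attributions`, shows the index arithmetic on a toy table whose entries record which (cell, class) pair they came from.

```python
def reshape_attributions(flat_rows, n_cells, n_classes):
    """Split (n_cells * n_classes) rows into a nested [cell][class] structure.

    Assumes class is the fastest-varying index, i.e. row = cell * n_classes + class.
    """
    if len(flat_rows) != n_cells * n_classes:
        raise ValueError("row count must equal n_cells * n_classes")
    return [flat_rows[i * n_classes:(i + 1) * n_classes] for i in range(n_cells)]


n_cells, n_classes, n_genes = 3, 2, 4
# Each toy row stores its own (cell, class) pair, repeated once per gene.
flat = [[(row // n_classes, row % n_classes)] * n_genes for row in range(n_cells * n_classes)]
cube = reshape_attributions(flat, n_cells, n_classes)

# After reshaping, cube[cell][cls] holds the row that was generated for that pair.
print(cube[2][1][0])
```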

For each class (= cell type), calculate the mean only across cells of that class and ignore all other cells. Then keep only the feature importances for that particular class. Reasoning: we want to learn which features were important for classifying a cell of cell type a *as* cell type a, and not the features that made the model *not* classify it as cell type a (the latter would give negative markers rather than positive ones).

In [None]:
means = pd.DataFrame(index=adata.var_names,columns=ct_names)

In [None]:
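for ct in batch['labels'].unique():
    ct_value = ct.item()
    ct_int = int(ct_value)  # integer class index, used as output index and column position
    ct_indexing = (batch['labels'] == ct_value).reshape(-1)  # boolean mask: cells with this label
    # mean attribution towards class ct_int over cells labelled ct_int;
    # assumes the column order of `means` (ct_names) matches the model's integer labels
    class_mean = torch.mean(attribution_map_genes_3d[ct_indexing, ct_int, :], dim=0)
    means.iloc[:, ct_int] = class_mean.detach().cpu().numpy()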
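
The averaging step can be checked by hand on a toy example. The sketch below uses nested lists instead of tensors, and the attribution values and the helper `class_specific_means` are hypothetical. For each class it averages only over the cells carrying that label, and only the attributions towards that same class.

```python
def class_specific_means(attributions, labels):
    """attributions[cell][class][gene] -> {class: mean over that class's own cells}."""
    n_genes = len(attributions[0][0])
    means = {}
    for cls in sorted(set(labels)):
        members = [i for i, label in enumerate(labels) if label == cls]
        means[cls] = [
            sum(attributions[i][cls][g] for i in members) / len(members)
            for g in range(n_genes)
        ]
    return means


# 3 cells x 2 classes x 2 genes, hypothetical values
attributions = [
    [[0.4, 0.0], [-0.4, 0.0]],   # cell 0, labelled class 0
    [[0.2, 0.2], [-0.2, -0.2]],  # cell 1, labelled class 0
    [[-0.5, 0.1], [0.5, -0.1]],  # cell 2, labelled class 1
]
labels = [0, 0, 1]
means = class_specific_means(attributions, labels)
print(means)
```

Cell 2 has a negative attribution for gene 0 towards class 0, but because it is labelled class 1 it does not enter the class 0 mean. This is how the negative markers are kept out.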

### measure against latent (to do)

## Store results:

In [None]:
means.to_csv("../outputs/ablation/task2.csv")