Inspired by the implementation proposed in Fisher, Rudin, Dominici (2018) https://arxiv.org/abs/1801.01489

In [7]:
# Optional one-time environment setup; uncomment a line to install the package.
# !conda install -c conda-forge shap
# !pip install captum

In [3]:
import torch
import scvi
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

# Notebook-wide scanpy figure defaults, collected in one place so they are easy to tweak.
figure_defaults = dict(dpi=100, frameon=False, color_map='Reds', facecolor=None)
sc.set_figure_params(**figure_defaults)

# Print the versions of scanpy and its dependencies for reproducibility.
sc.logging.print_header()

scanpy==1.9.1 anndata==0.8.0 umap==0.5.3 numpy==1.22.4 scipy==1.9.1 pandas==1.4.4 scikit-learn==1.1.2 statsmodels==0.13.2 pynndescent==0.5.7


  IPython.display.set_matplotlib_formats(*ipython_format)


## load model

In [4]:
# Read the `.h5ad` file into an AnnData object (path is relative to the working directory).
hlca_path = '../datasets/hlca.h5ad'
adata = sc.read(hlca_path)
# Bare last expression so the notebook displays the AnnData summary.
adata

AnnData object with n_obs × n_vars = 584944 × 2000
    obs: 'sample', 'original_celltype_ann', 'study_long', 'study', 'last_author_PI', 'subject_ID', 'subject_ID_as_published', 'pre_or_postnatal', 'age_in_years', 'age_range', 'sex', 'ethnicity', 'mixed_ethnicity', 'smoking_status', 'smoking_history', 'BMI', 'known_lung_disease', 'condition', 'subject_type', 'cause_of_death', 'sample_type', 'anatomical_region_coarse', 'anatomical_region_detailed', 'tissue_dissociation_protocol', 'cells_or_nuclei', 'single_cell_platform', "3'_or_5'", 'enrichment', 'sequencing_platform', 'reference_genome_coarse', 'ensembl_release_reference_genome', 'cell_ranger_version', 'disease_status', 'fresh_or_frozen', 'cultured', 'cell_viability_%', 'comments', 'Processing_site', 'dataset', 'anatomical_region_level_1', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_ne

In [68]:
# Directory that holds the saved SCANVI model.
# NOTE(review): absolute, user-specific path — adjust before running on another machine.
scanvi_model_dir = '/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/scanvi_model/'

# Restore the saved model and attach it to `adata`.
# NOTE(review): the recorded output warns that `var_names` differ from the training data and that
# `adata.X` does not contain unnormalized counts — confirm this AnnData matches the model.
model = scvi.model.SCANVI.load(scanvi_model_dir, adata)
model

[34mINFO    [0m File [35m/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/sca[0m
         [35mnvi_model/[0m[95mmodel.pt[0m already downloaded                                               


var_names for adata passed in does not match var_names of adata used to train the model. For valid results, the vars need to be the same and in the same order as the adata used to train the model.
adata.X does not contain unnormalized count data. Are you sure this is what you want?
Category 18 in adata.obs['_scvi_labels'] has fewer than 3 cells. Models may not train properly.
Category 18 in adata.obs['_scvi_labels'] has fewer than 3 cells. Models may not train properly.




In [192]:
# Build a loader over the first 1000 cells in mini-batches of 100 and keep only the first mini-batch.
# `_make_data_loader` has a leading underscore, i.e. it is a private helper of the model.
first_cells = list(range(1000))
scdl = model._make_data_loader(adata=adata, indices=first_cells, batch_size=100)
batch = next(iter(scdl))

## captum

In [193]:
import numpy as np
# from captum.attr import visualization as viz
from captum.attr import FeatureAblation

Generate feature mask.

In [181]:
# NOTE(review): sketch only. While this cell stays commented out it defines no `feature_mask`.
# As written it would not run: `np.array(n)` builds a 0-d array holding `n` (not `n` entries),
# and NumPy arrays have no `.isin` method (`np.isin` is a function).
# # faux mask for when we might want to ablate or permute in groups
# feature_mask = np.array(adata.shape[1])
# feature_mask[feature_mask.isin([5, 10])] = 1 
# feature_mask[feature_mask == 20] = 2
# feature_mask[feature_mask == 255] = 3

### measure against posterior

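Wrap `model.module.forward` in a function that takes the tensors as separate positional arguments: captum checks internally that the inputs passed to it are tensors, whereas `forward` is called with a dict.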

In [194]:
def forw(x, y, z):
    """Tensor-only wrapper around ``model.module.forward`` for use with captum.

    Packs the three positional tensors into the dict that ``model.module.forward``
    is called with, disables the loss computation, and returns only the
    ``'px_scale'`` entry of the second element of the forward output.

    Parameters
    ----------
    x
        Tensor passed under the ``'X'`` key.
    y
        Tensor passed under the ``'batch'`` key.
    z
        Tensor passed under the ``'labels'`` key.

    Returns
    -------
    The ``'px_scale'`` entry of ``model.module.forward(...)[1]``.
    """
    tensors = {'X': x, 'batch': y, 'labels': z}
    outputs = model.module.forward(tensors, compute_loss=False)
    return outputs[1]['px_scale']

In [195]:
# Sanity check: shape of the `px_scale` output when the mini-batch dict is passed to `forward` directly.
model.module.forward(batch, compute_loss=False)[1]['px_scale'].shape

torch.Size([100, 2000])

Run captum.

In [196]:
# Feature-ablation attributor built on the tensor-only wrapper `forw`.
ablator = FeatureAblation(forw)

In [197]:
%%time
# # per feature per output
# attribution_map = ablator.attribute(
#     (batch['X'], batch['batch'], batch['labels']),
# #     target=['Macrophages']*batch['X'].shape[1],
#     feature_mask=(torch.tensor(feature_mask), torch.tensor(feature_mask), torch.tensor(feature_mask)))

# per feature (aggregation) (takes longer)
attribution_map = ablator.attribute(
    (batch['X'], batch['batch'], batch['labels']),
    perturbations_per_eval=1,
#     show_progress=True
)

CPU times: user 1h 30min 49s, sys: 33min 49s, total: 2h 4min 38s
Wall time: 7min 50s


In [175]:
# One shape per returned attribution tensor.
[attribution.shape for attribution in attribution_map]

[torch.Size([10000, 2000]), torch.Size([10000, 1]), torch.Size([10000, 1])]

In [177]:
# One shape per returned attribution tensor (same check as the cell before it).
[attribution.shape for attribution in attribution_map]

[torch.Size([10000, 2000]), torch.Size([10000, 1]), torch.Size([10000, 1])]

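Captum returns a tuple with one attribution tensor per input (`n_inputs` tensors). Each tensor appears to have shape `(n_rows, *input.shape[1:])`, e.g. `(n_rows, n_features)` for `X`; whether `n_rows` is `n_features * output_size` or something else is still to be confirmed.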

In [166]:
# Shape of the attribution tensor for the first input in the tuple passed to `attribute`.
attribution_map[0].shape

torch.Size([10000, 2000])

In [115]:
# Display the raw attribution tuple.
# NOTE(review): this prints the full tensors; a summary or slice would keep the output short.
attribution_map

(tensor([[-1.7431e-08, -1.7431e-08, -1.7431e-08,  ..., -1.7431e-08,
          -1.7431e-08, -1.7431e-08],
         [-1.0250e-10, -1.0250e-10, -1.0250e-10,  ..., -1.0250e-10,
          -1.0250e-10, -1.0250e-10],
         [-1.9924e-08, -1.9924e-08, -1.9924e-08,  ..., -1.9924e-08,
          -1.9924e-08, -1.9924e-08],
         ...,
         [-8.1410e-07, -8.1410e-07, -8.1410e-07,  ..., -8.1410e-07,
          -8.1410e-07, -8.1410e-07],
         [-9.7044e-01, -9.7044e-01, -9.7044e-01,  ..., -9.7044e-01,
          -9.7044e-01, -9.7044e-01],
         [-1.0539e-06, -1.0539e-06, -1.0539e-06,  ..., -1.0539e-06,
          -1.0539e-06, -1.0539e-06]]),
 tensor([[ 1.2044e-12],
         [-4.3241e-15],
         [ 1.4188e-12],
         [ 1.7650e-12],
         [ 5.3920e-14],
         [ 1.0132e-12],
         [ 1.1324e-12],
         [ 2.9206e-12],
         [ 1.7511e-16],
         [ 1.1511e-12],
         [ 2.4874e-13],
         [ 1.7375e-12],
         [ 8.7122e-13],
         [ 0.0000e+00],
         [-2.1124e

In [116]:
# Shape of the attribution tensor for the first input in the tuple passed to `attribute`.
attribution_map[0].shape

torch.Size([140, 2000])

In [117]:
# Shape of the attribution tensor for the second input in the tuple passed to `attribute`.
attribution_map[1].shape

torch.Size([140, 1])

### measure against classification probabilities

In [179]:
# Rebind `ablator` to attribute against `model.module.classify` instead of the `forw` wrapper.
ablator = FeatureAblation(model.module.classify)

In [159]:
# model.module.classify(batch['X'], batch['batch']).shape

torch.Size([5, 28])

In [182]:
# per feature per output
attribution_map = ablator.attribute(
    (batch['X'], batch['batch']),
#     target=['Macrophages']*batch['X'].shape[1],
    feature_mask=(torch.tensor(feature_mask), torch.tensor(feature_mask)))

In [183]:
# One shape per returned attribution tensor.
[attribution.shape for attribution in attribution_map]

[torch.Size([140, 2000]), torch.Size([140, 1])]

### measure against latent