Inspired by the implementation proposed in Fisher, Rudin, and Dominici (2018), https://arxiv.org/abs/1801.01489
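For orientation, the paper defines model reliance as a ratio: the model's loss after a feature's values are switched between observations, divided by its original loss. The sketch below approximates that ratio with random permutations in plain NumPy. `model_reliance`, `predict` and the toy data are hypothetical. This is not the paper's exact estimator, and it is not the pipeline used in this notebook, which ablates features with captum instead.

```python
import numpy as np


def model_reliance(predict, X, y, loss, feature, n_repeats=10, seed=0):
    """Permutation-based approximation of model reliance on one feature.

    Returns loss(permuted) / loss(original). The ratio is about 1 if the model
    ignores the feature, and it grows the more the predictions depend on it.
    """
    rng = np.random.default_rng(seed)
    e_orig = loss(y, predict(X))
    e_perm = []
    for _ in range(n_repeats):
        X_perm = X.copy()
        # shuffle one column across samples, breaking its link to the outcome
        X_perm[:, feature] = rng.permutation(X_perm[:, feature])
        e_perm.append(loss(y, predict(X_perm)))
    return float(np.mean(e_perm) / e_orig)


def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))


def predict(A):
    # toy model that only reads column 0
    return 3 * A[:, 0]


rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
y = 3 * X[:, 0] + 0.1 * rng.normal(size=500)

print(model_reliance(predict, X, y, mse, feature=0))  # far above 1
print(model_reliance(predict, X, y, mse, feature=1))  # ~1.0: column 1 is never read
```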

In [7]:
# !conda install -c conda-forge shap
# !pip install captum

In [1]:
import torch
import scvi
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

sc.set_figure_params(dpi=100, frameon=False, color_map='Reds', facecolor=None)
sc.logging.print_header()

Global seed set to 0


scanpy==1.9.1 anndata==0.8.0 umap==0.5.3 numpy==1.21.5 scipy==1.9.1 pandas==1.4.4 scikit-learn==1.1.2 statsmodels==0.13.2 pynndescent==0.5.7


## load model

In [2]:
base_path = '/home/icb/yuge.ji/projects/feature-attribution-sc'
hlca_path = f'{base_path}/datasets/hlca_subset.h5ad'
adata = sc.read(hlca_path)
adata

AnnData object with n_obs × n_vars = 584944 × 2000
    obs: 'sample', 'original_celltype_ann', 'study_long', 'study', 'last_author_PI', 'subject_ID', 'subject_ID_as_published', 'pre_or_postnatal', 'age_in_years', 'age_range', 'sex', 'ethnicity', 'mixed_ethnicity', 'smoking_status', 'smoking_history', 'BMI', 'known_lung_disease', 'condition', 'subject_type', 'cause_of_death', 'sample_type', 'anatomical_region_coarse', 'anatomical_region_detailed', 'tissue_dissociation_protocol', 'cells_or_nuclei', 'single_cell_platform', "3'_or_5'", 'enrichment', 'sequencing_platform', 'reference_genome_coarse', 'ensembl_release_reference_genome', 'cell_ranger_version', 'disease_status', 'fresh_or_frozen', 'cultured', 'cell_viability_%', 'comments', 'Processing_site', 'dataset', 'anatomical_region_level_1', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_ne

In [3]:
model = scvi.model.SCANVI.load('/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/scanvi_model/', adata)
model

INFO     File /home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/scanvi_model/model.pt already downloaded

In [142]:
batch_size=10

In [143]:
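# `_make_data_loader` is a private scvi-tools helper; here it serves minibatches from the first 1000 cells
scdl = model._make_data_loader(adata=adata, indices=list(range(1000)), batch_size=batch_size)
batch = next(iter(scdl))  # a dict of tensors, e.g. batch['X'], batch['batch'], batch['labels']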

## captum

In [144]:
import numpy as np
# from captum.attr import visualization as viz
from captum.attr import FeatureAblation

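Generate a feature mask (optional; the attributions below use captum's default of one feature per gene).

In [15]:
# # faux mask for when we might want to ablate or permute genes in groups:
# # genes sharing an integer id (0..n_groups-1) are ablated together
# feature_mask = torch.zeros((1, adata.n_vars), dtype=torch.long)  # group 0: all other genes
# feature_mask[0, [5, 10]] = 1  # genes 5 and 10 form group 1
# feature_mask[0, 20] = 2
# feature_mask[0, 255] = 3
# # would be passed as ablator.attribute(..., feature_mask=(feature_mask, torch.zeros((1, 1), dtype=torch.long)))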
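If a mask like this were passed to captum through `feature_mask`, genes that share a group id would be ablated together. Here is a minimal NumPy sketch of that idea. It assumes that ablation means setting features to a baseline of 0 and that the attribution is the original output minus the ablated output, which is captum's `FeatureAblation` convention as we understand it. `grouped_ablation` and the toy linear model are hypothetical.

```python
import numpy as np


def grouped_ablation(forward, X, feature_mask, baseline=0.0):
    """Ablate groups of genes and record how much each output changes.

    forward: maps (n_cells, n_genes) -> (n_cells, n_outputs)
    feature_mask: int array of shape (n_genes,) with group ids 0..n_groups-1
    Returns (n_cells, n_outputs, n_groups): original output minus ablated output.
    """
    base_out = forward(X)
    n_groups = int(feature_mask.max()) + 1
    attr = np.zeros(base_out.shape + (n_groups,))
    for g in range(n_groups):
        X_abl = X.copy()
        # set every gene of group g to the baseline, in all cells at once
        X_abl[:, feature_mask == g] = baseline
        attr[..., g] = base_out - forward(X_abl)
    return attr


rng = np.random.default_rng(0)
X = rng.poisson(2.0, size=(4, 300)).astype(float)  # 4 cells, 300 genes
W = rng.normal(size=(300, 3))                       # toy linear "classifier"

mask = np.zeros(300, dtype=int)  # group 0: all other genes
mask[[5, 10]] = 1
mask[20] = 2
mask[255] = 3

attr = grouped_ablation(lambda A: A @ W, X, mask)
print(attr.shape)  # (4, 3, 4)
```

For a linear model with a zero baseline, each group's attribution is exactly that group's contribution to the output. The group attributions therefore sum to the full output, which makes a convenient sanity check.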

### measure against posterior

Wrap `model.module.forward`, because captum has an internal check that the inputs passed are tensors.
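captum calls the forward function with each input tensor as a positional argument. A forward that consumes a minibatch dictionary, like the `batch` dicts from the data loader above, instead expects named entries. The following is a minimal adapter, sketched with NumPy arrays standing in for tensors. `positional_forward` and `toy_forward` are hypothetical and do not reproduce the scvi-tools module API.

```python
import numpy as np


def positional_forward(dict_forward, keys):
    """Adapt a forward that takes a dict of tensors to one that takes them positionally."""
    def wrapped(*tensors):
        if len(tensors) != len(keys):
            raise ValueError(f"expected {len(keys)} inputs, got {len(tensors)}")
        return dict_forward(dict(zip(keys, tensors)))
    return wrapped


# toy stand-in for a module forward that reads named entries from a minibatch dict
def toy_forward(tensors):
    return tensors["X"].sum(axis=1, keepdims=True) + tensors["batch"]


f = positional_forward(toy_forward, ["X", "batch"])
out = f(np.ones((2, 3)), np.array([[0.0], [1.0]]))
print(out.ravel())  # [3. 4.]
```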

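Captum returns one attribution tensor per input. With the default arguments (no `target`) and a forward function that returns an `(n_cells, n_outputs)` tensor, each of these has shape `(n_cells * n_outputs, n_features)`, as seen in the classification example below; other settings were not checked here.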

### measure against classification probabilities

In [145]:
ablator = FeatureAblation(model.module.classify)

In [201]:
model.module.classify(batch['X'],batch_index=batch['batch']).shape

torch.Size([10, 28])

The attribution below outputs two tensors, one per input. The first ablates the gene features one at a time (each gene is set to the baseline 0 in all cells of the batch at once) and has shape `(n_cells * n_ct_classes, n_genes)`. The second ablates the batch covariate. The model receives it as a single index per cell and only one-hot encodes it into an `n_datasets`-dimensional vector internally, so captum treats it as one feature. It ablates that feature once (which amounts to reassigning every cell to batch 0) and returns a tensor of shape `(n_cells * n_ct_classes, 1)`. We'll ignore that one.
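To make the shapes concrete, here is a plain-NumPy imitation of the gene attribution using a toy linear classifier. Each gene is set to 0 in all cells at once, and the change in every (cell, class) output is stored as one row. The sketch assumes that rows are ordered cell-major (row = cell * n_classes + class), which is the layout the reshape below relies on. `ablate_each_gene` and the toy model are hypothetical, not captum internals.

```python
import numpy as np


def ablate_each_gene(forward, X, baseline=0.0):
    """Ablate one gene at a time (in all cells); return (n_cells * n_classes, n_genes)."""
    base_out = forward(X)  # (n_cells, n_classes)
    n_genes = X.shape[1]
    attr = np.zeros((base_out.size, n_genes))
    for j in range(n_genes):
        X_abl = X.copy()
        X_abl[:, j] = baseline
        # flatten cell-major: row index = cell * n_classes + class
        attr[:, j] = (base_out - forward(X_abl)).reshape(-1)
    return attr


rng = np.random.default_rng(0)
n_cells, n_genes, n_classes = 10, 50, 28
X = rng.poisson(1.0, size=(n_cells, n_genes)).astype(float)
W = rng.normal(size=(n_genes, n_classes))  # toy linear classifier

attr = ablate_each_gene(lambda A: A @ W, X)
print(attr.shape)  # (280, 50)
```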

In [147]:
# per feature per output
attribution_map = ablator.attribute((batch['X'], batch['batch']))

In [148]:
attribution_map_genes = attribution_map[0] # take only the first tensor (explained above)

Reshape so that the `n_cells * n_classes` axis is split into two dimensions (the 280 rows become 10 × 28).

In [125]:
reshaped = attribution_map_genes.reshape((batch['X'].shape[0], -1, adata.n_vars))  # (n_cells, n_classes, n_genes) = (10, 28, 2000)
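PyTorch's `reshape`, like NumPy's default C-order reshape, splits the first axis so that flat row `i * n_classes + c` lands at index `[i, c]`. Below is a toy check in NumPy; the sizes are chosen for illustration only.

```python
import numpy as np

n_cells, n_classes, n_genes = 10, 28, 5
# stand-in for the flat attribution map: one row per (cell, class) pair
flat = np.arange(n_cells * n_classes * n_genes).reshape(n_cells * n_classes, n_genes)
reshaped = flat.reshape(n_cells, n_classes, n_genes)

i, c = 3, 7
assert np.array_equal(reshaped[i, c], flat[i * n_classes + c])
print(reshaped.shape)  # (10, 28, 5)
```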

Calculate means across cells

TO DO INSTEAD: For each class, calculate mean only across cells of that class (= cell type), ignore other cells.
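Here is one way to implement the TODO, sketched in NumPy. For each class `c`, it averages the attributions for output `c` over only the cells labelled `c`. The sketch makes these assumptions:

- `reshaped` has shape `(n_cells, n_classes, n_genes)`.
- Labels are integer class ids aligned with the classifier outputs.
- Ids at or above `n_classes` (for example an "unlabeled" category) are skipped.
- Classes absent from the batch get NaN rows.

`per_class_mean` is a hypothetical helper.

```python
import numpy as np


def per_class_mean(reshaped, labels, n_classes):
    """For each class c, mean attribution for output c over the cells labelled c.

    reshaped: (n_cells, n_classes, n_genes); labels: (n_cells,) integer class ids.
    Returns (n_classes, n_genes); classes with no cells in the input stay NaN.
    """
    out = np.full((n_classes, reshaped.shape[2]), np.nan)
    for c in np.unique(labels):
        if c >= n_classes:
            continue  # assumption: ids >= n_classes (e.g. "unlabeled") have no classifier output
        cells = labels == c
        out[c] = reshaped[cells, c, :].mean(axis=0)
    return out


# labels copied from the batch['labels'] output below; attributions are random
labels = np.array([13, 19, 1, 13, 4, 4, 27, 26, 2, 26])
reshaped = np.random.default_rng(0).normal(size=(10, 28, 6))

m = per_class_mean(reshaped, labels, n_classes=28)
print(m.shape)  # (28, 6)
```

With torch tensors, the inputs could be obtained as `reshaped.cpu().numpy()` and `batch['labels'].reshape(-1).long().cpu().numpy()`, since the labels are stored as floats. With only 10 cells per batch, most classes have no cells, so stable means would need a larger batch or accumulation over several batches.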

In [154]:
batch['labels']

tensor([[13.],
        [19.],
        [ 1.],
        [13.],
        [ 4.],
        [ 4.],
        [27.],
        [26.],
        [ 2.],
        [26.]])

In [167]:
(batch['labels'] == 13).reshape(-1)

tensor([ True, False, False,  True, False, False, False, False, False, False])

In [160]:
reshaped.shape

torch.Size([10, 28, 2000])

In [170]:
reshaped[(batch['labels'] == 1).reshape(-1),:,:].shape

torch.Size([1, 28, 2000])

In [183]:
batch['labels'].unique()

tensor(1.)

In [180]:
for ct in batch['labels'].unique():
    print(ct)
    print(batch['labels'] == ct)

IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number

In [149]:
means = torch.mean(reshaped.float(), dim=0)

In [150]:
means.shape

torch.Size([28, 2000])

### measure against latent

## Store results

In [184]:
from scvi import REGISTRY_KEYS

In [203]:
ct_names = model.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)['categorical_mapping']
# remove "unlabeled"
ct_names = [ct for ct in ct_names if ct != "unlabeled"]

In [204]:
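# check that the names line up with the classifier outputs, then move to NumPy; rows = genes, columns = cell types
assert len(ct_names) == means.shape[0], "cell-type names must match classifier outputs"
means_df = pd.DataFrame(data=means.T.cpu().numpy(), index=adata.var_names, columns=ct_names)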

In [205]:
means_df.to_csv("../outputs/ablation/task2.csv")