# Feature Permutation

- Theoretical background: https://arxiv.org/abs/1312.6034
- API docs: https://captum.ai/api/saliency.html

In [1]:
import torch
import scvi
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

sc.set_figure_params(dpi=100, frameon=False, color_map='Reds', facecolor=None)
sc.logging.print_header()

Global seed set to 0


scanpy==1.9.1 anndata==0.8.0 umap==0.5.3 numpy==1.23.3 scipy==1.9.1 pandas==1.4.4 scikit-learn==1.1.2 statsmodels==0.13.2 pynndescent==0.5.7


## load model

In [2]:
hlca_path = '../datasets/hlca.h5ad'
adata = sc.read(hlca_path)
adata

AnnData object with n_obs × n_vars = 584944 × 2000
    obs: 'sample', 'original_celltype_ann', 'study_long', 'study', 'last_author_PI', 'subject_ID', 'subject_ID_as_published', 'pre_or_postnatal', 'age_in_years', 'age_range', 'sex', 'ethnicity', 'mixed_ethnicity', 'smoking_status', 'smoking_history', 'BMI', 'known_lung_disease', 'condition', 'subject_type', 'cause_of_death', 'sample_type', 'anatomical_region_coarse', 'anatomical_region_detailed', 'tissue_dissociation_protocol', 'cells_or_nuclei', 'single_cell_platform', "3'_or_5'", 'enrichment', 'sequencing_platform', 'reference_genome_coarse', 'ensembl_release_reference_genome', 'cell_ranger_version', 'disease_status', 'fresh_or_frozen', 'cultured', 'cell_viability_%', 'comments', 'Processing_site', 'dataset', 'anatomical_region_level_1', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_ne

In [3]:
model = scvi.model.SCANVI.load('../models/scanvi_model/', adata)
model

[34mINFO    [0m File ..[35m/models/scanvi_model/[0m[95mmodel.pt[0m already downloaded                                                   






## Applying saliency

### Measure against classification probabilities

In [6]:
from captum.attr import visualization as viz
from captum.attr import Saliency

In [7]:
saliency = Saliency(model.module.classify)

In [44]:
model.module.classify(batch['X'], batch["batch"]).shape

torch.Size([10, 28])

In [39]:
def divide_chunks(l, n):
    for i in range(0, len(l), n):
        yield l[i:i + n]
 

labels_2 = list(divide_chunks(labels, 10))
labels_2

[[13, 19, 1, 13, 4, 4, 27, 26, 2, 26],
 [8, 26, 13, 13, 13, 2, 4, 26, 13, 4],
 [13, 4, 28, 28, 13, 13, 1, 13, 1, 0],
 [4, 13, 13, 26, 4, 8, 3, 4, 4, 13],
 [1, 22, 22, 26, 1, 23, 13, 17, 8, 26],
 [13, 4, 22, 13, 13, 8, 4, 2, 8, 4],
 [20, 26, 22, 1, 8, 22, 26, 8, 22, 25],
 [28, 13, 17, 19, 8, 13, 13, 17, 13, 28],
 [19, 1, 22, 13, 4, 28, 13, 4, 1, 1],
 [1, 4, 1, 4, 1, 1, 17, 1, 13, 26],
 [22, 17, 22, 7, 26, 13, 10, 26, 13, 13],
 [27, 22, 12, 7, 28, 26, 13, 4, 4, 4],
 [13, 22, 23, 1, 26, 13, 22, 8, 19, 8],
 [13, 10, 4, 27, 19, 22, 13, 13, 8, 7],
 [4, 4, 1, 26, 13, 13, 26, 17, 13, 22],
 [8, 10, 4, 28, 21, 8, 4, 8, 13, 4],
 [7, 27, 8, 17, 17, 13, 13, 22, 19, 26],
 [4, 1, 13, 4, 7, 26, 13, 0, 13, 22],
 [13, 4, 28, 4, 1, 26, 12, 4, 8, 10],
 [13, 13, 28, 22, 4, 28, 22, 4, 13, 4],
 [8, 19, 8, 28, 13, 22, 13, 4, 13, 1],
 [10, 4, 22, 8, 28, 2, 22, 21, 13, 1],
 [4, 4, 26, 17, 13, 4, 13, 13, 26, 13],
 [17, 17, 13, 13, 22, 1, 13, 4, 4, 13],
 [22, 4, 22, 8, 12, 28, 8, 7, 28, 17],
 [17, 22, 8, 22, 17, 

In [49]:
import numpy as np

REEEEEEEEEEEEEEEEEEEEEE = np.array(labels[:10]).reshape([10, 1])

In [51]:
n_cells = 10
n_cell_types = 28
n_labels = n_cells * n_cell_types
scdl = model._make_data_loader(adata=adata,
                               indices=list(range(n_cells)),
                               batch_size=100)
batch = next(scdl.__iter__())

labels = list(adata.obs._scvi_labels)[:n_labels]
# Not quite clear why we couldn't get this to work
# We're thinking that we somehow need to get a 10 x 28 tensor. The splitting above (label_2) didn't really work.
attr = saliency.attribute((batch['X'], batch['batch']),
                          target=torch.Tensor(REEEEEEEEEEEEEEEEEEEEEE)
                         # target=tuple([labels_2, labels_2])
                         )

AssertionError: Tensor target dimension torch.Size([10, 1]) is not valid. torch.Size([10, 28])

## Collapsing into required output structure

The output that we desire is genes x cell types. Currently we have number_batches * (28 cell types * batch size) * number_genes and this 14 times for the 14 datasets.

In [32]:
# 1. Collapse datasets attributions into a single tuple of attributions
# 2. Collapse the batches and cell types into a single vector
# 3. Reshape into genes * cell types