# Shapley sampling

- Theoretical background: https://www.sciencedirect.com/science/article/abs/pii/S0305054808000804
- API docs: https://captum.ai/api/shapley_value_sampling.html

In [4]:
import torch
import scvi
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

# Notebook-wide scanpy plotting defaults: 100 dpi figures, no axis frame,
# 'Reds' as the default colour map, and no explicit face colour.
sc.set_figure_params(
    dpi=100,
    frameon=False,
    color_map='Reds',
    facecolor=None,
)

# Print the versions of scanpy and its core dependencies for reproducibility.
sc.logging.print_header()

Global seed set to 0


scanpy==1.9.1 anndata==0.8.0 umap==0.5.3 numpy==1.23.3 scipy==1.9.1 pandas==1.4.4 scikit-learn==1.1.2 statsmodels==0.13.2 pynndescent==0.5.7


## load model

In [5]:
# Location of the HLCA dataset on disk, relative to this notebook.
hlca_path = '../datasets/hlca.h5ad'
# Read the .h5ad file into an AnnData object via scanpy.
adata = sc.read(hlca_path)
# Bare expression: display the AnnData summary (dimensions and annotation columns).
adata

AnnData object with n_obs × n_vars = 584944 × 2000
    obs: 'sample', 'original_celltype_ann', 'study_long', 'study', 'last_author_PI', 'subject_ID', 'subject_ID_as_published', 'pre_or_postnatal', 'age_in_years', 'age_range', 'sex', 'ethnicity', 'mixed_ethnicity', 'smoking_status', 'smoking_history', 'BMI', 'known_lung_disease', 'condition', 'subject_type', 'cause_of_death', 'sample_type', 'anatomical_region_coarse', 'anatomical_region_detailed', 'tissue_dissociation_protocol', 'cells_or_nuclei', 'single_cell_platform', "3'_or_5'", 'enrichment', 'sequencing_platform', 'reference_genome_coarse', 'ensembl_release_reference_genome', 'cell_ranger_version', 'disease_status', 'fresh_or_frozen', 'cultured', 'cell_viability_%', 'comments', 'Processing_site', 'dataset', 'anatomical_region_level_1', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_ne

In [6]:
# Load the saved SCANVI model from ../models/scanvi_model/, passing the AnnData
# loaded above so the model is attached to that data.
model = scvi.model.SCANVI.load('../models/scanvi_model/', adata)
# Bare expression: display the model's summary repr.
model

[34mINFO    [0m File ..[35m/models/scanvi_model/[0m[95mmodel.pt[0m already downloaded                                                   






## Applying shapley sampling

### Measure against classification probabilities

In [7]:
from captum.attr import visualization as viz
from captum.attr import ShapleyValueSampling

In [8]:
# Build the Captum Shapley-value-sampling explainer around the model module's
# `classify` method, so attributions are computed with respect to whatever
# `classify` returns (presumably per-cell-type class scores -- TODO confirm).
shapley_sampling = ShapleyValueSampling(model.module.classify)

In [20]:
# Compute Shapley-sampling attributions for a small subset of cells.
n_cells = 10        # number of cells (first rows of `adata`) to explain
n_cell_types = 28   # number of cell-type classes (see the TODO on targets below)
n_labels = n_cells * n_cell_types

# Data loader restricted to the first `n_cells` observations. With
# batch_size=100 > n_cells, all selected cells should arrive in the first batch.
scdl = model._make_data_loader(
    adata=adata,
    indices=list(range(n_cells)),
    batch_size=100,
)
# Use the built-in iter() rather than calling the __iter__ dunder directly.
batch = next(iter(scdl))

# Slice the label column *before* converting it to a Python list, so only the
# first `n_labels` entries are materialised instead of the whole column.
# NOTE(review): `labels` is not used further in this cell, and it holds
# n_cells * n_cell_types entries although only n_cells cells are in `batch`
# -- confirm the intended length.
labels = adata.obs['_scvi_labels'].iloc[:n_labels].tolist()

# TODO: the target is not supposed to be fixed -> need a way to attribute
# against all classes. Passing list(range(28)) did not work either.
# NOTE(review): per the Captum docs a list-valued `target` needs one entry per
# input example (n_cells here, not n_cell_types), which is presumably why
# list(range(28)) failed; covering every class likely means one attribute()
# call per class index -- confirm.
attr = shapley_sampling.attribute(
    (batch['X'], batch['batch']),
    target=1,
)

## Collapsing into required output structure

The desired output is a genes × cell types matrix. Currently the attributions have shape number_batches × (28 cell types × batch size) × number_genes, and there is one such set of attributions for each of the 14 datasets.

In [32]:
# 1. Collapse datasets attributions into a single tuple of attributions
# 2. Collapse the batches and cell types into a single vector
# 3. Reshape into genes * cell types