# scANVI Explainer demo

date: 12-09-2024

author: Martin Proks

In [None]:
# IPython shell escape: print the path of the `pip` executable found on PATH
# (useful to check which environment the notebook kernel is using).
!which pip

In [None]:
import shap
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from anndata import AnnData
from scvi.hub import HubModel
from scanvi_explainer import SCANVIDeep
from sklearn.model_selection import train_test_split

import warnings

from numba.core.errors import NumbaDeprecationWarning
# Ignore NumbaDeprecationWarning to keep the notebook output readable.
# NOTE: this filter is installed only after the imports above have run, so any
# such warning raised while importing those modules is not suppressed by it.
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)


def _explainer_inputs(adata_subset: AnnData, batch_key: str, labels_key: str) -> dict:
    """Convert an AnnData subset into the tensor dict consumed by SCANVIDeep."""
    counts = adata_subset.layers['counts']
    # Sparse matrices no longer provide the `.A` shortcut in recent SciPy
    # releases; `toarray()` works everywhere, and dense layers pass through.
    counts = counts.toarray() if hasattr(counts, 'toarray') else np.asarray(counts)
    return {
        'X': torch.from_numpy(counts).type(torch.DoubleTensor),
        'batch': torch.from_numpy(adata_subset.obs[batch_key].cat.codes.values[:, np.newaxis]),
        'labels': torch.from_numpy(adata_subset.obs[labels_key].cat.codes.values[:, np.newaxis]),
    }


def train_test_group_split(
    adata: AnnData,
    groupby: str,
    test_size: float = 0.2,
    random_state=None,
    batch_key: str = 'batch',
    labels_key: str = 'ct',
):
    """
    Split an AnnData object into train/test sets per group (80/20 by default),
    in the format required by the SCANVIDeep explainer.

    Every group of ``adata.obs[groupby]`` is split independently, so each group
    contributes cells to both sets. Groups with fewer than two cells cannot be
    split and are placed entirely in the training (background) set.

    adata:
        AnnData with counts in ``layers['counts']`` (sparse or dense) and
        categorical ``obs`` columns ``batch_key`` and ``labels_key``.
    groupby:
        ``obs`` column whose groups are split independently.
    test_size:
        forwarded to ``sklearn.model_selection.train_test_split`` for each group.
    random_state:
        forwarded to every ``train_test_split`` call; ``None`` (default) gives
        a different split on every run, an int makes it reproducible.
    batch_key, labels_key:
        ``obs`` columns whose category codes become the 'batch' and 'labels'
        tensors.

    Returns:
        (X_train, X_test): dicts with 'X' (float64 counts), 'batch' and
        'labels' (category codes with shape (n_cells, 1)).
    """
    train, test = [], []
    # observed=True skips unused categories (which would arrive as empty
    # groups) and avoids pandas' FutureWarning about the observed default.
    for _, cells in adata.obs.groupby(groupby, observed=True).groups.items():
        cells = np.asarray(cells)
        if len(cells) < 2:
            # train_test_split raises on groups this small; keep them as background.
            train.append(cells)
            continue
        cells_train, cells_test = train_test_split(
            cells, test_size=test_size, random_state=random_state
        )
        train.append(cells_train)
        test.append(cells_test)

    train, test = np.concatenate(train), np.concatenate(test)

    # Build each AnnData view once instead of once per tensor.
    X_train = _explainer_inputs(adata[train], batch_key, labels_key)
    X_test = _explainer_inputs(adata[test], batch_key, labels_key)
    return X_train, X_test


def _mean_shap_weights(shaps: pd.DataFrame, signed: bool) -> pd.DataFrame:
    """Return per-feature mean SHAP values as a ('feature', 'weight') frame sorted descending."""
    values = shaps if signed else shaps.abs()
    return (
        values.mean(axis=0)
        .sort_values(ascending=False)
        # rename_axis/reset_index(name=) work whether or not the feature index
        # already carries a name (renaming the 'index' column would not).
        .rename_axis('feature')
        .reset_index(name='weight')
    )


def feature_plot(X_test, shap_values: np.ndarray, classes: pd.Index, features: np.ndarray, subset: bool = False):
    """
    Plot the top contributing features (mean SHAP value) for each cell type.

    One bar plot per class is drawn on a two-column grid whose number of rows
    grows with the number of classes.

    X_test:
        test dataset; a dict whose 'labels' entry holds the class code of each
        test cell. Only read when ``subset`` is True.
    shap_values:
        SHAP values indexed by class; ``shap_values[idx]`` is the
        (cells x features) matrix for ``classes[idx]``.
    classes:
        list of classifiers (cell types in this case)
    features:
        list of genes (HVGs); used as the columns of each SHAP matrix
    subset:
        If True, average the signed SHAP values over only the test cells that
        belong to that class and show the 5 most positive and 5 most negative
        features.
        Else, average the absolute SHAP values over all test cells (whatever
        their cell type) and show the 10 features with the largest non-zero
        contribution.
    """
    n_cols = 2
    # Size the grid from the number of classes (the old fixed 8x2 grid failed
    # for more than 16 classes).
    n_rows = max(1, (len(classes) + n_cols - 1) // n_cols)
    _, axes = plt.subplots(
        n_rows, n_cols, sharex=False, figsize=[20, 5 * n_rows], squeeze=False
    )

    # 1-D label codes for boolean masking (the stored tensor is (n_cells, 1)).
    labels = np.asarray(X_test['labels']).ravel() if subset else None

    for idx, ct in enumerate(classes):
        shaps = pd.DataFrame(shap_values[idx], columns=features)

        if subset:
            weights = _mean_shap_weights(shaps[labels == idx], signed=True)
            positive = weights.query('weight > 0').head(5)
            # Sorted descending, so the tail holds the most negative features.
            negative = weights.query('weight < 0').tail(5)
            avg = pd.concat([positive, negative])
            title = f'Mean(SHAP value) average importance for: {ct}'
        else:
            avg = _mean_shap_weights(shaps, signed=False).query('weight > 0').head(10)
            title = f'Mean(|SHAP value|) average importance for: {ct}'

        ax = axes[idx // n_cols, idx % n_cols]
        sns.barplot(x='weight', y='feature', data=avg, ax=ax)
        ax.set_title(title)

    # Hide the leftover grid cell when the number of classes is odd.
    for ax in axes.flat[len(classes):]:
        ax.set_visible(False)

In [None]:
# Download the model stored in the Hugging Face Hub repo
# "brickmanlab/mouse-scanvi", pinned to revision "v1.0", caching it under
# /tmp/mouse_scanvi (presumably the pretrained scANVI model this demo explains).
hmo = HubModel.pull_from_huggingface_hub(
    repo_name="brickmanlab/mouse-scanvi",
    cache_dir="/tmp/mouse_scanvi",
    revision="v1.0",
)

In [None]:
# Model object wrapped by the hub entry; leaving the bare name as the cell's
# last expression makes the notebook display its repr.
lvae = hmo.model
lvae

In [None]:
# Split the model's cells 80/20 within each cell type ('ct'). The training part
# becomes the explainer's background data. The split is not seeded, so
# background/test membership changes between runs.
background, test = train_test_group_split(lvae.adata, groupby='ct')

In [None]:
# Build the SHAP explainer from the model's underlying module (`lvae.module`),
# using the background split as reference data.
e = SCANVIDeep(lvae.module, background)

In [None]:
# SHAP values for every test cell. Downstream code indexes the result per class
# (shap_values[idx] -> cells x genes). NOTE(review): may be slow for large test sets.
shap_values = e.shap_values(test)

In [None]:
# SHAP summary plot over all test cells: features are labelled with the gene
# names (var_names) and classes with the categories of obs['ct'].
shap.summary_plot(
    shap_values, 
    test['X'], 
    feature_names=lvae.adata.var_names, 
    class_names=lvae.adata.obs.ct.cat.categories
)

In [None]:
# Per-cell-type bar plots of the top genes. subset=True restricts each plot to
# test cells of that cell type and shows signed mean SHAP values.
feature_plot(test, shap_values, classes=lvae.adata.obs.ct.cat.categories, features=lvae.adata.var_names, subset=True)