# Explainability with SHAP values

In [52]:
import shap
import joblib
import pandas as pd
import re
import numpy as np
import seaborn as sns

import mygene
import matplotlib.pyplot as plt
from matplotlib.pyplot import savefig, close

We use SHAP's `LinearExplainer`. From its documentation:


Computes SHAP values for a linear model, optionally accounting for inter-feature correlations.

This computes the SHAP values for a linear model and can account for the correlations among the input features. Assuming features are independent leads to interventional SHAP values, which for a linear model are `coef[i] * (x[i] - X.mean(0)[i])` for the ith feature. If instead we account for correlations, we prevent any problems arising from collinearity and share credit among correlated features. Accounting for correlations can be computationally challenging, but LinearExplainer uses sampling to estimate a transform that can then be applied to explain any prediction of the model.

In this notebook the explainer is given a `shap.maskers.Independent` masker, so features are treated as independent and the values computed are the interventional ones.

In [53]:
# Helper functions: map Ensembl gene IDs to gene symbols, and summarize SHAP importance

def convert_gene_name(ensembl_ids, ticks=False):
    """Map Ensembl gene IDs to human gene symbols via the MyGene.info service.

    Note: this issues one network request per ID and therefore requires an
    internet connection.

    Parameters
    ----------
    ensembl_ids : iterable
        Ensembl gene IDs, optionally version-suffixed (``"<stable_id>.<version>"``);
        everything from the first ``.`` on is dropped before querying.
        When ``ticks`` is True the items are instead objects exposing
        ``get_text()`` (e.g. matplotlib tick labels) whose text is the ID.
    ticks : bool, default False
        Read each ID via ``item.get_text()`` instead of using the item directly.

    Returns
    -------
    list of str
        One label per input, in input order. Symbols shorter than 3 characters
        are followed by the full gene name on a new line (``"<symbol>\\n(<name>)"``)
        to make them readable. IDs with no hit, or whose top hit lacks the
        symbol/name field, are returned unchanged (version suffix included).
    """
    mg = mygene.MyGeneInfo()
    gene_symbol_names = []
    for item in ensembl_ids:
        gene_id = item.get_text() if ticks else item
        stable_id = gene_id.split(".")[0]
        # Ask for symbol and name together so short symbols need no second round
        # trip. No `as_dataframe` option is passed, so the result is a plain dict
        # with a "hits" list.
        query = mg.query(stable_id,
                         scopes='ensembl.gene', fields='symbol,name',
                         species='human', size=1)
        try:
            top_hit = query['hits'][0]
            symbol = top_hit['symbol']
            if len(symbol) < 3:
                gene_symbol_names.append(symbol + "\n(" + top_hit['name'] + ")")
            else:
                gene_symbol_names.append(symbol)
        except (KeyError, IndexError):
            # No hit, or the top hit lacks a requested field: keep the raw ID.
            gene_symbol_names.append(gene_id)
    return gene_symbol_names


def normalize_score(shap_values_df):
    """Summarize per-column importance as mean absolute value and percentage share.

    Parameters
    ----------
    shap_values_df : pandas.DataFrame
        SHAP values, one column per feature.

    Returns
    -------
    pandas.DataFrame
        Columns ``column_name`` (the input column label), ``score`` (mean of
        the absolute values in that column) and ``normalized_score`` (``score``
        as a percentage of the sum of all scores). Rows whose ``score`` is
        exactly 0 are dropped; surviving rows keep their positional index.
    """
    per_feature = shap_values_df.abs().mean()
    pct_of_total = per_feature.div(per_feature.sum()).mul(100)

    summary = pd.DataFrame(
        data={
            "column_name": per_feature.index,
            "score": per_feature.values,
            "normalized_score": pct_of_total.values,
        }
    )
    is_nonzero = summary["score"].ne(0)
    return summary.loc[is_nonzero]




In [54]:
# Load the fitted per-tissue pipelines saved by the modelling step.
# NOTE(review): they were pickled with scikit-learn 1.4.1.post1; loading them
# under a different version emits InconsistentVersionWarning — align versions.
lung, ovary = (
    joblib.load(f"../../results/4.gexp_models/{tissue}/13.pipeline_{tissue}.pkl")
    for tissue in ("lung", "ovary")
)

Trying to unpickle estimator QuantileTransformer from version 1.4.1.post1 when using version 1.3.0. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
Trying to unpickle estimator ElasticNet from version 1.4.1.post1 when using version 1.3.0. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
Trying to unpickle estimator Pipeline from version 1.4.1.post1 when using version 1.3.0. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [55]:
def compute_shap_values(pipeline, tissue):
    """Compute SHAP values for a fitted tissue pipeline on its held-out test samples.

    Loads the tissue's expression matrix and keeps the genes listed in the saved
    ElasticNet feature-importance file. It then splits samples into train/test
    using the test-set metadata, normalizes both with the pipeline's first step,
    and explains the pipeline's second step with ``shap.LinearExplainer``.

    Parameters
    ----------
    pipeline : fitted scikit-learn Pipeline
        Accessed positionally: ``pipeline[0]`` is the normalization transformer
        and ``pipeline[1]`` the linear model to explain.
    tissue : str
        Tissue name used to build every input path (e.g. ``"lung"``).

    Returns
    -------
    shap_values : array
        Output of ``explainer.shap_values`` for the normalized test samples.
    shap_values_df : pandas.DataFrame
        ``shap_values`` with the selected gene IDs as columns and the test
        sample IDs as index.
    gexp_test_norm : array
        The normalized test expression matrix that was explained (output of
        ``pipeline[0].transform``).

    Raises
    ------
    ValueError
        If no expression rows match the train or the test sample IDs.
    """
    coef = pd.read_csv(f"../../results/4.gexp_models/{tissue}/13.EN_feature_importance_{tissue}.csv")

    # Load this tissue's expression data (rows: samples, columns: genes) and metadata.
    gexp = pd.read_csv(f"../../data/X_coding_{tissue}_log2.csv", header=0, index_col=0)
    test_set = pd.read_csv(f"../../metadata/{tissue}_test_metadata.csv")
    metadata = pd.read_csv(f"../../metadata/gene_expression_metadata/metadata_{tissue}.tsv", sep="\t")

    # Strip "-SM-..." suffixes from the column labels.
    # NOTE(review): sample IDs are matched against the *index* below, while this
    # only rewrites the *columns* (gene IDs after the feature selection below).
    # Confirm whether it was meant for gexp.index.
    gexp.columns = [re.sub(r"-SM-.*", "", i) for i in gexp.columns.to_list()]

    # Samples listed in the test metadata form the test split; all others train.
    metadata_train = metadata.query("tissue_sample_id not in @test_set.sample_id")
    metadata_test = metadata.query("tissue_sample_id in @test_set.sample_id")

    # Feature selection: keep the model's genes, in the coefficient-file order.
    gexp = gexp[coef.genes]
    gexp_train = gexp[gexp.index.isin(metadata_train.tissue_sample_id)]
    gexp_test = gexp[gexp.index.isin(metadata_test.tissue_sample_id)]
    if gexp_train.empty or gexp_test.empty:
        raise ValueError(
            f"{tissue}: {len(gexp_train)} train / {len(gexp_test)} test expression rows "
            "matched the metadata sample IDs; check the sample ID format."
        )

    # Normalize with the already-fitted first pipeline step (transform only).
    qt_norm = pipeline[0]
    gexp_train_norm = qt_norm.transform(gexp_train)
    gexp_test_norm = qt_norm.transform(gexp_test)

    # The Independent masker yields interventional SHAP values (features treated
    # as independent). Its default max_samples=100 would subsample the
    # background. Use every training row so the baseline is the full training mean.
    masker = shap.maskers.Independent(gexp_train_norm, max_samples=gexp_train_norm.shape[0])
    explainer = shap.LinearExplainer(model=pipeline[1], masker=masker)
    shap_values = explainer.shap_values(gexp_test_norm)

    shap_values_df = pd.DataFrame(
        shap_values, columns=gexp_train.columns.tolist(), index=gexp_test.index
    )
    return shap_values, shap_values_df, gexp_test_norm

In [56]:
# SHAP values on each tissue's held-out test set.
# NOTE: `meth_test_norm_*` hold the normalized *gene-expression* test matrices;
# the `meth_` prefix looks carried over from another notebook and is kept
# because the plotting cell refers to these names.
shap_values_lung, shap_values_lung_df, meth_test_norm_lung = compute_shap_values(
    pipeline=lung, tissue="lung"
)
shap_values_ovary, shap_values_ovary_df, meth_test_norm_ovary = compute_shap_values(
    pipeline=ovary, tissue="ovary"
)

In [57]:
# Translate the Ensembl IDs used as SHAP column labels into gene symbols for plotting
# (requires an internet connection).
feature_names_lung = convert_gene_name(ensembl_ids=list(shap_values_lung_df.columns))
feature_names_ovary = convert_gene_name(ensembl_ids=list(shap_values_ovary_df.columns))

In [59]:
# Mean |SHAP| per gene and its percentage share of the total (zero-score genes dropped).
shap_values_lung_df_norm, shap_values_ovary_df_norm = (
    normalize_score(frame) for frame in (shap_values_lung_df, shap_values_ovary_df)
)



In [75]:

# Save a SHAP summary plot (top 10 genes) for each tissue model.
# Output file names are unchanged. NOTE(review): summary_plot's default plot
# type here is a dot/beeswarm plot, so the ovary "violin_bar" name looks stale.
for shap_vals, test_norm, names, out_pdf in (
    (shap_values_lung, meth_test_norm_lung, feature_names_lung,
     "../../aging_notes/figures/6.feature_analysis/_lung_shap_feature_plot_gene_expression.pdf"),
    (shap_values_ovary, meth_test_norm_ovary, feature_names_ovary,
     "../../aging_notes/figures/6.feature_analysis/_ovary_shap_feature_violin_bar.pdf"),
):
    shap.summary_plot(shap_vals, test_norm, feature_names=names, max_display=10, show=False)
    # bbox_inches="tight" keeps long or two-line gene labels from being clipped.
    savefig(out_pdf, bbox_inches="tight")
    close()