In [None]:
"""Initial analysis of shap values behavior."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement
import os
from pathlib import Path

import numpy as np
import plotly.express as plt
import plotly.io as pio
import shap
from IPython.display import display

# Make "notebook" the default renderer for every plotly figure shown in this notebook.
pio.renderers.default = "notebook"

from epi_ml.core import metadata
from epi_ml.core.analysis import SHAP_Handler
from epi_ml.core.model_pytorch import LightningDenseClassifier

In [None]:
# Draw matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
# Root of the project tree. The default is the location this notebook was
# originally written for; set the EPILAP_PROJECTS_DIR environment variable to
# run the notebook from another machine or checkout without editing this cell.
home = Path(
    os.environ.get("EPILAP_PROJECTS_DIR", "/home/local/USHERBROOKE/rabj2301/Projects")
)

# Inputs: metadata JSON describing the datasets.
input_dir = home / "epilap/input"
metadata_path = input_dir / "metadata/merge_EpiAtlas_allmetadata-v11-mod.json"

# Outputs of earlier runs: stored SHAP values (logdir) and the trained model (model_dir).
output = home / "epilap/output"
logdir = output / "logs/hg38_2022-epiatlas/shap"
model_dir = output / "models/split0"

my_meta = metadata.Metadata(metadata_path)
# Only the `mapping` attribute of the restored model is kept.
target_mapping = LightningDenseClassifier.restore_model(model_dir).mapping

In [None]:
# Filter the metadata in place before looking up labels.
# NOTE(review): presumably drops datasets whose track_type is one of the listed values — confirm.
my_meta.remove_category_subsets("track_type", ["raw", "fc", "Unique_raw"])
# NOTE(review): presumably drops "assay" classes with fewer than 10 datasets — confirm.
my_meta.remove_small_classes(10, "assay")
# Show what is left for both categories.
my_meta.display_labels("assay")
my_meta.display_labels("track_type")

In [None]:
# Show the mapping taken from the restored model.
display(target_mapping)
# `classes` holds the mapping's values; later cells pass list(classes) as plot class names.
# NOTE(review): presumably the mapping goes from output index to class label — confirm.
classes = target_mapping.values()
print(classes, list(classes))

In [None]:
# Load the two stored SHAP result sets that are compared below
# (same file name pattern, different timestamps).
# NOTE(review): these appear to be pickle files — only load pickles from a trusted
# source, since unpickling can execute arbitrary code.
eval_shaps_1 = SHAP_Handler.load_from_pickle(
    path=str(logdir / "shap_values_background_effect_test_2022-12-12_18-54-39.pickle")
)
eval_shaps_2 = SHAP_Handler.load_from_pickle(
    path=str(logdir / "shap_values_background_effect_test_2022-12-12_18-57-56.pickle")
)

In [None]:
# List the entries of the loaded result set (later cells read "ids" and "shap").
eval_shaps_1.keys()

In [None]:
# Assay label of each id stored with the first SHAP result set, in the order of "ids".
labels = [my_meta[md5]["assay"] for md5 in eval_shaps_1["ids"]]
# Preview every fifth label, starting with the first and running to the end of the list.
print(labels[::5])

In [None]:
# Shape of the first entry of the "shap" sequence.
# NOTE(review): presumably one (samples, features) matrix per class — confirm.
eval_shaps_1["shap"][0].shape

In [None]:
# SHAP summary plot for the first result set, displaying up to 300 features.
# NOTE(review): assumes list(classes) is ordered like the entries of eval_shaps_1["shap"] — confirm.
shap.summary_plot(eval_shaps_1["shap"], max_display=300, class_names=list(classes))

In [None]:
# SHAP summary plot for the second result set, displaying up to 300 features.
# NOTE(review): assumes list(classes) is ordered like the entries of eval_shaps_2["shap"] — confirm.
shap.summary_plot(eval_shaps_2["shap"], max_display=300, class_names=list(classes))

In [None]:
def average_impact(shap_values_matrices):
    """Element-wise mean of the absolute values of several equally shaped matrices.

    Args:
        shap_values_matrices: Non-empty sequence of arrays that all share the
            shape of its first element (presumably one SHAP matrix per class).

    Returns:
        A new float array, shaped like the first matrix, holding the mean
        absolute value at each position. The input matrices are not modified.

    Raises:
        IndexError: If the sequence is empty.
    """
    mean_abs = np.zeros(shap_values_matrices[0].shape)
    for class_matrix in shap_values_matrices:
        # Accumulate in place so only one result-sized buffer is kept around.
        np.add(mean_abs, np.abs(class_matrix), out=mean_abs)
    return mean_abs / len(shap_values_matrices)

In [None]:
def n_most_important_features(sample_shaps, n):
    """Return the indices of the ``n`` largest totals along axis 0.

    The values are summed along axis 0 as-is (signed, not absolute), so
    positive and negative entries cancel; pass absolute values when ranking
    by magnitude is intended.

    Args:
        sample_shaps: Array exposing ``.sum(axis=0)`` (presumably samples x features).
        n: Maximum number of indices to return.

    Returns:
        Array of at most ``n`` indices, ordered from largest to smallest total.
    """
    per_feature_total = sample_shaps.sum(axis=0)
    ascending_order = np.argsort(per_feature_total)
    descending_order = np.flip(ascending_order)
    return descending_order[:n]

In [None]:
# Element-wise mean absolute SHAP value for each of the two loaded result sets.
total_avg_1, total_avg_2 = (
    average_impact(result_set["shap"]) for result_set in (eval_shaps_1, eval_shaps_2)
)

# Features ranked in the top `n` of both result sets.
n = 1000
most_important_features = set(n_most_important_features(total_avg_1, n)).intersection(
    n_most_important_features(total_avg_2, n)
)
print(len(most_important_features))

In [None]:
# For every pair of matrices at the same position in the two result sets, count
# how many of the top-100 features they share.
# NOTE(review): the raw matrices are ranked here, so totals are signed, whereas the
# overall comparison above ranks mean absolute values — confirm this is intended.
print(classes)
for i, (matrix1, matrix2) in enumerate(zip(eval_shaps_1["shap"], eval_shaps_2["shap"])):
    most_important_class_features = set(n_most_important_features(matrix1, 100)).intersection(
        n_most_important_features(matrix2, 100)
    )
    # Position in the "shap" sequence, then the size of the shared top-100 set.
    print(i, len(most_important_class_features))

In [None]:
def box_plot(avg_impact, title=None):
    """Show a box plot of the totals of ``avg_impact`` along axis 0.

    Args:
        avg_impact: Array exposing ``.sum(axis=0)``; each resulting total
            becomes one point of the distribution that is plotted.
        title: Optional figure title, so that several plots shown one after
            another can be told apart. Defaults to no title.
    """
    column_totals = avg_impact.sum(axis=0)
    # `plt` is plotly.express in this notebook, not matplotlib.pyplot.
    fig = plt.box(y=column_totals, title=title)
    fig.show()

In [None]:
# Distribution of the summed mean absolute SHAP values, one plot per result set
# (first result set, then second).
box_plot(total_avg_1)
box_plot(total_avg_2)