In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")

import seaborn as sns
import pandas as pd
import torch
import numpy as np
from interpretability import sample_from_hist, get_hist
import random
import matplotlib.pyplot as plt

# Seed every RNG the notebook touches (torch, NumPy's global RNG, stdlib random)
# so that Restart & Run All reproduces the figures.
SEED = 4444
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

plt.rcParams.update({'font.size': 15})

In [None]:
def shapley_value(x, b, i, func, hist=None):
    """Exact Shapley value of feature ``i`` for input ``x``, returned in absolute value.

    Every coalition S of the remaining features is enumerated. For each S, two
    copies of the baseline are built: ``v2`` takes ``x`` on S, and ``v1`` takes
    ``x`` on S plus feature ``i``. The marginal contribution
    ``func(v1) - func(v2)`` is weighted by ``|S|! (n - |S| - 1)! / n!``.

    Args:
        x: 1-D tensor holding the input to explain (length n).
        b: 1-D tensor baseline of the same length. It is ignored when ``hist``
            is given.
        i: index of the feature being attributed.
        func: callable mapping a 1-D tensor to a scalar (e.g. a 0-d tensor).
        hist: optional sequence of n per-feature histograms. When given, a new
            baseline is drawn once per call via ``sample_from_hist(hist[j])``
            for every feature j. Presumably these are objects produced by
            ``get_hist``; TODO confirm.

    Returns:
        ``abs`` of the accumulated Shapley value, of the same type as
        ``func``'s output.

    Cost: 2 ** (n - 1) coalitions, each evaluating ``func`` twice.
    """
    from itertools import combinations
    from math import factorial

    n = len(x)
    others = list(range(n))
    del others[i]  # every feature except the attributed one

    if hist is not None:
        # The loop variable is `j` so it cannot be mistaken for the attributed index `i`.
        b = torch.tensor([sample_from_hist(hist[j]) for j in range(n)], dtype=torch.float32).flatten()

    sv = 0
    # Same enumeration order as a powerset (by size, then lexicographic), so the
    # floating-point accumulation order is unchanged.
    for size in range(len(others) + 1):
        # The Shapley weight depends only on the coalition size.
        const = factorial(size) * factorial(n - size - 1) / factorial(n)
        for coalition in combinations(others, size):
            S = list(coalition)
            v1, v2 = b.clone(), b.clone()
            v1[i] = x[i]
            if S:  # guard: skip indexing with an empty coalition
                v1[S] = x[S]
                v2[S] = x[S]
            sv += const * (func(v1) - func(v2))
    return abs(sv)

In [None]:
def func1():
    """Build a linear model ``v -> sum_j w_j * v_j`` with fixed weights w = (1, 2, 3, 4, 1).

    Returns:
        A closure that multiplies its tensor argument elementwise by the weights
        and returns the sum as a 0-d tensor. It expects a length-5 float tensor.
    """
    weights = torch.tensor([1, 2, 3, 4, 1], dtype=torch.float32)

    def linear_model(v):
        return (weights * v).sum()

    return linear_model

def func2():
    """Build the model ``v -> sum_j exp(-v_j ** 2 / 2)``, a sum of unnormalized Gaussian bumps.

    Returns:
        A closure mapping a float tensor to a 0-d tensor.
    """
    def gaussian_bumps(v):
        return torch.exp(-0.5 * torch.square(v)).sum()

    return gaussian_bumps

In [None]:
def generate1():
    """Draw one sample with 5 independent features, feature k ~ N(loc_k, 1), loc = (1, 1, 1, 1, 8).

    Uses NumPy's global RNG (one (1, 1) draw per feature, in order) and returns
    a float32 tensor of shape (5,).
    """
    columns = []
    for loc in [1, 1, 1, 1, 8]:
        columns.append(np.random.normal(loc=loc, scale=1, size=(1, 1)))
    sample = np.concatenate(columns, axis=1)
    return torch.tensor(sample, dtype=torch.float32).flatten()

def generate2():
    """Draw one 2-feature sample from torch's global RNG.

    Feature 0 is a 50/50 mixture of N(10, 1) and N(-10, 1). Both components
    are always drawn, then one is selected. Feature 1 is N(0, 1).

    Returns:
        A tensor of shape (2,).
    """
    positive = torch.normal(mean=10, std=1, size=(1, 1))
    negative = torch.normal(mean=-10, std=1, size=(1, 1))
    pick_positive = torch.rand(1, 1) < 0.5
    mixture = torch.where(pick_positive, positive, negative)
    noise = torch.normal(mean=0, std=1, size=(1, 1))
    return torch.cat((mixture, noise), dim=1).flatten()

In [None]:
def _average_abs_shapley(generate, func, reps, n, apply_center, take_hist, take_mean, alpha):
    """Return a (reps, m) array: per repetition, the mean |Shapley value| of each of the m features over n samples."""
    # One throw-away draw to learn the feature count. It is kept so the RNG
    # stream (and hence the figures) matches earlier runs of this notebook.
    m = generate().shape[0]

    res = []
    for _ in range(reps):
        f = func()
        x = torch.stack([generate() for _ in range(n)])

        if apply_center:
            x = x - torch.mean(x, dim=0)

        hist_input = None
        if take_hist:
            hist_input = [get_hist(x[:, j], alpha=alpha) for j in range(m)]

        # The baseline is ignored by shapley_value whenever hist_input is given.
        b = torch.mean(x, dim=0) if take_mean else torch.zeros(x.shape[1])

        per_sample = [
            [shapley_value(x[j], b, i, f, hist=hist_input).item() for i in range(m)]
            for j in range(n)
        ]
        res.append(np.mean(np.array(per_sample), axis=0))
    # reshape keeps a 2-D result even when reps == 0.
    return np.array(res).reshape(reps, m)


def run(generate, func, reps=10, n=100, apply_center=False, take_hist=True, take_mean=True, alpha=1, ax=None, title=None):
    """Box-plot the average absolute Shapley value per feature over ``reps`` repetitions.

    Each repetition draws ``n`` samples with ``generate``, optionally
    mean-centers them, builds the baseline, and averages ``|shapley_value|``
    over the samples for every feature. The baseline is one of: per-feature
    histograms (``take_hist``, with exponent ``alpha`` passed to ``get_hist``),
    the sample mean (``take_mean``), or zeros.

    Args:
        generate: zero-arg callable returning one 1-D sample tensor.
        func: zero-arg factory returning the model to explain.
        reps: number of repetitions (one box-plot point per feature per repetition).
        n: samples per repetition.
        apply_center: subtract the per-feature sample mean before explaining.
        take_hist: use a histogram-sampled baseline.
        take_mean: use the sample mean as the baseline (when ``take_hist`` is False).
        alpha: exponent forwarded to ``get_hist``.
        ax: matplotlib axes to draw on. Defaults to the current axes.
        title: optional axes title.

    Returns:
        The axes the box plot was drawn on.
    """
    res = _average_abs_shapley(generate, func, reps, n, apply_center, take_hist, take_mean, alpha)
    m = res.shape[1]
    data = pd.DataFrame({"x": res.flatten(), "class": np.tile(range(m), reps) + 1})

    if ax is None:
        ax = plt.gca()
    ax = sns.boxplot(x="class", y="x", data=data, ax=ax)
    # Labels go after plotting: seaborn may otherwise replace them with the column names.
    ax.set(xlabel="Group", ylabel="Average absolute Shapley Value")
    if title is not None:
        # Title the target axes, not pyplot's "current" axes (the last subplot).
        ax.set_title(title)
    return ax

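## Experiment 1: linear model

Features 1–4 are drawn from $\mathcal{N}(1, 1)$ and feature 5 from $\mathcal{N}(8, 1)$. The model is linear, $f(x) = \sum_j \phi_j x_j$ with $\phi = (1, 2, 3, 4, 1)$. Each panel shows the average absolute Shapley value per feature under a different choice of baseline.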

In [None]:
# Zero baseline: no histogram, no mean. Every feature is compared against 0.
fig, axs = plt.subplots(1, 1, figsize=(6, 5), constrained_layout=True)
run(
    generate1,
    func1,
    take_hist=False,
    take_mean=False,
    ax=axs,
    title="Zero baseline",
)
fig.savefig("base_zeros.png", dpi=100)

In [None]:
# One panel per baseline, in row-major subplot order. The run order is fixed
# because every call consumes the global RNGs.
experiment1_panels = [
    (dict(take_hist=False, take_mean=True), "Mean baseline"),
    (dict(alpha=1), r"Distribution baseline, $\alpha = 1$"),
    (dict(alpha=0), r"Uniform baseline, $\alpha = 0$"),
    (dict(alpha=-1), r"Inverse proportional baseline, $\alpha = -1$"),
]
fig, axs = plt.subplots(2, 2, figsize=(13, 10), constrained_layout=True)
for panel_ax, (options, panel_title) in zip(axs.flat, experiment1_panels):
    run(generate1, func1, ax=panel_ax, title=panel_title, **options)
fig.savefig("base.png", dpi=100)

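## Experiment 2: bimodal feature

Feature 1 is a 50/50 mixture of $\mathcal{N}(10, 1)$ and $\mathcal{N}(-10, 1)$, and feature 2 is $\mathcal{N}(0, 1)$. Inputs are mean-centered before explanation. The model is $f(x) = \sum_j e^{-x_j^2/2}$.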

In [None]:
# Same baseline comparison on the bimodal data, with inputs mean-centered.
# The panel order (and therefore the RNG consumption order) matches row-major subplots.
experiment2_panels = [
    (dict(take_hist=False, take_mean=True), "Mean baseline"),
    (dict(alpha=1), r"Distribution baseline, $\alpha = 1$"),
    (dict(alpha=0), r"Uniform baseline, $\alpha = 0$"),
    (dict(alpha=-1), r"Inverse proportional baseline, $\alpha = -1$"),
]
fig, axs = plt.subplots(2, 2, figsize=(12, 9), constrained_layout=True)
for panel_ax, (options, panel_title) in zip(axs.flat, experiment2_panels):
    run(generate2, func2, apply_center=True, ax=panel_ax, title=panel_title, **options)
fig.savefig("base2.png", dpi=100)