In [None]:
#!/usr/bin/python
from __future__ import absolute_import

import sys
sys.path.append("../../")

import numpy as np
#from tensorflow import keras
import os
import argparse
import copy
import random
import time
import torch

from shapley.apps import Label, InclusionExclusion
from shapley.loader import FashionMnist, MNIST, Flower, CIFAR
from shapley.measures import KNN_Shapley, KNN_LOO, G_Shapley, LOO, TMC_Shapley, FastWeightedShapley
from shapley.utils.plotter import LabelPlotter, PoisoningPlotter

%matplotlib inline
# import global vars
import init 

# Fix the random seed once for the whole notebook; the same value is later
# handed to the data loader.
seed = 1023
init.set_seed(seed)  # presumably seeds the Python / NumPy / torch RNGs -- confirm in init.py
# cuDNN is switched off for the entire session.
# NOTE(review): the reason is not recorded here (determinism?) -- confirm.
torch.backends.cudnn.enabled = False

# Fail fast with a readable message when no GPU is visible.
assert torch.cuda.is_available(), "CUDA not available"

In [None]:
# --- Experiment configuration ------------------------------------------------
num_train = 50000
num_test = 10000
datasource = "cifar"
model_family = "resnet18"
app_name = "inclusion_exclusion"
model_checkpoint_dir = f"../../checkpoints/{app_name}/{datasource}"

# exist_ok=True creates the directory tree only when it is missing and avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs(model_checkpoint_dir, exist_ok=True)

# --- Data --------------------------------------------------------------------
# Alternative loaders kept for reference:
#loader = MNIST(num_train=num, one_hot=False, shuffle=True, by_label=True)
loader = CIFAR(num_train=num_train, num_test=num_test, all_classes=True, seed=seed)
# loader = Flower(num_train=num)
X_data, y_data, X_test_data, y_test_data = loader.prepare_data()
# Quick sanity check of the loaded array shapes (test first, then train).
print(X_test_data.shape, X_data.shape)

In [None]:
# Run the inclusion/exclusion application with the FastWeightedShapley measure.
start = time.time()

measure_fast_weighted_shapley = FastWeightedShapley(
    K=init.K,
    model_checkpoint_dir=model_checkpoint_dir,
)
app = InclusionExclusion(
    X=X_data,
    y=y_data,
    X_test=X_test_data,
    y_test=y_test_data,
    method_name="fw_shapley",
    app_name=app_name,
    # Reuse the configured data source instead of repeating the literal, so the
    # dataset tag cannot drift away from the checkpoint directory.
    dataset=datasource,
    model_family=model_family,
    model_checkpoint_dir=model_checkpoint_dir,
)
exclusion_scores = app.run(measure_fast_weighted_shapley)
# `start` was recorded above but never read; report the wall-clock cost.
print(f"fw_shapley run took {time.time() - start:.1f}s")
exclusion_scores

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Concatenate whatever the FW-Shapley measure stored in `.utility` and move the
# result to the CPU.
tmp = torch.concat(measure_fast_weighted_shapley.utility).cpu()
# Density plot over a random subset of at most 10000 entries
# (presumably to keep the KDE cheap -- confirm).
sample_size = 10000
indexes = torch.randperm(tmp.size(0))[:sample_size]
sns.kdeplot(tmp[indexes])

### KNN Shapley

In [None]:
# Run the inclusion/exclusion application with the KNN-Shapley measure.
start = time.time()

measure_KNN = KNN_Shapley(K=10)
app = InclusionExclusion(
    X=X_data,
    y=y_data,
    X_test=X_test_data,
    y_test=y_test_data,
    method_name="knn_shapley",
    app_name=app_name,
    # Reuse the configured data source instead of repeating the literal, so the
    # dataset tag cannot drift away from the checkpoint directory.
    dataset=datasource,
    model_family=model_family,
    model_checkpoint_dir=model_checkpoint_dir,
)
exclusion_scores = app.run(measure_KNN)
# `start` was recorded above but never read; report the wall-clock cost.
print(f"knn_shapley run took {time.time() - start:.1f}s")


In [None]:
# Read the saved KNN-Shapley exclusion scores back from disk and display them.
# NOTE(review): presumably written by the KNN-Shapley run -- confirm.
knn_scores_path = "../../results/inclusion_exclusion/cifar/knn_shapley_exc_scores.npy"
scores = np.load(knn_scores_path)
scores

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Histogram of column 0 of the saved FW-Shapley score array.
fw_scores_path = "../../results/inclusion_exclusion/cifar/fw_shapley_shap_scores.npy"
fw_scores_all = np.load(fw_scores_path)
shap_scores = fw_scores_all[:, 0]
plt.hist(shap_scores)

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image

shap_scores = np.load("../../results/inclusion_exclusion/cifar/fw_shapley_shap_scores.npy")[:,0]
# Ascending argsort: the lowest-scored training points come first.
indexes = np.argsort(shap_scores)

plt.figure(figsize=(20, 20))

# Label histogram over the 200 lowest-scored points.
print(np.bincount(y_data[indexes[:200],0]))

# Show the 100 lowest-scored images. The grid is sized from the number of
# images: a fixed 10x5 layout has only 50 slots, so plt.subplot raises a
# ValueError as soon as the 51st image is placed.
shown = indexes[:100]
n_cols = 10
n_rows = (len(shown) + n_cols - 1) // n_cols  # ceiling division
for i, index in enumerate(shown):
    inst_x = X_data[index]
    plt.subplot(n_rows, n_cols, i+1)
    # transpose(1, 2, 0) moves axis 0 to the end
    # (presumably channels-first -> channels-last for imshow -- confirm).
    plt.imshow(inst_x.transpose(1,2,0))
    plt.title(y_data[index])

In [None]:
# Label histogram over the first 100 entries of `indexes`, as left by the
# preceding cell (there: ascending argsort of shap_scores).
first_100_labels = y_data[indexes[:100], 0]
label_counts = np.bincount(first_100_labels)
print(label_counts)