In [None]:
#!/usr/bin/python
from __future__ import absolute_import

import sys
sys.path.append("../../")

import numpy as np
#from tensorflow import keras
import pickle
import argparse
import copy
import random
import time
import torch
import pickle

from shapley.apps import Label, Poisoning
from shapley.loader import FashionMnist, MNIST, Flower, CIFAR
from shapley.measures import KNN_Shapley, KNN_LOO, G_Shapley, LOO, TMC_Shapley, FastWeightedShapley
from shapley.utils.plotter import LabelPlotter, PoisoningPlotter

%matplotlib inline
# import global vars
from init import set_seed

# Single seed for the notebook; it is handed to set_seed() here and to the
# CIFAR loader in the data-loading cell.
seed = 1023
# set_seed is imported from the project's `init` module.
# presumably it seeds the Python/NumPy/torch RNGs — TODO confirm against init.py
set_seed(seed)
# Switch cuDNN off for every torch op run from this kernel.
# NOTE(review): looks like a determinism/compatibility choice — confirm it is still required.
torch.backends.cudnn.enabled = False

In [None]:
# Sanity check shown as the cell output: True when PyTorch can use a CUDA device.
torch.cuda.is_available()

In [None]:
# --- Experiment configuration -------------------------------------------------
datasource = "cifar"
app_name = "noisy_label"
model_family = "resnet18"

# Requested split sizes, forwarded to the loader below.
# NOTE(review): the standard CIFAR-10/100 training split holds 50,000 images;
# confirm how the loader treats num_train=60000.
num_train = 60000
num_test = 10000

# Checkpoints are grouped by application and then by dataset.
model_checkpoint_dir = f"../../checkpoints/{app_name}/{datasource}"

# --- Data ---------------------------------------------------------------------
# Alternative loaders left by the author (note: `num` is not defined in this notebook):
#   loader = MNIST(num_train=num, one_hot=False, shuffle=True, by_label=True)
#   loader = Flower(num_train=num)
loader = CIFAR(
    num_train=num_train,
    num_test=num_test,
    all_classes=True,
    seed=seed,
)
X_data, y_data, X_test_data, y_test_data = loader.prepare_data()

# Print the test shape first, then the train shape.
print(X_test_data.shape, X_data.shape)

In [None]:
# Build the `Label` app from the train/test arrays and the model settings.
app = Label(
    X_data,
    y_data,
    X_test_data,
    y_test_data,
    model_family=model_family,
    model_checkpoint_dir=model_checkpoint_dir,
)

# Snapshot the app next to the notebook; a later cell reloads it from this file.
with open("./app_label.pkl", mode="wb") as pkl_file:
    pickle.dump(app, pkl_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# FW Shapley (1,1): alpha=1, beta=1
start = time.time()

measure_fast_weighted_shapley = FastWeightedShapley(
    K=5,
    model_checkpoint_dir=model_checkpoint_dir,
    alpha=1,
    beta=1,
)
res = app.run(measure_fast_weighted_shapley)

# Keep the first entry of every result row in a flat float array.
res_fwshapley = np.zeros(len(res))
for row_idx, row in enumerate(res):
    res_fwshapley[row_idx] = row[0]

# The reported time covers the run plus the extraction loop above.
elapsed = time.time() - start
print('Fast Weighted Shapley compute time: ', elapsed)

# The full result (not just the first column) is what gets written to disk.
np.save(f"../../results/{app_name}/{datasource}/fw_shapley_results_1_1.npy", res)

In [None]:
# FW Shapley (4,1): alpha=4, beta=1
# (the header previously read "(1,1)", a leftover from the copied cell; the
# parameters and the output file name below are the (4,1) configuration)
start = time.time()

measure_fast_weighted_shapley = FastWeightedShapley(
    K=5,
    model_checkpoint_dir=model_checkpoint_dir,
    alpha=4,
    beta=1,
)
res = app.run(measure_fast_weighted_shapley)

# Keep the first entry of every result row in a flat float array.
# Note: this overwrites `res_fwshapley` from any FW-Shapley cell run earlier.
res_fwshapley = np.zeros(len(res))
for row_idx, row in enumerate(res):
    res_fwshapley[row_idx] = row[0]

# The reported time covers the run plus the extraction loop above.
print('Fast Weighted Shapley compute time: ', time.time() - start)

# The full result (not just the first column) is what gets written to disk.
np.save(f"../../results/{app_name}/{datasource}/fw_shapley_results_4_1.npy", res)

In [None]:
# FW Shapley (8,1): alpha=8, beta=1
# (the header previously read "(1,1)", a leftover from the copied cell; the
# parameters and the output file name below are the (8,1) configuration)
start = time.time()

measure_fast_weighted_shapley = FastWeightedShapley(
    K=5,
    model_checkpoint_dir=model_checkpoint_dir,
    alpha=8,
    beta=1,
)
res = app.run(measure_fast_weighted_shapley)

# Keep the first entry of every result row in a flat float array.
# Note: this overwrites `res_fwshapley` from any FW-Shapley cell run earlier.
res_fwshapley = np.zeros(len(res))
for row_idx, row in enumerate(res):
    res_fwshapley[row_idx] = row[0]

# The reported time covers the run plus the extraction loop above.
print('Fast Weighted Shapley compute time: ', time.time() - start)

# The full result (not just the first column) is what gets written to disk.
np.save(f"../../results/{app_name}/{datasource}/fw_shapley_results_8_1.npy", res)

In [None]:
# FW Shapley (16,1): alpha=16, beta=1
start = time.time()

measure_fast_weighted_shapley = FastWeightedShapley(
    K=5,
    model_checkpoint_dir=model_checkpoint_dir,
    alpha=16,
    beta=1,
)
res = app.run(measure_fast_weighted_shapley)

# Keep the first entry of every result row in a flat float array.
res_fwshapley = np.zeros(len(res))
for row_idx, row in enumerate(res):
    res_fwshapley[row_idx] = row[0]

# The reported time covers the run plus the extraction loop above.
elapsed = time.time() - start
print('Fast Weighted Shapley compute time: ', elapsed)

# The full result (not just the first column) is what gets written to disk.
np.save(f"../../results/{app_name}/{datasource}/fw_shapley_results_16_1.npy", res)

### Get data importance using other measures

In [None]:
# Optional: restore the pickled app when it is not already in the kernel.
# if 'app' not in locals() or not app:
#     with open('./app_label.pkl', 'rb') as inp:
#         app = pickle.load(inp)

# All baseline measures are instantiated up front, including the two whose
# runs are currently disabled further down.
measure_KNN = KNN_Shapley(K=5)
measure_KNNLOO = KNN_LOO(K=5)
measure_gshap = G_Shapley()
measure_tmc = TMC_Shapley()
measure_LOO = LOO()

dir_path = f"../../results/{app_name}/{datasource}"


def _timed_run(measure, label, filename=None):
    """Run `measure` through `app`, print the elapsed time, optionally save.

    Args:
        measure: measure object handed to ``app.run``.
        label: text printed in front of the elapsed seconds.
        filename: file name under ``dir_path`` for ``np.save``; ``None`` skips saving.

    Returns:
        Whatever ``app.run(measure)`` returned.
    """
    started = time.time()
    values = app.run(measure)
    # Timing is printed before saving, so it excludes the write to disk.
    print(label, time.time() - started)
    if filename is not None:
        np.save(f"{dir_path}/{filename}", values)
    return values


res_knn = _timed_run(measure_KNN, 'KNN Shapley compute time: ', "knn_shapley_results.npy")
res_knnloo = _timed_run(measure_KNNLOO, 'KNN LOO compute time: ', "res_knnloo.npy")
res_gshap = _timed_run(measure_gshap, 'GShapley compute time: ', "g_shapley_results.npy")

# Disabled by the author (TMC Shapley and LOO runs):
# res_tmc = _timed_run(measure_tmc, 'TMC Shapley compute time: ', "tmc_results.npy")
# res_loo = _timed_run(measure_LOO, 'LOO compute time: ')


### Plot importances

In [None]:
import matplotlib.pyplot as plt
# Histogram of the FW-Shapley values currently held in `res_fwshapley`.
# NOTE(review): every FW-Shapley cell above overwrites `res_fwshapley`, so this
# shows whichever configuration ran last — confirm which one is intended.
plt.hist(res_fwshapley)

In [None]:
# Reload the `Label` app that an earlier cell pickled to ./app_label.pkl.
# NOTE: pickle.load can execute arbitrary code from the file; only load a
# pickle that this notebook wrote itself.
with open('./app_label.pkl', 'rb') as inp:
    app = pickle.load(inp)

In [None]:
# Compare KNN-Shapley with the four FW-Shapley configurations using LabelPlotter.
dir_path = f"../../results/{app_name}/{datasource}"

# Each FW-Shapley file holds the full result saved by its compute cell;
# column 0 is the per-sample value that gets plotted.
res_fwshapley_1_1 = np.load(f"{dir_path}/fw_shapley_results_1_1.npy")[:, 0]
res_fwshapley_4_1 = np.load(f"{dir_path}/fw_shapley_results_4_1.npy")[:, 0]
res_fwshapley_8_1 = np.load(f"{dir_path}/fw_shapley_results_8_1.npy")[:, 0]
res_fwshapley_16_1 = np.load(f"{dir_path}/fw_shapley_results_16_1.npy")[:, 0]

# KNN-Shapley used to be taken from kernel state only, which raised NameError
# on a fresh kernel. Reuse the in-memory values when present, otherwise fall
# back to the file written by the baseline-measures cell.
if "res_knn" not in globals():
    res_knn = np.load(f"{dir_path}/knn_shapley_results.npy")

# Further baselines that can be added to the plot once loaded:
# res_gshap = np.load(f"{dir_path}/g_shapley_results.npy")
# res_knnloo = np.load(f"{dir_path}/res_knnloo.npy")

LabelPlotter(
    app,
    ('KNN-Shapley', res_knn),
    # ('G-Shapley', res_gshap),
    # ('KNN-LOO', res_knnloo),
    ('FW-Shapley (1,1)', res_fwshapley_1_1),
    ('FW-Shapley (4,1)', res_fwshapley_4_1),
    ('FW-Shapley (8,1)', res_fwshapley_8_1),
    ('FW-Shapley (16,1)', res_fwshapley_16_1),
).plot()

In [None]:
# Echo the FW-Shapley value array as the cell output.
# NOTE(review): this is the array left by whichever FW-Shapley cell ran last,
# and it prints every entry — consider res_fwshapley[:10] instead.
res_fwshapley

In [None]:
# Image grid of the 20 samples that come first when the FW-Shapley values are
# sorted in ascending order.
# NOTE(review): the original header read "High importance images", but an
# ascending argsort puts the *smallest* values first (the KNN image cell sorts
# `-res_knn` instead). The ordering is kept as it was — confirm which end is intended.
# NOTE(review): `res_fwshapley` is whatever the last executed FW-Shapley cell left behind.
import os

import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image


def _as_displayable_image(sample):
    """Turn one sample into an (H, W) or (H, W, 3) array that ``plt.imshow`` accepts.

    The previous hard-coded ``reshape(28, 28)`` only fits 784-value samples and
    fails for colour data. This accepts flat vectors of length S*S (grayscale)
    or 3*S*S (colour) as well as arrays that already carry their image shape.

    Raises:
        ValueError: if a flat sample matches neither length pattern.
    """
    img = np.squeeze(np.asarray(sample))
    if img.ndim == 1:
        side = int(round(np.sqrt(img.size)))
        if side * side == img.size:
            return img.reshape(side, side)
        side = int(round(np.sqrt(img.size / 3)))
        if 3 * side * side != img.size:
            raise ValueError(f"cannot infer an image shape from {img.size} values")
        # assumes channel-first flattening (all R, then G, then B) — TODO confirm against the loader
        img = img.reshape(3, side, side)
    if img.ndim == 3 and img.shape[0] == 3 and img.shape[-1] != 3:
        # channel-first -> channel-last, which is what imshow expects
        img = np.transpose(img, (1, 2, 0))
    if img.ndim == 3 and img.dtype != np.uint8:
        # imshow wants float RGB inside [0, 1]; rescale only when outside that range
        img = img.astype(float)
        lo, hi = img.min(), img.max()
        if lo < 0.0 or hi > 1.0:
            img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    return img


indexes = np.argsort(res_fwshapley)
shown = indexes[:20]

# Size the grid to the number of images instead of a fixed 10x5 layout.
n_cols = 5
n_rows = max(1, -(-len(shown) // n_cols))

plt.figure(figsize=(20, 4 * n_rows))
for i, index in enumerate(shown):
    plt.subplot(n_rows, n_cols, i + 1)
    # cmap only affects 2-D (grayscale) images; RGB arrays ignore it
    plt.imshow(_as_displayable_image(X_data[index]), cmap='gray')
    plt.title(y_data[index])

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(f"./figs/{datasource}", exist_ok=True)
plt.savefig(f"./figs/{datasource}/fw-shapley.png")

In [None]:
# High importance images: the 20 samples with the largest KNN-Shapley values
# (argsort on the negated values gives a descending order).

import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image


def _as_displayable_image(sample):
    """Turn one sample into an (H, W) or (H, W, 3) array that ``plt.imshow`` accepts.

    The previous hard-coded ``reshape(28, 28)`` only fits 784-value samples and
    fails for colour data. This accepts flat vectors of length S*S (grayscale)
    or 3*S*S (colour) as well as arrays that already carry their image shape.

    Raises:
        ValueError: if a flat sample matches neither length pattern.
    """
    img = np.squeeze(np.asarray(sample))
    if img.ndim == 1:
        side = int(round(np.sqrt(img.size)))
        if side * side == img.size:
            return img.reshape(side, side)
        side = int(round(np.sqrt(img.size / 3)))
        if 3 * side * side != img.size:
            raise ValueError(f"cannot infer an image shape from {img.size} values")
        # assumes channel-first flattening (all R, then G, then B) — TODO confirm against the loader
        img = img.reshape(3, side, side)
    if img.ndim == 3 and img.shape[0] == 3 and img.shape[-1] != 3:
        # channel-first -> channel-last, which is what imshow expects
        img = np.transpose(img, (1, 2, 0))
    if img.ndim == 3 and img.dtype != np.uint8:
        # imshow wants float RGB inside [0, 1]; rescale only when outside that range
        img = img.astype(float)
        lo, hi = img.min(), img.max()
        if lo < 0.0 or hi > 1.0:
            img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    return img


indexes = np.argsort(-res_knn)
shown = indexes[:20]

# Size the grid to the number of images instead of a fixed 10x5 layout.
n_cols = 5
n_rows = max(1, -(-len(shown) // n_cols))

plt.figure(figsize=(20, 4 * n_rows))
for i, index in enumerate(shown):
    plt.subplot(n_rows, n_cols, i + 1)
    # cmap only affects 2-D (grayscale) images; RGB arrays ignore it
    plt.imshow(_as_displayable_image(X_data[index]), cmap='gray')
    plt.title(y_data[index])
