In [None]:
import os

from sklearn.model_selection import train_test_split

import attack as wfpattack
from argparser import shap_parser
from dataset import TraceDataset
import joblib
import torch
import shap
from tqdm import tqdm

# Parse arguments.
# The values are hard-coded in this notebook; the (currently disabled)
# command-line equivalents are:
#   args = shap_parser().parse_args()
#   attack_name = args.attack, note = args.note, train = args.train,
#   dataset = args.dataset, model_epoch = args.epoch
random_seed = 11
ds_root = "./data"
cache_root = "./data/cache"
attack_name = "RF"
note = "20240307base"
train = "undefend"
dataset = "undefend"
model_epoch = 0

# Directory the trained model is loaded from; stop immediately (reporting the
# path) when it does not exist.
model_dir = f"data/dump/{note}/{attack_name}/train_{dataset}_d{train}"
assert os.path.exists(model_dir), model_dir

print("Loading test dataset...")
test_ds = TraceDataset(dataset, ds_root)
test_ds.load_defended_by_name(train)
num_classes = test_ds.num_classes()

# NOTE(review): the meaning of the positional arguments 10000 and 0 is set by
# the attack constructor in attack.py -- confirm there before changing them.
attack_cls = wfpattack.get_attack(attack_name)
attack: wfpattack.DNNAttack = attack_cls(10000, num_classes, 0)


# Keep 20% of the sample indices for evaluation. The fixed random_state makes
# the split reproducible across kernel restarts.
# NOTE(review): presumably this mirrors the split used at training time so the
# evaluation indices are unseen data -- confirm against the training script.
ds_len = len(test_ds)

_, evaluate_slice = train_test_split(
    list(range(ds_len)), test_size=0.2, random_state=random_seed
)


def get_cached_data(ds, attack, evaluate_slice):
    """Return the preprocessed evaluation subset of ``ds`` for ``attack``.

    A cache file named ``<attack.name>_<ds.get_hash()>.pkl`` is looked up
    under the module-level ``cache_root``.

    * Cache hit: the file is loaded and a dict with the keys ``"traces"`` and
      ``"labels"`` is returned, each entry indexed by ``evaluate_slice``.
    * Cache miss: ``ds[evaluate_slice]`` is unpacked into
      ``attack.data_preprocess`` and its result is returned unchanged. Nothing
      is written to the cache in this case.

    Args:
        ds: Dataset object providing ``get_hash()`` and index access.
        attack: Attack object providing ``name`` and ``data_preprocess``.
        evaluate_slice: Indices of the samples to return.

    Returns:
        The evaluation data; the caller reads ``"traces"`` and ``"labels"``
        from it.
    """
    cache_file = os.path.join(cache_root, f"{attack.name}_{ds.get_hash()}.pkl")

    if not os.path.exists(cache_file):
        # No cache available: preprocess only the requested samples.
        return attack.data_preprocess(*ds[evaluate_slice])

    # joblib.load unpickles the file, so only load cache files that were
    # produced locally by a trusted run.
    cached = joblib.load(cache_file)
    return {key: cached[key][evaluate_slice] for key in ("traces", "labels")}


print("Preparing test data...")

test_data = get_cached_data(test_ds, attack, evaluate_slice)
test_features = test_data["traces"]
test_labels = test_data["labels"]

attack.init_model()
print("Evaluating...")
# data=True is expected to make evaluate() return a 3-tuple of
# (metrics, true labels, predictions); the assert below enforces that shape.
result = attack.evaluate(
    test_features,
    test_labels,
    load_dir=model_dir,
    epoch=model_epoch,
    data=True,
)
assert isinstance(result, tuple) and len(result) == 3
metrics_dict, y_true, y_pred = result

# Reduce axis 1 to the index of its largest entry (the class id).
y_true = y_true.argmax(axis=1)
y_pred = y_pred.argmax(axis=1)

# Keep only the samples whose predicted class matches the true class.
correct_mask = y_true == y_pred
good_features = test_features[correct_mask]
good_labels = test_labels[correct_mask]
good_labels_int = good_labels.argmax(axis=1)

In [None]:
# Group the correctly classified traces by class id.
traces_by_class = {
    label: good_features[good_labels_int == label] for label in range(num_classes)
}

# The first two traces of every class form the explainer's background set;
# the remaining traces of each class are the ones to be explained.
bg_traces = torch.cat(
    [class_traces[:2] for class_traces in traces_by_class.values()], dim=0
)
test_traces = {
    label: class_traces[2:] for label, class_traces in traces_by_class.items()
}

In [None]:
# Build the SHAP explainer for the loaded attack model.
# (`shap` is already imported in the first cell, so it is not re-imported here.)
device = attack.device
model = attack.model
# Put the model in evaluation mode before explaining it (for torch modules
# this turns off training-only behaviour such as dropout).
model.eval()
# The background samples must live on the same device as the model.
explainer = shap.DeepExplainer(model, bg_traces.to(device))

In [None]:
# Quick check on a single trace of class 0 before running the full loop.
# The call stays the last expression so the cell displays its result.
smoke_input = test_traces[0][:1, :].to(device)
explainer.shap_values(smoke_input)

In [None]:
# Aggregate SHAP attributions per class: one summed attribution per label,
# appended in label order.
shap_values = []
for label in tqdm(range(num_classes)):
    # The explainer's model and background were placed on `device`, so the
    # inputs must be moved there too (a no-op when they already are).
    class_inputs = test_traces[label].to(device)
    # ranked_outputs=1 with output_rank_order="max" restricts the explanation
    # to the highest-scoring model output for each sample.
    # NOTE(review): [0][0] assumes shap_values returns (per-output list, ranks)
    # so that it selects the attributions of the top-ranked output -- the
    # return layout differs between shap releases; confirm for the installed one.
    # NOTE(review): squeeze() drops every size-1 axis; if a class has exactly
    # one test trace the batch axis would be dropped as well, changing what the
    # two sums below reduce over -- confirm the expected shapes.
    cur_shap_values = explainer.shap_values(
        class_inputs, ranked_outputs=1, output_rank_order="max"
    )[0][0].squeeze()
    # Sum over the first axis twice (samples, then the next leading axis).
    shap_values.append(cur_shap_values.sum(axis=0).sum(axis=0))