以下の3つそれぞれのモデルの，repair set/test set に対する rACC (newAcc/oldAcc), repair rate, break rateを算出する．
- retrained model
- neuronに対するrepaired model
- weightに対するrepaired model

In [8]:
import os, sys, math
from tqdm import tqdm
from collections import defaultdict
import numpy as np
import argparse
import torch
import pickle
import evaluate
met_acc = evaluate.load("accuracy")
met_f1 = evaluate.load("f1")
from datasets import load_from_disk
from transformers import DefaultDataCollator, ViTForImageClassification, Trainer
from utils.helper import get_device
from utils.vit_util import processor, transforms, compute_metrics, transforms_c100, localize_weights
from utils.constant import ViTExperiment
# Resolve the device through the project helper.
device = get_device()

# Experiment selection: `ds_name` is the attribute looked up on ViTExperiment,
# `k` is substituted into its OUTPUT_DIR template (presumably a fold index -- confirm).
ds_name, k = "c100", 0

# Dataset splits that the comparison cells iterate over.
splits = ["repair", "test"]

# Output directory of the original model; each retrained/repaired variant
# stores its results in a sub-directory of it.
pretrained_dir = getattr(ViTExperiment, ds_name).OUTPUT_DIR.format(k=k)
retrained_dir, neuron_repair_dir, weight_repair_dir = (
    os.path.join(pretrained_dir, sub_dir)
    for sub_dir in (
        "retraining_with_repair_set",
        "repair_neuron_by_de",
        "repair_weight_by_de",
    )
)

Device: cuda


In [35]:
# Original vs. retrained model: for each split, report accuracy, relative
# accuracy (new/old), repair rate and break rate.
ori_res_dir = os.path.join(pretrained_dir, "pred_results", "PredictionOutput")
retrain_res_dir = os.path.join(retrained_dir, "pred_results", "PredictionOutput")
for s in splits:
    print(f"Split: {s}")
    # Original model predictions.
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by this project.
    ori_filename = os.path.join(ori_res_dir, f"{s}_pred.pkl")
    with open(ori_filename, "rb") as f:
        ori_pred = pickle.load(f)
    ori_pred_labels = ori_pred.predictions.argmax(-1)
    ori_true_labels = ori_pred.label_ids
    ori_acc = met_acc.compute(predictions=ori_pred_labels, references=ori_true_labels)
    print("ori_acc", ori_acc)
    # Retrained model predictions.
    retrained_filename = os.path.join(retrain_res_dir, f"{s}_pred.pkl")
    with open(retrained_filename, "rb") as f:
        retrained_pred = pickle.load(f)
    retrained_pred_labels = retrained_pred.predictions.argmax(-1)
    retrained_true_labels = retrained_pred.label_ids
    retrained_acc = met_acc.compute(predictions=retrained_pred_labels, references=retrained_true_labels)
    print("retrained_acc", retrained_acc)
    # Both result files must hold the same ground-truth labels in the same
    # order, otherwise the per-sample comparison below is meaningless.
    # np.array_equal also returns False (instead of failing obscurely) when
    # the two arrays differ in shape.
    if not np.array_equal(ori_true_labels, retrained_true_labels):
        raise ValueError(
            f"Ground-truth labels differ between original and retrained results for split '{s}'"
        )
    # Relative accuracy (undefined, hence NaN, when the original accuracy is 0).
    racc = 100 * retrained_acc["accuracy"] / ori_acc["accuracy"] if ori_acc["accuracy"] else float("nan")
    print(f"relative acc = {racc:.1f} %")
    # Per-sample correctness masks of both models.
    ori_correct_mask = ori_pred_labels == ori_true_labels
    retrained_correct = retrained_pred_labels == retrained_true_labels
    # Repair rate: share of samples the original got wrong that the retrained
    # model gets right. NaN when the original made no mistakes.
    retrained_correct_in_ori_incorrect = retrained_correct[~ori_correct_mask]
    if retrained_correct_in_ori_incorrect.size:
        rr = 100 * retrained_correct_in_ori_incorrect.mean()
    else:
        rr = float("nan")
    print(f"repair rate = {rr:.1f} %")
    # Break rate: share of samples the original got right that the retrained
    # model gets wrong. NaN when the original had no correct predictions.
    retrained_incorrect_in_ori_correct = ~retrained_correct[ori_correct_mask]
    if retrained_incorrect_in_ori_correct.size:
        br = 100 * retrained_incorrect_in_ori_correct.mean()
    else:
        br = float("nan")
    print(f"break rate = {br:.1f} %")

Split: repair
ori_acc {'accuracy': 0.9074}


retrained_acc {'accuracy': 0.9865}
relative acc = 108.7 %
repair rate = 87.3 %
break rate = 0.2 %
Split: test
ori_acc {'accuracy': 0.9118}
retrained_acc {'accuracy': 0.893}
relative acc = 97.9 %
repair rate = 28.7 %
break rate = 4.8 %


In [47]:
# Original vs. neuron-repaired model: for each split, report accuracy,
# relative accuracy (new/old), repair rate and break rate.
ori_res_dir = os.path.join(pretrained_dir, "pred_results", "PredictionOutput")
rep_res_dir = os.path.join(neuron_repair_dir, "pred_results", "best_patch_1")
for s in splits:
    print(f"Split: {s}")
    # Original model predictions.
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by this project.
    ori_filename = os.path.join(ori_res_dir, f"{s}_pred.pkl")
    with open(ori_filename, "rb") as f:
        ori_pred = pickle.load(f)
    ori_pred_labels = ori_pred.predictions.argmax(-1)
    ori_true_labels = ori_pred.label_ids
    ori_acc = met_acc.compute(predictions=ori_pred_labels, references=ori_true_labels)
    print("ori_acc", ori_acc)
    # Repaired model predictions. The archive is opened as a context manager
    # so its file handle is released; arrays are read while it is still open.
    rep_filename = os.path.join(rep_res_dir, f"{s}_pred_results.npz")
    with np.load(rep_filename) as rep_pred:
        rep_pred_labels = rep_pred["all_pred_labels"]
        rep_true_labels = rep_pred["true_labels"]
    repaired_acc = met_acc.compute(predictions=rep_pred_labels, references=rep_true_labels)
    print("repaired_acc", repaired_acc)
    # Both result files must hold the same ground-truth labels in the same
    # order, otherwise the per-sample comparison below is meaningless.
    # np.array_equal also returns False (instead of failing obscurely) when
    # the two arrays differ in shape.
    if not np.array_equal(ori_true_labels, rep_true_labels):
        raise ValueError(
            f"Ground-truth labels differ between original and repaired results for split '{s}'"
        )
    # Relative accuracy (undefined, hence NaN, when the original accuracy is 0).
    racc = 100 * repaired_acc["accuracy"] / ori_acc["accuracy"] if ori_acc["accuracy"] else float("nan")
    print(f"relative acc = {racc:.1f} %")
    # Per-sample correctness masks of both models.
    ori_correct_mask = ori_pred_labels == ori_true_labels
    repaired_correct = rep_pred_labels == rep_true_labels
    # Repair rate: share of samples the original got wrong that the repaired
    # model gets right. NaN when the original made no mistakes.
    repaired_correct_in_ori_incorrect = repaired_correct[~ori_correct_mask]
    if repaired_correct_in_ori_incorrect.size:
        rr = 100 * repaired_correct_in_ori_incorrect.mean()
    else:
        rr = float("nan")
    print(f"repair rate = {rr:.1f} %")
    # Break rate: share of samples the original got right that the repaired
    # model gets wrong. NaN when the original had no correct predictions.
    repaired_incorrect_in_ori_correct = ~repaired_correct[ori_correct_mask]
    if repaired_incorrect_in_ori_correct.size:
        br = 100 * repaired_incorrect_in_ori_correct.mean()
    else:
        br = float("nan")
    print(f"break rate = {br:.1f} %")

Split: repair
ori_acc {'accuracy': 0.9074}
repaired_acc {'accuracy': 0.8923}
relative acc = 98.3 %
repair rate = 15.2 %
break rate = 3.2 %
Split: test
ori_acc {'accuracy': 0.9118}
repaired_acc {'accuracy': 0.8956}
relative acc = 98.2 %
repair rate = 13.3 %
break rate = 3.1 %


In [48]:
# Original vs. weight-repaired model: for each split, report accuracy,
# relative accuracy (new/old), repair rate and break rate.
ori_res_dir = os.path.join(pretrained_dir, "pred_results", "PredictionOutput")
rep_res_dir = os.path.join(weight_repair_dir, "pred_results", "best_patch_1")
for s in splits:
    print(f"Split: {s}")
    # Original model predictions.
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by this project.
    ori_filename = os.path.join(ori_res_dir, f"{s}_pred.pkl")
    with open(ori_filename, "rb") as f:
        ori_pred = pickle.load(f)
    ori_pred_labels = ori_pred.predictions.argmax(-1)
    ori_true_labels = ori_pred.label_ids
    ori_acc = met_acc.compute(predictions=ori_pred_labels, references=ori_true_labels)
    print("ori_acc", ori_acc)
    # Repaired model predictions. The archive is opened as a context manager
    # so its file handle is released; arrays are read while it is still open.
    rep_filename = os.path.join(rep_res_dir, f"{s}_pred_results.npz")
    with np.load(rep_filename) as rep_pred:
        rep_pred_labels = rep_pred["all_pred_labels"]
        rep_true_labels = rep_pred["true_labels"]
    repaired_acc = met_acc.compute(predictions=rep_pred_labels, references=rep_true_labels)
    print("repaired_acc", repaired_acc)
    # Both result files must hold the same ground-truth labels in the same
    # order, otherwise the per-sample comparison below is meaningless.
    # np.array_equal also returns False (instead of failing obscurely) when
    # the two arrays differ in shape.
    if not np.array_equal(ori_true_labels, rep_true_labels):
        raise ValueError(
            f"Ground-truth labels differ between original and repaired results for split '{s}'"
        )
    # Relative accuracy (undefined, hence NaN, when the original accuracy is 0).
    racc = 100 * repaired_acc["accuracy"] / ori_acc["accuracy"] if ori_acc["accuracy"] else float("nan")
    print(f"relative acc = {racc:.1f} %")
    # Per-sample correctness masks of both models.
    ori_correct_mask = ori_pred_labels == ori_true_labels
    repaired_correct = rep_pred_labels == rep_true_labels
    # Repair rate: share of samples the original got wrong that the repaired
    # model gets right. NaN when the original made no mistakes.
    repaired_correct_in_ori_incorrect = repaired_correct[~ori_correct_mask]
    if repaired_correct_in_ori_incorrect.size:
        rr = 100 * repaired_correct_in_ori_incorrect.mean()
    else:
        rr = float("nan")
    print(f"repair rate = {rr:.1f} %")
    # Break rate: share of samples the original got right that the repaired
    # model gets wrong. NaN when the original had no correct predictions.
    repaired_incorrect_in_ori_correct = ~repaired_correct[ori_correct_mask]
    if repaired_incorrect_in_ori_correct.size:
        br = 100 * repaired_incorrect_in_ori_correct.mean()
    else:
        br = float("nan")
    print(f"break rate = {br:.1f} %")

Split: repair
ori_acc {'accuracy': 0.9074}
repaired_acc {'accuracy': 0.8944}
relative acc = 98.6 %
repair rate = 13.6 %
break rate = 2.8 %
Split: test
ori_acc {'accuracy': 0.9118}
repaired_acc {'accuracy': 0.8971}
relative acc = 98.4 %
repair rate = 11.6 %
break rate = 2.7 %
