修正（repair）または再学習（retraining）の前後で、誤分類の頻度や種類がどれくらい変わったかを確認したい。

In [1]:
import os, sys, time
import argparse
import numpy as np
import pickle
from itertools import product
from utils.constant import ViTExperiment

In [2]:
# Dataset key; looked up as an attribute of ViTExperiment to find OUTPUT_DIR.
ds_name = "c100"
# Value substituted for "{k}" in the OUTPUT_DIR template
# (presumably a fold index -- TODO confirm).
k = 0
# Selects the "misclf_top{tgt_rank}" sub-directory holding the retrained and
# repaired results.
tgt_rank = 1
# Alpha values that appear in the repaired-model directory names
# ("misclf_info_n{n}_alpha{alpha}").
alpha_list = [0.2, 0.4, 0.6, 0.8]
# n values that appear in the same directory names; every (n, alpha)
# combination is loaded via itertools.product.
n_list = [77]

In [3]:
def load_misinfo(save_dir, tgt_split):
    """Load the saved misclassification artifacts for one data split.

    The confusion-style matrix is stored as a ``.npy`` file; the remaining
    three artifacts are stored as pickles. All file names are prefixed with
    ``tgt_split``.

    Args:
        save_dir: Directory containing the saved artifacts.
        tgt_split: Split name used as the file-name prefix.

    Returns:
        Tuple ``(mis_matrix, mis_ranking, mis_indices, met_dict)``.

    Raises:
        FileNotFoundError: If any of the four expected files is missing.
    """
    def _read_pickle(stem):
        # NOTE: pickle files are assumed to come from this project's own
        # pipeline; do not point this at untrusted data.
        pkl_path = os.path.join(save_dir, f"{tgt_split}_{stem}.pkl")
        with open(pkl_path, "rb") as fh:
            return pickle.load(fh)

    matrix_path = os.path.join(save_dir, f"{tgt_split}_mis_matrix.npy")
    mis_matrix = np.load(matrix_path)
    mis_ranking = _read_pickle("mis_ranking")
    mis_indices = _read_pickle("mis_indices")
    met_dict = _read_pickle("met_dict")
    return mis_matrix, mis_ranking, mis_indices, met_dict

def get_misinfo_each_model(tgt_split):
    """Collect misclassification info for the original, retrained and repaired models.

    Directory locations are derived from the module-level settings
    ``ds_name``, ``k``, ``tgt_rank``, ``n_list`` and ``alpha_list``.

    Args:
        tgt_split: Split name forwarded to ``load_misinfo``.

    Returns:
        Tuple ``(ori_misinfo, retrained_misinfo, repaired_misinfo_list)``.
        The first two are the tuples returned by ``load_misinfo``; the last
        is a list of such tuples, one per ``(n, alpha)`` pair in
        ``product(n_list, alpha_list)`` order.
    """
    # Original model: <OUTPUT_DIR>/misclf_info
    base_dir = getattr(ViTExperiment, ds_name).OUTPUT_DIR.format(k=k)
    ori_misinfo = load_misinfo(os.path.join(base_dir, "misclf_info"), tgt_split)

    # Retrained model: <OUTPUT_DIR>/misclf_top{rank}/<retraining dir>/misclf_info
    rank_dir = os.path.join(base_dir, f"misclf_top{tgt_rank}")
    # NOTE: the retraining directory name was marked "SHOULD BE CHANGED" by
    # the original author.
    retrain_root = os.path.join(rank_dir, "retraining_with_only_repair_target")
    retrained_misinfo = load_misinfo(os.path.join(retrain_root, "misclf_info"), tgt_split)

    # Repaired models: one misclf_info directory per (n, alpha) setting.
    repair_root = os.path.join(rank_dir, "src_tgt_repair_weight_by_de")
    repaired_misinfo_list = [
        load_misinfo(os.path.join(repair_root, f"misclf_info_n{n}_alpha{alpha}"), tgt_split)
        for n, alpha in product(n_list, alpha_list)
    ]
    return ori_misinfo, retrained_misinfo, repaired_misinfo_list

def summarize_met(met_list, topn):
    """Return the ``topn`` lowest-scoring entries of a metric as (index, value) rows.

    Entries are ordered by ascending metric value, so the first row is the
    worst-scoring index.

    Args:
        met_list: 1-D sequence or array of metric values, one per index.
        topn: Maximum number of rows to return. If it exceeds the number of
            entries, all entries are returned (slice-like truncation); values
            of zero or below yield no rows.

    Returns:
        Array of shape ``(rows, 2)`` whose columns are ``[index, value]``.
        Both columns share one dtype, so indices are stored as floats when
        the metric values are floats. The result is 2-D even when empty, so
        it can always be concatenated along ``axis=1``.
    """
    values = np.asarray(met_list)
    # Clamp at zero so a negative topn does not turn into a from-the-end slice.
    order = np.argsort(values)[:max(topn, 0)]
    return np.column_stack((order, values[order]))

In [4]:
# Number of rows kept from each ranking / metric summary.
topn = 10
for tgt_split in ["repair", "test"]:
    ori_misinfo, retrained_misinfo, repaired_misinfo_list = get_misinfo_each_model(tgt_split)
    # Column order in every CSV: original, retrained, then each repaired variant.
    all_misinfo = [ori_misinfo, retrained_misinfo] + repaired_misinfo_list

    # Element 1 of each misinfo tuple is mis_ranking; keep its first `topn`
    # entries and place the models side by side.
    misclf_arr = np.concatenate(
        [np.array(misinfo[1][:topn]) for misinfo in all_misinfo],
        axis=1,
    )
    np.savetxt(f'./misclf_cnt_{tgt_split}.csv', misclf_arr, delimiter=',', fmt='%d')

    for metric in ["precision", "recall", "f1"]:
        # The last element of each misinfo tuple is met_dict, keyed by metric
        # name; summarize_met picks the `topn` lowest values with their indices.
        met_summary_list = [
            summarize_met(misinfo[-1][metric], topn) for misinfo in all_misinfo
        ]
        met_summary_arr = np.concatenate(met_summary_list, axis=1)
        np.savetxt(f'./{metric}_summary_{tgt_split}.csv', met_summary_arr, delimiter=',', fmt='%.3f')