In [45]:
import os, sys, time, pickle, json, math
from tqdm import tqdm
from collections import defaultdict
from itertools import product
from sklearn.metrics import confusion_matrix
import numpy as np
import pandas as pd
import torch
from utils.helper import get_device, json2dict
from utils.vit_util import identfy_tgt_misclf, ViTFromLastLayer, maybe_initialize_repair_weights_, get_new_model_predictions, get_batched_hs, get_batched_labels
from utils.constant import ViTExperiment, ExperimentRepair1, Experiment3, ExperimentRepair2
from utils.log import set_exp_logging
from logging import getLogger
from datasets import load_from_disk
from transformers import ViTForImageClassification
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
# For notebook use: render matplotlib figures inline
%matplotlib inline
from itertools import product
from sklearn.preprocessing import MinMaxScaler

# Module-level logger, looked up by the name "base_logger".
logger = getLogger("base_logger")
# Device handle from the project helper; main() moves the model and the cached
# hidden states onto it.
# NOTE(review): get_device() presumably also prints the selected device — confirm.
device = get_device()

def _select_misclf_indices(pred_labels, true_labels, misclf_type, fpfn, misclf_pair, tgt_label):
    """Return the sample indices that match the requested misclassification.

    Args:
        pred_labels: 1-D array-like of predicted class ids.
        true_labels: 1-D array-like of ground-truth class ids (same length).
        misclf_type: "src_tgt" (a specific predicted/true pair) or "tgt"
            (all false positives or false negatives of one label).
        fpfn: "fp" or "fn"; only consulted when ``misclf_type == "tgt"``.
        misclf_pair: ``(predicted_label, true_label)``; used for "src_tgt".
        tgt_label: label of interest; used for "tgt".

    Returns:
        list[int]: positions, in ascending order, of the matching samples.

    Raises:
        ValueError: if the ``misclf_type`` / ``fpfn`` combination is unsupported.
    """
    pred = np.asarray(pred_labels)
    true = np.asarray(true_labels)
    if misclf_type == "src_tgt":
        mask = (pred == misclf_pair[0]) & (true == misclf_pair[1])
    elif misclf_type == "tgt" and fpfn == "fp":
        # False positive: predicted as the target label although it is not.
        mask = (pred == tgt_label) & (true != tgt_label)
    elif misclf_type == "tgt" and fpfn == "fn":
        # False negative: truly the target label but predicted as something else.
        mask = (true == tgt_label) & (pred != tgt_label)
    else:
        raise ValueError(f"misclf_type={misclf_type} and fpfn={fpfn} is not supported.")
    return np.flatnonzero(mask).tolist()


def main(ds_name, k, tgt_rank, misclf_type, fpfn, tgt_split="repair"):
    """Report how many samples of ``tgt_split`` fall into the target misclassification.

    The misclassification target (pair or label) is always looked up from the
    misclf_info computed on the "repair" split. For ``tgt_split == "repair"``
    the indices returned by that lookup are reported directly. For
    ``tgt_split == "test"`` the pretrained model's last layer is re-run on the
    cached hidden states of the test split and the same kind of
    misclassification is counted there.

    Args:
        ds_name: dataset name; one of "c10", "c100", "tiny-imagenet".
        k: fold id, used in the dataset directory name and the model directory.
        tgt_rank: rank of the misclassification to look up.
        misclf_type: "src_tgt" or "tgt".
        fpfn: "fp", "fn", or None (ignored for "src_tgt").
        tgt_split: "repair" (default) or "test".

    Returns:
        None. The result is printed.

    Raises:
        ValueError: for an unsupported ``ds_name``, ``tgt_split`` or
            ``misclf_type`` / ``fpfn`` combination.
        FileNotFoundError: if the cached hidden states for ``tgt_split`` are missing.
    """
    print(f"ds_name: {ds_name}, fold_id: {k}, tgt_rank: {tgt_rank}, misclf_type: {misclf_type}, fpfn: {fpfn}")

    # Validate arguments up front so that bad input fails with a clear message
    # instead of a NameError on label_col or a silent no-op.
    if ds_name in ("c10", "tiny-imagenet"):
        label_col = "label"
    elif ds_name == "c100":
        label_col = "fine_label"
    else:
        raise ValueError(f"ds_name={ds_name} is not supported.")
    if tgt_split not in ("repair", "test"):
        raise ValueError(f"tgt_split={tgt_split} is not supported.")

    # Load the dataset (needed for the true labels).
    ds_dirname = f"{ds_name}_fold{k}"
    ds = load_from_disk(os.path.join(ViTExperiment.DATASET_DIR, ds_dirname))

    # Directory of the pretrained model for this dataset / fold.
    pretrained_dir = getattr(ViTExperiment, ds_name.replace("-", "_")).OUTPUT_DIR.format(k=k)

    # Look up the misclassification target of the requested rank.
    tgt_layer = 11  # NOTE: we only use the last layer for repairing
    misclf_info_dir = os.path.join(pretrained_dir, "misclf_info")
    # The target is intentionally identified on the repair split even when
    # tgt_split == "test"; the test branch below recomputes the indices.
    misclf_pair, tgt_label, tgt_mis_indices = identfy_tgt_misclf(
        misclf_info_dir, tgt_split="repair", tgt_rank=tgt_rank, misclf_type=misclf_type, fpfn=fpfn
    )

    if tgt_split == "test":
        # Check the cache before the (expensive) model load so we fail fast.
        hs_save_dir = os.path.join(pretrained_dir, f"cache_hidden_states_before_layernorm_{tgt_split}")
        hs_save_path = os.path.join(hs_save_dir, f"hidden_states_before_layernorm_{tgt_layer}.npy")
        if not os.path.exists(hs_save_path):
            # Raise instead of sys.exit(1): a SystemExit from inside a notebook
            # cell hides the cause of the failure.
            raise FileNotFoundError(f"{hs_save_path} does not exist.")

        model, loading_info = ViTForImageClassification.from_pretrained(pretrained_dir, output_loading_info=True)
        model.to(device).eval()
        model = maybe_initialize_repair_weights_(model, loading_info["missing_keys"])
        vit_from_last_layer = ViTFromLastLayer(model)
        vit_from_last_layer.eval()

        hs_before_layernorm = torch.from_numpy(np.load(hs_save_path)).to(device)

        # Hidden states and labels of the whole target split, batched.
        batch_size = ViTExperiment.BATCH_SIZE
        # Only the target split's labels are needed (the dataset is not shuffled).
        ori_tgt_labels = np.array(ds[tgt_split][label_col])
        batch_hs_before_layernorm = get_batched_hs(hs_save_path, batch_size, device=device, hs=hs_before_layernorm)
        batch_labels = get_batched_labels(ori_tgt_labels, batch_size)
        # NOTE(review): position 0 is presumably the CLS token
        # (ViTExperiment.CLS_IDX) — confirm before replacing the literal.
        pred_labels_old, true_labels_old = get_new_model_predictions(
            vit_from_last_layer,
            batch_hs_before_layernorm,
            batch_labels,
            tgt_pos=0
        )
        print(f"pred_labels_old: {pred_labels_old.shape}, true_labels_old: {true_labels_old.shape}")

        tgt_mis_indices = _select_misclf_indices(
            pred_labels_old, true_labels_old, misclf_type, fpfn, misclf_pair, tgt_label
        )

    print(f"misclf_pair: {misclf_pair}, tgt_label: {tgt_label}, len(tgt_mis_indices): {len(tgt_mis_indices)}")
        

Device: cuda


In [43]:
# Sweep configuration: datasets, folds, target ranks and the
# (misclf_type, fpfn) combinations to report.
ds_name_list = ["c100", "tiny-imagenet"]
k_list = [0]
tgt_rank_list = [1, 2, 3]
misclf_type_fpfn_pair_list = [("src_tgt", None), ("tgt", "fp"), ("tgt", "fn")]

# Explicit nested loops, visiting the combinations in the same order as
# itertools.product(ds_name_list, k_list, misclf_type_fpfn_pair_list, tgt_rank_list).
# main() is called with its default target split.
for ds_name in ds_name_list:
    for k in k_list:
        for misclf_type, fpfn in misclf_type_fpfn_pair_list:
            for tgt_rank in tgt_rank_list:
                main(ds_name, k, tgt_rank, misclf_type, fpfn)

ds_name: c100, fold_id: 0, tgt_rank: 1, misclf_type: src_tgt, fpfn: None
misclf_pair: (52, 47), tgt_label: None, len(tgt_mis_indices): 19
ds_name: c100, fold_id: 0, tgt_rank: 2, misclf_type: src_tgt, fpfn: None
misclf_pair: (61, 10), tgt_label: None, len(tgt_mis_indices): 16
ds_name: c100, fold_id: 0, tgt_rank: 3, misclf_type: src_tgt, fpfn: None
misclf_pair: (11, 35), tgt_label: None, len(tgt_mis_indices): 14
ds_name: c100, fold_id: 0, tgt_rank: 1, misclf_type: tgt, fpfn: fp
misclf_pair: None, tgt_label: 47, len(tgt_mis_indices): 27
ds_name: c100, fold_id: 0, tgt_rank: 2, misclf_type: tgt, fpfn: fp
misclf_pair: None, tgt_label: 11, len(tgt_mis_indices): 32
ds_name: c100, fold_id: 0, tgt_rank: 3, misclf_type: tgt, fpfn: fp
misclf_pair: None, tgt_label: 52, len(tgt_mis_indices): 31
ds_name: c100, fold_id: 0, tgt_rank: 1, misclf_type: tgt, fpfn: fn
misclf_pair: None, tgt_label: 47, len(tgt_mis_indices): 27
ds_name: c100, fold_id: 0, tgt_rank: 2, misclf_type: tgt, fpfn: fn
misclf_pair: No

In [46]:
# Sweep configuration: datasets, folds, target ranks and the
# (misclf_type, fpfn) combinations to report.
ds_name_list = ["c100", "tiny-imagenet"]
k_list = [0]
tgt_rank_list = [1, 2, 3]
misclf_type_fpfn_pair_list = [("src_tgt", None), ("tgt", "fp"), ("tgt", "fn")]

# Explicit nested loops, visiting the combinations in the same order as
# itertools.product(ds_name_list, k_list, misclf_type_fpfn_pair_list, tgt_rank_list).
# This sweep evaluates the "test" split.
for ds_name in ds_name_list:
    for k in k_list:
        for misclf_type, fpfn in misclf_type_fpfn_pair_list:
            for tgt_rank in tgt_rank_list:
                main(ds_name, k, tgt_rank, misclf_type, fpfn, tgt_split="test")

ds_name: c100, fold_id: 0, tgt_rank: 1, misclf_type: src_tgt, fpfn: None


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: (52, 47), tgt_label: None, len(tgt_mis_indices): 21
ds_name: c100, fold_id: 0, tgt_rank: 2, misclf_type: src_tgt, fpfn: None


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: (61, 10), tgt_label: None, len(tgt_mis_indices): 8
ds_name: c100, fold_id: 0, tgt_rank: 3, misclf_type: src_tgt, fpfn: None


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: (11, 35), tgt_label: None, len(tgt_mis_indices): 8
ds_name: c100, fold_id: 0, tgt_rank: 1, misclf_type: tgt, fpfn: fp


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: None, tgt_label: 47, len(tgt_mis_indices): 34
ds_name: c100, fold_id: 0, tgt_rank: 2, misclf_type: tgt, fpfn: fp


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: None, tgt_label: 11, len(tgt_mis_indices): 23
ds_name: c100, fold_id: 0, tgt_rank: 3, misclf_type: tgt, fpfn: fp


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: None, tgt_label: 52, len(tgt_mis_indices): 42
ds_name: c100, fold_id: 0, tgt_rank: 1, misclf_type: tgt, fpfn: fn


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: None, tgt_label: 47, len(tgt_mis_indices): 28
ds_name: c100, fold_id: 0, tgt_rank: 2, misclf_type: tgt, fpfn: fn


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: None, tgt_label: 11, len(tgt_mis_indices): 25
ds_name: c100, fold_id: 0, tgt_rank: 3, misclf_type: tgt, fpfn: fn


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /src/src/out_vit_c100_fold0 and are newly initialized: ['vit.encoder.layer.5.intermediate.repair.weight', 'vit.encoder.layer.6.intermediate.repair.weight', 'vit.encoder.layer.2.intermediate.repair.weight', 'vit.encoder.layer.7.intermediate.repair.weight', 'vit.encoder.layer.9.intermediate.repair.weight', 'vit.encoder.layer.11.intermediate.repair.weight', 'vit.encoder.layer.10.intermediate.repair.weight', 'vit.encoder.layer.0.intermediate.repair.weight', 'vit.encoder.layer.8.intermediate.repair.weight', 'vit.encoder.layer.4.intermediate.repair.weight', 'vit.encoder.layer.1.intermediate.repair.weight', 'vit.encoder.layer.3.intermediate.repair.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🛠️ Initializing intermediate.repair.weight as identity matrix (for missing weights)
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: None, tgt_label: 35, len(tgt_mis_indices): 29
ds_name: tiny-imagenet, fold_id: 0, tgt_rank: 1, misclf_type: src_tgt, fpfn: None
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: (66, 0), tgt_label: None, len(tgt_mis_indices): 14
ds_name: tiny-imagenet, fold_id: 0, tgt_rank: 2, misclf_type: src_tgt, fpfn: None
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: (95, 86), tgt_label: None, len(tgt_mis_indices): 12
ds_name: tiny-imagenet, fold_id: 0, tgt_rank: 3, misclf_type: src_tgt, fpfn: None
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: (117, 157), tgt_label: None, len(tgt_mis_indices): 9
ds_name: tiny-imagenet, fold_id: 0, tgt_rank: 1, misclf_type: tgt, fpfn: fp
pred_labels_old: (10000,), true_labels_old: (10000,)
misclf_pair: None, tgt_label: 9, len(tgt_mis_indices): 25
ds_name: tiny-ima