# 目的
違う手法ごとに特定された重みの位置の性質を見る．

具体的には，中間ニューロンのインデックスのユニーク性をチェックしたい．

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
from utils.vit_util import transforms_c100, get_batched_hs, get_batched_labels, ViTFromLastLayer, get_ori_model_predictions, identfy_tgt_misclf
from utils.constant import ViTExperiment, Experiment1, Experiment3, ExperimentRepair1, ExperimentRepair2, ExperimentRepair3
from utils.helper import get_device
from utils.de import set_new_weights
from transformers import ViTForImageClassification
from tqdm import tqdm
import torch
import torch.nn as nn
import os
from collections import defaultdict
import numpy as np
import pandas as pd
from datasets import load_from_disk

# デバイスの設定
device = get_device()

def get_location_path(n, fl_method, location_dir, beta=None):
    if fl_method == "ours":
        location_file = f"exp-fl-6_location_n{n}_beta{beta}_weight_ours.npy"
    elif fl_method == "bl":
        location_file = f"exp-fl-2_location_n{n}_weight_bl.npy"
    elif fl_method == "random":
        location_file = f"exp-fl-1_location_n{n}_weight_random.npy"
    else:
        raise ValueError(f"Unknown fl_method: {fl_method}")
    location_path = os.path.join(location_dir, location_file)
    return location_path

Device: cuda


In [10]:
def main(fl_method, n, beta, rank):
    # rank を tgt_rank として利用（int型）
    tgt_rank = rank
    tgt_pos = ViTExperiment.CLS_IDX
    
    # プリトレーニング済みモデルとキャッシュの hidden states のロード
    pretrained_dir = ViTExperiment.c100.OUTPUT_DIR.format(k=0)
    
    # 実験条件のループ（misclf_type と fpfn は内部ループ）
    misclf_type_list = ["src_tgt", "tgt"] # allは対象外にする
    fpfn_list = [None, "fp", "fn"]
    results_list = []
    for misclf_type in misclf_type_list:
        for fpfn in fpfn_list:
            # ルール：misclf_type=="all" は tgt_rank が 1 かつ fpfn が None のみ有効
            if misclf_type == "all":
                if tgt_rank >= 2 or fpfn is not None:
                    continue
            # misclf_type=="src_tgt" の場合、fpfn は None のみ
            if misclf_type == "src_tgt" and fpfn is not None:
                continue
            location_dir = os.path.join(pretrained_dir, f"misclf_top{tgt_rank}", f"{misclf_type}_weights_location")
            location_path = get_location_path(n, fl_method, location_dir, beta)
            pos_before, pos_after = np.load(location_path, allow_pickle=True)
            print(f"Location file: {location_path}")
            print(f"pos_before.shape: {pos_before.shape}, pos_after.shape: {pos_after.shape}")
            results_list.append({
                "misclf_type": misclf_type,
                "fpfn": fpfn,
                "pos_before": pos_before,
                "pos_after": pos_after,
            })
    return results_list

In [17]:
all_results = []
fl_method_list = ["ours", "bl", "random"]
tgt_rank_list = [1, 2, 3, 4, 5]
n_list = [12, 24, 48, 96]

for n in n_list:
    for tgt_rank in tgt_rank_list:
        for fl_method in fl_method_list:
            if fl_method == "ours":
                beta_list = [0.1, 0.25, 0.5, 0.75, 1.0]
            else:
                beta_list = [None]
            for beta in beta_list:
                print(f"\nn: {n}, tgt_rank: {tgt_rank}, fl_method: {fl_method}, beta: {beta}")
                print(f"========================")
                ret_list = main(fl_method=fl_method, n=n, beta=beta, rank=tgt_rank)
                for result in ret_list:
                    all_results.append({
                        "fl_method": fl_method,
                        "tgt_rank": tgt_rank,
                        "n": n,
                        "beta": beta,
                        "misclf_type": result["misclf_type"],
                        "fpfn": result["fpfn"],
                        "pos_before": result["pos_before"],
                        "pos_after": result["pos_after"],
                    })


n: 12, tgt_rank: 1, fl_method: ours, beta: 0.1
Location file: /src/src/out_vit_c100_fold0/misclf_top1/src_tgt_weights_location/exp-fl-6_location_n12_beta0.1_weight_ours.npy
pos_before.shape: (739, 2), pos_after.shape: (413, 2)
Location file: /src/src/out_vit_c100_fold0/misclf_top1/tgt_weights_location/exp-fl-6_location_n12_beta0.1_weight_ours.npy
pos_before.shape: (769, 2), pos_after.shape: (383, 2)
Location file: /src/src/out_vit_c100_fold0/misclf_top1/tgt_weights_location/exp-fl-6_location_n12_beta0.1_weight_ours.npy
pos_before.shape: (769, 2), pos_after.shape: (383, 2)
Location file: /src/src/out_vit_c100_fold0/misclf_top1/tgt_weights_location/exp-fl-6_location_n12_beta0.1_weight_ours.npy
pos_before.shape: (769, 2), pos_after.shape: (383, 2)

n: 12, tgt_rank: 1, fl_method: ours, beta: 0.25
Location file: /src/src/out_vit_c100_fold0/misclf_top1/src_tgt_weights_location/exp-fl-6_location_n12_beta0.25_weight_ours.npy
pos_before.shape: (586, 2), pos_after.shape: (566, 2)
Location file:

In [18]:
df_all = pd.DataFrame(all_results)
df_all["misclf_type_name"] = df_all.apply(
    lambda row: row["misclf_type"] if row["fpfn"] is None else f'{row["misclf_type"]}_{row["fpfn"]}',
    axis=1
)

In [None]:
df_all

Unnamed: 0,fl_method,tgt_rank,n,beta,misclf_type,fpfn,pos_before,pos_after,misclf_type_name
0,ours,1,12,0.10,src_tgt,,"[[1313, 678], [260, 678], [1313, 475], [2055, ...","[[674, 1755], [674, 1313], [131, 1755], [174, ...",src_tgt
1,ours,1,12,0.10,tgt,,"[[1509, 678], [322, 448], [200, 435], [200, 14...","[[263, 1754], [135, 1755], [42, 1755], [766, 1...",tgt
2,ours,1,12,0.10,tgt,fp,"[[1509, 678], [322, 448], [200, 435], [200, 14...","[[263, 1754], [135, 1755], [42, 1755], [766, 1...",tgt_fp
3,ours,1,12,0.10,tgt,fn,"[[1509, 678], [322, 448], [200, 435], [200, 14...","[[263, 1754], [135, 1755], [42, 1755], [766, 1...",tgt_fn
4,ours,1,12,0.25,src_tgt,,"[[1313, 678], [260, 678], [1313, 475], [2055, ...","[[674, 1313], [674, 1755], [131, 1755], [174, ...",src_tgt
...,...,...,...,...,...,...,...,...,...
555,bl,5,96,,tgt,fn,"[[328, 678], [1389, 678], [115, 597], [115, 49...","[[189, 2245], [635, 863], [56, 2245], [68, 224...",tgt_fn
556,random,5,96,,src_tgt,,"[[1841, 639], [1841, 194], [1841, 311], [1841,...","[[679, 1841], [679, 1187], [679, 1174], [679, ...",src_tgt
557,random,5,96,,tgt,,"[[104, 647], [104, 438], [104, 670], [104, 697...","[[682, 104], [682, 1296], [682, 1284], [682, 2...",tgt
558,random,5,96,,tgt,fp,"[[104, 647], [104, 438], [104, 670], [104, 697...","[[682, 104], [682, 1296], [682, 1284], [682, 2...",tgt_fp


In [24]:
# pos_before, pos_after はそれぞれリスト型（もしくは list のように len() が使える型）と仮定

# 各行で「before の数 + after の数」を計算
df_all['num_before'] = df_all['pos_before'].apply(len)
df_all['num_after']  = df_all['pos_after'].apply(len)
df_all['num_total']  = df_all['num_before'] + df_all['num_after']

# 期待値 8 * n * n を計算
df_all['expected'] = 8 * df_all['n'] ** 2

# 一致しているかどうかのブール列を追加
df_all['matches'] = df_all['num_total'] == df_all['expected']

# 全行一致しているかの確認
all_ok = df_all['matches'].all()
print(f"全行一致している: {all_ok}")

# もし不一致の行があれば表示
mismatch = df_all.loc[~df_all['matches'], ['fl_method','tgt_rank','n','num_total','expected']]
if not mismatch.empty:
    print("不一致の行:")
    print(mismatch)


全行一致している: True


In [45]:
from collections import defaultdict

# 結果を記録する辞書
summary_before = defaultdict(defaultdict)
summary_after = defaultdict(defaultdict)

# グループ単位で集計
for (misclf_type_name, tgt_rank), group in df_all.groupby(["misclf_type_name", "tgt_rank"]):
    summary_before[misclf_type_name][tgt_rank] = {}
    summary_after[misclf_type_name][tgt_rank] = {}
    for fl_method in group["fl_method"].unique():
        subset = group[group["fl_method"] == fl_method]
        # pos_before の行インデックスを集めてユニーク数をカウント
        all_rows = []
        for arr in subset["pos_before"]:
            arr = np.array(arr)  # 念のため明示
            all_rows.append(arr[:, 0])  # 行方向インデックス（3072次元の方）を抽出
        if all_rows:
            flat = np.concatenate(all_rows)
            flat_unique = np.unique(flat)
            unique_count = len(flat_unique)
        else:
            unique_count = 0
        summary_before[misclf_type_name][tgt_rank][fl_method] = unique_count
        
        # pos_after の行インデックスを集めてユニーク数をカウント
        all_cols = []
        for arr in subset["pos_after"]:
            arr = np.array(arr)
            all_cols.append(arr[:, 1])  # 行方向（W_aft なら 768の方）に変えるならここで調整
        if all_cols:
            flat = np.concatenate(all_cols)
            flat_unique = np.unique(flat)
            unique_count = len(flat_unique)
        else:
            unique_count = 0
        summary_after[misclf_type_name][tgt_rank][fl_method] = unique_count

# 表形式で表示（pandas）
# import pandas as pd
# df_summary = pd.DataFrame(summary).T  # index: (misclf_type, fpfn), columns: fl_method
# df_summary.index.names = ["misclf_type", "fpfn"]

In [26]:
summary_before

defaultdict(collections.defaultdict,
            {'src_tgt': defaultdict(None,
                         {1: {'ours': 253, 'bl': 259, 'random': 672},
                          2: {'ours': 289, 'bl': 311, 'random': 671},
                          3: {'ours': 329, 'bl': 353, 'random': 673},
                          4: {'ours': 285, 'bl': 325, 'random': 667},
                          5: {'ours': 299, 'bl': 334, 'random': 672}}),
             'tgt': defaultdict(None,
                         {1: {'ours': 376, 'bl': 432, 'random': 672},
                          2: {'ours': 452, 'bl': 562, 'random': 668},
                          3: {'ours': 455, 'bl': 499, 'random': 669},
                          4: {'ours': 474, 'bl': 578, 'random': 674},
                          5: {'ours': 696, 'bl': 733, 'random': 662}}),
             'tgt_fn': defaultdict(None,
                         {1: {'ours': 376, 'bl': 432, 'random': 672},
                          2: {'ours': 452, 'bl': 562, 'random': 668}

In [27]:
summary_after

defaultdict(collections.defaultdict,
            {'src_tgt': defaultdict(None,
                         {1: {'ours': 205, 'bl': 222, 'random': 672},
                          2: {'ours': 213, 'bl': 226, 'random': 671},
                          3: {'ours': 216, 'bl': 240, 'random': 673},
                          4: {'ours': 214, 'bl': 243, 'random': 667},
                          5: {'ours': 212, 'bl': 243, 'random': 672}}),
             'tgt': defaultdict(None,
                         {1: {'ours': 233, 'bl': 267, 'random': 672},
                          2: {'ours': 271, 'bl': 331, 'random': 668},
                          3: {'ours': 266, 'bl': 298, 'random': 669},
                          4: {'ours': 249, 'bl': 300, 'random': 674},
                          5: {'ours': 377, 'bl': 399, 'random': 662}}),
             'tgt_fn': defaultdict(None,
                         {1: {'ours': 233, 'bl': 267, 'random': 672},
                          2: {'ours': 271, 'bl': 331, 'random': 668}

# 逆にしてみる

In [28]:
from collections import defaultdict

# 結果を記録する辞書
summary_before = defaultdict(defaultdict)
summary_after = defaultdict(defaultdict)

# グループ単位で集計
for (misclf_type_name, tgt_rank), group in df_all.groupby(["misclf_type_name", "tgt_rank"]):
    summary_before[misclf_type_name][tgt_rank] = {}
    summary_after[misclf_type_name][tgt_rank] = {}
    for fl_method in group["fl_method"].unique():
        subset = group[group["fl_method"] == fl_method]
        # pos_before の行インデックスを集めてユニーク数をカウント
        all_rows = []
        for arr in subset["pos_before"]:
            arr = np.array(arr)  # 念のため明示
            all_rows.append(arr[:, 1])  # 行方向インデックス（3072次元の方）を抽出
        if all_rows:
            flat = np.concatenate(all_rows)
            flat_unique = np.unique(flat)
            unique_count = len(flat_unique)
        else:
            unique_count = 0
        summary_before[misclf_type_name][tgt_rank][fl_method] = unique_count
        
        # pos_after の行インデックスを集めてユニーク数をカウント
        all_cols = []
        for arr in subset["pos_after"]:
            arr = np.array(arr)
            all_cols.append(arr[:, 0])  # 行方向（W_aft なら 768の方）に変えるならここで調整
        if all_cols:
            flat = np.concatenate(all_cols)
            flat_unique = np.unique(flat)
            unique_count = len(flat_unique)
        else:
            unique_count = 0
        summary_after[misclf_type_name][tgt_rank][fl_method] = unique_count

# 表形式で表示（pandas）
# import pandas as pd
# df_summary = pd.DataFrame(summary).T  # index: (misclf_type, fpfn), columns: fl_method
# df_summary.index.names = ["misclf_type", "fpfn"]

In [29]:
summary_before

defaultdict(collections.defaultdict,
            {'src_tgt': defaultdict(None,
                         {1: {'ours': 768, 'bl': 681, 'random': 170},
                          2: {'ours': 768, 'bl': 698, 'random': 168},
                          3: {'ours': 768, 'bl': 715, 'random': 160},
                          4: {'ours': 768, 'bl': 703, 'random': 164},
                          5: {'ours': 768, 'bl': 709, 'random': 168}}),
             'tgt': defaultdict(None,
                         {1: {'ours': 768, 'bl': 766, 'random': 164},
                          2: {'ours': 768, 'bl': 765, 'random': 169},
                          3: {'ours': 768, 'bl': 765, 'random': 169},
                          4: {'ours': 768, 'bl': 761, 'random': 164},
                          5: {'ours': 768, 'bl': 766, 'random': 167}}),
             'tgt_fn': defaultdict(None,
                         {1: {'ours': 768, 'bl': 766, 'random': 164},
                          2: {'ours': 768, 'bl': 765, 'random': 169}

In [30]:
summary_after

defaultdict(collections.defaultdict,
            {'src_tgt': defaultdict(None,
                         {1: {'ours': 768, 'bl': 344, 'random': 160},
                          2: {'ours': 768, 'bl': 338, 'random': 167},
                          3: {'ours': 768, 'bl': 326, 'random': 165},
                          4: {'ours': 768, 'bl': 316, 'random': 167},
                          5: {'ours': 768, 'bl': 320, 'random': 166}}),
             'tgt': defaultdict(None,
                         {1: {'ours': 768, 'bl': 560, 'random': 170},
                          2: {'ours': 768, 'bl': 535, 'random': 165},
                          3: {'ours': 768, 'bl': 502, 'random': 165},
                          4: {'ours': 768, 'bl': 439, 'random': 166},
                          5: {'ours': 768, 'bl': 527, 'random': 160}}),
             'tgt_fn': defaultdict(None,
                         {1: {'ours': 768, 'bl': 560, 'random': 170},
                          2: {'ours': 768, 'bl': 535, 'random': 165}

In [31]:
from collections import defaultdict

# 結果を記録する辞書
summary_before = defaultdict(dict)
summary_after = defaultdict(dict)

# グループ単位で集計
for misclf_type_name, group in df_all.groupby(["misclf_type_name"]):
    for fl_method in group["fl_method"].unique():
        subset = group[group["fl_method"] == fl_method]
        # pos_before の行インデックスを集めてユニーク数をカウント
        all_rows = []
        for arr in subset["pos_before"]:
            arr = np.array(arr)  # 念のため明示
            all_rows.append(arr[:, 1])  # 行方向インデックス（3072次元の方）を抽出 # XXX: ここが逆！
        if all_rows:
            flat = np.concatenate(all_rows)
            flat_unique = np.unique(flat)
            unique_count = len(flat_unique)
        else:
            unique_count = 0
        summary_before[misclf_type_name][fl_method] = unique_count
        
        # pos_after の行インデックスを集めてユニーク数をカウント
        all_cols = []
        for arr in subset["pos_after"]:
            arr = np.array(arr)
            all_cols.append(arr[:, 0])  # 行方向（W_aft なら 768の方）に変えるならここで調整 # XXX: ここが逆！
        if all_cols:
            flat = np.concatenate(all_cols)
            flat_unique = np.unique(flat)
            unique_count = len(flat_unique)
        else:
            unique_count = 0
        summary_after[misclf_type_name][fl_method] = unique_count


In [32]:
summary_before

defaultdict(dict,
            {'src_tgt': {'ours': 768, 'bl': 768, 'random': 533},
             'tgt': {'ours': 768, 'bl': 768, 'random': 538},
             'tgt_fn': {'ours': 768, 'bl': 768, 'random': 538},
             'tgt_fp': {'ours': 768, 'bl': 768, 'random': 538}})

In [33]:
summary_after

defaultdict(dict,
            {'src_tgt': {'ours': 768, 'bl': 746, 'random': 536},
             'tgt': {'ours': 768, 'bl': 759, 'random': 539},
             'tgt_fn': {'ours': 768, 'bl': 759, 'random': 539},
             'tgt_fp': {'ours': 768, 'bl': 759, 'random': 539}})