In [45]:
import os
from collections import defaultdict
from itertools import product

import matplotlib.gridspec as gridspec
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams
from matplotlib.patches import Patch
from matplotlib_venn import venn2

from utils.vit_util import ViTExperiment  # provisional import; adjust to the actual package path

In [46]:
# ---- Experiment settings ---------------------------------------------------
# Which datasets / target ranks / misclassification types / N_w values to plot.
ds_list = ["c100", "tiny-imagenet"]
k_list = [0]                                    # k values (the plotting cells format OUTPUT_DIR with k=0 directly)
tgt_rank_list = list(range(1, 4))               # target ranks 1, 2, 3
misclf_type_list = ["src_tgt", "tgt"]
fpfn_list = [None, "fp", "fn"]                  # None goes with "src_tgt"; "fp"/"fn" go with "tgt"
w_num_list = [236 * factor for factor in (1, 2, 4)]  # N_w = 236, 472, 944
tgt_split = "test"                              # split name embedded in the result file names
num_reps = 5                                    # repetitions per setting

# ---- Figure output directory -----------------------------------------------
output_dir = "./venn_outputs"
os.makedirs(output_dir, exist_ok=True)

# ====================================
# Helper functions
# ====================================
def load_indices(base_path, tgt_rank, wnum, fl_method, reps_id, misclf_type, fpfn,
                 indices_type="repair_indices_tgt", split=None):
    """Load one repetition's index set from a saved ``.npz`` repair result.

    The file read is::

        {base_path}/{misclf_folder}/exp-repair-4-1-change_indices_{split}_
            n{wnum}_alpha0.9090909090909091_boundsArachne_{fl_method}_reps{reps_id}.npz

    where ``misclf_folder`` is ``misclf_top{tgt_rank}/src_tgt_repair_weight_by_de``
    for ``misclf_type="src_tgt"`` and
    ``misclf_top{tgt_rank}/tgt_{fpfn}_repair_weight_by_de`` for ``misclf_type="tgt"``.

    Args:
        base_path: Root directory that contains the ``misclf_top*`` folders.
        tgt_rank: Target rank, used as ``misclf_top{tgt_rank}``.
        wnum: Number of weights N_w, used as ``n{wnum}``.
        fl_method: Method tag in the file name (the notebook uses ``"ours"`` / ``"bl"``).
        reps_id: Repetition index, used as ``reps{reps_id}``.
        misclf_type: ``"src_tgt"`` or ``"tgt"``.
        fpfn: Ignored for ``"src_tgt"``; ``"fp"`` or ``"fn"`` for ``"tgt"``.
        indices_type: Key of the array to read from the archive
            (the notebook uses ``"repair_indices_tgt"`` and ``"break_indices_overall"``).
        split: Split name embedded in the file name. ``None`` (default) falls back
            to the notebook-level ``tgt_split`` setting.

    Returns:
        A ``set`` of the values stored under ``indices_type``, or ``None`` when the
        ``misclf_type``/``fpfn`` combination is unsupported or the file does not exist.

    Raises:
        KeyError: if the archive exists but has no array named ``indices_type``.
    """
    if misclf_type == "src_tgt":
        misclf_folder = f"misclf_top{tgt_rank}/{misclf_type}_repair_weight_by_de"
    elif misclf_type == "tgt" and fpfn is not None:
        misclf_folder = f"misclf_top{tgt_rank}/{misclf_type}_{fpfn}_repair_weight_by_de"
    else:
        # Unsupported combination, e.g. "tgt" without an FP/FN variant.
        return None

    if split is None:
        split = tgt_split  # notebook-level setting from the configuration cell

    # Only results produced with this alpha and the Arachne bounds are read.
    setting_id = "_".join([f"n{wnum}", "alpha0.9090909090909091", "boundsArachne"])
    filename = f"exp-repair-4-1-change_indices_{split}_{setting_id}_{fl_method}_reps{reps_id}.npz"
    filepath = os.path.join(base_path, misclf_folder, filename)

    if not os.path.exists(filepath):
        return None
    # np.load on an .npz returns an NpzFile that keeps the archive open until it is
    # closed; the context manager releases the handle as soon as the array is read.
    with np.load(filepath) as data:
        return set(data[indices_type])


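# N_w と rank の組ごとにグリッド状のベン図を作る
各図は 2×3 のグリッドで，列が誤分類タイプ (SRC-TGT, TGT-FP, TGT-FN の 3 つ)，行がデータセット (C100, TinyImg の 2 つ)．
各サブプロットでは，REPTRAN (ours) と ArachneW (bl) について，全 5 回の試行すべてに共通して現れた修正 (repaired) サンプルのインデックス集合を比較する．
図は N_w (3 通り) × rank (3 通り) の組ごとに 1 枚ずつ作るので，合計 9 枚になる．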

In [47]:
# One figure per (target rank, N_w): rows = dataset, columns = misclassification type.
# Each panel compares the repaired indices shared by all repetitions of
# REPTRAN ("ours") and ArachneW ("bl").
# output_dir and the *_list settings come from the configuration cell.
os.makedirs(output_dir, exist_ok=True)

type_fpfn_labels = [("src_tgt", None), ("tgt", "fp"), ("tgt", "fn")]
method_colors = ("#1f77b4", "#2ca02c")  # REPTRAN ("ours"), ArachneW ("bl")
method_names = ("REPTRAN", "ArachneW")

# The results directory depends only on the dataset, so resolve it once.
ds_dirs = {ds: getattr(ViTExperiment, ds.replace("-", "_")).OUTPUT_DIR.format(k=0) for ds in ds_list}

for tgt_rank, wnum in product(tgt_rank_list, w_num_list):
    fig = plt.figure(figsize=(9, 6))
    gs = gridspec.GridSpec(len(ds_list), len(type_fpfn_labels), figure=fig, hspace=0.2)

    for row, ds in enumerate(ds_list):
        pretrained_dir = ds_dirs[ds]
        ds_repr = "C100" if ds == "c100" else "TinyImg"

        for col, (misclf_type, fpfn) in enumerate(type_fpfn_labels):
            # Honour the configured subsets of misclassification types / FP-FN variants.
            if misclf_type not in misclf_type_list or fpfn not in fpfn_list:
                continue

            # Per-repetition index sets; a repetition counts only if both methods have results.
            ours_runs, bl_runs = [], []
            for reps_id in range(num_reps):
                ours = load_indices(pretrained_dir, tgt_rank, wnum, "ours", reps_id, misclf_type, fpfn)
                bl = load_indices(pretrained_dir, tgt_rank, wnum, "bl", reps_id, misclf_type, fpfn)
                if ours is None or bl is None:
                    continue
                ours_runs.append(ours)
                bl_runs.append(bl)

            # Keep only indices present in every one of the num_reps repetitions. If a
            # repetition is missing nothing can qualify, so report it instead of
            # silently drawing an all-zero diagram.
            if len(ours_runs) < num_reps:
                print(f"Warning: {len(ours_runs)}/{num_reps} repetitions found for "
                      f"{ds}, rank {tgt_rank}, n{wnum}, {misclf_type}, {fpfn}; plotting empty sets.")
            complete = num_reps > 0 and len(ours_runs) == num_reps
            ours_final = set.intersection(*ours_runs) if complete else set()
            bl_final = set.intersection(*bl_runs) if complete else set()

            ax = fig.add_subplot(gs[row, col])
            venn = venn2([ours_final, bl_final], set_labels=("", ""), ax=ax, set_colors=method_colors)

            # Nudge the exclusive-region counts outward, away from the overlap.
            for subset_id, dx in (("10", -0.025), ("01", 0.025)):
                label = venn.get_label_by_id(subset_id)
                if label:
                    x, y = label.get_position()
                    label.set_position((x + dx, y))

            # Exclusive regions fully opaque (venn2 default alpha is 0.4); the overlap is
            # drawn as a translucent grey region with a grey border.
            for subset_id in ("10", "01", "11"):
                patch = venn.get_patch_by_id(subset_id)
                if patch:
                    patch.set_alpha(1)
                    if subset_id == "11":
                        patch.set_color("grey")
                        patch.set_edgecolor("grey")
                        patch.set_alpha(0.3)

            misclf_str = "SRC-TGT" if misclf_type == "src_tgt" else f"TGT-{fpfn.upper()}"
            ax.set_title(f"{ds_repr}, {misclf_str}", fontstyle="italic")

    # Explicit legend handles: fig.legend(labels) alone borrows the first patches of
    # the first panel, which mislabels the legend whenever one of that panel's
    # regions is empty and therefore has no patch.
    legend_handles = [Patch(facecolor=color, label=name) for color, name in zip(method_colors, method_names)]
    fig.legend(handles=legend_handles, loc="lower center", ncol=2, fontsize=12, frameon=True)
    out_path = os.path.join(output_dir, f"repaired_venn_rank{tgt_rank}_n{wnum}.pdf")
    fig.tight_layout(rect=[0, 0.1, 1, 1])  # keep the bottom 10% free for the legend
    fig.savefig(out_path, bbox_inches="tight", dpi=300)
    print(f"Saved: {out_path}")
    plt.close(fig)


findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'sans-serif' not found because none of the following families were found: Helvetica


Saved: ./venn_outputs/repaired_venn_rank1_n236.pdf




Saved: ./venn_outputs/repaired_venn_rank1_n472.pdf
Saved: ./venn_outputs/repaired_venn_rank1_n944.pdf




Saved: ./venn_outputs/repaired_venn_rank2_n236.pdf
Saved: ./venn_outputs/repaired_venn_rank2_n472.pdf




Saved: ./venn_outputs/repaired_venn_rank2_n944.pdf
Saved: ./venn_outputs/repaired_venn_rank3_n236.pdf




Saved: ./venn_outputs/repaired_venn_rank3_n472.pdf
Saved: ./venn_outputs/repaired_venn_rank3_n944.pdf


# (repaired) N_w ごとにまとめたベン図（行: rank，列: データセット × 誤分類タイプ）

In [80]:
# One figure per N_w: rows = target rank, columns = dataset x misclassification type.
# Each panel compares the repaired indices ("repair_indices_tgt") shared by all
# repetitions of REPTRAN ("ours") and ArachneW ("bl").
# output_dir and the *_list settings come from the configuration cell.
os.makedirs(output_dir, exist_ok=True)

type_fpfn_labels = [("src_tgt", None), ("tgt", "fp"), ("tgt", "fn")]
method_colors = ("#1f77b4", "#2ca02c")  # REPTRAN ("ours"), ArachneW ("bl")
method_names = ("REPTRAN", "ArachneW")

# The results directory depends only on the dataset, so resolve it once.
ds_dirs = {ds: getattr(ViTExperiment, ds.replace("-", "_")).OUTPUT_DIR.format(k=0) for ds in ds_list}

for wnum in w_num_list:
    fig = plt.figure(figsize=(12, 6))
    gs = gridspec.GridSpec(len(tgt_rank_list), len(type_fpfn_labels) * len(ds_list), figure=fig, hspace=0.3, wspace=0.2)

    for row, tgt_rank in enumerate(tgt_rank_list):
        for id_ds, ds in enumerate(ds_list):
            pretrained_dir = ds_dirs[ds]
            ds_repr = "C100" if ds == "c100" else "TinyImg"

            for type_col, (misclf_type, fpfn) in enumerate(type_fpfn_labels):
                # Honour the configured subsets of misclassification types / FP-FN variants.
                if misclf_type not in misclf_type_list or fpfn not in fpfn_list:
                    continue

                # Per-repetition index sets; a repetition counts only if both methods have results.
                ours_runs, bl_runs = [], []
                for reps_id in range(num_reps):
                    ours = load_indices(pretrained_dir, tgt_rank, wnum, "ours", reps_id, misclf_type, fpfn)
                    bl = load_indices(pretrained_dir, tgt_rank, wnum, "bl", reps_id, misclf_type, fpfn)
                    if ours is None or bl is None:
                        continue
                    ours_runs.append(ours)
                    bl_runs.append(bl)

                # Keep only indices present in every one of the num_reps repetitions. If a
                # repetition is missing nothing can qualify, so report it instead of
                # silently drawing an all-zero diagram.
                if len(ours_runs) < num_reps:
                    print(f"Warning: {len(ours_runs)}/{num_reps} repetitions found for "
                          f"{ds}, rank {tgt_rank}, n{wnum}, {misclf_type}, {fpfn}; plotting empty sets.")
                complete = num_reps > 0 and len(ours_runs) == num_reps
                ours_final = set.intersection(*ours_runs) if complete else set()
                bl_final = set.intersection(*bl_runs) if complete else set()

                col = id_ds * len(type_fpfn_labels) + type_col
                ax = fig.add_subplot(gs[row, col])
                venn = venn2([ours_final, bl_final], set_labels=("", ""), ax=ax, set_colors=method_colors)

                # Nudge the exclusive-region counts outward, away from the overlap.
                for subset_id, dx in (("10", -0.1), ("01", 0.1)):
                    label = venn.get_label_by_id(subset_id)
                    if label:
                        x, y = label.get_position()
                        label.set_position((x + dx, y))

                for subset_id in ("10", "01", "11"):
                    patch = venn.get_patch_by_id(subset_id)
                    label = venn.get_label_by_id(subset_id)
                    if patch:
                        # Opaque fill with a thin black outline (venn2 default alpha is 0.4).
                        patch.set_edgecolor("black")
                        patch.set_linewidth(1.0)
                        patch.set_alpha(1.0)
                        if subset_id == "11":
                            # set_color replaces both face and edge colour of the overlap.
                            patch.set_color("#888888")
                            patch.set_alpha(0.2)
                    if label:
                        if label.get_text() == "0":
                            label.set_text("")  # hide zero counts
                        label.set_fontsize(20)
                        label.set_fontweight("bold")
                        if subset_id in ("10", "01"):
                            # Black digits with a white halo on the coloured exclusive regions.
                            label.set_color("black")
                            label.set_path_effects([
                                path_effects.Stroke(linewidth=2.0, foreground="white"),
                                path_effects.Normal(),
                            ])

                misclf_str = "SRC-TGT" if misclf_type == "src_tgt" else f"TGT-{fpfn.upper()}"
                ax.set_title(f"{ds_repr},\nRank {tgt_rank}, {misclf_str}", fontsize=10, pad=1)

    # Explicit legend handles: fig.legend(labels) alone borrows the first patches of
    # the first panel, which mislabels the legend whenever one of that panel's
    # regions is empty and therefore has no patch.
    legend_handles = [
        Patch(facecolor=color, edgecolor="black", linewidth=1.0, label=name)
        for color, name in zip(method_colors, method_names)
    ]
    fig.legend(handles=legend_handles, loc="lower center", ncol=2, fontsize=12, frameon=True)
    out_path = os.path.join(output_dir, f"repaired_venn_n{wnum}.pdf")
    fig.tight_layout(rect=[0, 0, 1, 1])
    fig.savefig(out_path, bbox_inches="tight", dpi=300)
    print(f"Saved: {out_path}")
    plt.close(fig)




Saved: ./venn_outputs/repaired_venn_n236.pdf




Saved: ./venn_outputs/repaired_venn_n472.pdf




Saved: ./venn_outputs/repaired_venn_n944.pdf


# (broken) N_w ごとにまとめたベン図（行: rank，列: データセット × 誤分類タイプ）

In [32]:
# One figure per N_w: rows = target rank, columns = dataset x misclassification type.
# Each panel compares the "break_indices_overall" indices shared by all
# repetitions of REPTRAN ("ours") and ArachneW ("bl").
# output_dir and the *_list settings come from the configuration cell.
os.makedirs(output_dir, exist_ok=True)

type_fpfn_labels = [("src_tgt", None), ("tgt", "fp"), ("tgt", "fn")]
method_colors = ("#1f77b4", "#2ca02c")  # REPTRAN ("ours"), ArachneW ("bl")
method_names = ("REPTRAN", "ArachneW")

# The results directory depends only on the dataset, so resolve it once.
ds_dirs = {ds: getattr(ViTExperiment, ds.replace("-", "_")).OUTPUT_DIR.format(k=0) for ds in ds_list}

for wnum in w_num_list:
    fig = plt.figure(figsize=(12, 6))
    gs = gridspec.GridSpec(len(tgt_rank_list), len(type_fpfn_labels) * len(ds_list), figure=fig, hspace=0.35, wspace=0.2)

    for row, tgt_rank in enumerate(tgt_rank_list):
        for id_ds, ds in enumerate(ds_list):
            pretrained_dir = ds_dirs[ds]
            ds_repr = "C100" if ds == "c100" else "TinyImg"

            for type_col, (misclf_type, fpfn) in enumerate(type_fpfn_labels):
                # Honour the configured subsets of misclassification types / FP-FN variants.
                if misclf_type not in misclf_type_list or fpfn not in fpfn_list:
                    continue

                # Per-repetition index sets; a repetition counts only if both methods have results.
                ours_runs, bl_runs = [], []
                for reps_id in range(num_reps):
                    ours = load_indices(pretrained_dir, tgt_rank, wnum, "ours", reps_id, misclf_type, fpfn, indices_type="break_indices_overall")
                    bl = load_indices(pretrained_dir, tgt_rank, wnum, "bl", reps_id, misclf_type, fpfn, indices_type="break_indices_overall")
                    if ours is None or bl is None:
                        continue
                    ours_runs.append(ours)
                    bl_runs.append(bl)

                # Keep only indices present in every one of the num_reps repetitions. If a
                # repetition is missing nothing can qualify, so report it instead of
                # silently drawing an all-zero diagram.
                if len(ours_runs) < num_reps:
                    print(f"Warning: {len(ours_runs)}/{num_reps} repetitions found for "
                          f"{ds}, rank {tgt_rank}, n{wnum}, {misclf_type}, {fpfn}; plotting empty sets.")
                complete = num_reps > 0 and len(ours_runs) == num_reps
                ours_final = set.intersection(*ours_runs) if complete else set()
                bl_final = set.intersection(*bl_runs) if complete else set()

                col = id_ds * len(type_fpfn_labels) + type_col
                ax = fig.add_subplot(gs[row, col])
                venn = venn2([ours_final, bl_final], set_labels=("", ""), ax=ax, set_colors=method_colors)

                # Nudge the exclusive-region counts outward, away from the overlap.
                for subset_id, dx in (("10", -0.1), ("01", 0.1)):
                    label = venn.get_label_by_id(subset_id)
                    if label:
                        x, y = label.get_position()
                        label.set_position((x + dx, y))

                for subset_id in ("10", "01", "11"):
                    patch = venn.get_patch_by_id(subset_id)
                    label = venn.get_label_by_id(subset_id)
                    if patch:
                        # Opaque fill with a thin black outline (venn2 default alpha is 0.4).
                        patch.set_edgecolor("black")
                        patch.set_linewidth(1.0)
                        patch.set_alpha(1.0)
                        if subset_id == "11":
                            # Overlap: translucent grey fill with a grey (not black) border.
                            patch.set_color("grey")
                            patch.set_edgecolor("grey")
                            patch.set_alpha(0.3)
                    if label:
                        label.set_fontsize(14)
                        if subset_id in ("10", "01"):
                            # White bold digits with a black outline on the coloured exclusive regions.
                            label.set_fontweight("bold")
                            label.set_color("white")
                            label.set_path_effects([
                                path_effects.Stroke(linewidth=2.5, foreground="black"),
                                path_effects.Normal(),
                            ])

                misclf_str = "SRC-TGT" if misclf_type == "src_tgt" else f"TGT-{fpfn.upper()}"
                ax.set_title(f"{ds_repr},\nRank {tgt_rank}, {misclf_str}", fontsize=10, pad=1)

    # Explicit legend handles: fig.legend(labels) alone borrows the first patches of
    # the first panel, which mislabels the legend whenever one of that panel's
    # regions is empty and therefore has no patch.
    legend_handles = [
        Patch(facecolor=color, edgecolor="black", linewidth=1.0, label=name)
        for color, name in zip(method_colors, method_names)
    ]
    fig.legend(handles=legend_handles, loc="lower center", ncol=2, fontsize=12, frameon=True)
    out_path = os.path.join(output_dir, f"broken_venn_n{wnum}.pdf")
    fig.tight_layout(rect=[0, 0, 1, 1])
    fig.savefig(out_path, bbox_inches="tight", dpi=300)
    print(f"Saved: {out_path}")
    plt.close(fig)


findfont: Font family ['serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'serif' not found because none of the following families were found: Times, Palatino, serif
findfont: Font family ['serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'serif' not found because none of the following families were found: Times, Palatino, serif
findfont: Font family ['serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'serif' not found because none of the following families were found: Times, Palatino, serif
findfont: Font family ['serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'serif' not found because none of the following families were found: Times, Palatino, serif


Saved: ./venn_outputs/broken_venn_n236.pdf




Saved: ./venn_outputs/broken_venn_n472.pdf




Saved: ./venn_outputs/broken_venn_n944.pdf
