In [1]:
import os
import json
import pandas as pd
from pandas.api.types import is_float_dtype
from collections import defaultdict
from utils.constant import ViTExperiment
from itertools import product

In [2]:
# Experiment configuration.
# Dataset key and the index substituted into the output-directory template.
ds = "c100"
k = 0

# Number of per-repetition result files averaged for each configuration.
num_reps = 5

# Grid of settings that the aggregation cell iterates over.
tgt_rank_list = list(range(1, 4))
misclf_type_list = ["src_tgt", "tgt"]
fpfn_list = [None, "fp", "fn"]
tgt_split_list = ["repair", "test"]
fl_method_list = ["random"]

# Values embedded in the setting identifier, which is part of the result file names.
alpha = 10 / 11
w_num = 472
setting_id = "n{}_alpha{}_boundsArachne".format(w_num, alpha)

In [3]:
# Aggregate the per-repetition repair metrics into one CSV per FL method.
#
# For every (tgt_rank, misclf_type, fpfn) configuration and every split in
# `tgt_split_list`, the JSON metric files of all `num_reps` repetitions are read
# and averaged into a single row. Timing columns (`t_fl`, `t_repair`) are taken
# from the "repair" split only.

# Output directory of the run (the original comment marks this location as assumed).
exp_obj = getattr(ViTExperiment, ds.replace("-", "_"))
pretrained_dir = exp_obj.OUTPUT_DIR.format(k=k)

# Metric keys that must be present in every per-repetition JSON file.
metric_keys = ("repair_rate_tgt", "break_rate_overall", "r_acc", "diff_correct")

# FL timing table. It is only needed for the "ours" method, so it is loaded lazily
# and at most once instead of being re-read for every configuration.
df_fl_time = None

for fl_method in fl_method_list:
    # One dict per configuration; turned into a DataFrame after the loop.
    results = []

    for tgt_rank, misclf_type, fpfn in product(tgt_rank_list, misclf_type_list, fpfn_list):
        # "src_tgt"/"all" are only combined with fpfn=None, "tgt" only with "fp"/"fn".
        if (misclf_type in ["src_tgt", "all"] and fpfn is not None) or (misclf_type == "tgt" and fpfn is None):
            continue

        misclf_ptn = misclf_type if fpfn is None else f"{misclf_type}_{fpfn}"
        save_dir = os.path.join(pretrained_dir, f"misclf_top{tgt_rank}", f"{misclf_ptn}_repair_weight_by_de")
        row = {"ds": ds, "tgt_rank": tgt_rank, "misclf_type": misclf_ptn, "fl_method": fl_method}

        for tgt_split in tgt_split_list:
            is_repair_split = tgt_split == "repair"
            # The total repair time is only collected from the "repair" split.
            required_keys = metric_keys + (("tot_time",) if is_repair_split else ())
            per_rep = {key: [] for key in required_keys}

            for reps_id in range(num_reps):
                filename = f"exp-repair-4-1-metrics_for_{tgt_split}_{setting_id}_{fl_method}_reps{reps_id}.json"
                json_path = os.path.join(save_dir, filename)
                if not os.path.exists(json_path):
                    raise FileNotFoundError(f"JSON file not found: {json_path}")
                with open(json_path, "r", encoding="utf-8") as f:
                    d = json.load(f)

                # Fail here, naming the file, rather than later inside sum() with a
                # TypeError that carries no hint about which file was incomplete.
                missing = [key for key in required_keys if d.get(key) is None]
                if missing:
                    raise KeyError(f"Missing metric(s) {missing} in {json_path}")
                for key in required_keys:
                    per_rep[key].append(d[key])

            rr_list = per_rep["repair_rate_tgt"]
            br_list = per_rep["break_rate_overall"]
            racc_list = per_rep["r_acc"]
            diff_corr_list = per_rep["diff_correct"]

            row[f"RR_{tgt_split}"] = sum(rr_list)/len(rr_list)
            row[f"BR_{tgt_split}"] = sum(br_list)/len(br_list)
            row[f"Racc_{tgt_split} (#diff)"] = f"{sum(racc_list)/len(racc_list):.4f} ({sum(diff_corr_list)/len(diff_corr_list):.1f})"

            if is_repair_split:
                if fl_method == "ours":
                    if df_fl_time is None:
                        fl_time_path = f"/src/src/exp-repair-4-1-1_time_{ds}.csv"
                        df_fl_time = pd.read_csv(fl_time_path)
                    # Missing `fpfn` values in the table are compared as empty strings.
                    fpfn_match = "" if fpfn is None else fpfn
                    matched_row = df_fl_time[
                        (df_fl_time["ds"] == ds) &
                        (df_fl_time["k"] == k) &
                        (df_fl_time["tgt_rank"] == tgt_rank) &
                        (df_fl_time["misclf_type"] == misclf_type) &
                        (df_fl_time["fpfn"].fillna("") == fpfn_match)
                    ]
                    if matched_row.empty:
                        raise IndexError(
                            "No FL timing row for "
                            f"ds={ds!r}, k={k!r}, tgt_rank={tgt_rank!r}, "
                            f"misclf_type={misclf_type!r}, fpfn={fpfn_match!r}"
                        )
                    # If several rows match, the first one is used.
                    row["t_fl"] = matched_row["elapsed_time"].values[0]
                else:
                    assert fl_method == "random", "Unknown FL method"
                    row["t_fl"] = 0.0  # the random method is assigned zero FL time
                t_repair_list = per_rep["tot_time"]
                row["t_repair"] = sum(t_repair_list)/len(t_repair_list)
        results.append(row)

    # Build the result table for this FL method.
    df_flat = pd.DataFrame(results)

    # Round every float column to 3 decimals (columns stay numeric).
    # NOTE(review): the CSV below is written with "%.4f", so rounded values end in
    # a trailing 0 -- confirm which precision is intended.
    float_cols = [col for col in df_flat.columns if is_float_dtype(df_flat[col])]
    df_flat[float_cols] = df_flat[float_cols].round(3)

    # Move the timing columns to the end. Only columns that actually exist are
    # moved, so a run without the "repair" split (or without rows) does not fail.
    time_cols = [col for col in ("t_fl", "t_repair") if col in df_flat.columns]
    other_cols = [col for col in df_flat.columns if col not in time_cols]
    df_flat = df_flat[other_cols + time_cols]

    csv_path = f"./exp-repair-4-2-{setting_id}_{fl_method}_{ds}.csv"
    df_flat.to_csv(csv_path, index=False, float_format="%.4f")