In [3]:
import os
import json
from glob import glob
from collections import defaultdict

import pandas as pd

def aggregate_comp_metrics(base_dir: str,
                           save_csv: str | None = None,
                           save_json: str | None = None) -> pd.DataFrame:
    """
    base_dir: .../checkpoints/generated_loras  (그 아래에 <TASK>/compressed/lora_0/comp_metrics.json)
    반환: 모듈별 A/B mse, bpp의 (단순 평균, 파라미터 가중 평균)를 담은 DataFrame
    """

    # comp_metrics.json 파일들 수집
    pattern = os.path.join(base_dir, "*", "compressed", "lora_0", "comp_metrics.json")
    files = glob(pattern)
    if not files:
        raise FileNotFoundError(f"No comp_metrics.json found under: {pattern}")

    # 누적용 딕셔너리
    acc = defaultdict(lambda: {
        # 단순합(나중에 count로 나눠 단순평균)
        "sum_mse_A": 0.0, "sum_mse_B": 0.0,
        "sum_bpp_A": 0.0, "sum_bpp_B": 0.0,
        "count": 0,

        # 파라미터 가중 평균 계산용
        "w_sum_mse_A": 0.0, "w_sum_mse_B": 0.0,
        "w_sum_bpp_A": 0.0, "w_sum_bpp_B": 0.0,
        "sum_params_A": 0,  "sum_params_B": 0,
    })

    # 파일 단위로 읽어와서 모듈별로 누적
    for fp in files:
        with open(fp, "r") as f:
            data = json.load(f)  # {module: {...}}

        for module, m in data.items():
            # 필수 키가 모두 있는지 체크 (없으면 건너뜀)
            need_keys = ["mse_A", "mse_B", "bpp_A", "bpp_B", "num_params_A", "num_params_B"]
            if not all(k in m for k in need_keys):
                continue

            acc[module]["sum_mse_A"] += float(m["mse_A"])
            acc[module]["sum_mse_B"] += float(m["mse_B"])
            acc[module]["sum_bpp_A"] += float(m["bpp_A"])
            acc[module]["sum_bpp_B"] += float(m["bpp_B"])
            acc[module]["count"]     += 1

            nA = int(m["num_params_A"])
            nB = int(m["num_params_B"])
            acc[module]["w_sum_mse_A"] += float(m["mse_A"]) * nA
            acc[module]["w_sum_mse_B"] += float(m["mse_B"]) * nB
            acc[module]["w_sum_bpp_A"] += float(m["bpp_A"]) * nA
            acc[module]["w_sum_bpp_B"] += float(m["bpp_B"]) * nB
            acc[module]["sum_params_A"] += nA
            acc[module]["sum_params_B"] += nB

    # 집계 → DataFrame
    rows = []
    for module, s in acc.items():
        c  = max(1, s["count"])
        nA = max(1, s["sum_params_A"])
        nB = max(1, s["sum_params_B"])

        row = {
            "module": module,
            # 단순 평균(파일 개수 기준)
            "mean_mse_A":  s["sum_mse_A"] / c,
            "mean_mse_B":  s["sum_mse_B"] / c,
            "mean_bpp_A":  s["sum_bpp_A"] / c,
            "mean_bpp_B":  s["sum_bpp_B"] / c,

            # 파라미터 가중 평균
            "wmean_mse_A": s["w_sum_mse_A"] / nA,
            "wmean_mse_B": s["w_sum_mse_B"] / nB,
            "wmean_bpp_A": s["w_sum_bpp_A"] / nA,
            "wmean_bpp_B": s["w_sum_bpp_B"] / nB,

            # 참고로 남기는 집계량
            "num_files": s["count"],
            "sum_params_A": s["sum_params_A"],
            "sum_params_B": s["sum_params_B"],
        }
        rows.append(row)

    df = pd.DataFrame(rows).sort_values("module").reset_index(drop=True)

    if save_csv:
        os.makedirs(os.path.dirname(save_csv), exist_ok=True)
        df.to_csv(save_csv, index=False)
    if save_json:
        os.makedirs(os.path.dirname(save_json), exist_ok=True)
        with open(save_json, "w") as f:
            json.dump(df.to_dict(orient="records"), f, indent=2)

    return df


if __name__ == "__main__":
    # Example location; point base_dir at your own generated_loras directory.
    base_dir = "/workspace/Weight_compression/text-to-lora/train_outputs/compnet_recon/hyper_lora/20250904-101518_5rcL6bSr/checkpoints/generated_loras"

    # Both summary files are written directly inside base_dir.
    csv_path = os.path.join(base_dir, "aggregated_comp_metrics.csv")
    json_path = os.path.join(base_dir, "aggregated_comp_metrics.json")
    df = aggregate_comp_metrics(base_dir, save_csv=csv_path, save_json=json_path)

    # For scientific notation, wrap the print in:
    #   with pd.option_context('display.float_format', '{:.6e}'.format):
    print(df)


   module  mean_mse_A    mean_mse_B  mean_bpp_A  mean_bpp_B  wmean_mse_A  \
0  q_proj    0.000002  1.131920e-06    2.565513    2.865241     0.000002   
1  v_proj    0.000002  1.762381e-07    2.611844   11.812771     0.000002   

    wmean_mse_B  wmean_bpp_A  wmean_bpp_B  num_files  sum_params_A  \
0  1.131920e-06     2.565513     2.865241         29      30408704   
1  1.762381e-07     2.611844    11.812771         29      30408704   

   sum_params_B  
0      30408704  
1       7602176  
