In [25]:
import pandas as pd

from coolchic.eval.results import parse_hypernet_metrics
from coolchic.utils.paths import COOLCHIC_REPO_ROOT

In [33]:
dataset = "clic20-pro-valid"
# dataset = "kodak"
ablation_dir = COOLCHIC_REPO_ROOT / "switch-ablation-exps"

all_ablation_metrics = {}
for dir in ablation_dir.iterdir():
    parts = dir.stem.split("_")
    dir_metrics = pd.DataFrame(
        [
            s.model_dump()
            for seq_res in parse_hypernet_metrics(dir, dataset=dataset).values()
            for s in seq_res
        ]
    ).assign(variant=dir.name if dir.name != "none" else "baseline")
    all_ablation_metrics[dir.name] = dir_metrics

all_metrics = pd.concat(
    all_ablation_metrics.values(),
    ignore_index=True,
).sort_values(by=["seq_name", "lmbda"])

In [34]:
all_metrics

Unnamed: 0,seq_name,lmbda,rate_bpp,rate_latent_bpp,rate_nn_bpp,psnr_db,n_itr,n_train_loops,variant
0,alberto-montalesi-176097,0.0001,0.901182,0.898579,0.002603,38.296180,,,synthesis_upsampling
246,alberto-montalesi-176097,0.0001,0.857501,0.853158,0.004343,38.296180,,,synthesis_arm_upsampling
492,alberto-montalesi-176097,0.0001,0.854909,0.853169,0.001740,38.205856,,,arm
738,alberto-montalesi-176097,0.0001,0.898579,0.898579,0.000000,38.205694,,,baseline
1,alberto-montalesi-176097,0.0002,0.359428,0.356822,0.002607,37.007979,,,synthesis_upsampling
...,...,...,...,...,...,...,...,...,...
981,zugr-108,0.0040,0.075918,0.075918,0.000000,34.704006,,,baseline
244,zugr-108,0.0200,0.040181,0.038201,0.001980,31.268002,,,synthesis_upsampling
490,zugr-108,0.0200,0.039226,0.035550,0.003677,31.268002,,,synthesis_arm_upsampling
736,zugr-108,0.0200,0.037237,0.035540,0.001697,31.085431,,,arm


In [37]:
agg_df = all_metrics.groupby(["variant", "lmbda"]).agg(
    {
        "rate_nn_bpp": "mean",
        "psnr_db": "mean",
        "rate_latent_bpp": "mean",
    }
)

# First group the baseline rows by 'lmbda' and subtract them from each group.
baseline = agg_df.loc["baseline"]


def subtract_baseline(group):
    lmbda_val = group.name
    base_row = baseline.loc[lmbda_val]
    return group - base_row


# Apply the subtraction grouped by 'lmbda'
diffs = agg_df.groupby("lmbda").apply(subtract_baseline, include_groups=False)
diffs = (
    diffs.reset_index(level=0, drop=True)
    # .drop("baseline", level=0)
    .reset_index()
    .sort_values(
        by=["variant", "lmbda"],
        key=lambda x: x.map(
            {
                "with ARM modulations": 0,
                "with synthesis modulations": 1,
                "full hypernetwork (ours)": 2,
            }
        ),
    )
    .rename(
        columns={
            "rate_nn_bpp": "rate_nn_bpp_diff",
            "psnr_db": "psnr_db_diff",
            "rate_latent_bpp": "rate_latent_bpp_diff",
        }
    )
    .assign(total_rate_diff=lambda x: x["rate_nn_bpp_diff"] + x["rate_latent_bpp_diff"])
)
diffs

Unnamed: 0,variant,lmbda,rate_nn_bpp_diff,psnr_db_diff,rate_latent_bpp_diff,total_rate_diff
0,arm,0.0001,0.00336,0.000188,-0.040215,-0.036855
1,baseline,0.0001,0.0,0.0,0.0,0.0
2,synthesis_arm_upsampling,0.0001,0.008481,0.071659,-0.040183,-0.031703
3,synthesis_upsampling,0.0001,0.005113,0.071659,0.0,0.005113
4,arm,0.0002,0.003281,-1.6e-05,-0.033741,-0.03046
5,baseline,0.0002,0.0,0.0,0.0,0.0
6,synthesis_arm_upsampling,0.0002,0.008493,0.070451,-0.033764,-0.02527
7,synthesis_upsampling,0.0002,0.005204,0.070451,0.0,0.005204
8,arm,0.0004,0.003413,-2.1e-05,-0.020804,-0.017391
9,baseline,0.0004,0.0,0.0,0.0,0.0


In [42]:
diffs.groupby("variant").mean().reset_index().round(4)

Unnamed: 0,variant,lmbda,rate_nn_bpp_diff,psnr_db_diff,rate_latent_bpp_diff,total_rate_diff
0,arm,0.0043,0.0033,0.0,-0.0189,-0.0155
1,baseline,0.0043,0.0,0.0,0.0,0.0
2,synthesis_arm_upsampling,0.0043,0.0081,0.0708,-0.0188,-0.0107
3,synthesis_upsampling,0.0043,0.0048,0.0708,0.0,0.0048
