# Plot 1: RD plots for dataset averages and a few individual images

In [None]:
from typing import Literal

import pandas as pd

from coolchic.eval.hypernet import (
    get_hypernet_flops,
    plot_hypernet_rd,
    plot_hypernet_rd_avg,
)
from coolchic.eval.results import parse_hypernet_metrics
from coolchic.utils.paths import DATA_DIR, RESULTS_DIR

In [None]:
# Result trees read by `rd_plots_from_dataset` below.
# `sweep_path` holds the runs that are tagged with anchor "hypernet".
sweep_path = RESULTS_DIR / "exps/copied/delta-hn/from-orange/"
# `compare_no_path` holds the runs that are tagged with anchor "no-coolchic"
# and plotted alongside the hypernet runs for comparison.
compare_no_path = RESULTS_DIR / "exps/copied/no-cchic/orange-nocc/"

In [None]:
def _anchor_metrics_frame(
    results_path, dataset: Literal["kodak", "clic20-pro-valid"], anchor: str
) -> pd.DataFrame:
    """Flatten all parsed metrics under ``results_path`` into one DataFrame.

    Args:
        results_path: Directory handed to ``parse_hypernet_metrics``.
        dataset: Dataset whose results are parsed.
        anchor: Label written to the ``anchor`` column of every row so the
            rows can be told apart after concatenation.

    Returns:
        One row per parsed metric entry (its ``model_dump()``), plus the
        ``anchor`` column.
    """
    # NOTE(review): `premature=True` is forwarded unchanged from the original
    # code; its exact meaning is defined by `parse_hypernet_metrics`.
    metrics = parse_hypernet_metrics(results_path, dataset=dataset, premature=True)
    return pd.DataFrame(
        [entry.model_dump() for seq_res in metrics.values() for entry in seq_res]
    ).assign(anchor=anchor)


def rd_plots_from_dataset(
    dataset: Literal["kodak", "clic20-pro-valid"], num_images: int = 5
) -> None:
    """Plot the dataset-average RD curve and per-image RD curves.

    Results from ``sweep_path`` (anchor "hypernet") and ``compare_no_path``
    (anchor "no-coolchic") are merged into a single frame that feeds both the
    average plot and the per-image plots.

    Args:
        dataset: Dataset to plot; also the sub-directory of ``DATA_DIR`` that
            is scanned for ``*.png`` images.
        num_images: How many images (taken in stem order) get an individual
            RD plot. Defaults to 5, the previously hard-coded value.

    Raises:
        ValueError: If ``num_images`` is negative.
    """
    if num_images < 0:
        # A negative bound would silently slice from the end of the list.
        raise ValueError(f"num_images must be non-negative, got {num_images}")

    df = pd.concat(
        [
            _anchor_metrics_frame(sweep_path, dataset, anchor="hypernet"),
            # For comparison with no hypernet.
            _anchor_metrics_frame(compare_no_path, dataset, anchor="no-coolchic"),
        ],
    ).sort_values(by=["seq_name", "lmbda"])  # So plots come out nice and in order.

    plot_hypernet_rd_avg(df, dataset=dataset)

    # Sort by stem so the selected images are the same on every machine.
    all_images = sorted((DATA_DIR / dataset).glob("*.png"), key=lambda x: x.stem)
    for img in all_images[:num_images]:
        plot_hypernet_rd(img.stem, df, dataset=dataset)

In [None]:
# Average RD plot plus per-image RD plots for the Kodak dataset.
rd_plots_from_dataset("kodak")

In [None]:
# Average RD plot plus per-image RD plots for the CLIC20 pro validation set.
rd_plots_from_dataset("clic20-pro-valid")

# Plot 2: FLOPs vs. BD-rate

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from coolchic.eval.hypernet import get_hypernet_flops
from coolchic.hypernet.hypernet import DeltaWholeNet, NOWholeNet

sns.set_theme(context="notebook", style="whitegrid")

In [None]:
# FLOPs reported by `get_hypernet_flops` with `get_coolchic_flops=True`; used
# as the base unit for one Cool-chic step.
coolchic_fwd = get_hypernet_flops(DeltaWholeNet, get_coolchic_flops=True)
flops = {
    "coolchic": coolchic_fwd,
    # NOTE(review): the factor 3 presumably accounts for forward + backward
    # cost of one training step -- confirm.
    "coolchic_step": 3 * coolchic_fwd,
    "nocc": get_hypernet_flops(NOWholeNet),
    "delta": get_hypernet_flops(DeltaWholeNet),
}

# Hand-entered BD-rates per dataset and method.
# NOTE(review): the Kodak "nocc" and "delta" values are identical -- verify
# this is not a copy-paste placeholder.
bd_rates = {
    "kodak": {"nocc": 33.478488207355284, "delta": 33.478488207355284},
    "clic20-pro-valid": {"nocc": 36.233904571367106, "delta": 28.284958907776783},
}

# One row per (dataset, method), with the method's cost also expressed as a
# number of equivalent Cool-chic steps.
metrics_dfg = pd.DataFrame(
    [
        (dataset_name, method_name, rate, flops[method_name])
        for dataset_name, rates_by_method in bd_rates.items()
        for method_name, rate in rates_by_method.items()
    ],
    columns=["dataset", "method", "bd_rate", "flops"],
).assign(num_coolchic_steps=lambda frame: frame["flops"] / flops["coolchic_step"])
metrics_dfg.head()

In [None]:
# Scatter the CLIC20 rows only: x is the method's cost in equivalent
# Cool-chic steps, y is its BD-rate, one colour per method.
fig, ax = plt.subplots()
sns.scatterplot(
    data=metrics_dfg.loc[metrics_dfg["dataset"].eq("clic20-pro-valid")],
    x="num_coolchic_steps",
    y="bd_rate",
    hue="method",
    ax=ax,
)
ax.set(title="BD Rate vs. Equivalent number of Cool-chic Steps, CLIC20 dataset")
sns.despine()

# Plot 3: finetuning

In [None]:
from collections import defaultdict

from coolchic.eval.hypernet import find_crossing_it
from coolchic.utils.paths import COOLCHIC_REPO_ROOT

In [None]:
finetuning_data = pd.read_csv(COOLCHIC_REPO_ROOT / "clic_finetuning_results.csv")
dataset = "clic20-pro-valid"
# Multiplier that turns a crossing index into an iteration count when printing.
freq_valid = 10  # Assuming the finetuning results file was generated with our code.

# Short key used in the stored dicts -> anchor label passed to
# `find_crossing_it` for that training regime.
crossing_sources = {"hn": "nocc-finetuning", "scratch": "coolchic-training"}

# anchor name -> sequence name -> list of {"hn": ..., "scratch": ...} crossings.
crossing_its: dict[str, dict[str, list[dict[Literal["hn", "scratch"], int]]]] = {
    "jpeg": defaultdict(list),
    "hm": defaultdict(list),
    "hypernet": defaultdict(list),
}

# Build the lookup once instead of scanning the column for every image.
seqs_with_results = set(finetuning_data["seq_name"])

# Sorted by stem so plots and printed rows come out in a reproducible order
# (glob order depends on the filesystem).
for image in sorted((DATA_DIR / dataset).glob("*.png"), key=lambda p: p.stem):
    # Skip images if they are not in the results.
    if image.stem not in seqs_with_results:
        continue
    plot_hypernet_rd(image.stem, finetuning_data, dataset)
    for anchor_name in crossing_its:
        crossing_its[anchor_name][image.stem].append(
            {
                key: find_crossing_it(
                    image.stem,
                    finetuning_data,
                    source_anchor,
                    anchor_name=anchor_name,
                    dataset=dataset,
                )
                for key, source_anchor in crossing_sources.items()
            }
        )

for anchor_name, crossings_per_img in crossing_its.items():
    print(f"Crossing iterations for {anchor_name}")
    for seq_name, crossings in crossings_per_img.items():
        for cross in crossings:
            # Non-positive values are printed as "no" (no crossing reported).
            print(
                f"{seq_name:<30}, crossing iterations: "
                f"nocc-finetuning: {cross['hn'] * freq_valid if cross['hn'] > 0 else 'no':>4}, "
                f"coolchic-training: {cross['scratch'] * freq_valid if cross['scratch'] > 0 else 'no':>4}"
            )
# plt.show()

# Plot 4: BD-rate vs. FLOPs as training progresses

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from coolchic.eval.bd_rate import bd_rates_summary_anchor_name
from coolchic.eval.hypernet import get_hypernet_flops
from coolchic.eval.results import SummaryEncodingMetrics
from coolchic.hypernet.hypernet import DeltaWholeNet, NOWholeNet

sns.set_theme(context="notebook", style="whitegrid")

In [None]:
# FLOPs reported by `get_hypernet_flops` with `get_coolchic_flops=True`; used
# as the base unit for one Cool-chic step.
coolchic_fwd = get_hypernet_flops(DeltaWholeNet, get_coolchic_flops=True)
flops = {
    "coolchic": coolchic_fwd,
    # NOTE(review): the factor 3 presumably accounts for forward + backward
    # cost of one training step -- confirm.
    "coolchic_step": 3 * coolchic_fwd,
    "nocc": get_hypernet_flops(NOWholeNet),
    "delta": get_hypernet_flops(DeltaWholeNet),
}

# Hand-entered BD-rates per dataset and method.
# NOTE(review): the Kodak "nocc" and "delta" values are identical -- verify
# this is not a copy-paste placeholder.
bd_rates = {
    "kodak": {"nocc": 33.478488207355284, "delta": 33.478488207355284},
    "clic20-pro-valid": {"nocc": 36.233904571367106, "delta": 28.284958907776783},
}

# One row per (dataset, method), with the method's cost also expressed as a
# number of equivalent Cool-chic steps.
metrics_df = pd.DataFrame(
    [
        (dataset_name, method_name, rate, flops[method_name])
        for dataset_name, rates_by_method in bd_rates.items()
        for method_name, rate in rates_by_method.items()
    ],
    columns=["dataset", "method", "bd_rate", "flops"],
).assign(num_coolchic_steps=lambda frame: frame["flops"] / flops["coolchic_step"])
metrics_df.head()

In [None]:
dataset = "clic20-pro-valid"
# Config-number string -> lambda value.
# NOTE(review): not referenced in the cells visible here -- confirm it is
# still needed.
CONFIG_NUM_TO_LMBDA = {"00": 0.0001, "01": 0.0004, "02": 0.001, "03": 0.004, "04": 0.02}

finetuning_dir = RESULTS_DIR / "finetuning" / dataset
# Sort the files so the concatenated row order does not depend on the
# filesystem's directory listing order.
finetuning_csvs = sorted(finetuning_dir.glob("*.csv"))
if not finetuning_csvs:
    # pd.concat would raise a bare "No objects to concatenate"; name the
    # directory so the failure is actionable.
    raise ValueError(f"No finetuning CSV files found in {finetuning_dir}")
# ignore_index=True yields a fresh 0..n-1 index across all files.
finetuning_df = pd.concat(
    [pd.read_csv(csv_path) for csv_path in finetuning_csvs], ignore_index=True
)

In [None]:
def get_bd_rate_from_df(
    df: pd.DataFrame,
    *,
    anchor: str = "hm",
    dataset: Literal["kodak", "clic20-pro-valid"] = "clic20-pro-valid",
) -> float:
    """Compute the single BD-rate described by the rows of ``df``.

    Each row becomes a ``SummaryEncodingMetrics`` grouped by its ``seq_name``;
    the grouped metrics are handed to ``bd_rates_summary_anchor_name``.

    Args:
        df: Rows with at least the columns ``seq_name``, ``lmbda``,
            ``rate_bpp`` and ``psnr_db``.
        anchor: Anchor forwarded to ``bd_rates_summary_anchor_name``.
            Defaults to the previously hard-coded ``"hm"``.
        dataset: Dataset forwarded to ``bd_rates_summary_anchor_name``.
            Defaults to the previously hard-coded ``"clic20-pro-valid"``.

    Returns:
        The one BD-rate value returned for ``df``.

    Raises:
        KeyError: If a required column is missing from ``df``.
        ValueError: If the BD-rate summary does not contain exactly one entry.
    """
    required = ["seq_name", "lmbda", "rate_bpp", "psnr_db"]
    # Validate once up front rather than asserting on every row; asserts are
    # also stripped when Python runs with -O.
    missing = [column for column in required if column not in df.columns]
    if missing:
        raise KeyError(f"DataFrame is missing required column(s): {missing}")

    metrics = defaultdict(list)
    for seq_name, lmbda, rate_bpp, psnr_db in df[required].itertuples(
        index=False, name=None
    ):
        metrics[seq_name].append(
            SummaryEncodingMetrics(
                seq_name=seq_name,
                lmbda=lmbda,
                rate_bpp=rate_bpp,
                psnr_db=psnr_db,
            )
        )

    bds = bd_rates_summary_anchor_name(
        metrics, anchor=anchor, dataset=dataset, only_latent_rate=False
    )
    if len(bds) != 1:
        raise ValueError(f"Expected exactly one BD rate result, got {len(bds)}.")
    return next(iter(bds.values()))


# One BD-rate per (n_itr, seq_name, anchor) group.
# The groups are iterated explicitly instead of using `groupby(...).apply`:
# `get_bd_rate_from_df` reads the `seq_name` column, and pandas has deprecated
# passing the grouping columns to the function given to `apply`. Iterating a
# groupby always yields the full sub-frame, grouping columns included.
group_columns = ["n_itr", "seq_name", "anchor"]
bd_df = pd.DataFrame(
    [
        (*group_key, get_bd_rate_from_df(group))
        for group_key, group in finetuning_df.groupby(group_columns)
    ],
    columns=[*group_columns, "bd_rate"],
)
# Total cost in equivalent Cool-chic steps: the iteration count plus the
# "nocc" FLOPs converted to steps.
# NOTE(review): the "nocc" cost is added for every anchor, not only the
# nocc-based ones -- confirm this is intended.
bd_df["num_coolchic_steps"] = bd_df["n_itr"] + flops["nocc"] / flops["coolchic_step"]

In [None]:
# Fixed reference points: one marker per method for the selected dataset.
fig, ax = plt.subplots()
sns.scatterplot(
    data=metrics_df.loc[metrics_df["dataset"].eq(dataset)],
    x="num_coolchic_steps",
    y="bd_rate",
    hue="method",
    ax=ax,
)

# Average the per-image BD-rates for every (anchor, iteration count) pair.
avg_bd_df = bd_df.groupby(["anchor", "n_itr"], as_index=False).agg(
    {"bd_rate": "mean", "num_coolchic_steps": "mean"}
)
# Keep only the "nocc-finetuning" curve, strictly between 50 and 1000
# iterations.
select_bds = avg_bd_df.query('anchor == "nocc-finetuning" and 50 < n_itr < 1000')
sns.lineplot(
    data=select_bds,
    x="num_coolchic_steps",
    y="bd_rate",
    ax=ax,
)
ax.set(title="BD Rate vs. Equivalent number of Cool-chic Steps, CLIC20 dataset")
sns.despine()