# Initialization


## Imports


In [None]:
import numpy as np
import pandas as pd
import altair as alt
from scipy import stats

# Route every Altair chart in this notebook through the "vegafusion" data
# transformer.
# NOTE(review): presumably enabled so that large DataFrames can be charted -
# confirm against the Altair documentation.
alt.data_transformers.enable("vegafusion")

## Utils


In [None]:
def read_wandb_table(path: str) -> pd.DataFrame:
    """Load a table stored as JSON into a DataFrame.

    The file must contain a JSON object with a ``"columns"`` entry (the
    column names) and a ``"data"`` entry (a list of rows, each ordered like
    ``"columns"``).

    Args:
        path: Location of the JSON file to read.

    Returns:
        A DataFrame with one row per entry of ``"data"`` and the columns
        named by ``"columns"``.

    Raises:
        KeyError: If ``"columns"`` or ``"data"`` is missing from the file.
    """
    import json

    # JSON text is UTF-8; be explicit instead of relying on the platform's
    # locale encoding, which would garble non-ASCII values on some systems.
    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)
    columns = data["columns"]
    rows = data["data"]
    return pd.DataFrame(rows, columns=columns)

# Metrics


## Download


In [None]:
# import wandb

# from utils.wandb import wandb_path

# runs = wandb.Api().runs(
#     wandb_path(False),
#     filters={
#         "jobType": "test",
#         "createdAt": {"$lt": "2026-01-08T00:00:00Z"},
#     },
# )

# len(runs)

# for run in runs:
#     run_id = run.name.split(" ")[2]
#     dataset = run.config["test_dataset"]
#     model = run.config["model"]
#     group = run.group
#     run.logged_artifacts()[1].download(
#         f"logs/wandb/3_metrics/{group} {model} {dataset} {run_id}"
#     )
#     print(group, model, dataset, run_id)

In [None]:
import os

df_list = []

wandb_dir = "logs/wandb/3_metrics"
# Each entry is expected to be named "<group> <model> <dataset> <run_id>".
# - sorted(): os.listdir returns entries in arbitrary order, so sorting keeps
#   the row order of the combined CSV reproducible between runs.
# - run_dir: the loop variable must not be called `dir`, which would shadow
#   the builtin dir() for the rest of the notebook session.
for run_dir in sorted(os.listdir(wandb_dir)):
    group, model, dataset, _ = run_dir.split(" ")
    df = read_wandb_table(f"{wandb_dir}/{run_dir}/metrics.table.json")
    df.drop(columns=["type", "epoch", "loss"], inplace=True)
    # Each insert goes to position 0, so the final leading columns are
    # method, model, dataset.
    df.insert(0, "dataset", dataset)
    df.insert(0, "model", model)
    df.insert(0, "method", group)
    df_list.append(df)

# Stack all per-run tables and persist them as a single CSV.
meta_metrics_df = pd.concat(df_list)
meta_metrics_df.to_csv("logs/wandb/3_meta_metrics.csv", index=False)

## Preparation


In [None]:
# Load the combined metrics table from its CSV file.
meta_metrics_df = pd.read_csv("logs/wandb/3_meta_metrics.csv")

# Scale IoU by 100 (the charts below label it as a percentage).
meta_metrics_df["iou"] *= 100


def def_full_method(row):
    """Build the label "<method> <model tag>" for one metrics row.

    The tag is "DL3+" when the row's ``model`` value contains "deeplab" and
    "UNM" for any other model.
    """
    model_tag = "DL3+" if "deeplab" in row["model"] else "UNM"
    return f"{row['method']} {model_tag}"


# Overwrite the method column with the combined "<method> <model tag>" label;
# the separate model column is removed afterwards.
meta_metrics_df["method"] = meta_metrics_df.apply(def_full_method, axis="columns")
del meta_metrics_df["model"]

meta_metrics_df

## Comparison


In [None]:
def compare_metrics(use_best: bool) -> pd.DataFrame:
    """Summarise IoU per dataset and method, with +/- 1.96 standard-error bounds.

    Rows of the module-level ``meta_metrics_df`` whose sparsity mode is
    "point" with a sparsity value of 1 are left out before aggregating.

    Args:
        use_best: If False, every remaining row of a (dataset, method) pair
            is pooled. If True, rows are first aggregated per
            (shot, sparsity_mode, sparsity_value) configuration and, for each
            (dataset, method) pair, only the configuration with the highest
            mean IoU is kept.

    Returns:
        One row per (dataset, method) pair with ``iou`` (mean), ``iou_std``,
        ``iou_count``, ``iou_std_err`` and the bounds ``iou_low`` and
        ``iou_high``. With ``use_best`` the columns describing the winning
        configuration are included as well.
    """
    pair_keys = ["dataset", "method"]
    config_keys = pair_keys + ["shot", "sparsity_mode", "sparsity_value"]
    iou_summary = {
        "iou": ("iou", "mean"),
        "iou_std": ("iou", "std"),
        "iou_count": ("iou", "count"),
    }

    is_single_point = (meta_metrics_df["sparsity_mode"] == "point") & (
        meta_metrics_df["sparsity_value"] == 1
    )
    data = meta_metrics_df[~is_single_point]

    if not use_best:
        comparison_df = (
            data[pair_keys + ["iou"]]
            .groupby(pair_keys)
            .agg(**iou_summary)
            .reset_index()
        )
    else:
        # dropna=False keeps configurations whose key columns contain NaN.
        per_config = (
            data[config_keys + ["iou"]]
            .groupby(config_keys, dropna=False)
            .agg(**iou_summary)
            .reset_index()
        )
        # reset_index() above gives unique row labels, so idxmax() labels can
        # be fed straight into .loc.
        best_rows = per_config.groupby(pair_keys)["iou"].idxmax()
        comparison_df = per_config.loc[best_rows]

    comparison_df["iou_std_err"] = (
        comparison_df["iou_std"] / comparison_df["iou_count"] ** 0.5
    )
    margin = 1.96 * comparison_df["iou_std_err"]
    comparison_df["iou_low"] = comparison_df["iou"] - margin
    comparison_df["iou_high"] = comparison_df["iou"] + margin

    return comparison_df

In [None]:
# Pooled summary over all configurations, and the summary restricted to each
# (dataset, method) pair's best configuration.
comparison_df = compare_metrics(use_best=False)
best_comparison_df = compare_metrics(use_best=True)

In [None]:
# Show the summary ordered by dataset and then IoU, both descending.
comparison_df.sort_values(["dataset", "iou"], ascending=[False, False])

In [None]:
# Average each method's mean IoU across datasets and rank the methods.
(
    comparison_df.groupby("method")[["iou"]]
    .mean()
    .sort_values("iou", ascending=False)
)

## Visualization


In [None]:
def compose_bar_chart(
    data: pd.DataFrame,
    scale: tuple[float, float],
    title: str,
    hide_header: bool = False,
):
    """Draw one bar chart of mean IoU per dataset, one bar per method variant.

    Args:
        data: Frame providing the ``dataset``, ``method``, ``iou``,
            ``iou_low`` and ``iou_high`` fields used by the encodings.
        scale: Lower and upper limit of the y axis; values are clamped to it.
        title: Title of the dataset facet columns.
        hide_header: When True, the facet header labels get font size 0.

    Returns:
        The layered chart (bars, error bars, value labels; 200x200 per panel)
        faceted into one column per dataset.
    """
    # Fixed bar order; colours pair up by position with the methods.
    method_order = [
        "PS UNM",
        "PA UNM",
        "PAS UNM",
        "PS DL3+",
        "PA DL3+",
        "PAS DL3+",
    ]
    method_colors = ["#ff9896", "#aec7e8", "#98df8a", "#d62728", "#1f77b4", "#2ca02c"]
    color_scale = alt.Scale(domain=method_order, range=method_colors)
    y_scale = alt.Scale(domain=scale, clamp=True)

    # Shared x encoding: the legend identifies the bars, so the axis shows
    # neither labels nor ticks.
    x_encoding = alt.X(
        "method:N",
        title=None,
        sort=method_order,
        axis=alt.Axis(labels=False, ticks=False),
    )
    base = alt.Chart(data).encode(x=x_encoding)

    legend = alt.Legend(
        orient="bottom",
        direction="horizontal",
        titleAnchor="start",
        columns=3,
    )
    bars = base.mark_bar().encode(
        y=alt.Y("iou:Q", title=None, scale=y_scale),
        color=alt.Color(
            "method:N",
            scale=color_scale,
            title="Variant",
            legend=legend,
        ),
    )

    # The bounds come pre-computed from the data (iou_low / iou_high).
    # NOTE(review): with explicit y/y2 bounds, `extent="ci"` presumably has no
    # effect - confirm against the Vega-Lite error bar documentation.
    error_bars = base.mark_errorbar(
        extent="ci", thickness=2.0, ticks=True, color="black"
    ).encode(
        y=alt.Y("iou_low:Q", title=None, scale=y_scale),
        y2="iou_high:Q",
    )

    # Rounded IoU value per bar; no y encoding, positioned via the dy offset.
    value_labels = base.mark_text(
        align="center", baseline="top", dy=85, fontSize=16
    ).encode(
        text=alt.Text("iou:Q", format=".0f"),
    )

    layered = (bars + error_bars + value_labels).properties(  # type: ignore
        width=200, height=200
    )

    if hide_header:
        header = alt.Header(labelFontSize=0)
    else:
        header = alt.Header()
    facet_column = alt.Column("dataset:N", title=title, header=header)
    return layered.facet(column=facet_column, spacing=4)

In [None]:
# Bar chart of the pooled mean IoU, y axis limited to 30-80.
(
    compose_bar_chart(
        data=comparison_df,
        scale=(30, 80),
        title="Comparison of Overall Mean IoU (%)",
    )
    .configure_axis(labelFontSize=14, titleFontSize=16)
    .configure_header(labelFontSize=14, titleFontSize=16)
    .configure_legend(labelFontSize=14, titleFontSize=16)
)

In [None]:
# Bar chart of the best-configuration mean IoU, y axis limited to 35-85.
(
    compose_bar_chart(
        data=best_comparison_df,
        scale=(35, 85),
        title="Comparison of Best Mean IoU (%)",
    )
    .configure_axis(labelFontSize=14, titleFontSize=16)
    .configure_header(labelFontSize=14, titleFontSize=16)
    .configure_legend(labelFontSize=14, titleFontSize=16)
)

In [None]:
def _pool_iou_across_datasets(per_dataset: pd.DataFrame) -> pd.DataFrame:
    """Pool per-dataset IoU statistics into one row per plotted point.

    Args:
        per_dataset: One row per (dataset, shot, sparsity_mode,
            sparsity_value) with the columns ``iou_mean``, ``iou_var`` and
            ``count``.

    Returns:
        One row per (shot, sparsity_mode, sparsity_value) holding the pooled
        ``iou_mean`` and ``iou_var`` plus the interval bounds ``iou_lower``
        and ``iou_upper``. Empty (no columns) when ``per_dataset`` is empty.
    """
    # NOTE(review): 50 is presumably the largest point count, so the division
    # maps "point" sparsity onto the 0-1 range of the other modes - confirm.
    point_divisor = 50

    pooled_rows = []
    grouped = per_dataset.groupby(["shot", "sparsity_mode", "sparsity_value"])
    for (shot, sparsity_mode, sparsity_value), subset in grouped:
        count = subset["count"].sum()
        iou_mean = subset["iou_mean"].mean()
        # Average within-dataset variance plus the (population) variance of
        # the dataset means.
        iou_var = subset["iou_var"].mean() + np.var(subset["iou_mean"])
        # NOTE(review): the t quantile uses `count` degrees of freedom;
        # n - 1 is the more common choice - confirm this is intended.
        iou_delta = stats.t.ppf(0.975, count) * np.sqrt(iou_var / count)
        if sparsity_mode == "point":
            x_value = sparsity_value / point_divisor
        else:
            x_value = sparsity_value
        pooled_rows.append(
            {
                "shot": shot,
                "sparsity_mode": sparsity_mode,
                "sparsity_value": x_value,
                "iou_mean": iou_mean,
                "iou_lower": iou_mean - iou_delta,
                "iou_upper": iou_mean + iou_delta,
                "iou_var": iou_var,
            }
        )
    return pd.DataFrame(pooled_rows)


def compose_line_chart(method: str, scale: tuple[float, float]) -> alt.FacetChart:
    """Plot mean IoU against sparsity value for one method, faceted by shot.

    Rows of the module-level ``meta_metrics_df`` belonging to ``method`` are
    aggregated per dataset and then pooled across datasets. Each sparsity mode
    is drawn as a line with point markers and a shaded 95% t-interval band.
    Rows with sparsity mode "point" and sparsity value 1 are not filtered out
    here.

    Args:
        method: Value of the ``method`` column to plot, e.g. "PAS DL3+".
        scale: Lower and upper limit of the y axis; values are clamped to it.

    Returns:
        A faceted chart with one 170x200 panel per shot count.

    Raises:
        ValueError: If no rows can be plotted for ``method``.
    """
    color_scale = alt.Scale(
        domain=["point", "contour", "grid", "region", "skeleton"],
        scheme="category10",
    )

    # Filter first and copy afterwards: assigning a column to a filtered slice
    # of another frame triggers pandas' SettingWithCopyWarning.
    method_rows = meta_metrics_df[meta_metrics_df["method"] == method].copy()
    method_rows["shot"] = method_rows["shot"].apply(
        lambda x: f"{x} shot" if x == 1 else f"{x} shots"
    )

    per_dataset = (
        method_rows.groupby(["dataset", "shot", "sparsity_mode", "sparsity_value"])
        .agg(
            iou_mean=("iou", "mean"),
            iou_var=("iou", "var"),
            count=("iou", "count"),
        )
        .reset_index()
    )
    new_data = _pool_iou_across_datasets(per_dataset)
    if new_data.empty:
        # Without rows the frame has no columns and the chart could not be
        # rendered; fail here with a message that names the cause.
        raise ValueError(f"No plottable metrics found for method {method!r}")

    encodings = {
        "x": alt.X(
            "sparsity_value", title=None, scale=alt.Scale(domain=[0.1, 1.0], clamp=True)
        ),
        "color": alt.Color(
            "sparsity_mode:N",
            title="Sparse Label Type",
            scale=color_scale,
            legend=alt.Legend(
                orient="bottom", direction="horizontal", titleAnchor="start"
            ),
        ),
    }
    y_kwargs = {
        "title": "Mean IoU (%)",
        "scale": alt.Scale(domain=scale, clamp=True),
    }

    error_area = (
        alt.Chart(new_data)
        .mark_area(opacity=0.1)
        .encode(
            y=alt.Y("iou_upper", **y_kwargs),
            y2=alt.Y2("iou_lower"),
            **encodings,
        )
    )
    line = (
        alt.Chart(new_data)
        .mark_line(strokeWidth=1.5)
        .encode(y=alt.Y("iou_mean", **y_kwargs), **encodings)
    )
    point = (
        alt.Chart(new_data)
        .mark_point(size=7)
        .encode(y=alt.Y("iou_mean", **y_kwargs), **encodings)
    )

    combined_chart = line + point + error_area  # type: ignore
    combined_chart = (
        combined_chart.properties(width=170, height=200)
        .facet(
            column=alt.Column(
                "shot",
                # Explicit order; the labels are strings and would not sort
                # numerically on their own.
                sort=["1 shot", "5 shots", "10 shots", "15 shots", "20 shots"],
                header=alt.Header(title="Density Values", titleOrient="bottom"),
            ),
            spacing=10,
        )
        .resolve_scale(x="independent")
        .configure_axis(labelFontSize=12, titleFontSize=16)
        .configure_header(labelFontSize=16, titleFontSize=16)
        .configure_legend(labelFontSize=14, titleFontSize=16)
    )

    return combined_chart

In [None]:
# Line chart for the "PAS DL3+" variant, y axis limited to 55-75.
compose_line_chart(method="PAS DL3+", scale=(55, 75))

In [None]:
# Line chart for the "PA DL3+" variant, y axis limited to 55-75.
compose_line_chart(method="PA DL3+", scale=(55, 75))