In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to imported project code are picked up automatically.
%load_ext autoreload
%autoreload 2

In [61]:
import pandas as pd
import numpy as np


def soft_label_to_nparray(d: dict, n_classes: int = 6) -> np.ndarray:
    """Convert a soft-label mapping into a dense vector of class weights.

    Keys are 1-based class labels (ints or int-like strings); values are the
    weight assigned to that class. Classes missing from ``d`` get weight 0.

    Args:
        d: Mapping from class label to weight, e.g. ``{"1": 0.25, "3": 0.75}``.
        n_classes: Length of the returned vector. Defaults to 6.

    Returns:
        A float array of length ``n_classes``. ``pd.NA`` is returned instead
        (after logging a warning) when ``d`` is not a mapping, when a key is
        not an integer in ``1..n_classes``, or when a value cannot be stored
        as a float.
    """
    # Local imports/logger: this function is called before the notebook cell
    # that imports `logging` and creates the module-level `logger` has run.
    import logging
    from collections.abc import Mapping

    log = logging.getLogger(__name__)

    if not isinstance(d, Mapping):
        log.warning("Invalid soft label (expected a mapping): %r", d)
        return pd.NA

    array = np.zeros(n_classes)
    for k, v in d.items():
        try:
            idx = int(k) - 1
        except (TypeError, ValueError):
            log.warning("Invalid key: %s", k)
            return pd.NA
        # Reject 0/negative keys explicitly: a negative index would silently
        # wrap around and write to the wrong class.
        if not 0 <= idx < n_classes:
            log.warning("Invalid key: %s", k)
            return pd.NA
        try:
            array[idx] = v
        except (TypeError, ValueError):
            log.warning("Invalid value for key %s: %r", k, v)
            return pd.NA
    return array


In [None]:
from lewidi_lib.src.funcs import load_dataset, enable_logging

# Call the library's logging helper before loading anything.
enable_logging()

# Load the CSC dataset, number its rows with a positional `request_idx`
# (the join key used when merging with the responses), and convert each
# soft label into a dense target vector.
ddf = load_dataset(dataset_name="CSC")
n_examples = len(ddf)
ddf = ddf.assign(request_idx=range(n_examples))
ddf["target"] = ddf["soft_label"].apply(soft_label_to_nparray)
print(n_examples)
ddf.head(2)

In [None]:
import logging

logger = logging.getLogger(__name__)

# Read the responses (one JSON object per line), report how many have no
# response text, drop those rows, and trim whitespace from the rest.
rdf = pd.read_json("responses.jsonl", lines=True)
response_missing = rdf["response"].isna()
are_na = int(response_missing.sum())
logger.info("Number of responses that are NA: %d", are_na)
rdf = rdf.loc[~response_missing].copy()
rdf["response"] = rdf["response"].str.strip()
rdf.head(2)

In [None]:
import json_repair

# Parse every response with json_repair, then convert the parsed object into
# a dense prediction vector (pd.NA marks responses that could not be converted).
parsed_responses = rdf["response"].apply(json_repair.loads)
rdf["pred"] = parsed_responses.apply(soft_label_to_nparray)

pred_missing = rdf["pred"].isna()
logger.info("Dropping %d NA predictions", int(pred_missing.sum()))
rdf = rdf.loc[~pred_missing].copy()

# A prediction is flagged valid when its entries sum to 1 within 0.01.
rdf["pred_sum"] = rdf["pred"].apply(lambda vec: vec.sum())
rdf["is_valid_pred"] = rdf["pred_sum"].sub(1).abs().lt(0.01)
rdf["reasoning_isnull"] = rdf["reasoning"].isna()

# Add columns indicating if the run has reasoning: a run is marked as a
# reasoning run only when none of its rows has a null `reasoning` value.
reasoning_by_run = (
    (~rdf.groupby("run_id")["reasoning_isnull"].any())
    .rename("is_reasoning")
    .reset_index()
)
rdf = rdf.merge(reasoning_by_run, how="left", on="run_id")
rdf = rdf.drop(columns=["reasoning_isnull"])
rdf.head(2)

In [None]:
# Per-run share of valid predictions, joined with the per-run reasoning flag.
valid_share_by_run = rdf.groupby("run_id")[["is_valid_pred"]].mean()
run_info = valid_share_by_run.merge(reasoning_by_run, how="left", on="run_id")
run_info

In [None]:
# Report, then discard, the predictions flagged as not summing to 1.
n_invalid = int((~rdf["is_valid_pred"]).sum())
logger.info("Dropping %d predictions that don't sum to 1", n_invalid)
rdf = rdf.loc[rdf["is_valid_pred"]].copy()

In [None]:
# Attach each response row to the target vector of the request it answers
# (inner join on `request_idx`).
joint_df = ddf[["request_idx", "target"]].merge(rdf, how="inner", on="request_idx")
joint_df.head(2)

In [79]:
import scipy


def l0_loss(tgt: np.ndarray, pred: np.ndarray) -> float:
    """Return the sum of absolute element-wise differences between two vectors.

    Note: despite the name, the quantity computed here is the L1 (Manhattan)
    distance between ``tgt`` and ``pred``.
    """
    abs_diff = np.abs(tgt - pred)
    return abs_diff.sum()


def ws_loss(tgt: np.ndarray, pred: np.ndarray) -> float:
    """Wasserstein distance between two discrete distributions over ordered classes.

    Both vectors are treated as weights over the same support
    ``0, 1, ..., len(tgt) - 1``, so the distance is measured in "class steps".
    See https://stackoverflow.com/a/76061410/5730291

    Args:
        tgt: Target class weights.
        pred: Predicted class weights; must have the same length as ``tgt``.

    Returns:
        The first Wasserstein distance between the two weighted distributions.

    Raises:
        ValueError: Raised by SciPy when the weights are invalid, e.g. when
            ``pred`` does not have the same length as ``tgt``.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.stats` is loaded on every SciPy version.
    from scipy.stats import wasserstein_distance

    # Derive the support from the input instead of hard-coding six classes.
    support = np.arange(len(tgt))
    return wasserstein_distance(support, support, tgt, pred)


# Score every (target, pred) row with each loss function and store the result
# in a column of the same name.
for _column, _loss_fn in (("l0_loss", l0_loss), ("ws_loss", ws_loss)):
    joint_df[_column] = joint_df.apply(
        lambda row, _fn=_loss_fn: _fn(row["target"], row["pred"]),
        axis=1,
    )