In [None]:
import pickle
from collections.abc import Callable
from pprint import pprint

import numpy as np
import pandas as pd

In [None]:
# Load the pickled cache artifact into memory.
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only run this cell on an artifact that comes from a trusted source.
with open("../artifacts/human_eval_cache.pkl", "rb") as f:
    cache = pickle.load(f)

# Number of entries in the cache.
len(cache)

In [None]:
# Materialise the cache's keys and values, then do the same for the first value
# (which is itself a mapping) so that the nesting can be inspected.
keys, values = list(cache.keys()), list(cache.values())
value_keys, value_values = list(values[0].keys()), list(values[0].values())

# One output line per level: the key, its elements, the value, and the types of
# the first key/value stored inside that value.
print(
    "\n".join(
        [
            f"Key: {type(keys[0])}",
            f"Key elements: {list(map(type, keys[0]))}",
            f"Value: {type(values[0])}",
            f"Value key: {type(value_keys[0])}",
            f"Value value: {type(value_values[0])}",
        ]
    )
)

In [None]:
# Pretty-print the first cache entry: its key, a blank separator line, then its value.
print("Key:")
pprint(keys[0])
print("\nValue:")
pprint(values[0])

In [None]:
# Hash every cache key, then show the number of hashes next to the number of
# distinct hashes (equal counts mean no two keys share a hash).
hashes = list(map(hash, cache))
(len(hashes), len(set(hashes)))

In [None]:
# One row per cache entry: the three components of the key plus the entry's
# "val_annotations", stored in ascending order.
df = pd.DataFrame(
    [
        dict(q=q, a=a, n=n, ann=sorted(entry["val_annotations"]))
        for (q, a, n), entry in cache.items()
    ]
)

In [None]:
# Preview the first rows of the frame built from the cache.
df.head()

In [None]:
# Constant added to every annotation score.
# NOTE(review): presumably this moves a scale centred on 0 onto a non-negative
# one — confirm against the annotation scheme.
ANN_OFFSET = 2

# Keep the unshifted scores in "ann_raw" and always derive "ann" from them, so
# re-running this cell cannot add the offset a second time.
if "ann_raw" not in df.columns:
    df["ann_raw"] = df["ann"]
df["ann"] = df["ann_raw"].map(lambda x: [i + ANN_OFFSET for i in x])

# Smallest and largest shifted score across all rows.
df["ann"].explode().agg(["min", "max"])

In [None]:
# Spread of each row's annotation list. np.std defaults to ddof=0, i.e. the
# population standard deviation.
df["std"] = df["ann"].map(np.std)
# Summary statistics of that spread over all rows.
df["std"].describe()

In [None]:
# Rows ordered from the lowest to the highest "std" (ascending is the default).
df.sort_values("std")

In [None]:
# Rows whose annotation list is exactly [0, 2, 4]. "ann" holds sorted lists, so
# a plain list comparison matches regardless of the original annotation order.
df[df["ann"].map(lambda x: x == [0, 2, 4])]

In [None]:
# Scratch check: differences between consecutive elements of a sample list.
lst = [0, 2, 4]
np.diff(lst)

In [None]:
def listeq(lst: list[int]) -> Callable[[list[int]], bool]:
    """Build a predicate that reports whether its argument equals ``lst``.

    Args:
        lst: The reference list to compare against.

    Returns:
        A one-argument callable returning ``lst == el`` for the list ``el`` it
        is given; convenient as the function passed to ``Series.map``.
    """
    return lambda el: lst == el


def entropy(data: list[int]) -> float:
    """Shannon entropy, in bits, of the empirical distribution of ``data``.

    Args:
        data: Discrete labels. Any values ``np.unique`` can sort are accepted,
            including negative integers.

    Returns:
        ``-sum(p * log2(p))`` over the distinct values of ``data``, where ``p``
        is each value's relative frequency. Empty or constant input gives
        ``0.0`` (never ``-0.0``).
    """
    if len(data) == 0:
        return 0.0

    # np.unique reports only the values that actually occur, so there are no
    # zero probabilities to mask out, and negative labels are handled too.
    _, counts = np.unique(data, return_counts=True)
    p = counts / len(data)

    # Subtract from 0.0 instead of negating: negating a zero sum yields -0.0,
    # which would show up as "-0.0" for unanimous annotations.
    return float(0.0 - np.sum(p * np.log2(p)))


def gini_coefficient(data: list[int]) -> float:
    """Gini coefficient of the values in ``data``.

    Computed as ``2 * sum(i * x_i) / (n * sum(x)) - (n + 1) / n`` where ``x`` is
    ``data`` sorted ascending, ``i`` runs from 1 to ``n`` and ``n`` is the
    number of values.

    Args:
        data: One-dimensional sequence of numbers.

    Returns:
        The coefficient as a float. ``0.0`` is returned when ``data`` is empty,
        when all values are equal, or when the values sum to (approximately)
        zero, since the formula would otherwise divide by zero.
    """
    values = np.asarray(data)

    # Compare element-wise on an array: with a plain list, `data == data[0]`
    # is a single False, so the all-equal shortcut would never trigger.
    if (
        values.size == 0
        or np.all(values == values[0])
        or np.isclose(np.sum(values), 0)
    ):
        return 0.0

    sorted_values = np.sort(values)
    n = len(values)
    ranks = np.arange(1, n + 1)

    return float(
        2 * np.sum(ranks * sorted_values) / (n * np.sum(sorted_values))
        - (n + 1) / n
    )


# Add the two dispersion measures next to the existing "std" column.
dd = df.assign(
    entropy=lambda x: x["ann"].map(entropy),
    gini=lambda x: x["ann"].map(gini_coefficient),
)

# Annotation patterns to compare the measures on.
examples = [
    [0, 0, 0],
    [1, 1, 1],
    [1, 2, 3],
    [0, 0, 4],
    [0, 2, 4],
]

# For each pattern take the first row carrying exactly those annotations. A
# pattern may be absent from the data, so skip it instead of letting `.iloc[0]`
# raise IndexError on an empty selection.
matches = [dd[dd["ann"].map(listeq(lst))] for lst in examples]
found_rows = [match.iloc[0] for match in matches if not match.empty]

# pd.concat refuses an empty list, so fall back to an empty slice of `dd` when
# none of the patterns occur.
summary = pd.concat(found_rows, axis=1).transpose() if found_rows else dd.iloc[:0]
summary[["ann", "std", "entropy", "gini"]]