In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import srsly
from datasets import load_from_disk
from matplotlib import ticker
from scipy.special import entr, softmax
from sklearn.manifold import TSNE
from sklearn.metrics import jaccard_score
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
from tqdm.auto import tqdm

In [None]:
# Root directory of the re-initialisation study; the cells below walk it as
# <dataset>/<strategy>/<reinit setting>/<run> sub-directories.
path_to_experiment = Path("../outputs/multirun/reinit_study")
path_to_experiment

---
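### Pool uncertainty

Per-round entropy, margin and predicted / true-class probability on the pool, compared with and without model re-initialisation.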

In [None]:
# Default dataset / strategy names for this section.
dataset, strategy = "agnews", "entropy"

In [None]:
# Training split of `dataset` as a pandas DataFrame.
original_df = (
    load_from_disk(f"../data/processed/{dataset}/")["train"]
    .to_pandas()
)

In [None]:
# Sub-directories stored for `dataset` (walked as strategies further down).
paths = [entry for entry in (path_to_experiment / dataset).iterdir()]
paths

In [None]:
# Per-round uncertainty statistics on the pool for every
# dataset / strategy / re-init setting / seed combination under
# `path_to_experiment` (imdb and the random strategy are skipped).
# Loop variables use `*_dir` names so they do not clobber the `dataset` /
# `strategy` strings set at the top of this section.
stage = "pool"
# the part of a reinit directory name after the last "=" is expected to be
# "True" or "False"; parse it explicitly instead of eval()-ing a path name
reinit_flags = {"True": True, "False": False}

list_df = []
for dataset_dir in tqdm(sorted(path_to_experiment.iterdir()), desc="datasets"):
    if dataset_dir.name == "imdb":
        continue
    # gold labels, joined onto the logged predictions through `unique_id`
    labels_df = load_from_disk(f"../data/processed/{dataset_dir.name}/")[
        "train"
    ].to_pandas()[["unique_id", "labels"]]
    for strategy_dir in tqdm(
        sorted(dataset_dir.iterdir()), desc=dataset_dir.name, leave=False
    ):
        if strategy_dir.name == "random":
            continue
        for reinit_dir in sorted(strategy_dir.iterdir()):
            reinit = reinit_flags[reinit_dir.name.split("=")[-1]]
            for run_dir in sorted(reinit_dir.iterdir()):
                # run names look like "<key>=<data seed>_<key>=<model seed>"
                data_seed, model_seed = (
                    int(part.split("=")[1]) for part in run_dir.name.split("_")
                )

                csv_paths = sorted((run_dir / "logs" / stage).glob("*csv"))
                if not csv_paths:
                    raise FileNotFoundError(f"no {stage} logs found in {run_dir}")
                logs = pd.concat(
                    [pd.read_csv(path).assign(stage=stage) for path in csv_paths]
                ).merge(labels_df, on="unique_id", how="left")
                # an unmatched id would leave a NaN label and break the integer
                # indexing used for `prob_true` below
                assert logs["labels"].notna().all(), f"unlabelled ids in {run_dir}"

                logits = logs.loc[:, logs.columns.str.startswith("logit")].to_numpy()
                probs = softmax(logits, axis=1)
                sorted_probs = np.sort(probs, axis=1)  # ascending within each row
                labels = logs["labels"].to_numpy().astype(int)
                logs = logs.assign(
                    entropy=entr(probs).sum(-1),
                    margin=sorted_probs[:, -1] - sorted_probs[:, -2],
                    preds=probs.argmax(-1),
                    prob_pred=probs.max(-1),
                    prob_true=np.take_along_axis(probs, labels[:, None], axis=1)[:, 0],
                )

                # per-round summary of each uncertainty measure for this run
                df_agg = (
                    logs.groupby("round")[
                        ["entropy", "margin", "prob_pred", "prob_true"]
                    ]
                    .agg(["mean", "max", "min", "std", "size"])
                    .assign(
                        data_seed=data_seed,
                        model_seed=model_seed,
                        dataset=dataset_dir.name,
                        strategy=strategy_dir.name,
                        reinit=reinit,
                    )
                )
                list_df.append(df_agg)

In [None]:
# `list_df` holds every dataset found under `path_to_experiment` (except imdb).
# Keep only AG News so the "AGNEWS" pool-uncertainty plot below, which drops
# the `dataset` column, does not silently average over several datasets.
pool_stats_all = pd.concat(list_df).reset_index()
agnews = pool_stats_all.loc[pool_stats_all["dataset"] == "agnews"].reset_index(
    drop=True
)
assert not agnews.empty, "no agnews runs found under path_to_experiment"
agnews

In [None]:
# Long format for seaborn: one row per (run, round, metric) with the aggregate
# statistics (mean/max/min/std/size) as columns.
plot_data = (
    agnews.set_index(["round", "strategy", "reinit"])
    .drop(columns=["data_seed", "model_seed", "dataset"])
    .stack(0)
    .reset_index()
    # the unnamed level created by `stack(0)` is the 4th index level -> "level_3"
    .rename(columns={"level_3": "metric", "reinit": "re-initialise"})
    .sort_values("strategy")
)

In [None]:
# inspect the long-format table passed to `plot` below
plot_data

In [None]:
# Preview of the palette used inside `plot` below.
sns.diverging_palette(h_neg=250, h_pos=30, l=65, center="dark", n=3)

In [None]:
def plot(plot_data, dpi, ymin, title, metric, y):
    """Line plot of one pool-uncertainty statistic across labelling rounds.

    Parameters
    ----------
    plot_data : pd.DataFrame
        Long-format frame with columns "round", "strategy", "re-initialise",
        "metric" and a column named by `y`.
    dpi : int
        Figure resolution.
    ymin : float
        Lower y-axis limit.
    title : str
        Figure suptitle.
    metric : str
        Value of the "metric" column to plot (e.g. "margin").
    y : str
        Statistic column to put on the y-axis (e.g. "mean").
    """
    colors = sns.diverging_palette(250, 30, l=65, center="dark", n=3)

    plt.style.use("bmh")
    sns.set_context("paper")

    fig, ax = plt.subplots()
    fig.dpi = dpi

    sns.lineplot(
        data=plot_data.loc[plot_data["metric"] == metric],
        x="round",
        y=y,
        hue="strategy",
        style="re-initialise",
        ax=ax,
        palette=colors,
    )

    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(5))
    ax.set_ylim(ymin)

    fig.suptitle(title)
    ax.set_title(
        f"average {y} {metric} on the entire pool across rounds", fontsize=10
    )
    ax.set_ylabel(metric)
    ax.set_xlabel("Labelling round")
    # Move seaborn's own legend (with its hue/style section titles) outside the
    # axes. Re-creating it via `ax.legend()` rebuilds it from axes artists and
    # can drop entries; "upper left" pins its corner to the (1, 1) anchor.
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), fontsize=10)

    sns.despine()
    plt.show()

In [None]:
# AG News: per-round mean pool margin, averaged by seaborn over the runs.
plot(
    plot_data,
    title="AGNEWS",
    metric="margin",
    y="mean",
    dpi=800,
    ymin=0.2,
)

---
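### Similarity between selected subsets

Jaccard overlap of the `unique_id`s queried by each pair of strategies (per run and per labelling round), plus cosine similarity between the mean embeddings of the sets queried with and without re-initialisation.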

In [None]:
# echo the experiment root used by the similarity analysis below
path_to_experiment

In [None]:
# Labelled-dataset logs of every AG News run, tagged with strategy, re-init
# setting, run name and seeds.
# The part of a reinit directory name after "=" is expected to be "True" or
# "False"; parse it explicitly instead of eval()-ing a path name.
reinit_flags = {"True": True, "False": False}

list_dfs = []
for strategy_dir in sorted((path_to_experiment / "agnews").iterdir()):
    for reinit_dir in sorted(strategy_dir.iterdir()):
        reinit = reinit_flags[reinit_dir.name.split("=")[1]]
        for run_dir in sorted(reinit_dir.iterdir()):
            # run names look like "<key>=<data seed>_<key>=<model seed>"
            data_seed, model_seed = (
                int(part.split("=")[1]) for part in run_dir.name.split("_")
            )
            list_dfs.append(
                pd.read_parquet(run_dir / "logs" / "labelled_dataset.parquet").assign(
                    reinit=reinit,
                    strategy=strategy_dir.name,
                    experiment=run_dir.name,
                    data_seed=data_seed,
                    model_seed=model_seed,
                )
            )

In [None]:
# One frame with every labelled-set record across strategies / re-init / runs,
# renumbered 0..n-1.
df = pd.concat(list_dfs, ignore_index=True)

In [None]:
def jaccard_similarity(A, B):
    """Jaccard similarity |A intersect B| / |A union B| of two sets.

    Parameters
    ----------
    A : set
    B : iterable
        Combined with ``A`` via ``set.intersection`` / ``set.union``.

    Returns
    -------
    float
        Similarity in [0, 1], or NaN when both inputs are empty (0/0 is
        undefined; returning NaN keeps a surrounding groupby-apply running and
        is skipped by pandas/seaborn aggregations).
    """
    union = A.union(B)
    if not union:
        return float("nan")
    return len(A.intersection(B)) / len(union)


def compute_jaccard(df, a, b):
    """Jaccard similarity between the ``unique_id`` sets selected by strategies ``a`` and ``b``.

    ``df`` must contain at least the columns ``"strategy"`` and ``"unique_id"``.
    """
    ids_a = set(df.loc[df["strategy"] == a, "unique_id"].unique())
    ids_b = set(df.loc[df["strategy"] == b, "unique_id"].unique())
    return jaccard_similarity(ids_a, ids_b)

In [None]:
# Pair-wise Jaccard similarity of the sets queried by each pair of strategies,
# per run and re-init setting. Column suffixes encode the pair:
# e = entropy, m = margin_confidence, r = random.
strategy_pairs = {
    "jaccard_em": ("entropy", "margin_confidence"),
    "jaccard_er": ("entropy", "random"),
    "jaccard_mr": ("margin_confidence", "random"),
}
by_run = df.groupby(["experiment", "reinit"])
# groupby.apply forwards the extra positional args to `compute_jaccard`
dd = pd.concat(
    {
        column: by_run.apply(compute_jaccard, first, second)
        for column, (first, second) in strategy_pairs.items()
    },
    axis=1,
).reset_index()
dd

In [None]:
def compute_jaccard_strategy(df):
    """Jaccard similarity between the ids selected with and without re-initialisation.

    ``df`` must contain the columns ``"reinit"`` (bool) and ``"unique_id"``.
    """
    ids_reinit = set(df.loc[df["reinit"].eq(True), "unique_id"].unique())
    ids_no_reinit = set(df.loc[df["reinit"].eq(False), "unique_id"].unique())
    return jaccard_similarity(ids_reinit, ids_no_reinit)


def compute_cosine(df, embedding_matrix=None):
    """Cosine similarity between the mean embeddings of the reinit / no-reinit selections.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the columns ``"reinit"`` (bool) and ``"unique_id"``.
    embedding_matrix : np.ndarray, optional
        Indexed by ``unique_id`` values (row per id -- assumes ids are row
        positions, TODO confirm). Defaults to the notebook-global
        ``embeddings``.

    Returns
    -------
    float
        Cosine similarity of the two mean vectors, or NaN when either selection
        is empty (the mean of an empty selection is undefined).
    """
    if embedding_matrix is None:
        embedding_matrix = embeddings  # notebook-global, loaded from disk
    ids_reinit = df.loc[df["reinit"].eq(True), "unique_id"].unique()
    ids_no_reinit = df.loc[df["reinit"].eq(False), "unique_id"].unique()
    if len(ids_reinit) == 0 or len(ids_no_reinit) == 0:
        return np.nan

    mean_reinit = embedding_matrix[ids_reinit].mean(0).reshape(1, -1)
    mean_no_reinit = embedding_matrix[ids_no_reinit].mean(0).reshape(1, -1)
    return cosine_similarity(mean_reinit, mean_no_reinit).item()

In [None]:
# `compute_cosine` reads the notebook-global `embeddings`, which was otherwise
# only loaded further down the notebook; load it here so Restart & Run All works.
embeddings = np.load("../data/processed/agnews/ag_news_index.npy")

# mean / std over runs of the reinit-vs-no-reinit cosine similarity, per strategy
cosine_by_strategy = {
    strategy_name: (
        df.loc[df["strategy"] == strategy_name]
        .groupby("experiment")
        .apply(compute_cosine)
        .agg(["mean", "std"])
    )
    for strategy_name in df["strategy"].unique()
}

print(pd.DataFrame(cosine_by_strategy).T.to_markdown())

In [None]:
# Pair-wise Jaccard similarity of the queried sets per labelling round, melted
# to long format for seaborn (one row per round / run / re-init / pair).
strategy_pairs = {
    "jaccard_em": ("entropy", "margin_confidence"),
    "jaccard_er": ("entropy", "random"),
    "jaccard_mr": ("margin_confidence", "random"),
}
pair_labels = {
    "jaccard_em": "margin-entropy",
    "jaccard_mr": "margin-random",
    "jaccard_er": "entropy-random",
}
round_keys = ["labelling_round", "experiment", "reinit"]
by_round = df.groupby(round_keys)
# groupby.apply forwards the extra positional args to `compute_jaccard`
dd = (
    pd.concat(
        {
            column: by_round.apply(compute_jaccard, first, second)
            for column, (first, second) in strategy_pairs.items()
        },
        axis=1,
    )
    .reset_index()
    .melt(id_vars=round_keys)
    .rename(columns={"reinit": "re-initialise", "variable": "pairs"})
)
dd["pairs"] = dd["pairs"].map(pair_labels)
dd

In [None]:
# Preview of the palette used by the pair-wise similarity plot below; same
# arguments (including l=60) so the preview matches what is plotted.
sns.diverging_palette(150, 40, l=60, center="dark", n=3)

In [None]:
# Entropy strategy, seed-0 run, first labelling round, both re-init settings.
a = df.loc[
    lambda frame: frame["experiment"].eq("data.seed=0_model.seed=0")
    & frame["strategy"].eq("entropy")
    & frame["labelling_round"].eq(1)
]

In [None]:
# ids selected in that round without (b) and with (c) re-initialisation
b, c = (a.loc[a["reinit"].eq(flag), "unique_id"] for flag in (False, True))

In [None]:
# AG News embedding array; the cells below index its rows with `unique_id`
# values (presumably shape (n_examples, dim) with ids as row positions -- TODO confirm).
embeddings = np.load("../data/processed/agnews/ag_news_index.npy")

In [None]:
# raw ids selected without re-initialisation
b.to_numpy()

In [None]:
# cosine similarity between the mean embeddings of the two id sets
# (b: no re-init, c: re-init), returned as a 1x1 array
cosine_similarity(*(embeddings[ids].mean(0).reshape(1, -1) for ids in (b, c)))

In [None]:
# An unfinished `def compute_cosine(df)` stub used to sit here; it was a
# SyntaxError that stopped "Run All" before the final plot. The working
# `compute_cosine` is defined with the similarity helpers earlier in this section.

In [None]:
# The per-round pair-wise Jaccard table `dd` plotted below is built in the
# earlier per-round cell from the same, unchanged `df`; this cell used to
# recompute it verbatim and was reduced to this note to remove the copy-paste.

In [None]:
# AG News: pair-wise Jaccard similarity of the queried sets per labelling
# round, split by re-initialisation setting (data: `dd`).
plt.style.use("bmh")
sns.set_context("paper")
palette = sns.diverging_palette(150, 40, l=60, center="dark", n=3)

fig, ax = plt.subplots()
sns.lineplot(
    data=dd,
    x="labelling_round",
    y="value",
    hue="pairs",
    style="re-initialise",
    palette=palette,
    ax=ax,
)
fig.dpi = 800
fig.suptitle("AGNEWS")
ax.set(
    title="pair-wise similarity between queried sets",
    ylabel="jaccard similarity",
    xlabel="labelling round",
    ylim=(0.0, 0.025),  # fixed range; widen if any pair exceeds 0.025
)
ax.xaxis.set_minor_locator(ticker.MultipleLocator(5))
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), fontsize=10)
sns.despine()
plt.show()