In [1]:
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib import ticker
import polars as pl
import matplotlib.font_manager as fm
import numpy as np
import srsly
from itertools import permutations, combinations, product

In [2]:
# Root of the multirun seed-sweep outputs (relative to the working directory).
path = Path("..") / "outputs" / "multirun" / "seeds"

In [3]:
# Build one lazy frame per run and concatenate them into a single DataFrame.
# Every `*tb_logs.parquet` file sits in a run folder whose name starts with the
# strategy (text before the first "_"); that folder's parent is named after
# the dataset.
list_df = []
for p in path.rglob("*tb_logs.parquet"):
    run_dir = p.parent
    strategy = run_dir.name.split("_")[0]
    dataset = run_dir.parent.name

    # The seeds of the run are read from the hparams file stored next to the logs.
    meta = srsly.read_yaml(run_dir / "hparams.yaml")
    data_seed = meta["data"]["seed"]
    model_seed = meta["model"]["seed"]

    run_lf = (
        pl.scan_parquet(p)
        # Keep only the metric that is plotted in this notebook.
        .filter(pl.col("tag") == "test/f1_class1_vs_budget")
        .with_columns(
            strategy=pl.lit(strategy),
            dataset=pl.lit(dataset),
            path=pl.lit(run_dir.name),
            # Wrapped in pl.lit like the other columns so the seeds are always
            # treated as literal values (a bare string would be parsed as a
            # column name).
            data_seed=pl.lit(data_seed),
            model_seed=pl.lit(model_seed),
            seed=pl.lit(f"{data_seed}-{model_seed}"),
        )
    )

    list_df.append(run_lf)

# Fail with an actionable message instead of polars' generic empty-concat error.
if not list_df:
    raise ValueError(f"No '*tb_logs.parquet' files found under {path.resolve()}")

df = pl.concat(list_df).collect(streaming=True)

In [4]:
# Number of logged rows per raw strategy identifier.
df.get_column("strategy").value_counts()

strategy,counts
str,u32
"""all""",3838
"""random""",4141
"""all-anchorssub…",2929
"""randomsubset""",4242


In [5]:
# Row counts for every (dataset, strategy) pair, ordered by the same two keys.
(
    df
    .groupby(["dataset", "strategy"])
    .count()
    .sort(["dataset", "strategy"])
)

dataset,strategy,count
str,str,u32
"""agnews""","""all""",808
"""agnews""","""all-anchorssub…",909
"""agnews""","""random""",909
"""agnews""","""randomsubset""",909
"""amazon""","""all""",1111
"""amazon""","""all-anchorssub…",404
"""amazon""","""random""",505
"""amazon""","""randomsubset""",606
"""eurlex""","""all""",202
"""eurlex""","""all-anchorssub…",606


In [9]:
# Count the distinct run folders per (dataset, strategy) and show only the
# pairs with fewer than 9 of them (presumably the 3 x 3 data/model seed grid).
(
    df
    .groupby(["dataset", "strategy"])
    .agg(pl.col("path").n_unique())
    .sort("strategy")
    .filter(pl.col("path") < 9)
)

dataset,strategy,path
str,str,u32
"""eurlex""","""all""",2
"""agnews""","""all""",8
"""pubmed""","""all""",8
"""pubmed""","""all-anchorssub…",8
"""wiki_toxic""","""all-anchorssub…",2
"""amazon""","""all-anchorssub…",4
"""eurlex""","""all-anchorssub…",6
"""amazon""","""random""",5
"""amazon""","""randomsubset""",6


- wiki_toxic random, all


In [15]:
# Distinct raw strategy identifiers present in the loaded logs.
df.get_column("strategy").unique().to_list()

['randomsubset', 'random', 'all-anchorssubset-min', 'all']

In [23]:
# Seed pairs ("<data_seed>-<model_seed>") already logged for amazon / randomsubset.
done = set(
    df.filter(
        (pl.col("dataset") == "amazon") & (pl.col("strategy") == "randomsubset")
    )
    .get_column("seed")
    .unique()
)

# Every pair of the three seeds, formatted the same way as the `seed` column.
total = {f"{i}-{j}" for i, j in product((42, 0, 1994), repeat=2)}

# Pairs that still have no logs.
total - done

{'0-1994', '42-1994', '42-42'}

In [None]:
# Display names used in legends, keyed by the raw strategy identifier.
mapper = dict(
    [
        ("random", "Random"),
        ("randomsubset", "Random Subset"),
        ("all", "AnchorAL"),
        ("all-anchorssubset-min", "SEALS"),
    ]
)

# Add the display name as a new column; unmapped strategies become null.
df = df.with_columns(strategy_name=pl.col("strategy").map_dict(mapper))

In [None]:
# Apply the base style and the seaborn context FIRST: `sns.set_context` writes
# its own "font.size" and "axes.linewidth", so overrides set before it are
# silently discarded.
plt.style.use("bmh")
sns.set_context("paper")

# Notebook-specific overrides, applied last so that they actually take effect.
mpl.rcParams.update(
    {
        "font.family": "monospace",  # e.g. "DejaVu Sans Mono"
        "font.size": 14,
        "axes.linewidth": 2,
    }
)

# To list the fonts available on this machine: sorted(fm.get_font_names())

In [None]:
# Keep only the rows whose strategy has a display name in `mapper`.
plot_data = df.filter(pl.col("strategy_name").is_in(mapper.values()))

# Alphabetical list of display names; fixes the legend order across figures.
strategies = sorted(plot_data.get_column("strategy_name").unique().to_list())

# One "Set1" colour per strategy so every figure uses the same colour coding.
palette = {
    name: color
    for name, color in zip(
        strategies, sns.color_palette("Set1", n_colors=len(strategies))
    )
}


In [None]:
# F1 of the minority class against the labelling budget for a single dataset.
dataset = "amazon"

fig, ax = plt.subplots()

# One line per strategy; the band is +/- 2 standard errors.
sns.lineplot(
    data=plot_data.filter(pl.col("dataset") == dataset),
    x="step",
    y="value",
    hue="strategy_name",
    palette=palette,
    errorbar=("se", 2),
    ax=ax,
)

# Minor ticks every 100 budget units and every 0.1 of F1.
ax.xaxis.set_minor_locator(ticker.MultipleLocator(100))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))

ax.set(
    xlim=(90, 1000),
    ylim=(0, 0.8),
    xlabel="Budget",
    ylabel="F1 Minority Class",
)

# Title is the dataset name in title case with underscores removed.
ax.set_title(dataset.title().replace("_", ""), fontsize=12)
ax.legend(fontsize=10)

sns.despine()
plt.show()

In [None]:
# One panel per dataset: F1 of the minority class against the budget.
pdf = plot_data.to_pandas()

# NOTE: despite its name, `ymax` is the upper limit of the shared x-axis.
ymax = 2600

plt.style.use("bmh")
sns.set_context("paper")

g = sns.FacetGrid(
    pdf,
    col="dataset",
    col_wrap=3,
    sharex=True,
    sharey=False,
    legend_out=True,
    despine=True,
    xlim=(0, ymax),
    ylim=(-.001, None),
)

g.tight_layout()

# `map_dataframe` hands each facet its own subset through the `data` keyword,
# so no explicit `data=` is passed here (it would be overwritten anyway).
g.map_dataframe(
    sns.lineplot,
    x="step",
    y="value",
    errorbar=("se", 2),
    hue="strategy_name",
    palette=palette,
)

# Upper y-limit of each panel; two of them depend on the x-range being shown.
y_tops = {
    "agnews": 0.7 if ymax == 2600 else 0.6,
    "pubmed": 0.7,
    "eurlex": 0.7 if ymax == 2600 else 0.5,
    "wiki_toxic": 0.8,
    "amazon": 0.9,
}
for name, top in y_tops.items():
    g.axes_dict[name].set_ylim(-.001, top)

g.set_axis_labels("", "")
g.set_titles(col_template="{col_name}")

# Legend placed at a fixed figure position, entries in `strategies` order.
g.add_legend(label_order=strategies, bbox_to_anchor=(.82, .25))
plt.savefig(fname="plots.png", dpi=1000, transparent=True)
plt.show()

In [None]:
# Broken y-axis figure: two stacked panels showing the same curves, each
# zoomed on a different y-range, joined by slanted "cut" marks.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
fig.subplots_adjust(hspace=0.05)  # small gap between the two panels

# Draw identical data on both panels; only the y-limits differ afterwards.
for panel in (ax1, ax2):
    sns.lineplot(
        data=plot_data,
        x="step",
        y="value",
        hue="strategy",
        errorbar=("se", 2),
        legend=False,
        ax=panel,
    )

# Upper panel shows the high range, lower panel the low range.
ax1.set_ylim(.41, .65)
ax2.set_ylim(0, .21)

# Remove the spines facing the break and keep tick labels at the bottom only.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)
ax2.xaxis.tick_bottom()

# The cut marks are markers placed at the panel corners in axes coordinates
# ((0, 0) to (1, 1) span the axes), so they keep their angle and position
# whatever the axes size or scale. Clipping is disabled so they can be drawn
# on the axes edge.
d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = {
    "marker": [(-1, -d), (1, d)],
    "markersize": 12,
    "linestyle": "none",
    "color": 'k',
    "mec": 'k',
    "mew": 1,
    "clip_on": False,
}
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)


plt.show()

In [None]:
# Mosaic figure: panel "A" spans the top row, panel "B" sits bottom-left and
# shows the same curves over a narrower x-range.
fig = plt.figure(layout="constrained")
axes = fig.subplot_mosaic("AA;B.")

# Same data on both panels; the limits below create the zoom.
for key in ("A", "B"):
    sns.lineplot(
        data=plot_data,
        x="step",
        y="value",
        hue="strategy",
        errorbar=("se", 2),
        legend=False,
        ax=axes[key],
    )

axes["A"].set(xlim=(200, 2600), ylim=(0, .7))
axes["B"].set(xlim=(200, 800), ylim=(0, .62))
plt.show()


In [None]:
# F1 of the minority class against the budget for the pubmed dataset.
dataset = "pubmed"
plot_data = df.filter(pl.col("dataset") == dataset)

# `subplot_mosaic` returns a dict of Axes keyed by the layout letters (it
# cannot be unpacked as `fig, ax`), so the figure is created first and the
# wide top panel "A" is selected explicitly from this figure's own axes.
fig = plt.figure()
axes = fig.subplot_mosaic("AA;XB")
ax = axes["A"]
# NOTE(review): panels "X" and "B" are created by the layout but nothing is
# drawn on them yet.

sns.lineplot(
    x="step",
    y="value",
    data=plot_data,
    errorbar=("se", 2),  # band of +/- 2 standard errors
    ax=ax,
    hue="strategy",
)

# Minor ticks every 100 budget units and every 0.1 of F1.
ax.xaxis.set_minor_locator(ticker.MultipleLocator(100))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))

ax.set_xlim(100, 2600)
ax.set_ylim(0)  # only the lower bound is fixed

ax.set_xlabel("Budget")
ax.set_ylabel("F1 (Minority Class)")

ax.set_title(dataset.title(), fontsize=12)

ax.legend(fontsize=10, bbox_to_anchor=(1, 1))

sns.despine()
plt.show()



In [None]:
# F1 of the minority class against the budget for the agnews dataset.
dataset = "agnews"
plot_data = df.filter(pl.col("dataset") == dataset)

# A single Axes is requested: with `nrows=2`, `plt.subplots` returns an array
# of Axes, which can neither be passed as `ax=` to seaborn nor formatted
# through `ax.xaxis` / `ax.set_*` as done below.
fig, ax = plt.subplots()

sns.lineplot(
    x="step",
    y="value",
    data=plot_data,
    errorbar=("se", 2),  # band of +/- 2 standard errors
    ax=ax,
    hue="strategy",
)

# Minor ticks every 100 budget units and every 0.1 of F1.
ax.xaxis.set_minor_locator(ticker.MultipleLocator(100))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))

ax.set_xlim(100, 2600)
ax.set_ylim(0)  # only the lower bound is fixed

ax.set_xlabel("Budget")
ax.set_ylabel("F1 (Minority Class)")

ax.set_title(dataset.title(), fontsize=12)

ax.legend(fontsize=10, bbox_to_anchor=(1, 1))

sns.despine()
plt.show()

