In [1]:
from pathlib import Path

import pandas as pd
from datasets import load_from_disk
from tqdm.auto import tqdm

In [2]:
# Directory that is scanned for prepared datasets (relative to the notebook's working directory).
data_path = Path("..") / "data" / "prepared"

In [3]:
def format_integers(v: int) -> str:
    """Render ``v`` as the argument of a LaTeX ``\\integer{...}`` macro call."""
    return "\\integer{{{}}}".format(v)


def format_float(v: float) -> str:
    """Render ``v`` as the argument of a LaTeX ``\\percentage[2]{...}`` macro call.

    The value is inserted with its default string formatting; no rounding
    is applied on the Python side.
    """
    return "\\percentage[2]{{{}}}".format(v)


def format_labels(v: str) -> str:
    """Render a label as a LaTeX ``\\code{...}`` macro call.

    An empty label is passed through as an empty string so that it produces
    an empty cell rather than ``\\code{}``.
    """
    if len(v) == 0:
        return ""
    return "\\code{{{}}}".format(v)

In [12]:
# Per-dataset label statistics, keyed by the name of the dataset's directory.
# Each value is a DataFrame with the columns
# ["Dataset", "Split", "Label", "Count", "Proportion"], where Count,
# Proportion and Label have already been rendered as LaTeX macro strings.
stats = {}

for p in tqdm(list(data_path.iterdir())):
    # Only entries whose name contains "agnews" are processed.
    if "agnews" not in p.name:
        continue

    ds_dict = load_from_disk(p)

    # GET THE INFO
    # One frame per split: one row per label value plus a trailing totals row.
    s = {}
    for split in ds_dict:
        # `names` is indexed by the integer label value (see the lookup below).
        names = ds_dict[split].features["labels"].names
        df = pd.DataFrame(ds_dict[split]["labels"], columns=["labels"])

        # Relative frequency ("perc") and absolute frequency ("count") of each
        # label value, side by side and indexed by the label value.
        df = pd.concat(
            [df["labels"].value_counts(normalize=True), df["labels"].value_counts()], axis=1, keys=("perc", "count")
        )
        # Translate the integer label values in the index into their names.
        df["labels"] = [names[i] for i in df.index]

        # Order rows by label value and reverse the column order, which
        # turns (perc, count, labels) into (labels, count, perc).
        df = df.sort_index(ascending=True)[df.columns.tolist()[::-1]]
        # Append a row holding the column-wise sums (the totals row).
        df = pd.concat([df, df.sum(0).to_frame().T])
        # The summed "labels" cell (first column of the last row) is not a
        # meaningful label, so it is blanked; format_labels keeps "" empty.
        df.iloc[-1, 0] = ""
        s[split] = df

    # Stack the splits; the dict keys become the outer index level, which
    # reset_index() exposes as "level_0" (renamed to "Split"). The inner level
    # ("level_1") is dropped because the label name is already a column.
    df = pd.concat(s)
    df = (
        df.reset_index()
        .rename(columns={"level_0": "Split", "labels": "Label", "perc": "Proportion", "count": "Count"})
        .drop(columns=["level_1"])
        .assign(
            Dataset=p.name,
            Split=lambda _df: _df["Split"].str.title(),
            Count=lambda _df: _df["Count"].map(format_integers),
            Proportion=lambda _df: _df["Proportion"].map(format_float),
            Label=lambda _df: _df["Label"].map(format_labels),
        )[["Dataset", "Split", "Label", "Count", "Proportion"]]
    )

    stats[p.name] = df

  0%|          | 0/13 [00:00<?, ?it/s]

In [14]:
# Restrict the collected tables to datasets whose name contains "bert-base".
# (A previous variant of this filter matched "amazoncat" or "wiki" instead.)
stats = {name: table for name, table in stats.items() if "bert-base" in name}

In [16]:
# Stack the per-dataset tables in alphabetical order of dataset name.
tab = pd.concat([stats[name] for name in sorted(stats)])

# Melt to long form, then pivot so that rows are (Dataset, Label) and the
# columns are split x statistic, and write the result as a LaTeX table.
# escape=False is passed because the cells already contain LaTeX macros.
(
    tab.melt(id_vars=["Dataset", "Label", "Split"])
    .set_index(["Dataset", "Label", "Split", "variable"])
    .unstack("Split")
    .unstack("variable")
    .to_latex(buf="../results/tables/data_stats.tex", escape=False)
)