In [1]:
from pathlib import Path
import pandas as pd
import numpy as np

In [2]:
from utils import generate_table, Cell, to_latex, merge_cols

In [3]:
# Load the results CSV for this experiment; the path is relative to the notebook's directory.
data = pd.read_csv("../../results/cmnist/3_digits/first_decent_results.csv")

In [4]:
# Count how many rows each value of the "Group" column has.
data["Group"].value_counts()

.same_as_for2.zs2.even_longer.3dig_4miss    30
Name: Group, dtype: int64

In [5]:
# Swap the raw "Group" identifier for the LaTeX label used as the method name.
data = data.replace(
    {"Group": {".same_as_for2.zs2.even_longer.3dig_4miss": r"\texttt{ZSF}"}}
)

In [6]:
# Show the available column names (the per-colour metric columns are looked up by name below).
data.columns

Index(['Name', 'misc.seed', 'Runtime', 'Group', 'Hostname', 'State',
       'bias.log_dataset', 'suds.balanced_context', 'suds.lr',
       'misc.data_split_seed', 'misc.exp_group', 'misc.log_method',
       'Accuracy (pytorch_classifier)', 'Accuracy Discriminator (zy)',
       'Accuracy Predictor s', 'Accuracy Predictor y',
       'Accuracy_colour_0.0 (pytorch_classifier)',
       'Accuracy_colour_0.0-colour_1.0 (pytorch_classifier)',
       'Accuracy_colour_0.0-colour_2.0 (pytorch_classifier)',
       'Accuracy_colour_0.0÷colour_1.0 (pytorch_classifier)',
       'Accuracy_colour_0.0÷colour_2.0 (pytorch_classifier)',
       'Accuracy_colour_1.0 (pytorch_classifier)',
       'Accuracy_colour_1.0-colour_2.0 (pytorch_classifier)',
       'Accuracy_colour_1.0÷colour_2.0 (pytorch_classifier)',
       'Accuracy_colour_2.0 (pytorch_classifier)', 'ELBO', 'Learning rate',
       'Loss Adversarial', 'Loss Generator', 'Loss Predictor s',
       'Loss Predictor y', 'Loss reconstruction', 'Prior Lo

In [7]:
def compute_ratio_means(df, metric: str, suffix: str = " (pytorch_classifier)"):
    """Add a ``"{metric} ratio mean"`` column with the mean of the pairwise colour ratios.

    The three columns ``{metric}_colour_A÷colour_B{suffix}`` for the colour
    pairs (0.0, 1.0), (0.0, 2.0) and (1.0, 2.0) are averaged row by row. The
    result is written into ``df`` in place; nothing is returned.

    Args:
        df: DataFrame holding the three ratio columns; it gains one new column.
        metric: Prefix of the ratio column names, e.g. ``"TPR"``.
        suffix: Text appended to every ratio column name.

    Raises:
        KeyError: If one of the three ratio columns is missing.
        AssertionError: If any ratio fails ``ratio <= 1.0`` (NaN fails too).
            ``df`` is left untouched in that case.
    """
    ratios = (
        df[f"{metric}_colour_0.0÷colour_1.0" + suffix],
        df[f"{metric}_colour_0.0÷colour_2.0" + suffix],
        df[f"{metric}_colour_1.0÷colour_2.0" + suffix],
    )
    for ratio in ratios:
        within_range = ratio <= 1.0
        # Name the column and show only the offending rows so a failure is actionable.
        assert within_range.all(), (
            f"{ratio.name!r} has values that are not <= 1.0:\n{ratio[~within_range]}"
        )
    # sum() starts from 0, so the additions happen in the same order as a manual loop.
    df[f"{metric} ratio mean"] = sum(ratios) / len(ratios)

In [8]:
def compute_ratio_min(df, metric: str, suffix: str = " (pytorch_classifier)"):
    """Add a ``"{metric} ratio min"`` column with the smallest pairwise colour ratio.

    A running value starts at 1 for every row and is compared against each of
    the three ``{metric}_colour_A÷colour_B{suffix}`` columns in turn. It is kept
    only where it is strictly smaller than the candidate; everywhere else
    (ties and NaN comparisons included) the candidate replaces it. Because the
    start value is 1, the result never exceeds 1.

    Args:
        df: DataFrame holding the three ratio columns; modified in place.
        metric: Prefix of the ratio column names, e.g. ``"TPR"``.
        suffix: Text appended to every ratio column name.
    """
    candidates = [
        df[f"{metric}_colour_0.0÷colour_1.0" + suffix],
        df[f"{metric}_colour_0.0÷colour_2.0" + suffix],
        df[f"{metric}_colour_1.0÷colour_2.0" + suffix],
    ]
    smallest = pd.Series(1, candidates[0].index)
    for candidate in candidates:
        still_smallest = smallest < candidate
        smallest = smallest.where(still_smallest, candidate)
    df[f"{metric} ratio min"] = smallest

In [9]:
def compute_diff_max(df, metric: str, suffix: str = " (pytorch_classifier)"):
    """Add a ``"{metric} diff max"`` column with the largest pairwise colour difference.

    A running value starts at 0 for every row and is compared against each of
    the three ``{metric}_colour_A-colour_B{suffix}`` columns in turn. It is kept
    only where it is strictly greater than the candidate; everywhere else
    (ties and NaN comparisons included) the candidate replaces it. Because the
    start value is 0, the result never drops below 0.

    Args:
        df: DataFrame holding the three difference columns; modified in place.
        metric: Prefix of the difference column names, e.g. ``"TPR"``.
        suffix: Text appended to every difference column name.
    """
    candidates = [
        df[f"{metric}_colour_0.0-colour_1.0" + suffix],
        df[f"{metric}_colour_0.0-colour_2.0" + suffix],
        df[f"{metric}_colour_1.0-colour_2.0" + suffix],
    ]
    largest = pd.Series(0, candidates[0].index)
    for candidate in candidates:
        still_largest = largest > candidate
        largest = largest.where(still_largest, candidate)
    df[f"{metric} diff max"] = largest

In [10]:
# Add the derived summary columns to `data` in place. The outer loop runs over
# the helpers so all "ratio mean" columns are added first, then "ratio min",
# then "diff max".
for add_summary in (compute_ratio_means, compute_ratio_min, compute_diff_max):
    for metric_name in ("prob_pos", "TPR", "TNR"):
        add_summary(data, metric_name)

In [15]:
# Maps column names in `data` to the headers shown in the results table.
# Only the uncommented entries are selected and aggregated below; the
# commented-out ones are metrics currently left out of the table.
cols = {
#     'Clust/Context Accuracy': "Clust Acc",
#     'Clust/Context NMI': "Clust NMI",
#     'Clust/Context ARI': "Clust ARI",
#     "Accuracy (pytorch_classifier)": "Accuracy",
#     'prob_pos_colour_0.0÷colour_1.0 (pytorch_classifier)': "AR ratio",
#     "Renyi preds and s (pytorch_classifier)": "Renyi corr",
#     "TPR_colour_0.0÷colour_1.0 (pytorch_classifier)": "TPR ratio",
#     'TNR_colour_0.0÷colour_1.0 (pytorch_classifier)': "TNR ratio",
#     "prob_pos ratio mean": "AR ratio mean",
#     "TPR ratio mean": "TPR ratio mean",
#     "TNR ratio mean": "TNR ratio mean",
    "prob_pos ratio min": "AR ratio min",
    "TPR ratio min": "TPR ratio min",
    "TNR ratio min": "TNR ratio min",
    "prob_pos diff max": "AR diff max",
    "TPR diff max": "TPR diff max",
    "TNR diff max": "TNR diff max",
}

In [16]:
# Aggregate every selected metric per group with `Cell(round_to=3)`, switch to
# the display headers from `cols`, and expose the group key as a "Method" column.
groupby = "Group"
res = (
    data[[groupby, *cols]]
    .groupby([groupby])
    .agg(Cell(round_to=3))
    .rename(columns=cols)
    .reset_index(level=[groupby])
    .rename(columns={groupby: "Method"})
)
res

Unnamed: 0,Method,AR ratio min,TPR ratio min,TNR ratio min,AR diff max,TPR diff max,TNR diff max
0,\texttt{ZSF},0.604 $\pm$ 0.213,0.866 $\pm$ 0.176,0.702 $\pm$ 0.292,0.236 $\pm$ 0.189,0.133 $\pm$ 0.175,0.297 $\pm$ 0.291


In [17]:
# Baseline result files: maps the LaTeX method label used in the table to the
# CSV file name holding that baseline's runs (read by `collate` below).
baseline_files = {
    r"\texttt{Kamiran \& Calders (2012) CNN}": "cmnist_cnn_baseline_color_60epochs.csv",
    r"\texttt{Kamiran \& Calders (2012) CNN} (more stable)": "bs512_lr1e-3_wd1e-4_cnn_baseline_20epochs.csv",
    r"\texttt{FWD \cite{HasSriNamLia18}}": "cmnist_dro_baseline_color_60epochs.csv",
}

In [18]:
def collate(file_dict, dir_):
    """Read several result CSVs and stack them into a single DataFrame.

    Args:
        file_dict: Mapping from method label to CSV file name.
        dir_: Directory containing the files, joined onto
            ``../../results/cmnist/3_digits``.

    Returns:
        One DataFrame with the rows of every file in ``file_dict`` order, a
        fresh 0..n-1 index, and a leading ``log_method`` column carrying the
        label. For the FWD label, each row's label is suffixed with that row's
        ``eta`` value in parentheses. An empty DataFrame is returned when
        ``file_dict`` is empty.

    Raises:
        KeyError: If the FWD file has no ``eta`` column.
    """
    frames = []
    for log_method, filename in file_dict.items():
        df = pd.read_csv(Path("../../results/cmnist/3_digits") / dir_ / filename)
        label = log_method
        if log_method == r"\texttt{FWD \cite{HasSriNamLia18}}":
            # Rows of this file differ in `eta`, so give each eta its own label.
            label = log_method + " (" + df["eta"].astype(str) + ")"
        df.insert(0, "log_method", label)
        frames.append(df)
    if not frames:
        # pd.concat rejects an empty list; keep returning an empty frame instead.
        return pd.DataFrame()
    # Concatenate once instead of re-copying the accumulated frame per file.
    return pd.concat(frames, axis="index", ignore_index=True, sort=False)

In [19]:
# Load and stack all baseline CSVs; "." keeps collate's own base results directory.
baselines = collate(baseline_files, ".")

In [20]:
# Add the derived summary columns to `baselines` in place. The empty suffix is
# passed because the baseline column names carry no trailing text. The outer
# loop runs over the helpers so all "ratio mean" columns are added first, then
# "ratio min", then "diff max".
for add_summary in (compute_ratio_means, compute_ratio_min, compute_diff_max):
    for metric_name in ("prob_pos", "TPR", "TNR"):
        add_summary(baselines, metric_name, "")

In [21]:
# Show the baseline column names, including the derived columns added above.
baselines.columns

Index(['log_method', 'seed', 'data', 'method', 'wandb_url', 'Accuracy', 'TPR',
       'TNR', 'Renyi preds and s', 'Accuracy_colour_2.0',
       'Accuracy_colour_1.0', 'Accuracy_colour_0.0',
       'Accuracy_colour_0.0-colour_1.0', 'Accuracy_colour_0.0-colour_2.0',
       'Accuracy_colour_1.0-colour_2.0', 'Accuracy_colour_0.0÷colour_1.0',
       'Accuracy_colour_0.0÷colour_2.0', 'Accuracy_colour_1.0÷colour_2.0',
       'prob_pos_colour_2.0', 'prob_pos_colour_1.0', 'prob_pos_colour_0.0',
       'prob_pos_colour_0.0-colour_1.0', 'prob_pos_colour_0.0-colour_2.0',
       'prob_pos_colour_1.0-colour_2.0', 'prob_pos_colour_0.0÷colour_1.0',
       'prob_pos_colour_0.0÷colour_2.0', 'prob_pos_colour_1.0÷colour_2.0',
       'TPR_colour_2.0', 'TPR_colour_1.0', 'TPR_colour_0.0',
       'TPR_colour_0.0-colour_1.0', 'TPR_colour_0.0-colour_2.0',
       'TPR_colour_1.0-colour_2.0', 'TPR_colour_0.0÷colour_1.0',
       'TPR_colour_0.0÷colour_2.0', 'TPR_colour_1.0÷colour_2.0',
       'TNR_colour_2.0', 'TN

In [22]:
# Maps column names in `baselines` to the headers shown in the results table.
# Only the uncommented entries are selected and aggregated below; the
# commented-out ones are metrics currently left out of the table.
bl_cols = {
#     'Clust/Context Accuracy': "Clust Acc",
#     'Clust/Context NMI': "Clust NMI",
#     'Clust/Context ARI': "Clust ARI",
#     "Accuracy": "Accuracy",
#     'prob_pos_colour_0.0÷colour_1.0': "AR ratio",
#     "Renyi preds and s": "Renyi corr",
#     "TPR_colour_0.0÷colour_1.0": "TPR ratio",
#     'TNR_colour_0.0÷colour_1.0': "TNR ratio",
#     "prob_pos ratio mean": "AR, mean ratio",
#     "TPR ratio mean": "TPR, mean ratio",
#     "TNR ratio mean": "TNR, mean ratio",
    "prob_pos ratio min": "AR ratio min",
    "TPR ratio min": "TPR ratio min",
    "TNR ratio min": "TNR ratio min",
    "prob_pos diff max": "AR diff max",
    "TPR diff max": "TPR diff max",
    "TNR diff max": "TNR diff max",
}

In [23]:
# Aggregate every selected baseline metric per method label with
# `Cell(round_to=3)`, switch to the display headers from `bl_cols`, and expose
# the label as a "Method" column.
groupby = "log_method"
res2 = (
    baselines[[groupby, *bl_cols]]
    .groupby([groupby])
    .agg(Cell(round_to=3))
    .rename(columns=bl_cols)
    .reset_index(level=[groupby])
    .rename(columns={groupby: "Method"})
)
res2

Unnamed: 0,Method,AR ratio min,TPR ratio min,TNR ratio min,AR diff max,TPR diff max,TNR diff max
0,\texttt{FWD \cite{HasSriNamLia18}} (0.01),0.009 $\pm$ 0.023,0.023 $\pm$ 0.059,0.056 $\pm$ 0.087,0.954 $\pm$ 0.071,0.977 $\pm$ 0.059,0.944 $\pm$ 0.087
1,\texttt{FWD \cite{HasSriNamLia18}} (0.1),0.022 $\pm$ 0.049,0.059 $\pm$ 0.13,0.152 $\pm$ 0.166,0.878 $\pm$ 0.123,0.941 $\pm$ 0.13,0.848 $\pm$ 0.166
2,\texttt{FWD \cite{HasSriNamLia18}} (0.3),0.027 $\pm$ 0.05,0.077 $\pm$ 0.145,0.128 $\pm$ 0.147,0.887 $\pm$ 0.101,0.923 $\pm$ 0.145,0.871 $\pm$ 0.147
3,\texttt{FWD \cite{HasSriNamLia18}} (1.0),0.016 $\pm$ 0.059,0.04 $\pm$ 0.164,0.125 $\pm$ 0.182,0.9 $\pm$ 0.127,0.959 $\pm$ 0.163,0.875 $\pm$ 0.182
4,\texttt{Kamiran \& Calders (2012) CNN},0.072 $\pm$ 0.077,0.208 $\pm$ 0.217,0.039 $\pm$ 0.054,0.904 $\pm$ 0.087,0.792 $\pm$ 0.217,0.961 $\pm$ 0.054
5,\texttt{Kamiran \& Calders (2012) CNN} (more s...,0.0009 $\pm$ 0.0016,0.003 $\pm$ 0.005,0.026 $\pm$ 0.052,0.981 $\pm$ 0.035,0.997 $\pm$ 0.005,0.974 $\pm$ 0.052


In [24]:
# Stack the two summaries into one table with a fresh 0..n-1 row index.
table = pd.concat([res, res2], ignore_index=True)
table

Unnamed: 0,Method,AR ratio min,TPR ratio min,TNR ratio min,AR diff max,TPR diff max,TNR diff max
0,\texttt{ZSF},0.604 $\pm$ 0.213,0.866 $\pm$ 0.176,0.702 $\pm$ 0.292,0.236 $\pm$ 0.189,0.133 $\pm$ 0.175,0.297 $\pm$ 0.291
1,\texttt{FWD \cite{HasSriNamLia18}} (0.01),0.009 $\pm$ 0.023,0.023 $\pm$ 0.059,0.056 $\pm$ 0.087,0.954 $\pm$ 0.071,0.977 $\pm$ 0.059,0.944 $\pm$ 0.087
2,\texttt{FWD \cite{HasSriNamLia18}} (0.1),0.022 $\pm$ 0.049,0.059 $\pm$ 0.13,0.152 $\pm$ 0.166,0.878 $\pm$ 0.123,0.941 $\pm$ 0.13,0.848 $\pm$ 0.166
3,\texttt{FWD \cite{HasSriNamLia18}} (0.3),0.027 $\pm$ 0.05,0.077 $\pm$ 0.145,0.128 $\pm$ 0.147,0.887 $\pm$ 0.101,0.923 $\pm$ 0.145,0.871 $\pm$ 0.147
4,\texttt{FWD \cite{HasSriNamLia18}} (1.0),0.016 $\pm$ 0.059,0.04 $\pm$ 0.164,0.125 $\pm$ 0.182,0.9 $\pm$ 0.127,0.959 $\pm$ 0.163,0.875 $\pm$ 0.182
5,\texttt{Kamiran \& Calders (2012) CNN},0.072 $\pm$ 0.077,0.208 $\pm$ 0.217,0.039 $\pm$ 0.054,0.904 $\pm$ 0.087,0.792 $\pm$ 0.217,0.961 $\pm$ 0.054
6,\texttt{Kamiran \& Calders (2012) CNN} (more s...,0.0009 $\pm$ 0.0016,0.003 $\pm$ 0.005,0.026 $\pm$ 0.052,0.981 $\pm$ 0.035,0.997 $\pm$ 0.005,0.974 $\pm$ 0.052


In [25]:
# Lift the column-width limit while rendering: pandas versions that apply
# `display.max_colwidth` to `to_latex` would otherwise shorten long method
# labels to "..." in the emitted table. A large integer is used because it is
# accepted by both old and new pandas releases.
# escape=False keeps the LaTeX markup in the labels and cells intact.
with pd.option_context("display.max_colwidth", 10_000):
    print(table.to_latex(index=False, escape=False))

\begin{tabular}{lllllll}
\toprule
                                            Method &         AR ratio min &      TPR ratio min &      TNR ratio min &        AR diff max &       TPR diff max &       TNR diff max \\
\midrule
                                      \texttt{ZSF} &    0.604 $\pm$ 0.213 &  0.866 $\pm$ 0.176 &  0.702 $\pm$ 0.292 &  0.236 $\pm$ 0.189 &  0.133 $\pm$ 0.175 &  0.297 $\pm$ 0.291 \\
         \texttt{FWD \cite{HasSriNamLia18}} (0.01) &    0.009 $\pm$ 0.023 &  0.023 $\pm$ 0.059 &  0.056 $\pm$ 0.087 &  0.954 $\pm$ 0.071 &  0.977 $\pm$ 0.059 &  0.944 $\pm$ 0.087 \\
          \texttt{FWD \cite{HasSriNamLia18}} (0.1) &    0.022 $\pm$ 0.049 &   0.059 $\pm$ 0.13 &  0.152 $\pm$ 0.166 &  0.878 $\pm$ 0.123 &   0.941 $\pm$ 0.13 &  0.848 $\pm$ 0.166 \\
          \texttt{FWD \cite{HasSriNamLia18}} (0.3) &     0.027 $\pm$ 0.05 &  0.077 $\pm$ 0.145 &  0.128 $\pm$ 0.147 &  0.887 $\pm$ 0.101 &  0.923 $\pm$ 0.145 &  0.871 $\pm$ 0.147 \\
          \texttt{FWD \cite{HasSriNamLia18}} (1