In [None]:
import os
from pathlib import Path
import pyrootutils

notebook_path = Path(os.path.abspath(""))
pyrootutils.setup_root(notebook_path, indicator=".project-root", pythonpath=True)

DIRNAME = r"/cluster/home/vjimenez/adv_pa_new/results/dg/modelselection"

In [None]:
from src.plot.dg import *
from src.plot.dg._retrieve import *
from src.plot.dg._plot import *

In [None]:
# project = "DiagVib-6 OOD Model Selection"
# for dataset_name in ['pos_maxmixval']:
#     for mod in ["erm", "irm"]:
#         for opt in ["adam"]:
#             for lr in ["0.001"]:
#                 try:
#                     data_dict = get_dictionary(dataset_name, [f"mod={mod}_opt={opt}_lr={lr}"], datashift=False)
#                 except:
#                     print(f"No data found for this configuration: dataset_name={dataset_name}, mod={mod}, opt={opt}, lr={lr}")

ds_hue = ['hue_mixval','hue_maxmixval','hue_oodval']
ds_pos = ['pos_zero','pos_idval','pos_mixval', 'pos_oodval']
ds_hue_npair = ['hue_zero_npair','hue_idval_npair','hue_mixval_npair','hue_maxmixval_npair','hue_oodval_npair']
ds_pos_npair = ['pos_zero_npair','pos_idval_npair','pos_mixval_npair','pos_maxmixval_npair','pos_oodval_npair']

project = "DiagVib-6 OOD Model Selection"
for dataset_name in ds_hue + ds_pos + ds_hue_npair + ds_pos_npair:
    for lr in ["0.0001", "0.0005"]:
        try:
            data_dict = get_dictionary(dataset_name, [f"mod=lisa_ppred=0.5_opt=adam_lr={lr}"], datashift=False)
        except:
            print(f"No data found for this configuration: dataset_name={dataset_name}, mod=lisa_ppred=0.5, opt=adam, lr={lr}")

In [None]:
from matplotlib.ticker import MultipleLocator

def plot_variable_vs_run(
        data: dict,
        run_names: list,
        metrics: list,
        hue_attribute: str,
        hue_dict: dict,
        ylabel: str,
        legend_labels: list,
        title: str,
        savedir: str,
        yscale: Optional[bool] = "symlog",
        legend: Optional[bool] = True,
        legend_loc: Optional[str] = "best",
        save: Optional[bool] = False,
        version_appendix: Optional[str] = ""
    ) -> None:
    """
        Args:
            data (dict): Dictionary with all the data for the desired runs.
            metric (str): Name of the metric to plot.
            selection_metric (str): Name of the metric that guides the selection of the `metric` values to plot.
            selection_criterion (str): Criterion of the `selection_metric`. Accepts "min", "max", "first" and "last".
            selection_environment (Optional[int]): Environment to implement the selection criterion. If None, it will be implemented
                for each environment separately.
    """
    # Number of runs:
    
    run_attributes = [extract_names(name) for name in run_names] #(model, opt, lr)

    name_datasets = set(data["dataset"])
    num_datasets = len(name_datasets)

    # Get the font
    fontname = "DejaVu Serif"
    _ = fm.findfont(fm.FontProperties(family=fontname))

    # Subset of the dictionary:
    dict_to_iter = {
        "dataset": data["dataset"],
        "model": [attrs[0] for attrs in run_attributes]*num_datasets,
        "optimizer": [attrs[1] for attrs in run_attributes]*num_datasets,
        "lr": [float(attrs[2]) for attrs in run_attributes]*num_datasets,
    }    

    df_list = []
    for irun in range(len(data["dataset"])):
        dict_to_plot = {
            "epochs": np.arange(1, 101)
        }
        dict_to_plot.update({
            key: np.full(100, values[irun])
            for key, values in dict_to_iter.items()
        })
        dict_to_plot.update({
            metric: data[metric][irun]
            for metric in metrics
        })
        df_list.append(pd.DataFrame(dict_to_plot))
    
    level_set = pd.concat(df_list)
    
    # Create a line plot
    plt.close('all')
    _, ax = plt.subplots(figsize=(2 * 3.861, 2 * 2.7291))
    sns.set(font_scale=1.9)
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = fontname
    sns.set_style("ticks")


    for metric in metrics:
        sns.lineplot(
            data=level_set,
            ax=ax,
            x="epochs",
            y=metric,
            hue=hue_attribute,
            style=hue_attribute,
            palette=hue_dict,
            dashes=False, #[(2,2)] if metric == metrics[0] else False, #dash_styles.get(metric, False),
            marker=None,
            linewidth=3,
            legend=legend
        )

    # ax.minorticks_on()
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.set_xticks([1] + [i for i in range(10,101,10)])
    plt.xticks(rotation=45)

    ax.tick_params(axis="both", which="both", direction="in")
    xticks_font = fm.FontProperties(family=fontname)
    for tick in ax.get_xticklabels():
        tick.set_fontproperties(xticks_font)

    ax.grid(linestyle="--")

    # ax.set_ylim(min(level_set[metric])*2, 0.5)
        
    # posdiff = -(10**np.log10(abs(max(level_set[metric]))))/2
    # ax.set_ylim(min(level_set[metric])*2, posdiff)
    ax.set_xlabel("Epochs", fontname=fontname)
    # r"$10^{-4} \times $ PA"
    ax.set_ylabel(ylabel, fontname=fontname)
    ax.set_yscale(yscale) 
    
    # Legend
    if legend == True:
        handles, _ = ax.get_legend_handles_labels()

        # FOR TRAIN AND VAL METRICS:
        # legend_labels = ["Training", "Validation"] + legend_labels
        
        # for handle in handles[:2]:
        #     handle.set_color("black")
        # handles[1].set_linestyle("-")
        
            
        legend_properties = {
            "family": fontname,
            'size': 18,
        }  
        ax.legend(
            handles,
            legend_labels,
            loc=legend_loc,
            # loc="lower left",
            # fontsize=12,
            handlelength=0.5,
            prop=legend_properties
        )

    ax.set_title(title, fontname=fontname)
    plt.tight_layout()
    if save:
        plt.savefig(savedir)
        plt.clf()
        plt.close()
    else:
        plt.show()

In [None]:
run_names = [f"mod=irm_opt=adam_lr=0.001"]
ds_name = "pos_maxmixval"
data_dict = get_multiple_dict(
    [ds_name],
    run_names,
    datashift=False
)

In [None]:
trainval_metrics = ["acc", "loss", "specificity", "sensitivity", "precision"]
for met in trainval_metrics: 
    if met != "loss":
        continue
    
    legend = False
    if met == "loss":
        legend = True

    mettitle = met.capitalize()
    if met == "acc":
        mettitle = "Accuracy"

    plot_variable_vs_run(
        data=data_dict,
        run_names=run_names,
        metrics=[f"train/{met}", f"val/{met}"],
        hue_attribute="lr",
        hue_dict={
            0.0001: "tab:cyan",
            0.001: "tab:pink"
        },
        title=f"{mettitle}",
        legend_labels=["0.0001", "0.001"],
        ylabel="",
        # title=f"{mod.upper()}",
        savedir=os.path.join(DIRNAME, rf"{ds_name}/{mod}_lr_{met}.png"),
        yscale="linear",
        legend=legend,
        save=False,
        version_appendix=""
    )

In [None]:
mod = "irm"
run_names = [f"mod={mod}_opt=sgd_lr=0.001", f"mod={mod}_opt=sgd_lr=0.0001"]
ds_name = "hue_idval_17"
data_dict = get_multiple_dict(
    [ds_name],
    run_names,
    datashift=False
)

met = "beta"
plot_variable_vs_run(
        data=data_dict,
        run_names=run_names,
        metrics=[f"PA(0,1)/{met}"],
        hue_attribute="lr",
        hue_dict={
            0.0001: "tab:cyan",
            0.001: "tab:pink"
        },
        ylabel=r"$\beta$",
        legend_labels=["0.0001", "0.001"],
        legend_loc="best",
        title=f"{mod.upper()}",
        savedir=os.path.join(DIRNAME, rf"{ds_name}/{mod}_lr_{met}.png"),
        yscale="linear",
        legend=False,
        save=True,
        version_appendix=""
    )

In [None]:
def _get_cov_det(data_dict, epoch, run_index):
    cov_matrix = np.vstack([data_dict[f"COV_{epoch}"][run_index*2], data_dict[f"COV_{epoch}"][run_index*2 + 1]])
    return np.linalg.det(cov_matrix)

epoch_filter = np.asarray([0, 10, 20, 30, 40, 49])
cov_data = [_get_cov_det(data_dict, epoch, r) for epoch in epoch_filter]

# TABLE OF RESULTS

In [None]:
from collections import defaultdict
import pickle

def _get_table_of_results(
        model_list: list[str],
        main_factor: str,
        n_pair: bool,
        optimizer: str, 
        lr: float
    ):

    df_list = []
    for mod in model_list:
        # 1. LOAD THE APPROPIATE DATA:

        # The name of the runs we want to compare
        # run_names = [
        #     f"mod={mod}_opt={optimizer}_lr={lr}"
        # ]

        run_names = [
            f"mod=lisa_ppred=0.5_opt={optimizer}_lr={lr}"
        ]

        # The datasets for which we want to compare them:
        # val_names = ["zero", "idval", "mixval", "maxmixval", "oodval"]
        if n_pair == True:
            val_names = ["zero_npair", "idval_npair", "mixval_npair", "maxmixval_npair", "oodval_npair"]
        else:
            val_names = ["zero", "idval", "mixval", "maxmixval", "oodval"]

        ds_names = [
            main_factor + '_' + val_name
            for val_name in val_names
        ]

        # Obtain the whole dataset:
        data_dict_list = []
        for ds_name in ds_names:
            for run_name in run_names:
                try:
                    # data_dict_list.append(
                    #     get_dictionary(ds_name, [run_name], datashift=False) # this is model selection
                    # )

                    with open(rf"/cluster/home/vjimenez/adv_pa_new/results/dg/modelselection/{ds_name}/mod=lisa_ppred=0.5_opt={optimizer}_lr={lr}.pkl", 'rb') as f:
                        data_dict_list.append(
                            pickle.load(f)
                        )

                except:
                    print(f"No dataset for this configuration: {ds_name} + {run_name}")

        data_dict = defaultdict(list)
        for d in data_dict_list:
            for key, value in d.items():
                data_dict[key].extend(value)


        # 2. RETRIEVE ONLY USEFUL INFORMATION.
        num_datasets = len(data_dict['seed'])
        list_selection_metrics = ["val/acc", "PA(0,1)/AFR_pred", "PA(0,1)/logPA"]

        selection_dictionary_table = {}
        for idat, dat in enumerate(data_dict['dataset']):
            selection_dictionary_table[dat] = {
                "epoch": [data_dict[metric][idat].argmax() for metric in list_selection_metrics],
                "value": [data_dict[metric][idat].max() for metric in list_selection_metrics],
            }
            selection_dictionary_table[dat].update({
                f"acc@{e}": [
                    data_dict[f'oracle/acc_{e}'][idat][selection_dictionary_table[dat]["epoch"][imet]]
                    for imet in range(len(list_selection_metrics))
                ]
                for e in range(0, 6)
            })

        # 3. CONVERT TO PANDAS DF:
        columns = ['shift', 'metric'] + list(selection_dictionary_table.keys())
        rows = []
        # Iterate through shifts (acc@0 to acc@5)
        for shift in range(6):
            shift_key = f'acc@{shift}'
            # Iterate through metrics
            for idx, metric in enumerate(['acc', 'afrp', 'pa']):
                row = [shift, metric]
                for key in selection_dictionary_table:
                    # First value for 'acc', difference for 'afrp' and 'pa'
                    base_value = selection_dictionary_table[key][shift_key][0]
                    if metric == 'acc':
                        value = base_value
                    elif metric == 'afrp':
                        # value = selection_dictionary_table[key][shift_key][1] - base_value
                        value = selection_dictionary_table[key][shift_key][1]
                    elif metric == 'pa':
                        # value = selection_dictionary_table[key][shift_key][2] - base_value
                        value = selection_dictionary_table[key][shift_key][2]
                    row.append(value)
                rows.append(row)

        # Convert rows into a DataFrame
        df = pd.DataFrame(rows, columns=columns)
        df['model'] = mod
        df_list.append(df)

    return pd.concat(df_list)


## Add model selection keys based on the computation.

In [40]:
DATASET_DICT = {
    "zero": "SD",
    "idval": "ID",
    "mixval": "1F-MD",
    "maxmixval": "5F-MD",
    "oodval": "OOD"
}

FACTOR_DICT = {
    "pos": r"\texttt{position}",
    "hue": r"\texttt{hue}"
}

MODEL_NAMES = {
    "erm": r"{\color{tab:blue} \textbf{ERM}}",
    "irm": r"{\color{tab:orange} \textbf{IRM}}",
    "lisa_ppred=0.5": r"{\color{tab:green} \textbf{LISA}}"
}


def dataset_name_parser(ds_name: str):
    list_name = ds_name.split("_")
    return f"{DATASET_DICT[list_name[1]]}"

In [41]:
def generate_latex_table(df, optimizer: str, lr: float, main_factor: str, n_pair: bool):
    # Get the dataset names (e.g., pos_zero_npair, pos_idval_npair, etc.)
    dataset_names = df.columns[2:-1].unique()
    dataset_names_parsed = [dataset_name_parser(ds_name) for ds_name in dataset_names]
    # Get the model names (e.g., erm, irm, etc.)
    model_names = df['model'].unique()
    model_names_parsed = [MODEL_NAMES[model_name] for model_name in model_names]

    # Variable to track if significant improvement was observed in AFR$_P$
    significant_improvement_observed = False
    
    # Start building the LaTeX table
    latex_code = "\\begin{table}[H]\n\\centering\n\\resizebox{\\textwidth}{!}{%\n\\begin{tabular}{l|cl|cl|cl|cl|cl|cl}\n"

    # Header for shift values
    # shift_text = lambda shift: f"Test \#{shift}"
    shift_text = lambda shift: f"Test {shift}"
    latex_code += "\\multirow{2}{*}{} & " + " & ".join(
        [
        f"\\multicolumn{{2}}{{c|}}{{\\textbf{{{shift_text(shift)}}}}}" if shift < 5 else f"\\multicolumn{{2}}{{c}}{{\\textbf{{{shift_text(shift)}}}}}"
        for shift in range(6)
    ]
    ) + " \\\\\n"

    # Select the indexes of the best accuracies:    
    list_max_acc, list_max_pa = [], []
    for model in model_names:
        accs_dataset = np.zeros((len(dataset_names), 6))
        for idataset, dataset in enumerate(dataset_names):
            for shift in range(6):
                row = df[(df['shift'] == shift) & (df['model'] == model)]
                if not row.empty:
                    acc =  100.0*row[dataset][row['metric'] == 'acc'].values[0]
                else:
                    acc= 0.0
                accs_dataset[idataset, shift] = acc
        list_max_acc.append(np.argmax(accs_dataset, axis=0))
    
    
    # Iterate over each model
    for imodel, (model, model_parsed) in enumerate(zip(model_names, model_names_parsed)):
        # Header for metrics (Acc., PA)
        latex_code += f"\\textbf{{{model_parsed}}} & " + " & ".join(
            # [r"\textbf{Acc.} & \textbf{PA}" for _ in range(6)]
            [r"Acc. & $\Delta$Acc." for _ in range(6)]
        ) + " \\\\\n"
        latex_code += "\\midrule\n"

        # Iterate over each dataset
        for idataset, (dataset, dataset_parsed) in enumerate(zip(dataset_names, dataset_names_parsed)):
            latex_code += f"{dataset_parsed}"
            # latex_code += f"\\textbf{{{dataset_parsed}}}"
            for shift in range(6):
                # Filter the DataFrame for the current dataset, model, and shift
                row = df[(df['shift'] == shift) & (df['model'] == model)]
                if not row.empty:
                    # Extract the metric values for the current dataset, model, and shift
                    acc = 100.0*row[dataset][row['metric'] == 'acc'].values[0]
                    pa = 100.0*row[dataset][row['metric'] == 'pa'].values[0]

                    # Replace 0.000 with a dash "-"
                    # acc_str = f"\\textbf{{{acc:.1f}}}" if idataset == list_max_acc[imodel][shift] else f"{acc:.1f}"
                    acc_str = f"{acc:.1f}"
                    if float(f"{pa:.1f}") > 0:
                        pa_str_in = f"\Plus {abs(pa):.1f}"
                        pa_str = f"{{\\color{{tab:green}}  \\textbf{{{pa_str_in}}}}}"
                    elif float(f"{pa:.1f}") < 0:
                        pa_str_in = f"\Minus {abs(pa):.1f}"
                        pa_str = f"{{\\color{{tab:red}} \\textbf{{{pa_str_in}}}}}"
                    else:
                        pa_str = r"\PlusMinus 0.01" #"0.0"
                    
                    latex_code += f" & {acc_str} & {pa_str}"
                else:
                    latex_code += " & - & -"  # Placeholder if there's no data
            latex_code += " \\\\\n"


        if imodel < len(model_names)-1:
            latex_code += "\\midrule\n\\addlinespace\n\\addlinespace\n"
    
    latex_code += "\\bottomrule\n\\end{tabular}%\n}\n"

    # Add caption with significant improvement information
    caption_text = f"REMOVEopt={optimizer}-lr={lr}-mf={main_factor}-npair={n_pair}REMOVE Test performance on increasingly shifted datasets for models selected during ERM and IRM procedures. Different validation datasets are used, and the selection capabilities of PA and validation accuracy are compared."
    latex_code += f"\\caption{{{caption_text}}}\n\\label{{tab:label}}\n\\end" + "{" + "table}"
    
    return latex_code

Start here:

In [34]:
optimizer = 'adam'
lr = 0.0001
main_factor = "pos"
n_pair = False

In [35]:
df = _get_table_of_results(
        model_list = ['lisa_ppred=0.5'],
        main_factor = main_factor,
        n_pair = n_pair,
        optimizer = optimizer, 
        lr = lr
)

Check if `AFR_pred` improves the results, and write it in the caption:

In [None]:
df_afr = df.loc[df["metric"] == "afrp"]
df_afr

In [None]:
latex_code = generate_latex_table(df, optimizer=optimizer, lr=lr, main_factor=main_factor, n_pair=n_pair)
print(latex_code)

In [None]:
df

# EN MASSA

In [42]:
%%capture

latex_code = ""
for main_factor in ['hue']:
        for n_pair in [False]:
                for optimizer, lr in [('adam', 0.0001), ('adam', 0.0005)]:
                        df = _get_table_of_results(
                                model_list = ['lisa_ppred=0.5'],
                                main_factor = main_factor,
                                n_pair = n_pair,
                                optimizer = optimizer, 
                                lr = lr
                        )
                        latex_code_table = generate_latex_table(df, optimizer=optimizer, lr=lr, main_factor=main_factor, n_pair=n_pair)
                        latex_code += latex_code_table + "\n\n"

In [43]:
print(latex_code)

\begin{table}[H]
\centering
\resizebox{\textwidth}{!}{%
\begin{tabular}{l|cl|cl|cl|cl|cl|cl}
\multirow{2}{*}{} & \multicolumn{2}{c|}{\textbf{Test 0}} & \multicolumn{2}{c|}{\textbf{Test 1}} & \multicolumn{2}{c|}{\textbf{Test 2}} & \multicolumn{2}{c|}{\textbf{Test 3}} & \multicolumn{2}{c|}{\textbf{Test 4}} & \multicolumn{2}{c}{\textbf{Test 5}} \\
\textbf{{\color{tab:green} \textbf{LISA}}} & Acc. & $\Delta$Acc. & Acc. & $\Delta$Acc. & Acc. & $\Delta$Acc. & Acc. & $\Delta$Acc. & Acc. & $\Delta$Acc. & Acc. & $\Delta$Acc. \\
\midrule
SD & 99.5 & {\color{tab:green}  \textbf{\Plus 99.3}} & 78.9 & {\color{tab:green}  \textbf{\Plus 73.4}} & 73.1 & {\color{tab:green}  \textbf{\Plus 66.9}} & 71.1 & {\color{tab:green}  \textbf{\Plus 66.7}} & 69.9 & {\color{tab:green}  \textbf{\Plus 68.3}} & 29.7 & {\color{tab:green}  \textbf{\Plus 41.0}} \\
ID & 99.5 & {\color{tab:green}  \textbf{\Plus 99.4}} & 83.0 & {\color{tab:green}  \textbf{\Plus 81.0}} & 53.3 & {\color{tab:green}  \textbf{\Plus 86.1}} & 67.9 

# LATEX TABLE APPENDIX

In [44]:
def generate_latex_table(df, optimizer: str, lr: float, main_factor: str, n_pair: bool):
    # Get the dataset names (e.g., pos_zero_npair, pos_idval_npair, etc.)
    dataset_names = df.columns[2:-1].unique()
    dataset_names_parsed = [dataset_name_parser(ds_name) for ds_name in dataset_names]
    # Get the model names (e.g., erm, irm, etc.)
    model_names = df['model'].unique()
    model_names_parsed = [MODEL_NAMES[model_name] for model_name in model_names]

    # Variable to track if significant improvement was observed in AFR$_P$
    significant_improvement_observed = False

    eps = 1e-4
    
    # Start building the LaTeX table
    latex_code = "\\begin{table}[H]\n\\centering\n\\setlength{\\tabcolsep}{2.5pt}\n\\resizebox{\\textwidth}{!}{%\n\\begin{tabular}{l|ccc|ccc|ccc|ccc|ccc|ccc}\n"

    # Header for shift values
    # shift_text = lambda shift: f"Test \#{shift}"
    shift_text = lambda shift: f"Acc. Test {shift}"
    latex_code += "\\multirow{3}{*}{} & " + " & ".join(
        [
        f"\\multicolumn{{3}}{{c|}}{{\\textbf{{{shift_text(shift)}}}}}" if shift < 5 else f"\\multicolumn{{3}}{{c}}{{\\textbf{{{shift_text(shift)}}}}}"
        for shift in range(6)
    ]
    ) + " \\\\\n"    
    
    # Iterate over each model
    for imodel, (model, model_parsed) in enumerate(zip(model_names, model_names_parsed)):
        # Header for metrics (Acc., PA)
        latex_code += f"\\textbf{{{model_parsed}}} & " + " & ".join(
            # [r"\textbf{Acc.} & \textbf{PA}" for _ in range(6)]
            [r"Acc. & AFR$_\text{P}$ & PA" for _ in range(6)]
        ) + " \\\\\n"
        latex_code += "\\midrule\n"

        # Iterate over each dataset
        for idataset, (dataset, dataset_parsed) in enumerate(zip(dataset_names, dataset_names_parsed)):
            latex_code += f"{dataset_parsed}"
            # latex_code += f"\\textbf{{{dataset_parsed}}}"
            for shift in range(6):
                # Filter the DataFrame for the current dataset, model, and shift
                row = df[(df['shift'] == shift) & (df['model'] == model)]
                if not row.empty:
                    # Extract the metric values for the current dataset, model, and shift
                    acc = 100.0*row[dataset][row['metric'] == 'acc'].values[0]
                    pa = 100.0*row[dataset][row['metric'] == 'pa'].values[0]
                    afr = 100.0*row[dataset][row['metric'] == 'afrp'].values[0]

                    acc_str, pa_str, afr_str = f"{acc:.1f}", f"{pa:.1f}", f"{afr:.1f}"
                    str_metrics = [acc_str, afr_str, pa_str]
                    float_metrics = [float(val) for val in str_metrics]
                    max_val = max(float_metrics)
                    if abs(float_metrics[2] - max_val) < eps:
                        pa_str = f"{{\\textbf{{{pa_str}}}}}"
                    elif abs(float_metrics[0] - max_val) < eps:
                        acc_str = f"{{\\textbf{{{acc_str}}}}}"
                    else:
                        afr_str = f"{{\\textbf{{{afr_str}}}}}"

                    
                    latex_code += f" & {acc_str} & {afr_str} & {pa_str}"
                else:
                    latex_code += " & - & - & -"  # Placeholder if there's no data
            latex_code += " \\\\\n"


        if imodel < len(model_names)-1:
            latex_code += "\\midrule\n\\addlinespace\n\\addlinespace\n"
    
    latex_code += "\\bottomrule\n\\end{tabular}%\n}\n"

    # Add caption with significant improvement information
    caption_text = f"REMOVEopt={optimizer}-lr={lr}-mf={main_factor}-npair={n_pair}APPENDIX."
    latex_code += f"\\caption{{{caption_text}}}\n\\label{{tab:label}}\n\\end" + "{" + "table}"
    
    return latex_code

In [59]:
%%capture

latex_code = ""
for main_factor in ['hue']:
        for n_pair in [True]:
                for optimizer, lr in [('adam', 0.0001), ('adam', 0.0005)]: # ('adam', 0.0001), ('adam', 0.0005)
                        df = _get_table_of_results(
                                model_list = ['lisa_ppred=0.5'],
                                main_factor = main_factor,
                                n_pair = n_pair,
                                optimizer = optimizer, 
                                lr = lr
                        )
                        latex_code_table = generate_latex_table(df, optimizer=optimizer, lr=lr, main_factor=main_factor, n_pair=n_pair)
                        latex_code += latex_code_table + "\n\n"

In [60]:
print(latex_code)

\begin{table}[H]
\centering
\setlength{\tabcolsep}{2.5pt}
\resizebox{\textwidth}{!}{%
\begin{tabular}{l|ccc|ccc|ccc|ccc|ccc|ccc}
\multirow{3}{*}{} & \multicolumn{3}{c|}{\textbf{Acc. Test 0}} & \multicolumn{3}{c|}{\textbf{Acc. Test 1}} & \multicolumn{3}{c|}{\textbf{Acc. Test 2}} & \multicolumn{3}{c|}{\textbf{Acc. Test 3}} & \multicolumn{3}{c|}{\textbf{Acc. Test 4}} & \multicolumn{3}{c}{\textbf{Acc. Test 5}} \\
\textbf{{\color{tab:green} \textbf{LISA}}} & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA \\
\midrule
SD & 99.3 & 99.3 & {\textbf{99.4}} & {\textbf{88.4}} & 88.4 & 83.1 & 53.9 & 53.9 & {\textbf{78.5}} & 57.9 & 57.9 & {\textbf{80.5}} & 63.1 & 63.1 & {\textbf{77.0}} & {\textbf{38.8}} & 38.8 & 35.4 \\
ID & 99.3 & 99.3 & {\textbf{99.4}} & {\textbf{88.4}} & 88.4 & 83.1 & 53.9 & 53.9 & {\textbf{78.5}} & 57.9 & 57.9 & {\textbf{80.5}} & 63.1 & 63.1 & {\textbf{77.0

# TABLES OF SURPLUS