In [None]:
import os
from pathlib import Path
import pyrootutils

notebook_path = Path(os.path.abspath(""))
pyrootutils.setup_root(notebook_path, indicator=".project-root", pythonpath=True)

DIRNAME = r"/cluster/home/vjimenez/adv_pa_new/results/dg/datashift"

In [None]:
from src.plot.dg import *
from src.plot.dg._retrieve import *
from src.plot.dg._plot import *

# GET MODEL SELECTION TABLES

In [None]:
def get_dictionary(dataset_name, run_names: list, datashift: bool = True):
    dirname = r"/cluster/home/vjimenez/adv_pa_new/results/dg/datashift"

    data_dict_list = []
    for run_name in tqdm(run_names, desc="Run: "):
        fname = osp.join(osp.join(dirname, dataset_name), f"test_{run_name}.pkl")
        with open(fname, 'rb') as file:
            data_dict = pickle.load(file) 
        data_dict_list.append(data_dict)

    merged_dict = defaultdict(list)
    for d in data_dict_list:
        for key, value in d.items():
            merged_dict[key].extend(value)

    return merged_dict

def get_multiple_dict(dataset_names, run_names, datashift: bool = True):
    data_dict_list_2 = []
    for dataset_name in dataset_names:
        data_dict_list_2.append(
            get_dictionary(dataset_name, run_names, datashift)
        )

    merged_dict = defaultdict(list)
    for d in data_dict_list_2:
        for key, value in d.items():
            merged_dict[key].extend(value)

    return pd.DataFrame(merged_dict)

In [None]:
data_dict = get_multiple_dict(
        dataset_names=["CGO_1_hue"],
        run_names=[
            "erm",
        ],
        datashift = True
    )

# EN MASSA

In [None]:
with open("/cluster/home/vjimenez/adv_pa_new/results/dg/datashift/CGO_1_pos/test_irm.pkl", 'rb') as file:
    data_dict = pickle.load(file) 

df = pd.DataFrame(data_dict)

In [None]:
df

In [None]:
metric = "acc"
selection_metric = "AFR_pred"

df_sel = df.loc[
    (df["sr"] == 1.0) & (df["env1"] == "1"),
    [
        "sr",
        "selection_metric",
        "env1",
        metric,
     ]
]

df_sel

In [None]:
df_sel[df_sel["selection_metric"] == "logPA"]["acc"].values[0]

In [5]:
DATASET_DICT = {
    "ZSO_hue_3": "ZSO",
    "ZGO_hue_3": "ZGO",
    "CGO_1_hue": "1-CGO",
    "CGO_2_hue": "2-CGO",
    "CGO_3_hue": "3-CGO",
    "ZSO_pos_3": "ZSO",
    "ZGO_pos_3": "ZGO",
    "CGO_1_pos": "1-CGO",
    "CGO_2_pos": "2-CGO",
    "CGO_3_pos": "3-CGO"
}

FACTOR_DICT = {
    "pos": r"\texttt{position}",
    "hue": r"\texttt{hue}"
}

MODEL_NAMES = {
    "erm": r"{\color{tab:blue} \textbf{ERM}}",
    "irm": r"{\color{tab:orange} \textbf{IRM}}",
    "lisa": r"{\color{tab:green} \textbf{LISA}}"
}


def dataset_name_parser(ds_name: str):
    return f"{DATASET_DICT[ds_name]}"

In [24]:
def generate_latex_table(
        model_name_list: list[str],
        ds_name_list: list[str],
        lr: float,
        main_factor: str,
    ):
    dataset_names_parsed = [DATASET_DICT[ds_name] for ds_name in ds_name_list]
    # model_names_parsed = [MODEL_NAMES[model_name.split("_")[0]] for model_name in model_name_list]
    model_names_parsed = model_name_list
    
    # Start building the LaTeX table
    latex_code = "\\begin{table}[H]\n\\centering\n\\setlength{\\tabcolsep}{2.5pt}\n\\resizebox{\\textwidth}{!}{%\n\\begin{tabular}{l|ccc|ccc|ccc|ccc|ccc}\n"

    # Header for shift values
    # shift_text = lambda shift: f"Test \#{shift}"
    shift_text = lambda shift: f"Acc. Test {shift}"
    latex_code += "\\multirow{3}{*}{} & " + " & ".join(
        [
        f"\\multicolumn{{3}}{{c|}}{{\\textbf{{{shift_text(shift)}}}}}" if shift < 5 else f"\\multicolumn{{3}}{{c}}{{\\textbf{{{shift_text(shift)}}}}}"
        for shift in range(1, 6)
    ]
    ) + " \\\\\n"

    eps = 1e-4
    
    # Iterate over each model
    for imodel, (model, model_parsed) in enumerate(zip(model_name_list, model_names_parsed)):
        # Header for metrics (Acc., PA)
        latex_code += f"\\textbf{{{model_parsed}}} & " + " & ".join(
            # [r"\textbf{Acc.} & \textbf{PA}" for _ in range(6)]
            [r"Acc. & AFR$_\text{P}$ & PA" for _ in range(1, 6)]
        ) + " \\\\\n"
        latex_code += "\\midrule\n"

        # Iterate over each dataset
        for idataset, (dataset, dataset_parsed) in enumerate(zip(ds_name_list, dataset_names_parsed)):
            latex_code += f"{dataset_parsed}"
            # latex_code += f"\\textbf{{{dataset_parsed}}}"

            try:
                with open(rf"/cluster/home/vjimenez/adv_pa_new/results/dg/datashift/{dataset}/test_{model}.pkl", 'rb') as file:
                    df = pd.DataFrame((pickle.load(file)))
            except:
                continue
      
            for shift in range(1, 6):
                # Filter the DataFrame for the current dataset, model, and shift
                row = df[(df['env1'] == str(shift)) & (df["sr"] == 1.0)]
                if not row.empty:                    
                    # Extract the metric values for the current dataset, model, and shift
                    acc = 100.0*row[row["selection_metric"] == "acc"]["acc"].values[0]
                    beta = 100.0*row[row["selection_metric"] == "acc"]["beta"].values[0]

                    pa = 100.0*row[row["selection_metric"] == "logPA"]["acc"].values[0]
                    afr = 100.0*row[row["selection_metric"] == "AFR_pred"]["acc"].values[0]


                    acc_str, pa_str, afr_str = f"{acc:.1f}", f"{pa:.1f}", f"{afr:.1f}"
                    str_metrics = [acc_str, afr_str, pa_str]
                    float_metrics = [float(val) for val in str_metrics]
                    max_val = max(float_metrics)
                    if abs(float_metrics[2] - max_val) < eps:
                        pa_str = f"{{\\textbf{{{pa_str}}}}}"
                    elif abs(float_metrics[0] - max_val) < eps:
                        acc_str = f"{{\\textbf{{{acc_str}}}}}"
                    else:
                        afr_str = f"{{\\textbf{{{afr_str}}}}}"

                    pa_str = "_________" + str(beta) + "_________"


                    latex_code += f" & {acc_str} & {afr_str} & {pa_str}"
                else:
                    latex_code += " & - & - & -"  # Placeholder if there's no data


                break
            
            latex_code += " \\\\\n"

        if imodel < len(model_name_list)-1:
            latex_code += "\\midrule\n\\addlinespace\n\\addlinespace\n"
    
    latex_code += "\\bottomrule\n\\end{tabular}%\n}\n"

    # Add caption with significant improvement information
    caption_text = f"REMOVE-lr={lr}-mf={main_factor}REMOVE Test performance on increasingly shifted datasets for models selected during ERM and IRM procedures. Different validation datasets are used, and the selection capabilities of PA and validation accuracy are compared."
    latex_code += f"\\caption{{{caption_text}}}\n\\label{{tab:label}}\n\\end" + "{" + "table}"
    
    return latex_code

In [29]:
latex_table = generate_latex_table(
        model_name_list = ['irm_adam_001', 'irm_0001'],
        ds_name_list = ['ZGO_hue_3', 'CGO_1_hue','CGO_2_hue','CGO_3_hue','ZSO_hue_3'],
        # ds_name_list=['ZGO_pos_3','CGO_1_pos','CGO_2_pos','CGO_3_pos','ZSO_pos_3'],
        lr = 0.001,
        main_factor = 'hue',
)
print(latex_table)

\begin{table}[H]
\centering
\setlength{\tabcolsep}{2.5pt}
\resizebox{\textwidth}{!}{%
\begin{tabular}{l|ccc|ccc|ccc|ccc|ccc}
\multirow{3}{*}{} & \multicolumn{3}{c|}{\textbf{Acc. Test 1}} & \multicolumn{3}{c|}{\textbf{Acc. Test 2}} & \multicolumn{3}{c|}{\textbf{Acc. Test 3}} & \multicolumn{3}{c|}{\textbf{Acc. Test 4}} & \multicolumn{3}{c}{\textbf{Acc. Test 5}} \\
\textbf{irm_adam_001} & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA \\
\midrule
ZGO & 50.1 & 56.0 & _________0.0_________ \\
1-CGO & 63.0 & 64.0 & _________0.0_________ \\
2-CGO & 69.0 & {\textbf{82.4}} & _________0.0_________ \\
3-CGO & 79.5 & 91.1 & _________0.0_________ \\
ZSO & 99.4 & 99.5 & _________116.32434129714966_________ \\
\midrule
\addlinespace
\addlinespace
\textbf{irm_0001} & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA & Acc. & AFR$_\text{P}$ & PA

# TABLES OF SURPLUS 

In [None]:
def generate_latex_table(
        model_name_list: list[str],
        ds_name_list: list[str],
        lr: float,
        main_factor: str,
    ):
    dataset_names_parsed = [DATASET_DICT[ds_name] for ds_name in ds_name_list]
    model_names_parsed = [MODEL_NAMES[model_name.split("_")[0]] for model_name in model_name_list]
    
    # Start building the LaTeX table
    latex_code = "\\begin{table}[H]\n\\centering\n\\resizebox{\\textwidth}{!}{%\n\\begin{tabular}{l|cl|cl|cl|cl|cl}\n"

    # Header for shift values
    # shift_text = lambda shift: f"Test \#{shift}"
    shift_text = lambda shift: f"Test {shift}"
    latex_code += "\\multirow{2}{*}{} & " + " & ".join(
        [
        f"\\multicolumn{{2}}{{c|}}{{\\textbf{{{shift_text(shift)}}}}}" if shift < 5 else f"\\multicolumn{{2}}{{c}}{{\\textbf{{{shift_text(shift)}}}}}"
        for shift in range(1, 6)
    ]
    ) + " \\\\\n"

    eps = 1e-4
    
    # Iterate over each model
    for imodel, (model, model_parsed) in enumerate(zip(model_name_list, model_names_parsed)):
        # Header for metrics (Acc., PA)
        latex_code += f"\\textbf{{{model_parsed}}} & " + " & ".join(
            # [r"\textbf{Acc.} & \textbf{PA}" for _ in range(6)]
            [r"Acc. & $\Delta$Acc." for _ in range(1, 6)]
        ) + " \\\\\n"
        latex_code += "\\midrule\n"

        # Iterate over each dataset
        for idataset, (dataset, dataset_parsed) in enumerate(zip(ds_name_list, dataset_names_parsed)):
            latex_code += f"{dataset_parsed}"
            # latex_code += f"\\textbf{{{dataset_parsed}}}"

            try:
                with open(rf"/cluster/home/vjimenez/adv_pa_new/results/dg/datashift/{dataset}/test_{model}.pkl", 'rb') as file:
                    df = pd.DataFrame((pickle.load(file)))
            except:
                continue
      
            for shift in range(1, 6):
                # Filter the DataFrame for the current dataset, model, and shift
                row = df[(df['env1'] == str(shift)) & (df["sr"] == 1.0)]
                if not row.empty:
                    
                    # Extract the metric values for the current dataset, model, and shift
                    acc = 100.0*row[row["selection_metric"] == "acc"]["acc"].values[0]
                    pa = 100.0*row[row["selection_metric"] == "logPA"]["acc"].values[0] - 100.0*row[row["selection_metric"] == "acc"]["acc"].values[0]

                    acc_str = f"{acc:.1f}"
                    if float(f"{pa:.1f}") > 0:
                        pa_str_in = f"\Plus {abs(pa):.1f}"
                        pa_str = f"{{\\color{{tab:green}}  \\textbf{{{pa_str_in}}}}}"
                    elif float(f"{pa:.1f}") < 0:
                        pa_str_in = f"\Minus {abs(pa):.1f}"
                        pa_str = f"{{\\color{{tab:red}} \\textbf{{{pa_str_in}}}}}"
                    else:
                        pa_str = r"\PlusMinus 0.01" #"0.0"


                    latex_code += f" & {acc_str} & {pa_str}"
                else:
                    latex_code += " & - & -"  # Placeholder if there's no data
            latex_code += " \\\\\n"


        if imodel < len(model_name_list)-1:
            latex_code += "\\midrule\n\\addlinespace\n\\addlinespace\n"
    
    latex_code += "\\bottomrule\n\\end{tabular}%\n}\n"

    # Add caption with significant improvement information
    caption_text = f"REMOVE-lr={lr}-mf={main_factor}REMOVE"
    latex_code += f"\\caption{{{caption_text}}}\n\\label{{tab:label}}\n\\end" + "{" + "table}"
    
    return latex_code

In [None]:
latex_table = generate_latex_table(
        model_name_list = ['irm_0001', 'irm_adam_001'],
        ds_name_list = ['ZGO_hue_3', 'CGO_1_hue','CGO_2_hue','CGO_3_hue','ZSO_hue_3'],
        # ds_name_list=['ZSO_pos_3','CGO_1_pos','CGO_2_pos','CGO_3_pos','ZGO_pos_3'],
        lr = 0.001,
        main_factor = 'hue',
)
print(latex_table)