In [1]:
import os
from pathlib import Path
import pyrootutils

notebook_path = Path(os.path.abspath(""))
pyrootutils.setup_root(notebook_path, indicator=".project-root", pythonpath=True)

DIRNAME = r"/cluster/home/vjimenez/adv_pa_new/results/dg/datashift"

In [2]:
from src.plot.dg import *
from src.plot.dg._retrieve import *
from src.plot.dg._plot import *

In [6]:
def extract_names_shift(run_string):
    """(model, optimizer, gpu, ppred)"""
    if run_string[:3] == "sgd":
        optimizer = "sgd"
        run_string = run_string[3:]
    else:
        optimizer = "adam"

    split_run = run_string.split("_")
    if len(split_run) == 1: # erm, irm
        return (split_run[0], optimizer, "ddp", None)
    elif len(split_run) == 2:
        if split_run[1].isdigit(): # lisa_ppred
            return (split_run[0], optimizer, "ddp", split_run[1])
        else: # erm_gpu, irm_gpu
            return (split_run[0], optimizer, "gpu", None)
        
    else: # split_run == 3, lisa_gpu_ppred
        return (split_run[0],optimizer, "gpu", split_run[2])
    
    return None

In [33]:
from matplotlib.ticker import MultipleLocator

def plot_variable_vs_run(
        data: dict,
        dataset_list: list,
        run_names: list,
        metrics: list,
        hue_attribute: str,
        hue_dict: dict,
        ylabel: str,
        legend_labels: list,
        title: str,
        savedir: str,
        yscale: Optional[bool] = "symlog",
        legend: Optional[bool] = True,
        legend_loc: Optional[str] = "best",
        save: Optional[bool] = False,
        version_appendix: Optional[str] = ""
    ) -> None:
    """
        Args:
            data (dict): Dictionary with all the data for the desired runs.
            metric (str): Name of the metric to plot.
            selection_metric (str): Name of the metric that guides the selection of the `metric` values to plot.
            selection_criterion (str): Criterion of the `selection_metric`. Accepts "min", "max", "first" and "last".
            selection_environment (Optional[int]): Environment to implement the selection criterion. If None, it will be implemented
                for each environment separately.
    """
    # Number of runs:
    num_runs = len(run_names)
    run_attributes = [extract_names_shift(name) for name in run_names] #(model, opt, lr)

    name_datasets = set(dataset_list)
    num_datasets = len(name_datasets)

    # Get the font
    fontname = "DejaVu Serif"
    _ = fm.findfont(fm.FontProperties(family=fontname))

    # Subset of the dictionary:
    dict_to_iter = {
        "dataset": list(name_datasets)*num_runs,
        "model": [attrs[0] for attrs in run_attributes]*num_datasets,
        "trainer": [attrs[1] for attrs in run_attributes]*num_datasets,
        "optimizer": [attrs[2] for attrs in run_attributes]*num_datasets,
        "ppred": [attrs[3] for attrs in run_attributes]*num_datasets,
    }    

    df_list = []
    for irun in range(num_runs*num_datasets):
        dict_to_plot = {
            "epochs": np.arange(1, 101)
        }
        dict_to_plot.update({
            key: np.full(100, values[irun])
            for key, values in dict_to_iter.items()
        })
        dict_to_plot.update({
            metric: data[metric][irun]
            for metric in metrics
        })
        try:
            df = pd.DataFrame(dict_to_plot)
        except:
            import ipdb; ipdb.set_trace()
        df_list.append(df)
    
    level_set = pd.concat(df_list)
    
    # Create a line plot
    plt.close('all')
    _, ax = plt.subplots(figsize=(2 * 3.861, 2 * 2.7291))
    sns.set(font_scale=1.9)
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = fontname
    sns.set_style("ticks")


    for metric in metrics:
        sns.lineplot(
            data=level_set,
            ax=ax,
            x="epochs",
            y=metric,
            hue=hue_attribute,
            style=hue_attribute,
            palette=hue_dict,
            dashes=[(2,2)] if metric == metrics[0] else False, #dash_styles.get(metric, False),
            marker=None,
            linewidth=3,
            legend=legend
        )

    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.set_xticks([1] + [i for i in range(10,101,10)])
    plt.xticks(rotation=45)

    ax.tick_params(axis="both", which="both", direction="in")
    xticks_font = fm.FontProperties(family=fontname)
    for tick in ax.get_xticklabels():
        tick.set_fontproperties(xticks_font)

    ax.grid(linestyle="--")

    # ax.set_ylim(min(level_set[metric])*2, 0.5)
        
    # posdiff = -(10**np.log10(abs(max(level_set[metric]))))/2
    # ax.set_ylim(min(level_set[metric])*2, posdiff)
    ax.set_xlabel("Epochs", fontname=fontname)
    # r"$10^{-4} \times $ PA"
    ax.set_ylabel(ylabel, fontname=fontname)
    ax.set_yscale(yscale) 
    
    # Legend
    if legend == True:
        handles, _ = ax.get_legend_handles_labels()

        # FOR TRAIN AND VAL METRICS:
        legend_labels = ["Training", "Validation"] + legend_labels
        
        for handle in handles[:2]:
            handle.set_color("black")
        handles[1].set_linestyle("-")
        
            
        legend_properties = {
            "family": fontname,
            'size': 18,
        }  
        ax.legend(
            handles,
            legend_labels,
            loc=legend_loc,
            # loc="lower left",
            # fontsize=12,
            handlelength=0.5,
            prop=legend_properties
        )

    ax.set_title(title, fontname=fontname)
    plt.tight_layout()
    if save:
        plt.savefig(savedir)
        plt.clf()
        plt.close()
    else:
        plt.show()

In [34]:
run_names = [f"erm", f"irm", "lisa_00", "lisa_10"]
ds_list = ["CGO_2_hue"]
data_dict = get_multiple_dict(
    ds_list,
    run_names,
    datashift=True
)

Run: 100%|██████████| 4/4 [00:01<00:00,  2.87it/s]


In [35]:
trainval_metrics = ["acc", "loss", "specificity", "sensitivity", "precision"]
for met in trainval_metrics: 
    legend = False
    if met == "loss":
        legend = False

    mettitle = met.capitalize()
    if met == "acc":
        mettitle = "Accuracy"

    plot_variable_vs_run(
        data=data_dict,
        run_names=run_names,
        dataset_list = ds_list,
        metrics=[f"train/{met}", f"val/{met}"],
        hue_attribute="model",
        hue_dict={
            "erm": "tab:blue",
            "irm": "tab:orange",
            "lisa": "tab:green"
        },
        title=f"{mettitle}",
        legend_labels=["ERM", "IRM", "LISA"],
        ylabel="",
        # title=f"{mod.upper()}",
        savedir=os.path.join(DIRNAME, rf"{ds_list[0]}/adam_mod_{met}.png"),
        yscale="linear",
        legend=legend,
        save=True,
        version_appendix=""
    )

> [0;32m/tmp/ipykernel_1805482/4087828590.py[0m(66)[0;36mplot_variable_vs_run[0;34m()[0m
[0;32m     65 [0;31m            [0;32mimport[0m [0mipdb[0m[0;34m;[0m [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 66 [0;31m        [0mdf_list[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mdf[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m


{'epochs': array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100]), 'dataset': array(['CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue',
       'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue',
       'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue',
       'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue',
       'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue',
       'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', 'CGO_2_hue', '