In [1]:
import os
from pathlib import Path
import pyrootutils

# Locate the project root (the directory containing `.project-root`) starting from the
# notebook's working directory; `pythonpath=True` puts that root on sys.path so the
# `src.*` imports in the next cell resolve.
notebook_path = Path(os.path.abspath(""))
pyrootutils.setup_root(notebook_path, indicator=".project-root", pythonpath=True)

# Root directory of the cached test-paper DataFrames. Set the DG_DATASHIFT_DIR
# environment variable to point elsewhere instead of editing this absolute path.
DIRNAME = os.environ.get(
    "DG_DATASHIFT_DIR",
    r"/cluster/home/vjimenez/adv_pa_new/results/dg/datashift",
)

In [2]:
from src.plot.dg import *
from src.plot.dg._retrieve import *
from src.plot.dg._plot import *

In [3]:
LIST_PASPECIFIC_METRICS = ["logPA", "beta", "acc_pa", "AFR_true", "AFR_pred", "CD", "MMD", "FID", "CS", "KL", "W", "oracle"]

def _get_metric_value(
        data,
        list_metric_indexes,
        metric_name,
        run_ind,
        env = 1
    ):
    pos_to_select = list_metric_indexes[env-1][run_ind]
    try:
        if metric_name in LIST_PASPECIFIC_METRICS:
            key_name = f"PA(0,{env})/{metric_name}"
            if metric_name == "oracle":
                key_name = f"oracle/acc_{env-1}"
                
            out = data[key_name][run_ind][pos_to_select].item()
        else:
            out = data[f"{metric_name}"][run_ind][pos_to_select].item()
    except:
        import ipdb; ipdb.set_trace()
    return out

# Apply selection functions to the metric:
def _fun_selection_metric(metric_array, selection_criterion) -> int:
    if selection_criterion == "min":
        return np.argmin(metric_array).item()
    elif selection_criterion == "max":
        return np.argmax(metric_array).item()
    elif selection_criterion == "first":
        return 0
    else:
        return -1

def _generate_indexes(
        data,
        selection_metric,
        selection_criterion,
        selection_environment = None
    ):
    run_names = list(data.keys())
    num_runs = len(data[run_names[0]])

    # Environments to select from:
    envs_to_select = [selection_environment]*5
    if selection_environment == None:
        envs_to_select = [i for i in range(1, 6)]

    # Select the metric indexes:
    datakey = lambda env: selection_metric
    if selection_metric in LIST_PASPECIFIC_METRICS:
        datakey = lambda env: f"PA(0,{env})/{selection_metric}"
        if selection_metric == "oracle":
            datakey = lambda env: f"oracle/acc_{env-1}"

    try:
        list_metric_indexes = [
            [
                _fun_selection_metric(
                    data[datakey(env)][i],
                    selection_criterion
                )
                for i in range(num_runs)
            ]
            for env in envs_to_select
        ]
    except:
        import ipdb; ipdb.set_trace()
    return list_metric_indexes

In [14]:
LIST_PASPECIFIC_METRICS_WANDB = ["logPA", "beta", "acc_pa", "AFR_true", "AFR_pred", "CD", "MMD", "FID", "CS", "KL", "W"]

# PA metrics read for every test run, logged under `PA(0,1)/<metric>_test`.
_TESTPAPER_PA_METRICS = ["logPA", "beta", "AFR_pred", "AFR_true", "acc_pa"]
# Test-set metrics read for every run, logged under `test/<metric>_<selection metric>_epoch`.
_TESTPAPER_STAGE_METRICS = ["loss", "acc", "specificity", "sensitivity", "precision"]


def _parse_testpaper_run_name(run_name):
    """Split ``exp=<exp>_met=<selmet>_envs=<e0>&<e1>_sr=<sr>`` into its parts.

    Returns ``(exp, selmet, envs, sr)`` with ``envs`` a list of at least two strings
    and ``sr`` a float, or ``None`` when the name does not follow that pattern
    (such runs are not of interest).
    """
    try:
        sr = float(run_name.split("=")[-1])
        # Drop the trailing `_<next field name>` from every `=`-separated chunk.
        exp, selmet, envs = ["_".join(rn.split("_")[:-1]) for rn in run_name.split("=")[:-1]][1:]
    except (ValueError, AttributeError):
        return None
    envs = envs.split("&")
    if len(envs) < 2:
        return None
    return exp, selmet, envs, sr


def dg_pa_datashift_testpaper(
        exp_name: str,
        group: str = "paper",
        labelwise: bool = False,
        dirname: str = "results",
        cache: bool = False,
        entity: str = "malvai",
        project: str = "DiagVib-6 Paper",
    ) -> pd.DataFrame:
    """Collect the test results of ``exp_name`` from W&B into a DataFrame and cache it.

    Only runs of ``group`` named ``exp=<exp>_met=<selmet>_envs=<e0>&<e1>_sr=<sr>``
    with ``<exp> == exp_name`` and a ``model/_target_`` config entry are used;
    each produces exactly one row.

    Args:
        exp_name: Experiment name to keep.
        group: W&B group to query; also the cache sub-directory.
        labelwise: Currently unused (per-label CD retrieval is disabled); kept for
            API compatibility.
        dirname: Cache root; the file is ``<dirname>/<group>/test_<exp_name>.pkl``.
        cache: If True and the cache file exists, return its content without
            contacting W&B.
        entity: W&B entity.
        project: W&B project.

    Returns:
        The collected DataFrame, or the unpickled cache content on a cache hit.
    """
    pathdir = osp.join(dirname, group)
    fname = osp.join(pathdir, f"test_{exp_name}.pkl")

    # Serve from the cache before touching the W&B API (needs network and auth).
    if cache and osp.exists(fname):
        # NOTE: pickle executes arbitrary code on load; only point this at caches
        # written by this function.
        with open(fname, 'rb') as file:
            return pickle.load(file)

    api = wandb.Api(timeout=100)
    runs = api.runs(entity + '/' + project, {"group": group})

    # One dict per kept run; building whole rows keeps all columns the same length
    # even when a run is skipped part-way through.
    records = []
    for run in tqdm(runs, desc="Run: "):
        parsed = _parse_testpaper_run_name(run.name)
        if parsed is None:
            continue
        exp, selmet, envs, sr = parsed
        if exp != exp_name:
            continue

        config = run.config
        try:
            model_name = config["model/_target_"].split(".")[-1]
        except (KeyError, AttributeError):
            continue

        row = {
            "seed": 0,
            "dataset": group,
            "model": model_name,
            "ppred": config["model/ppred"] if model_name == "LISA" else None,
            # Characterizing the test
            "sr": sr,
            "selection_metric": selmet,
            "env0": envs[0],
            "env1": envs[1],
        }
        # `retrieve_from_history` is used for its first returned element, as before.
        for metric in _TESTPAPER_PA_METRICS:
            row[metric] = retrieve_from_history(run, f"PA(0,1)/{metric}_test")[0]
        for metric_stage in _TESTPAPER_STAGE_METRICS:
            row[metric_stage] = retrieve_from_history(run, f"test/{metric_stage}_{selmet}_epoch")[0]
        records.append(row)

    # Store it already, only one dataframe per experiment:
    df = pd.DataFrame(records)
    os.makedirs(pathdir, exist_ok=True)
    df.to_pickle(fname)
    print(f"dataframe stored in {fname}.")
    return df

In [12]:
def retrieve_testpaper(exp_name_list: list, group: str = "paper"):
    """Fetch (or load from the cache under ``DIRNAME``) the test results of each experiment.

    Returns a list with one result per entry of ``exp_name_list``, in the same order.
    The results are not concatenated, so each experiment stays separate.
    """
    return [
        dg_pa_datashift_testpaper(
            exp_name=name,
            group=group,
            dirname=DIRNAME,
            cache=True,
        )
        for name in exp_name_list
    ]

In [15]:
# Entry k of `exp_names` lists the experiments to fetch for every group in entry k
# of `group_names` (the `_hue` groups first, then the `_pos` groups).
exp_names = [
    ["erm_0001", "erm_adam_001", "irm_0001", "irm_adam_001"],
    ["erm", "erm_001", "irm", "irm_001"]
]
group_names = [
    ['CGO_1_hue', 'CGO_2_hue','CGO_3_hue','ZSO_hue_3'], #ZGO_hue_3
    ['CGO_1_pos','CGO_2_pos','CGO_3_pos','ZGO_pos_3','ZSO_pos_3']
]
assert len(exp_names) == len(group_names), "each experiment list needs a matching list of groups"

# Fetch (or load from cache) every group's results and keep them, keyed by group name.
testpaper_results = {}
for exp_name_list, groups in zip(exp_names, group_names):
    for group_name in groups:
        testpaper_results[group_name] = retrieve_testpaper(exp_name_list=exp_name_list, group=group_name)

Run: 100%|██████████| 1126/1126 [00:15<00:00, 72.17it/s] 
Run: 100%|██████████| 1126/1126 [00:08<00:00, 134.15it/s]
Run: 100%|██████████| 1126/1126 [00:10<00:00, 105.50it/s]
Run: 100%|██████████| 1126/1126 [00:07<00:00, 147.93it/s]
Run: 100%|██████████| 964/964 [00:06<00:00, 146.71it/s]
Run: 100%|██████████| 964/964 [00:06<00:00, 157.91it/s]
Run: 100%|██████████| 624/624 [00:04<00:00, 150.28it/s]
Run: 100%|██████████| 624/624 [00:04<00:00, 144.41it/s]
Run:  25%|██▍       | 151/608 [00:00<00:02, 168.18it/s]

exp=erm_001_met=acc_envs=0&1_sr=0.7
exp=erm_001_met=AFR_pred_envs=0&5_sr=0.6
exp=erm_001_met=AFR_pred_envs=0&4_sr=0.6


Run: 100%|██████████| 608/608 [00:03<00:00, 170.45it/s]
Run:  25%|██▍       | 151/608 [00:01<00:03, 145.51it/s]

exp=erm_001_met=acc_envs=0&1_sr=0.7
exp=erm_001_met=AFR_pred_envs=0&5_sr=0.6
exp=erm_001_met=AFR_pred_envs=0&4_sr=0.6


Run: 100%|██████████| 608/608 [00:03<00:00, 161.02it/s]
Run:  25%|██▍       | 151/608 [00:01<00:03, 147.60it/s]

exp=erm_001_met=acc_envs=0&1_sr=0.7
exp=erm_001_met=AFR_pred_envs=0&5_sr=0.6
exp=erm_001_met=AFR_pred_envs=0&4_sr=0.6


Run: 100%|██████████| 608/608 [00:03<00:00, 155.20it/s]
Run: 100%|██████████| 602/602 [00:03<00:00, 170.01it/s]
Run: 100%|██████████| 602/602 [00:03<00:00, 156.44it/s]
Run: 100%|██████████| 602/602 [00:03<00:00, 168.00it/s]
Run: 100%|██████████| 668/668 [00:05<00:00, 126.13it/s]
Run:  45%|████▌     | 301/668 [00:03<00:05, 69.34it/s] 