The output files got quite messed up due to patching up existing output with new
alternatives. This notebook is for making sense of and cleaning up some of the output
files.

In [1]:
import json
import os

import torch
from torch import Tensor

%matplotlib inline
import matplotlib.pyplot as plt

current_dir = os.getcwd()


def get_output_files(dirname: str) -> list:
    exp_dir = os.path.join(current_dir, dirname)
    config_path = os.path.join(exp_dir, "config.json")

    with open(config_path, "r") as f:
        config_dict = json.load(f)
        print(f"Config: \n {json.dumps(config_dict, indent=4)}")

    # get all the output files
    directory_list = os.listdir(exp_dir)
    output_files = [
        os.path.join(exp_dir, file) for file in directory_list if file[-3:] == ".pt"
    ]
    return output_files


def read_files(output_files: list) -> list:
    # read all outputs into a list
    output_dicts = list()
    for file_path in output_files:
        output_dicts.append(torch.load(file_path))
    output_count = len(output_dicts)
    print(f"Read {output_count} output files.")
    return output_dicts


clean_label = "LCEGP_Adam"


def clean_all_outputs(output_files: list, expected_index: int) -> None:
    for file_path in output_files:
        output = torch.load(file_path)
        if len(output["labels"]) <= expected_index:
            continue
        if output["labels"][expected_index] == clean_label:
            output["labels"].pop(expected_index)
            output["X_list"].pop(expected_index)
            output["Y_list"].pop(expected_index)
            output["pcs_estimates"].pop(expected_index)
            output["correct_selection"].pop(expected_index)
            torch.save(output, file_path)
            print(f"Done: {file_path[-7:]}")

def delete_outputs(output_files: list, label: str) -> None:
    r"""Delete the outputs with the given label"""
    for file in output_files:
        file_label = os.path.basename(file)[5:-3]
        if file_label == label:
            print(f"Deleting {os.path.basename(file)}")
            os.remove(file)


def move_outputs(source_files: list, target_dir: str, label: str) -> None:
    r"""Move the source files to target_dir given that the label agrees."""
    for file in source_files:
        file_label = os.path.basename(file)[5:-3]
        if file_label == label:
            print(f"Moving {os.path.basename(file)}")
            target_path = os.path.join(target_dir, os.path.basename(file))
            os.rename(file, target_path)

In [4]:
for dir_no in [
    "b_1_worst",
    "b_2_worst",
    "b_3_worst",
    "b_3_worst_fant",
    "b_4_worst",
    "b_5_worst",
    "g_1_mean",
    "g_2_mean",
    "g_3_mean",
    "g_3_mean_fant",
    "g_4_mean",
    "g_5_mean",
    "g_1_cvar",
    "g_2_cvar",
    "g_3_cvar",
    "g_3_cvar_fant",
    "g_4_cvar",
    "g_5_cvar",
]:
    for label in [
        "ML_PCS",
        "ML_IKG_rho",
        "LCEGP_Gao",
        "LCEGP_Gao_Adam_rand",
        "LCEGP_Gao_Adam_gp",
        "LCEGP_Gao_Adam_reuse",
    ]:
        dirname = "config_" + dir_no
        target_dirname = os.path.join(dirname, "inferior")
        if not os.path.isdir(target_dirname):
            os.mkdir(target_dirname)

        files = get_output_files(dirname)
        move_outputs(files, target_dirname, label)


Config: 
 {
    "iterations": 100,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4,
    "num_fantasies": 0,
    "rho_key": "worst",
    "ground_truth_kwargs": {
        "function": "branin"
    }
}
Config: 
 {
    "iterations": 100,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4,
    "num_fantasies": 0,
    "rho_key": "worst",
    "ground_truth_kwargs": {
        "function": "branin"
    }
}
Moving 0004_ML_IKG_rho.pt
Moving 0002_ML_IKG_rho.pt
Moving 0005_ML_IKG_rho.pt
Moving 0012_ML_IKG_rho.pt
Moving 0007_ML_IKG_rho.pt
Moving 0013_ML_IKG_rho.pt
Moving 0000_ML_IKG_rho.pt
Moving 0006_ML_IKG_rho.pt
Moving 0001_ML_IKG_rho.pt
Moving 0014_ML_IKG_rho.pt
Moving 0008_ML_IKG_rho.pt
Moving 0009_ML_IKG_rho.pt
Moving 0015_ML_IKG_rho.pt
Moving 0019_ML_IKG_rho.pt
Moving 0003_ML_IKG_rho.pt
Moving 0017_ML_IKG_rho.pt
Moving 0018_ML_IKG_rho.pt
Moving 0010_ML_IKG_rho.pt
Moving 0011_ML_IKG_rho.pt
Moving 0016_ML_IKG_rho.pt
Config: 
 {
    "iterations": 100,
    "fit_freq