Currently, all output corresponding to a given seed for a given configuration is stored
in a single file. This notebook reads such output and splits it into multiple files,
each named `<seed>_<label>.pt`, and then deletes the original combined file. This makes
the outputs easier to manage when we only want to focus on a subset of the labels.

In [1]:
import json
import os

import torch

# Experiment directories are looked up under "backup", relative to the working
# directory at the time this cell runs.
parent_dir = os.path.join(os.getcwd(), "backup")


def process(dirname: str, *, base_dir: "str | None" = None) -> None:
    """Split each combined per-seed output file of one experiment into per-label files.

    Reads and prints ``<base_dir>/<dirname>/config.json``. Then, for every ``*.pt``
    file in that directory whose name ends in a four-digit seed (e.g.
    ``..._0042.pt``), it loads the combined output dict. For each entry of the
    dict's ``"labels"`` list it writes one ``<seed>_<label>.pt`` file into the same
    directory. The combined file is deleted once all of its per-label files have
    been saved.

    Args:
        dirname: Name of the experiment directory inside ``base_dir``.
        base_dir: Directory containing the experiment directories. Defaults to
            the module-level ``parent_dir``.

    Files whose name does not end in four ASCII digits, or whose contents are
    not a dict with a ``"labels"`` key, are reported and left untouched. That
    makes it safe to re-run on an already (partially) split directory.
    """
    if base_dir is None:
        base_dir = parent_dir
    exp_dir = os.path.join(base_dir, dirname)
    config_path = os.path.join(exp_dir, "config.json")

    with open(config_path, "r") as f:
        config_dict = json.load(f)
    print(f"Config: \n {json.dumps(config_dict, indent=4)}")

    # Sorted so the processing order (and printed messages) are deterministic;
    # os.listdir returns entries in arbitrary order.
    output_names = sorted(
        name for name in os.listdir(exp_dir) if name.endswith(".pt")
    )
    print(f"Found {len(output_names)} files.")

    for name in output_names:
        file_path = os.path.join(exp_dir, name)
        stem = name[: -len(".pt")]
        seed = stem[-4:]
        # Require exactly four ASCII digits. int() would also accept "-123",
        # " 123" or "1_23". The former assert-based range check raised an
        # uncaught AssertionError on those (and is stripped under -O).
        if len(seed) != 4 or not (seed.isascii() and seed.isdigit()):
            print(f"Invalid file path: {file_path}!")
            continue
        output_dict = torch.load(file_path)
        # Per-label files from a previous run can also end in four digits
        # (numeric label, e.g. "0042_1234.pt"). They carry "label", not "labels".
        if not isinstance(output_dict, dict) or "labels" not in output_dict:
            print(f"Not a combined output file, skipping: {file_path}!")
            continue
        for i, label in enumerate(output_dict["labels"]):
            new_dict = {
                "label": label,
                "X": output_dict["X_list"][i],
                "Y": output_dict["Y_list"][i],
                # Not indexed per label; presumably shared by all labels of
                # this seed (TODO confirm).
                "true_means": output_dict["true_means"],
                "pcs_estimates": output_dict["pcs_estimates"][i],
                "correct_selection": output_dict["correct_selection"][i],
            }
            torch.save(new_dict, os.path.join(exp_dir, f"{seed}_{label}.pt"))
        # Deleted only after every per-label file was saved; an exception in
        # the loop above leaves the combined file in place.
        os.remove(file_path)
    print("Processing complete!")


In [19]:
# Split the combined per-seed outputs of experiment "config_13" into per-label files.
process(dirname="config_13")

Config: 
 {
    "iterations": 100,
    "fit_frequency": 25,
    "fit_tries": 10,
    "num_arms": 6,
    "num_contexts": 20,
    "batch_size": 5,
    "num_fantasies": 0
}
Found 20 files.
Processing complete!
