The output files became inconsistent after existing outputs were patched with new
alternatives. This notebook is for making sense of, and cleaning up, some of those
output files.

In [11]:
import json
import os

import torch
from torch import Tensor

%matplotlib inline
import matplotlib.pyplot as plt

# Working directory of the kernel at the time this cell runs; experiment directory
# names passed to `get_output_files` are resolved against it.
current_dir = os.getcwd()


def get_output_files(dirname: str) -> list:
    r"""Print an experiment's config and list its ``.pt`` output files.

    Args:
        dirname: Experiment directory name; it is joined onto the module-level
            ``current_dir`` to locate the directory.

    Returns:
        Paths (directory joined with file name) of every directory entry whose
        name ends in ``.pt``, sorted by file name.

    Raises:
        FileNotFoundError: If the directory or its ``config.json`` does not exist.
    """
    exp_dir = os.path.join(current_dir, dirname)
    config_path = os.path.join(exp_dir, "config.json")

    with open(config_path, "r") as f:
        config_dict = json.load(f)
    print(f"Config: \n {json.dumps(config_dict, indent=4)}")

    # os.listdir returns entries in arbitrary order; sort so that callers which
    # enumerate the result see the same ordering on every run.
    output_files = [
        os.path.join(exp_dir, file)
        for file in sorted(os.listdir(exp_dir))
        if file.endswith(".pt")
    ]
    return output_files


def read_files(output_files: list) -> list:
    r"""Load each given file with ``torch.load`` and report how many were read.

    Args:
        output_files: Paths of the files to load.

    Returns:
        The loaded objects, in the same order as ``output_files``.
    """
    loaded = [torch.load(path) for path in output_files]
    print(f"Read {len(loaded)} output files.")
    return loaded


# Label whose entry `clean_all_outputs` removes from the saved outputs.
clean_label = "LCEGP_Adam"


def clean_all_outputs(output_files: list, expected_index: int) -> None:
    r"""Strip the ``clean_label`` entry at ``expected_index`` from each output file.

    A file is rewritten in place only when its ``"labels"`` list is longer than
    ``expected_index`` and the label at that position equals ``clean_label``.
    In that case the element at ``expected_index`` is popped from each of the
    per-label lists and the result is saved back to the same path.

    Args:
        output_files: Paths of the files to inspect.
        expected_index: Position at which the unwanted label is expected.
    """
    per_label_keys = (
        "labels",
        "X_list",
        "Y_list",
        "pcs_estimates",
        "correct_selection",
    )
    for file_path in output_files:
        output = torch.load(file_path)
        labels = output["labels"]
        has_stale_entry = (
            len(labels) > expected_index and labels[expected_index] == clean_label
        )
        if not has_stale_entry:
            continue
        for key in per_label_keys:
            output[key].pop(expected_index)
        torch.save(output, file_path)
        print(f"Done: {file_path[-7:]}")

def delete_outputs(output_files: list, label: str) -> None:
    r"""Remove from disk every given file whose name carries ``label``.

    A file's label is its base name with the first five and the last three
    characters dropped (e.g. ``0009_Gao.pt`` -> ``Gao``).

    Args:
        output_files: Paths of the candidate files.
        label: Label that a file must match exactly to be deleted.
    """
    for path in output_files:
        name = os.path.basename(path)
        if name[5:-3] != label:
            continue
        print(f"Deleting {name}")
        os.remove(path)


def move_outputs(source_files: list, target_dir: str, label: str) -> None:
    r"""Move the source files to ``target_dir`` given that the label agrees.

    A file's label is its base name with the first five and the last three
    characters dropped (e.g. ``0009_LCEGP_Adam.pt`` -> ``LCEGP_Adam``). Matching
    files keep their base name in ``target_dir``.

    Args:
        source_files: Paths of the candidate files.
        target_dir: Directory that receives the matching files.
        label: Label that a file must match exactly to be moved.
    """
    # Local import so this definition does not depend on another cell.
    import shutil

    for file in source_files:
        file_name = os.path.basename(file)
        if file_name[5:-3] == label:
            print(f"Moving {file_name}")
            target_path = os.path.join(target_dir, file_name)
            # shutil.move renames when possible and falls back to copy + delete
            # when source and target are on different filesystems, where a
            # plain os.rename would raise.
            shutil.move(file, target_path)

In [10]:
# Delete every "Gao" output from the backup copy of each experiment directory.
label = "Gao"

for dir_no in (
    "0", "00", "000", "1", "2", "3", "6", "7", "8", "9", "10", "11", "12", "13"
):
    dirname = os.path.join("backup", "config_" + dir_no)
    files = get_output_files(dirname)
    delete_outputs(files, label)


Config: 
 {
    "iterations": 100,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4
}
Config: 
 {
    "iterations": 100,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4,
    "randomize_ties": 1
}
Config: 
 {
    "iterations": 100,
    "fit_frequency": 10,
    "fit_tries": 5,
    "num_arms": 4,
    "num_contexts": 4,
    "randomize_ties": 1
}
Config: 
 {
    "iterations": 400,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4,
    "batch_size": 5
}
Config: 
 {
    "iterations": 200,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 10,
    "batch_size": 1,
    "randomize_ties": 1
}
Config: 
 {
    "iterations": 200,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4,
    "batch_size": 5,
    "num_fantasies": 0
}
Config: 
 {
    "iterations": 400,
    "fit_frequency": 500,
    "num_arms": 4,
    "num_contexts": 8,
    "batch_size": 5,
    "num_fantasies": 0
}
Config: 
 {
    "iterations": 400,
    "fit_freque

In [20]:
# Move every "LCEGP_Adam" output from the backup copy of each experiment
# directory into the corresponding directory under `current_dir`.
label = "LCEGP_Adam"

for dir_no in (
    "0", "00", "000", "1", "2", "3", "6", "7", "8", "9", "10", "11", "12", "13"
):
    target_dir = os.path.join(current_dir, "config_" + dir_no)
    dirname = os.path.join("backup", "config_" + dir_no)
    files = get_output_files(dirname)
    move_outputs(files, target_dir, label)

Config: 
 {
    "iterations": 100,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4
}
Moving 0009_LCEGP_Adam.pt
Moving 0032_LCEGP_Adam.pt
Moving 0031_LCEGP_Adam.pt
Moving 0003_LCEGP_Adam.pt
Moving 0021_LCEGP_Adam.pt
Moving 0039_LCEGP_Adam.pt
Moving 0025_LCEGP_Adam.pt
Moving 0002_LCEGP_Adam.pt
Moving 0012_LCEGP_Adam.pt
Moving 0013_LCEGP_Adam.pt
Moving 0024_LCEGP_Adam.pt
Moving 0033_LCEGP_Adam.pt
Moving 0011_LCEGP_Adam.pt
Moving 0022_LCEGP_Adam.pt
Moving 0036_LCEGP_Adam.pt
Moving 0028_LCEGP_Adam.pt
Moving 0007_LCEGP_Adam.pt
Moving 0010_LCEGP_Adam.pt
Moving 0008_LCEGP_Adam.pt
Moving 0005_LCEGP_Adam.pt
Moving 0019_LCEGP_Adam.pt
Moving 0029_LCEGP_Adam.pt
Moving 0015_LCEGP_Adam.pt
Moving 0034_LCEGP_Adam.pt
Moving 0001_LCEGP_Adam.pt
Moving 0004_LCEGP_Adam.pt
Moving 0016_LCEGP_Adam.pt
Moving 0027_LCEGP_Adam.pt
Moving 0020_LCEGP_Adam.pt
Moving 0038_LCEGP_Adam.pt
Config: 
 {
    "iterations": 100,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 4,
    "randomize_

In [17]:
# Display the target directory left over from the last iteration of the move loop.
# NOTE(review): this cell's saved execution count (17) predates the move cell's
# (20), so the stored output may not reflect the current code; re-run to confirm.
target_dir

'/home/saitcakmak/PycharmProjects/contextual_rs/experiments/contextual_rs_categorical_from_gp/backup/config_13'

In [12]:
# Load all outputs of one experiment and summarise what each file contains.
dirname = "config_2"
files = get_output_files(dirname)
output = read_files(files)

# The number printed as "Seed" is the position of the file in `files`.
for seed, out in enumerate(output):
    labels = out["labels"]
    num_pcs = len(out["pcs_estimates"])
    print(f"Seed {seed}, labels: {labels}, len pcs: {num_pcs}")

Config: 
 {
    "iterations": 200,
    "fit_frequency": 10,
    "num_arms": 4,
    "num_contexts": 10,
    "batch_size": 1,
    "randomize_ties": 1
}
Read 19 output files.
Seed 0, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 1, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 2, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 3, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 4, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 5, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 6, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 7, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 8, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 9, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 10, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 11, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 12, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 13, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 14, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 15, labels: ['LCEGP', 'Li', 'Gao'], len pcs: 3
Seed 1

Checking that whenever LCEGP_reeval exists, we have a corresponding entry for
pcs_estimates and correct_selection.

In [13]:
# Show the length of every entry stored in the last loaded output.
seed = -1

for key, value in output[seed].items():
    print(f"key {key}, len {len(value)}")


key labels, len 3
key X_list, len 3
key Y_list, len 3
key true_means, len 40
key pcs_estimates, len 3
key correct_selection, len 3


In [14]:
# Every label needs a matching entry in both "pcs_estimates" and
# "correct_selection", so all three lists must have the same length. The label
# count is compared too; comparing only the last two lists would miss a label
# that has no corresponding estimates.
for seed, out in enumerate(output):
    len_label = len(out["labels"])
    len_pcs = len(out["pcs_estimates"])
    len_cs = len(out["correct_selection"])
    if not (len_label == len_pcs == len_cs):
        print(f"Seed {seed} has mismatch!")
        raise RuntimeError(
            f"Output {seed}: {len_label} labels, {len_pcs} pcs_estimates, "
            f"{len_cs} correct_selection entries."
        )

Now that we're sure everything is consistent (the previous step did not raise an
error), we can clean the output.


In [15]:
# Drop the entry at index 3 (the fourth slot) from each file in `files` whose
# label at that position equals `clean_label`.
clean_all_outputs(output_files=files, expected_index=3)