In [1]:
import sys
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import itertools
from util.save_load import load_kernel_model
from util.scrape_log import scrape_kernel_train_log
from kernels.wrapper import MODELS, KernelModelWrapper
from dataset.ipc2023_learning_domain_info import IPC2023_LEARNING_DOMAINS, get_number_of_ipc2023_training_data
from itertools import product
from IPython.display import display, HTML

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
_LOG_DIR = "icaps24_train_logs"

ITERATIONS = [4]
PRUNES = [0]

configs = product(
    ["1wl"],  # wl algorithms
    ITERATIONS,  # iterations
    PRUNES,  # prune
    ["ilg"],  # representation
    ["blocksworld", "childsnack", "ferry", "floortile", "miconic", "rovers", "satellite", "sokoban", "spanner", "transport"],  # domains
    ["linear-svr"],  # models
)
CONFIGS = list(configs)

DOMAINS = IPC2023_LEARNING_DOMAINS

PLT_DIR = "plots"
os.makedirs(PLT_DIR, exist_ok=True)

### Train metrics

In [5]:
""" display df over all domains """


def get_data(target):
    d = {
        "config": [],
        "train_mse": [],
        "val_mse": [],
        "train_f1": [],
        "val_f1": [],
        "nonzero_weights": [],
        "time": [],
    }

    assert target in {"H", "D"}

    for wl, iterations, prune, rep, domain, model in CONFIGS:
        desc = "_".join([domain, rep, wl, str(iterations), str(prune), model, target])
        log_file = _LOG_DIR + "/" + desc + ".log"

        if not os.path.exists(log_file):
            continue

        stats = scrape_kernel_train_log(log_file)
        stats["config"] = desc

        if len(stats) != len(d):
            continue

        for key in stats:
            d[key].append(stats[key])

    return d


def get_df(target):
    d = get_data(target)
    return pd.DataFrame(d)


# max_times = []
# for domain in IPC2023_LEARNING_DOMAINS:
#   data = get_df(domain, "H")
#   max_times.append(max(data.to_numpy()[:,-1]))
#   display(data)
# print("max time:", max(max_times))
df = get_df("H")
df

Unnamed: 0,config,train_mse,val_mse,train_f1,val_f1,nonzero_weights,time
0,blocksworld_ilg_1wl_4_0_linear-svr_H,0.01,1.97,1.0,0.28,51262,0.19
1,childsnack_ilg_1wl_4_0_linear-svr_H,0.4,0.4,0.64,0.62,684,0.88
2,ferry_ilg_1wl_4_0_linear-svr_H,0.01,0.23,1.0,0.59,29301,0.18
3,floortile_ilg_1wl_4_0_linear-svr_H,0.42,1.25,0.75,0.35,27358,4.92
4,miconic_ilg_1wl_4_0_linear-svr_H,0.41,0.91,0.69,0.34,3942,1.39
5,rovers_ilg_1wl_4_0_linear-svr_H,0.02,0.41,0.98,0.61,28051,4.03
6,satellite_ilg_1wl_4_0_linear-svr_H,0.01,0.22,1.0,0.74,33736,0.75
7,sokoban_ilg_1wl_4_0_linear-svr_H,23.56,29.4,0.13,0.02,2736,26.47
8,spanner_ilg_1wl_4_0_linear-svr_H,1.03,2.36,0.56,0.25,3230,1.73
9,transport_ilg_1wl_4_0_linear-svr_H,0.01,0.9,1.0,0.36,19897,0.24


In [4]:
rep = "ilg"
target = "H"
prune = 0
model = "linear-svr"
for metric, domain in product(["train_mse", "train_f1", "val_mse", "val_f1"], DOMAINS):
    for wl in ["1wl", "2gwl", "2lwl"]:
        xs = []
        ys = []
        for iterations in ITERATIONS:
            desc = "_".join([domain, rep, wl, str(iterations), str(prune), model, target])
            log_file = _LOG_DIR + "/" + desc + ".log"
            if not os.path.exists(log_file):
                continue
            stats = get_data_from_log_file(log_file)
            if metric not in stats:
                continue
            xs.append(iterations)
            ys.append(stats[metric])
        plt.plot(xs, ys, label=wl)
    if "f1" in metric:
        plt.ylim((0, 1))
    elif "mse" in metric:
        # plt.ylim((1e-1, 1e2))
        plt.yscale("log")
    # plt.xscale("log")
    # power_of_2_ticks = [2**i for i in range(int(np.log2(min(ITERATIONS))), int(np.log2(max(ITERATIONS))) + 1)]
    # plt.xticks(power_of_2_ticks, [str(tick) for tick in power_of_2_ticks])
    plt.title(f"{metric} {domain}")
    plt.legend(bbox_to_anchor=(1, 1), loc='upper left')
    try:
        plt.savefig(f"{PLT_DIR}/{metric}_{domain}.png", dpi=480, bbox_inches="tight")
    except:
        pass
    plt.clf()

NameError: name 'get_data_from_log_file' is not defined