In [2]:
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import itertools
from util.save_load import load_kernel_model
from kernels.wrapper import MODELS, KernelModelWrapper
from dataset.ipc2023_learning_domain_info import IPC2023_LEARNING_DOMAINS, get_number_of_ipc2023_training_data
from itertools import product
from IPython.display import display, HTML

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory searched for "<description>.log" files by the cells below.
_LOG_DIR = "../logs_training_kernels"

# Hyperparameter grids used to build the configuration products.
ITERATIONS = [1, 2, 4, 8, 16]
PRUNES = [0, 1, 2, 4, 8, 16]

# Lazy (single-pass) product over 5 fields:
# (wl algorithm, iterations, representation, domain, model).
configs = product(
    ["wl"],
    ITERATIONS,
    ["ig"],
    ["ferry", "blocksworld", "childsnack", "floortile", "miconic", "rovers", "satellite", "sokoban", "spanner", "transport"],
    ["linear-svr", "lasso", "ridge", "rbf-svr", "quadratic-svr", "cubic-svr", "mlp"],
)

# Materialised list of 6-tuples:
# (wl algorithm, iterations, prune count, representation, domain, ml model).
CONFIGS = list(
    product(
        ["wl", "2wl"],
        ITERATIONS,
        PRUNES,
        ["ig"],
        [
            "ferry",
            "blocksworld",
            "childsnack",
            "floortile",
            "miconic",
            "rovers",
            "satellite",
            "sokoban",
            "spanner",
            "transport",
        ],
        ["linear-svr", "quadratic-svr", "cubic-svr", "rbf-svr", "lasso", "ridge"],
    )
)

DOMAINS = IPC2023_LEARNING_DOMAINS

# Output directory for saved figures; created up front so savefig has a target.
PLT_DIR = "plots"
os.makedirs(PLT_DIR, exist_ok=True)

### Train metrics

In [8]:
def get_data_from_log_file(log_file):
    """Parse the metrics reported in a single training log file.

    Each line is matched against a fixed set of keywords (first match wins,
    in the order written below); the value is taken from a whitespace
    separated token of that line. If a keyword occurs on several lines the
    last occurrence is kept.

    Args:
        log_file: path to the log file to read.

    Returns:
        dict that may contain the keys "train_mse", "train_f1", "val_mse",
        "val_f1" and "time" (floats), present only when the corresponding
        line was found, and always contains "nonzero_weights": an int when a
        "zero_weights" line was found, otherwise the string "na".

    Raises:
        AssertionError: if ``log_file`` does not exist.
    """
    assert os.path.exists(log_file), log_file
    stats = {}

    # Use a context manager so the file handle is closed deterministically,
    # even if parsing a line raises.
    with open(log_file, "r") as f:
        for line in f:
            toks = line.split()
            if "train_mse" in line:
                stats["train_mse"] = float(toks[-1])
            elif "train_f1_macro" in line:
                stats["train_f1"] = float(toks[-1])
            elif "val_mse" in line:
                stats["val_mse"] = float(toks[-1])
            elif "val_f1_macro" in line:
                stats["val_f1"] = float(toks[-1])
            elif "zero_weights" in line:
                # Second token has the form "<zeros>/<weights>".
                zeros, weights = (int(part) for part in toks[1].split("/")[:2])
                stats["nonzero_weights"] = weights - zeros
            elif "Model training completed in " in line:
                # Last token is the duration with an "s" suffix, e.g. "12.5s".
                stats["time"] = float(toks[-1].replace("s", ""))

    # Logs without a "zero_weights" line get a placeholder value.
    if "nonzero_weights" not in stats:
        stats["nonzero_weights"] = "na"

    return stats

In [9]:
""" display df over all domains """
def get_data(domain, target):
  """Collect per-configuration training statistics for one domain.

  Iterates over every entry of ``CONFIGS`` whose domain equals ``domain``,
  reads the matching log file from ``_LOG_DIR`` (if it exists) and appends
  its statistics to column lists suitable for ``pd.DataFrame``.

  Args:
    domain: domain name; only configurations for this domain are collected.
    target: learning target, either "H" or "D".

  Returns:
    dict with the keys "config", "mse", "f1", "nonzero_weights" and "time"
    (in that order), each mapping to a list with one entry per log file
    that existed and contained every required statistic.

  Raises:
    AssertionError: if ``target`` is not "H" or "D".
  """
  assert target in {"H", "D"}

  d = {
    "config": [],
    "mse": [],
    "f1": [],
    "nonzero_weights": [],
    "time": [],
  }

  # Output column -> key produced by get_data_from_log_file.
  # NOTE(review): the log parser only emits train_/val_ prefixed metrics, so
  # the "mse"/"f1" columns are filled from the train metrics here -- confirm
  # that train (rather than validation) values are what this table should show.
  column_sources = {
    "mse": "train_mse",
    "f1": "train_f1",
    "nonzero_weights": "nonzero_weights",
    "time": "time",
  }

  # CONFIGS entries are 6-tuples (including the prune count). The loop
  # variable is named cfg_domain so it does not shadow the `domain` argument.
  for wl, iterations, prune, rep, cfg_domain, model in CONFIGS:
    if cfg_domain != domain:
      continue

    # Same file naming as used by the plotting cell of this notebook.
    # NOTE(review): confirm this matches the names written by the training script.
    desc = "_".join([model, wl, str(iterations), str(prune), rep, cfg_domain, target])
    log_file = _LOG_DIR + "/" + desc + ".log"

    if not os.path.exists(log_file):
      continue

    stats = get_data_from_log_file(log_file)

    # Skip incomplete logs so that all columns keep the same length.
    if any(source not in stats for source in column_sources.values()):
      continue

    d["config"].append(desc)
    for column, source in column_sources.items():
      d[column].append(stats[source])

  return d

def get_df(domain, target):
  """Return the column lists gathered by ``get_data`` as a pandas DataFrame."""
  return pd.DataFrame(get_data(domain, target))

# max_times = []
# for domain in IPC2023_LEARNING_DOMAINS:
#   data = get_df(domain, "H")
#   max_times.append(max(data.to_numpy()[:,-1]))
#   display(data)
# print("max time:", max(max_times))

In [18]:
# Scatter-plot every metric against the number of WL iterations, one figure
# per (wl algorithm, metric, domain) with one series per model, and save each
# figure as a PNG under PLT_DIR.
rep = "ig"
target = "H"

# The x ticks depend only on ITERATIONS, so compute them once up front.
power_of_2_ticks = [2**i for i in range(int(np.log2(min(ITERATIONS))), int(np.log2(max(ITERATIONS))) + 1)]

for wl, metric, domain in product(["wl", "2wl"], ["train_mse", "train_f1", "val_mse", "val_f1"], DOMAINS):
    for model in MODELS:
        xs = []
        ys = []
        # Every prune value contributes a point at x = iterations.
        for iterations, prune in product(ITERATIONS, PRUNES):
            desc = "_".join([model, wl, str(iterations), str(prune), rep, domain, target])
            log_file = _LOG_DIR + "/" + desc + ".log"
            if not os.path.exists(log_file):
                continue
            stats = get_data_from_log_file(log_file)
            if metric not in stats:
                continue
            xs.append(iterations)
            ys.append(stats[metric])
        plt.scatter(xs, ys, label=model)
    if "f1" in metric:
        plt.ylim((0, 1))
    elif "mse" in metric:
        plt.ylim((1e-1, 1e2))
        plt.yscale("log")
    plt.xscale("log")
    plt.xticks(power_of_2_ticks, [str(tick) for tick in power_of_2_ticks])
    plt.title(f"{metric} {domain} {wl}")
    plt.legend(bbox_to_anchor=(1, 1), loc='upper left')
    plot_path = f"{PLT_DIR}/{metric}_{domain}_{wl}.png"
    try:
        plt.savefig(plot_path, dpi=480, bbox_inches="tight")
    except Exception as err:
        # Saving stays best-effort (one bad figure must not abort the sweep),
        # but the failure is reported instead of being silently discarded, and
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print(f"could not save {plot_path}: {err!r}", file=sys.stderr)
    # Reset the current figure so the next iteration starts from a blank canvas.
    plt.clf()

<Figure size 640x480 with 0 Axes>