In [3]:
import numpy as np
import logging
import random
import yaml

from lgtree.lgtree import run_optimisation_mcts
from Trial_functions import FUNCTIONS, plot_3d_function

In [4]:
# Load per-benchmark configuration from YAML
# ------------------------------------------------------------
with open("mcts_benchmarks.yaml", "r") as f:
    BENCHMARK_CONFIG = yaml.safe_load(f)


In [5]:
# ------------------------------------------------------------
# Global knobs (shared across all benchmarks)
# ------------------------------------------------------------
SAMPLING_MODE = "logistic"  # default; can override per call

# These are still global across benchmarks unless you also
# put them in YAML. Easy to promote later if needed.
A_MIN, A_MAX = 0.008, 0.08      # range of 'a' scanned across trees
B0 = 0.5                        # initial window parameter b
TARGET = 0.0                    # fallback target if YAML doesn't specify

# Fallback defaults if YAML misses a key
DEFAULT_EVAL_CUTOFF = 15000
DEFAULT_NTREES = 20
DEFAULT_TOP_K = 1
DEFAULT_ALPHA = 0.1
DEFAULT_DROP_FACTOR = 0.93
DEFAULT_EXPLORE_CONSTANT_MIN = 1e-1


# ------------------------------------------------------------
# Run a single benchmark from YAML config
# ------------------------------------------------------------
def run_mcts_on_benchmark(
    benchmark_id: str,
    sampling_mode: str | None = None,
    custom_dim: int | None = None,
):
    # Use global default if caller doesn't pass a mode
    if sampling_mode is None:
        sampling_mode = SAMPLING_MODE

    if benchmark_id not in BENCHMARK_CONFIG:
        raise ValueError(f"Unknown benchmark_id '{benchmark_id}'")

    cfg = BENCHMARK_CONFIG[benchmark_id]
    func = FUNCTIONS[benchmark_id]          # Trial_functions.FUNCTIONS["Fxx"]

    # ------- extract + type-cast config from YAML -------
    name = cfg.get("name", benchmark_id)

    # YAML dimension
    yaml_dim = int(cfg["dim"])
    # If custom_dim is provided, override YAML; else use YAML
    dim = int(custom_dim) if custom_dim is not None else yaml_dim

    lb_val = float(cfg["lb"])
    ub_val = float(cfg["ub"])

    eval_cutoff = int(cfg.get("eval_cutoff", DEFAULT_EVAL_CUTOFF))
    ntrees = int(cfg.get("ntrees", DEFAULT_NTREES))
    alpha = float(cfg.get("alpha", DEFAULT_ALPHA))
    drop_factor = float(cfg.get("drop_factor", DEFAULT_DROP_FACTOR))
    explore_constant_min = float(cfg.get("explore_constant_min", DEFAULT_EXPLORE_CONSTANT_MIN))
    top_k = int(cfg.get("top_k", DEFAULT_TOP_K))

    # --- target info from YAML (optional) ---
    global_min_val = cfg.get("global_min_value", None)
    global_min_points = cfg.get("global_min_points", None)
    # per-benchmark target used inside MCTS window adaptation
    target_value = float(global_min_val) if global_min_val is not None else float(TARGET)

    lb = np.full(dim, lb_val, dtype=float)
    ub = np.full(dim, ub_val, dtype=float)

    # -------------------- random seeding --------------------
    seed_value = random.randint(1, 100000)
    random.seed(seed_value)
    np.random.seed(seed_value)

    # optional: keep track of used seeds
    with open("used_seeds.txt", "a") as seedfile:
        seedfile.write(f"{benchmark_id},{seed_value}\n")

    # -------------------- logging setup --------------------
    logging.basicConfig(
        filename=f"mcts_{benchmark_id.lower()}_{sampling_mode.lower()}.log",
        level=logging.INFO,
        filemode="w",
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    logger = logging.getLogger(f"MCTSLogger_{benchmark_id}")

    # -------------------- print / log configuration summary --------------------
    print("\n====================================================")
    print(f"Running benchmark      : {benchmark_id} ({name})")
    print("Extracted configuration:")
    print(f"  dim (YAML)           : {yaml_dim}")
    if custom_dim is not None:
        print(f"  dim (OVERRIDE)       : {dim}")
    print(f"  bounds               : [{lb_val}, {ub_val}]")
    print(f"  eval_cutoff          : {eval_cutoff}")
    print(f"  ntrees               : {ntrees}")
    print(f"  top_k                : {top_k}")
    print(f"  alpha (window)       : {alpha}")
    print(f"  drop_factor          : {drop_factor}")
    print(f"  explore_constant_min : {explore_constant_min}")
    print(f"  sampling_mode        : {sampling_mode}")
    print(f"  seed_value           : {seed_value}")
    if global_min_val is not None:
        print(f"  target f* (from YAML): {global_min_val}")
    else:
        print(f"  target f* (fallback) : {TARGET}")
    print("====================================================\n")

    logger.info(f"Benchmark: {benchmark_id} ({name})")
    logger.info(f"Random seed used: {seed_value}")
    logger.info(f"SAMPLING_MODE = {sampling_mode}")
    logger.info(
        f"DIM={dim} (yaml_dim={yaml_dim}), LB={lb_val}, UB={ub_val}, "
        f"EVAL_CUTOFF={eval_cutoff}, NTREES={ntrees}, "
        f"ALPHA={alpha}, DROP_FACTOR={drop_factor}, TOP_K={top_k}, "
        f"EXPLORE_CONSTANT_MIN={explore_constant_min}, "
        f"TARGET={target_value}"
    )

    # -------------------- 3D surface plot if possible --------------------
    # We can only do a 3D surface plot if the *input* is 2D (x1, x2 -> f)
    if dim == 2:
        print(f"\n=== {benchmark_id} ({name}) – 3D surface plot ===")
        plot_3d_function(func, lb, ub)
    else:
        print(f"\n=== {benchmark_id} ({name}) – dim={dim}, skipping 3D plot ===")

    # -------------------- run MCTS --------------------
    print(f"\n=== {benchmark_id} ({name}) – running MCTS ({sampling_mode}) ===")
    logger.info("Starting MCTS optimisation")

    minscore, best_parameter = run_optimisation_mcts(
        objfunc=func,              # func(x) -> (x_relaxed, score)
        logger=logger,
        lb=lb,
        ub=ub,
        ntrees=ntrees,
        top_k=top_k,
        a_min=A_MIN,
        a_max=A_MAX,
        b0=B0,
        explore_constant_min=explore_constant_min,
        niterations_local=5,
        niterations_global=5,
        nexpand=2,
        nsimulate=1,
        nplayouts=10,
        max_depth_global=12,
        n_total_evals=eval_cutoff,
        score_threshold_iterations_global=2,
        patience_mcts_global=4,
        patience_mcts_local=1,
        stagnation_threshold_pruning=2,
        alpha=alpha,
        target=target_value,
        aggressive_drop_factor=drop_factor,
        epsilon=1e-30,
        verbose=False,
        sampling_mode=sampling_mode,
    )

    # --------- summary + simple comparison with YAML optimum ---------
    print("---------- Result ----------")
    print(f"Benchmark ID              : {benchmark_id}")
    print(f"Name                      : {name}")
    print(f"Sampling mode             : {sampling_mode}")
    print(f"Best score (MCTS)         : {minscore}")
    print(
        f"Best parameter (first 5)  : {best_parameter[:5]} "
        f"... (dim={len(best_parameter)})"
    )

    logger.info("========== Summary / Comparison ==========")
    logger.info(f"Best score (MCTS): {minscore}")
    logger.info(f"Best parameter (first 5): {best_parameter[:5]}")

    # Print known optimum from YAML if present
    if global_min_val is not None:
        print("\n--- Known optimum from YAML ---")
        print(f"Known global minimum f*   : {global_min_val}")
        logger.info(f"Known global minimum f*: {global_min_val}")

    if global_min_points is not None:
        print("Known solution point(s)   :")
        print(global_min_points)
        logger.info(f"Known solution point(s): {global_min_points}")

    return minscore, best_parameter




In [6]:

run_mcts_on_benchmark("F2", sampling_mode="logistic")#custom_dim = 5



Running benchmark      : F2 (Schwefel 2.22)
Extracted configuration:
  dim (YAML)           : 30
  bounds               : [-10.0, 10.0]
  eval_cutoff          : 20000
  ntrees               : 20
  top_k                : 1
  alpha (window)       : 0.06
  drop_factor          : 0.95
  explore_constant_min : 0.1
  sampling_mode        : logistic
  seed_value           : 29721
  target f* (from YAML): 0.0


=== F2 (Schwefel 2.22) – dim=30, skipping 3D plot ===

=== F2 (Schwefel 2.22) – running MCTS (logistic) ===
Global Round:
  Best score: 2.8557570827352876
  Parameter: [0.07224666688513004, -0.08267734779439806, -0.12883349315602466, 0.19048352771002186, 0.05804638400686102, 0.2723435592730894, 0.15614506576893916, 0.031406921209786276, -0.09795017668440487, -0.11473417507756878, -0.005257050479935188, -0.13795028106393836, 0.05270527097959743, -0.032915523955320936, -0.02571555517696389, -0.23209221271635627, -0.004649796653460214, -0.07488175095791938, -0.008659315818862368, -0.29004

(1.524270787811588e-05,
 [-4.1346182300117107e-07,
  5.892838128573885e-07,
  5.883663529715477e-07,
  4.175822212459934e-07,
  -4.1594158695090755e-07,
  -4.1408570616852103e-07,
  -5.908050511038709e-07,
  -5.885942631067564e-07,
  -4.152982686633777e-07,
  5.883725719968425e-07,
  5.893179828575512e-07,
  5.910207150350288e-07,
  -4.1642575432376816e-07,
  4.1574505793562366e-07,
  -4.140096230287327e-07,
  -5.902642126187629e-07,
  5.908719469260859e-07,
  -4.1579261278457125e-07,
  5.879148510246068e-07,
  -5.894483301460696e-07,
  -4.1648526227788807e-07,
  -4.153124528727403e-07,
  4.140336038460646e-07,
  5.879030702260479e-07,
  5.861722822686488e-07,
  5.886257063991707e-07,
  4.184917443694758e-07,
  4.1841120790309105e-07,
  -5.859775047412086e-07,
  -5.886922984643661e-07])