In [0]:
import json
import numbers
import os
import sys
import time
import traceback
from pathlib import Path

import mlflow
from mlflow.tracking import MlflowClient
from pyspark.sql import functions as F

In [0]:
# Make the project's packages importable from this notebook. Each directory is
# added only if it is not already on sys.path, so re-running this cell does not
# keep appending duplicates.
for _p in (
    "/Workspace/9900-f18a-cake",
    "/Workspace/9900-f18a-cake/mt-method2/src",
    "/Workspace/9900-f18a-cake/mt-method2",
):
    if _p not in sys.path:
        sys.path.append(_p)
del _p

In [0]:
from mch.models.training import BatchModelTrainer

In [0]:
def _set_env_from_plan(plan: dict):
    """Copy the trainer's configuration knobs from ``plan`` into ``os.environ``.

    Every value is written as a string, using the plan entry when present and
    the default listed below otherwise. ``MCH_ONLY_NODE`` is special: it is set
    only when ``plan["only_node"]`` is truthy, and any existing value is
    removed otherwise.
    """
    env = os.environ

    # (environment variable, plan key, default) for the feature-prefilter knobs
    prefilter_knobs = (
        ("MCH_PREFILTER_TOPK", "prefilter_topk", 200),
        ("MCH_PREFILTER_SCAN_MAX", "prefilter_scan_max", 20000),
        ("MCH_PREFILTER_CHUNK_SIZE", "prefilter_chunk_size", 5000),
    )
    for var_name, plan_key, fallback in prefilter_knobs:
        env[var_name] = str(plan.get(plan_key, fallback))

    # "1" disables the Differential Methylation step; it is disabled unless
    # the plan explicitly sets a falsy "disable_dm".
    dm_disabled = plan.get("disable_dm", True)
    env["MCH_DISABLE_DM"] = "0" if not dm_disabled else "1"

    # Restrict training to a single tree node, or clear any previous restriction.
    target_node = plan.get("only_node")
    if not target_node:
        env.pop("MCH_ONLY_NODE", None)
    else:
        env["MCH_ONLY_NODE"] = str(target_node)

    # Parallelism hints the trainer reads (both default to 1)
    for var_name, plan_key in (("RF_N_JOBS", "rf_n_jobs"), ("CV_N_JOBS", "cv_n_jobs")):
        env[var_name] = str(plan.get(plan_key, 1))

In [0]:
def run_child(plan: dict) -> dict:
    """
    Run one training job driven by `plan` and return a result dict for the parent.

    The plan is exported to environment variables (via `_set_env_from_plan`)
    because `BatchModelTrainer` is configured through them. The job is recorded
    as one MLflow run in `plan["mlflow_experiment_path"]` (default
    "/Shared/methyl/experiments/classifier"), which `mlflow.set_experiment`
    creates if it does not exist yet.

    Args:
        plan: Job description. Keys read here: "node_id", "only_node",
            "disable_dm", "prefilter_topk", "prefilter_scan_max",
            "prefilter_chunk_size", "rf_n_jobs", "cv_n_jobs",
            "mlflow_experiment_path".

    Returns:
        A dict with keys:
          ok           -- True once training, artifact/metric logging and
                          closing the MLflow run all completed.
          node_id      -- plan["node_id"] (or "UNKNOWN").
          metrics      -- the node's "metrics" mapping from the trainer stats,
                          as returned (unfiltered).
          summary_path -- local path of the JSON summary written for this job.
          error        -- "<ExceptionType>: <message>" on failure, else None.
          trace        -- formatted traceback on failure, else None.
          t_sec        -- duration of this call in seconds (3 decimals).
          run_id       -- MLflow run id, or None if no run was started.

    Any exception raised during the job, including MLflow setup, is caught and
    reported in the returned dict instead of propagating. If an MLflow run had
    been started, it is closed with status FAILED.
    """
    start_ts = time.perf_counter()
    node_id = plan.get("node_id", "UNKNOWN")
    exp_path = plan.get("mlflow_experiment_path", "/Shared/methyl/experiments/classifier")
    only_node = plan.get("only_node")

    # Tags & params we want recorded on the MLflow run
    tags = {
        "orchestrator": "parent_notebook",
        "node_id": node_id,
        "only_node": str(only_node),
        "disable_dm": str(plan.get("disable_dm", True)),
    }
    params = {
        "prefilter_topk": plan.get("prefilter_topk", 200),
        "prefilter_scan_max": plan.get("prefilter_scan_max", 20000),
        "prefilter_chunk_size": plan.get("prefilter_chunk_size", 5000),
        "rf_n_jobs": plan.get("rf_n_jobs", 1),
        "cv_n_jobs": plan.get("cv_n_jobs", 1),
    }

    result = {
        "ok": False,
        "node_id": node_id,
        "metrics": {},
        "summary_path": None,
        "error": None,
        "trace": None,
        "t_sec": None,
        "run_id": None,
    }

    try:
        # set_experiment creates the experiment when it is missing and makes it
        # the active one, so no separate lookup/create step is needed.
        mlflow.set_experiment(exp_path)

        # Configure the trainer via env vars
        _set_env_from_plan(plan)

        # Deliberately no try/except inside the `with`: an exception escaping it
        # makes MLflow end the run as FAILED rather than FINISHED. The outer
        # handler then records the error in `result`.
        with mlflow.start_run(run_name=f"node={node_id}") as run:
            result["run_id"] = run.info.run_id
            mlflow.set_tags(tags)
            mlflow.log_params(params)

            # Instantiate & train (driven by the env set above)
            trainer = BatchModelTrainer()
            stats = trainer.train_all_models(save_dir=None, raise_on_error=False)

            # If training was limited to a single node, report that node's stats;
            # otherwise look them up under this job's node_id.
            node_key = only_node if only_node else node_id
            node_stats = stats.get(node_key, {})

            # Log a JSON summary artifact.
            # NOTE(review): the file name depends only on node_id, so children
            # sharing a node_id overwrite each other's local summary file.
            out_dir = Path("/dbfs/tmp/mt_method2_child")
            out_dir.mkdir(parents=True, exist_ok=True)
            summary_file = out_dir / f"{node_id}_summary.json"
            with open(summary_file, "w", encoding="utf-8") as f:
                # default=str: the trainer's stats may hold values json cannot
                # encode natively (e.g. numpy scalars/arrays -- TODO confirm).
                # Stringify them rather than failing the job after training
                # already succeeded.
                json.dump(
                    {"node": node_key, "stats": node_stats, "all": stats},
                    f,
                    indent=2,
                    default=str,
                )
            mlflow.log_artifact(str(summary_file), artifact_path="child_summaries")

            # Log the node's numeric metrics in one batch. numbers.Real also
            # matches numpy scalar types (float32, int64, ...), which a plain
            # (int, float) isinstance check silently skips.
            metrics = node_stats.get("metrics", {})
            numeric = {
                k: float(v) for k, v in metrics.items() if isinstance(v, numbers.Real)
            }
            if numeric:
                mlflow.log_metrics(numeric)

            result["metrics"] = metrics
            result["summary_path"] = str(summary_file)

        # Only report success once the MLflow run has also been closed cleanly.
        result["ok"] = True

    except Exception as e:
        result["error"] = f"{type(e).__name__}: {e}"
        result["trace"] = traceback.format_exc()

    finally:
        result["t_sec"] = round(time.perf_counter() - start_ts, 3)

    return result