In [0]:
import sys, os, time, pandas as pd, mlflow, joblib
from mlflow.tracking import MlflowClient
from pyspark.sql import functions as F
import polars as pl
from mlflow.exceptions import RestException

In [0]:
# Workspace path of the MLflow experiment that collects all training runs.
MLF_EXPERIMENT = "/Workspace/9900-f18a-cake/classifier"

client = MlflowClient()
exp = client.get_experiment_by_name(MLF_EXPERIMENT)

if exp is not None:
    # The experiment already exists: reuse its id.
    exp_id = exp.experiment_id
    print("Found experiment:", MLF_EXPERIMENT, "->", exp_id)
else:
    # No experiment under that name yet: create it in the workspace.
    try:
        exp_id = client.create_experiment(MLF_EXPERIMENT)
    except RestException as e:
        # Report the server-side error (likely a permissions or parent-folder
        # problem on the workspace path -- assumption) and re-raise unchanged.
        print("Create failed:", e)
        raise
    else:
        print("Created experiment:", MLF_EXPERIMENT, "->", exp_id)

# Make the resolved experiment the active one for subsequent mlflow calls.
mlflow.set_experiment(experiment_id=exp_id)

In [0]:
# Notebook launched once per trainable node by the fan-out cell.
WORKER_NOTEBOOK = "/Workspace/9900-f18a-cake/mt-method2/Training_model_MLFlow"
# Source directory appended to sys.path before the tree is loaded.
REPO_SRC = "/Workspace/9900-f18a-cake/mt-method2/src"
# Serialized disease tree that is read with joblib.
JOBLIB_PATH = "/Workspace/9900-f18a-cake/mt-method2/data/freeze0525/diseaseTree_mapped.joblib"

In [0]:
# Execution mode switch for the fan-out cell:
#   False = sequential (works now)
#   True  = parallel via Jobs API
USE_JOBS_API = False
# Maximum number of submitted runs in flight at once; only used if USE_JOBS_API=True.
MAX_IN_FLIGHT = 12

In [0]:
# Table names and column names forwarded to every worker notebook.
FEATURES_TABLE = "ccia.curated.filter_meth_mvalues_masked"
# Your table with (node_id, sample_id, direct_label).
LABELS_TABLE = "ccia.curated.node_direct_labels"
# Name of the label column passed to the worker.
LABEL_COL = "direct_label"
# Name of the sample-identifier column passed to the worker.
ID_COL = "sample_id"

In [0]:
# Make modules under REPO_SRC importable. This runs before the load below --
# presumably so classes referenced by the serialized tree can be resolved
# while unpickling (TODO confirm).
sys.path.append(REPO_SRC)
# NOTE(review): joblib files are pickle-based, so loading one can execute code
# embedded in the file. Only load artifacts from a trusted location.
tree = joblib.load(JOBLIB_PATH)
# Use `tree.root` when the loaded object exposes such an attribute; otherwise
# treat the loaded object itself as the root node.
root = getattr(tree, "root", tree)
def gname(n):
    """Return the node's ``name``; fall back to ``label``, then to "UNKNOWN"."""
    fallback = getattr(n, "label", "UNKNOWN")
    return getattr(n, "name", fallback)


def gkids(n):
    """Return the node's children.

    Looks at ``children`` first and ``child_nodes`` second, returning the
    first one that is non-empty; yields an empty list when neither attribute
    exists or both are empty/None.
    """
    for attr in ("children", "child_nodes"):
        kids = getattr(n, attr, [])
        if kids:
            return kids
    return []

In [0]:
def list_trainable(n, out=None):
    """Collect the names of all nodes at or below ``n`` that have >= 2 children.

    The tree is walked depth-first in pre-order, so a parent's name appears
    before the names found in its subtrees.

    Args:
        n: Node to start from. Children are obtained with ``gkids`` and the
            node's name with ``gname``.
        out: Optional list to append to. A new list is created when omitted.

    Returns:
        The accumulator list (the same object as ``out`` when one was given).
    """
    # Compare against None explicitly. The previous `out or []` replaced an
    # *empty* accumulator handed down by the parent call with a fresh list, so
    # every name found below a root with fewer than two children was lost.
    if out is None:
        out = []
    ks = gkids(n)
    if len(ks) >= 2:
        out.append(gname(n))
    for k in ks:
        list_trainable(k, out)
    return out

In [0]:
# Names of every node in the tree that has at least two children; each one is
# handed to the worker notebook by the fan-out cell.
nodes = list_trainable(root)
print("Trainable parents:", len(nodes))
# Show the plan as a one-column table.
display(pd.DataFrame(data={"node": nodes}))

In [0]:
# Select the experiment by name and open a fresh tracking client.
mlflow.set_experiment(MLF_EXPERIMENT)
client = MlflowClient()

# Settings passed to every worker notebook in addition to its node name and
# the parent run id.
extra_args = dict(
    mlflow_experiment=MLF_EXPERIMENT,
    features_table=FEATURES_TABLE,
    labels_table=LABELS_TABLE,
    label_col=LABEL_COL,
    id_col=ID_COL,
)

In [0]:
# Fan-out driver: opens one parent MLflow run, then launches the worker
# notebook once per entry in `nodes`, either sequentially in this notebook
# session or in parallel as one-off Jobs API runs on the current cluster.
with mlflow.start_run(run_name="methyl_fanout_session") as parent:
    parent_id = parent.info.run_id
    mlflow.set_tag("purpose", "hierarchical_train")
    # Store the planned node list as an artifact of the parent run.
    mlflow.log_dict({"nodes": nodes, "count": len(nodes)}, "plan/manifest.json")

    # default: sequential (no Jobs API; works without tokens)
    if not USE_JOBS_API:
        for n in nodes:
            print(f"[SEQUENTIAL] training {n}")
            # Each worker receives its node name, the parent run id and the
            # shared table/column settings as notebook arguments.
            dbutils.notebook.run(
                WORKER_NOTEBOOK,
                timeout_seconds=0,
                arguments={"only_node": n, "parent_run_id": parent_id, **extra_args},
            )
        print("All nodes done (sequential).")
    else:
        # ---- Parallel via Jobs API on the existing cluster (requires 'Run as' user/SP) ----
        from databricks.sdk import WorkspaceClient
        from databricks.sdk.service import jobs

        # With fewer than one slot nothing would ever be submitted and the
        # scheduling loop below would spin forever without sleeping.
        if MAX_IN_FLIGHT < 1:
            raise ValueError("MAX_IN_FLIGHT must be >= 1 when USE_JOBS_API=True")

        POLL_SECONDS = 8  # pause between status polls while runs are active
        TERMINAL_STATES = {"TERMINATED", "SKIPPED", "INTERNAL_ERROR"}

        host  = spark.conf.get("spark.databricks.workspaceUrl")
        token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
        w = WorkspaceClient(host=f"https://{host}", token=token)  # will work when this notebook is run AS user/SP

        EXISTING_CLUSTER_ID = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
        assert EXISTING_CLUSTER_ID, "Attach to a cluster first."

        def _state_name(value) -> str:
            """Normalise an SDK state (enum member, plain string or None) to a string.

            The SDK returns enum members for run states; comparing those
            directly with string literals never matches, so their `.value`
            is used when present.
            """
            if value is None:
                return "UNKNOWN"
            return str(getattr(value, "value", value))

        def submit(node_id: str) -> int:
            """Submit a one-off run of the worker notebook for `node_id`; return its run id."""
            payload = {
                "run_name": f"train_{node_id}",
                "existing_cluster_id": EXISTING_CLUSTER_ID,
                "notebook_task": {
                    "notebook_path": WORKER_NOTEBOOK,
                    "base_parameters": {
                        "only_node": node_id,
                        "parent_run_id": parent_id,
                        **extra_args
                    },
                },
            }
            resp = w.api_client.do("POST", "/api/2.1/jobs/runs/submit", body=payload)
            return resp["run_id"]

        def finished(rid: int) -> bool:
            """Return True once run `rid` has reached a terminal life-cycle state."""
            # Runs are fetched through the jobs service (`w.jobs.get_run`);
            # WorkspaceClient has no `runs` attribute.
            state = w.jobs.get_run(run_id=rid).state
            if state is None:
                return False
            return _state_name(state.life_cycle_state) in TERMINAL_STATES

        def result_state(rid: int) -> str:
            """Return the result state of run `rid` as a string, or "UNKNOWN" if unset."""
            state = w.jobs.get_run(run_id=rid).state
            if state is None:
                return "UNKNOWN"
            return _state_name(state.result_state)

        pending = list(nodes)   # nodes not submitted yet, in plan order
        active = {}             # run id -> node name for runs currently in flight
        fails = []              # (run id, result state) for every non-SUCCESS run

        def fill_slots() -> None:
            """Submit pending nodes until MAX_IN_FLIGHT runs are active or none are left."""
            while pending and len(active) < MAX_IN_FLIGHT:
                node = pending.pop(0)
                rid = submit(node)
                active[rid] = node
                print(f"[PARALLEL] submitted {node} as run {rid}")

        fill_slots()
        while active or pending:
            done = []
            for rid in list(active):
                if finished(rid):
                    rs = result_state(rid)
                    print(f"completed {rid} => {rs}")
                    if rs != "SUCCESS": fails.append((rid, rs))
                    done.append(rid)
            for rid in done: active.pop(rid, None)
            # Reuse freed slots immediately, then wait before polling again.
            fill_slots()
            if active: time.sleep(POLL_SECONDS)

        # NOTE(review): failures are only reported here, not raised, so the
        # parent MLflow run still ends normally when child runs failed.
        if fails:
            print("Failures:", fails)
        else:
            print("All nodes done (parallel).")