In [0]:
import time, pandas as pd, sys, os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import jobs, compute
import joblib

In [0]:
# Workspace locations used by this driver notebook, all under one project root.
PROJECT_ROOT = "/Workspace/9900-f18a-cake/mt-method2"
TRAINING_NOTEBOOK = f"{PROJECT_ROOT}/Training_model_job"  # notebook launched once per node
REPO_SRC = f"{PROJECT_ROOT}/src"                         # project Python sources

In [0]:
# Make the project sources importable. This is presumably also needed so that
# joblib can resolve the tree's classes while unpickling (TODO confirm).
# Guarded so re-running the cell does not append the same entry again.
if REPO_SRC not in sys.path:
    sys.path.append(REPO_SRC)

TREE_PATH = "/Workspace/9900-f18a-cake/mt-method2/data/freeze0525/diseaseTree_mapped.joblib"
# joblib.load unpickles the file, which can execute arbitrary code:
# only point TREE_PATH at trusted artifacts.
tree = joblib.load(TREE_PATH)
# Use `tree.root` when the loaded object wraps a root node; otherwise treat the
# object itself as the root.
root = getattr(tree, "root", tree)
def gname(n):
    """Return a tree node's display name.

    Prefers the ``name`` attribute, then ``label``, and finally the literal
    ``"UNKNOWN"`` when the node has neither.
    """
    return getattr(n, "name", getattr(n, "label", "UNKNOWN"))


def gkids(n):
    """Return a tree node's children, or ``[]`` for a leaf.

    Checks ``children`` first, then ``child_nodes``. The first non-empty
    (truthy) value is returned as-is; missing, ``None`` or empty values fall
    through to ``[]``.
    """
    for attr in ("children", "child_nodes"):
        kids = getattr(n, attr, None)
        if kids:
            return kids
    return []

In [0]:
# Concurrency cap: at most this many child training runs are active at once.
MAX_IN_FLIGHT, USE_EXISTING_CLUSTER = 12, True
# NOTE(review): USE_EXISTING_CLUSTER is not read by any visible cell; submit()
# always targets EXISTING_CLUSTER_ID.

In [0]:
# Build an SDK client for this workspace, authenticated with the Azure service
# principal credentials stored in the "ops" secret scope.
host = spark.conf.get("spark.databricks.workspaceUrl")
w = WorkspaceClient(
    host=f"https://{host}",
    azure_client_id=dbutils.secrets.get("ops", "SP_CLIENT_ID"),
    azure_tenant_id=dbutils.secrets.get("ops", "SP_TENANT_ID"),
    azure_client_secret=dbutils.secrets.get("ops", "SP_CLIENT_SECRET"),
)

# Sanity check of the identity in use. `current_user.me()` returns an iam.User,
# which has no `application_id` field, so accessing it directly raises
# AttributeError. For a service principal, `user_name` is expected to carry
# the application ID.
me = w.current_user.me()
print("Me (SP):", getattr(me, "application_id", None) or me.user_name)

In [0]:
# Fallback only: keep the value chosen in the config cell instead of silently
# overwriting it. Default to 12 when this cell runs without that one.
MAX_IN_FLIGHT = globals().get("MAX_IN_FLIGHT", 12)

In [0]:
# Child runs are submitted to the cluster this notebook is attached to; its ID
# is read from the Spark conf, and a missing/empty value aborts early.
_CLUSTER_ID_CONF_KEY = "spark.databricks.clusterUsageTags.clusterId"
EXISTING_CLUSTER_ID = spark.conf.get(_CLUSTER_ID_CONF_KEY, None)
print("EXISTING_CLUSTER_ID =", EXISTING_CLUSTER_ID)
has_cluster = bool(EXISTING_CLUSTER_ID)
if not has_cluster:
    raise ValueError("Open this notebook on a cluster (or paste an ID).")

In [0]:
def submit(node_id: str) -> int:
    """Launch a one-off training run for a single tree node.

    The run executes TRAINING_NOTEBOOK on EXISTING_CLUSTER_ID (the cluster
    this notebook is attached to), passing ``only_node=<node_id>`` as a
    notebook parameter.

    Args:
        node_id: Name of the node to train; also used in the run name.

    Returns:
        The ``run_id`` field of the Jobs ``runs/submit`` response.
    """
    notebook_task = {
        "notebook_path": TRAINING_NOTEBOOK,
        "base_parameters": {"only_node": node_id},
    }
    response = w.api_client.do(
        "POST",
        "/api/2.1/jobs/runs/submit",
        body={
            "run_name": f"train_{node_id}",
            "existing_cluster_id": EXISTING_CLUSTER_ID,
            "notebook_task": notebook_task,
        },
    )
    return response["run_id"]

In [0]:
def list_trainable(n, out=None, get_children=None, get_name=None):
    """Collect, in pre-order, the names of nodes that have two or more children.

    Presumably only branching nodes need their own model (TODO confirm).

    Args:
        n: Node to start from (typically the tree root).
        out: Optional list to append names to. It is filled in place and also
            returned. A fresh list is created when omitted.
        get_children: Child accessor. Defaults to the notebook's ``gkids``.
        get_name: Name accessor. Defaults to the notebook's ``gname``.

    Returns:
        The list of collected names, parents before their descendants.
    """
    # `out or []` would create a new list whenever the accumulator is still
    # empty. Recursive calls would then append to a list nobody keeps, dropping
    # e.g. everything below a single-child root. Only create when absent.
    if out is None:
        out = []
    get_children = get_children if get_children is not None else gkids
    get_name = get_name if get_name is not None else gname

    kids = get_children(n)
    if len(kids) >= 2:
        out.append(get_name(n))
    for kid in kids:
        list_trainable(kid, out, get_children, get_name)
    return out

In [0]:
def finished(run_id: int, client=None) -> bool:
    """Return True once a submitted run has reached a terminal life-cycle state.

    Args:
        run_id: Run ID as returned by ``submit``.
        client: Object exposing ``jobs.get_run(run_id=...)``. Defaults to the
            notebook's ``WorkspaceClient`` ``w``; injectable for testing.

    Returns:
        True when the life-cycle state is TERMINATED, SKIPPED or
        INTERNAL_ERROR. False otherwise, including when no state is reported yet.
    """
    client = client if client is not None else w
    # WorkspaceClient has no `runs` service; run status comes from `jobs.get_run`.
    state = client.jobs.get_run(run_id=run_id).state
    if state is None or state.life_cycle_state is None:
        return False
    lcs = state.life_cycle_state
    # The SDK reports an Enum (RunLifeCycleState), which never equals a plain
    # string. Compare on `.value`, falling back to the raw object for strings.
    return getattr(lcs, "value", lcs) in {"TERMINATED", "SKIPPED", "INTERNAL_ERROR"}

In [0]:
def result_state(rid, client=None):
    """Return a run's result state as a plain string (e.g. "SUCCESS").

    Args:
        rid: Run ID as returned by ``submit``.
        client: Object exposing ``jobs.get_run(run_id=...)``. Defaults to the
            notebook's ``WorkspaceClient`` ``w``; injectable for testing.

    Returns:
        The result state's string value, or ``"UNKNOWN"`` when the run reports
        no state or no result state.
    """
    client = client if client is not None else w
    # WorkspaceClient has no `runs` service; run status comes from `jobs.get_run`.
    state = client.jobs.get_run(run_id=rid).state
    rs = getattr(state, "result_state", None)
    # Unwrap the SDK Enum (RunResultState) so callers' `!= "SUCCESS"` checks work.
    return getattr(rs, "value", rs) or "UNKNOWN"

In [0]:
# One training run per branching node of the disease tree.
nodes = list_trainable(root)
node_count = len(nodes)
print("Trainable parents:", node_count)
nodes_df = pd.DataFrame({"node": nodes})
display(nodes_df)

In [0]:
# Initialise scheduler state and fill the first MAX_IN_FLIGHT slots.
# NOTE(review): this cell and the polling cell after it duplicate the
# reset/prime/poll cells further down. Those cells re-queue every node, so
# running the notebook top to bottom trains each node twice. Keep only one
# scheduler.
pending = list(nodes)
active = {}  # run_id -> node name, so failures can be attributed to a node
fails = []
while pending and len(active) < MAX_IN_FLIGHT:
    node = pending.pop(0)
    active[submit(node)] = node

In [0]:
# Poll in-flight runs, record failures and refill free slots until all nodes
# have been submitted and every run has finished.
# NOTE(review): duplicates the scheduler further down (see the note on the
# previous cell).
if MAX_IN_FLIGHT < 1:
    # With no free slots, `pending` never drains and this loop spins forever.
    raise ValueError("MAX_IN_FLIGHT must be at least 1")

while active or pending:
    done = []
    for rid in list(active):
        if finished(rid):
            rs = result_state(rid)
            if rs != "SUCCESS":
                # (node, run_id, result), matching the summary table's columns.
                fails.append((active[rid], rid, rs))
            done.append(rid)
    for rid in done:
        active.pop(rid, None)
    while pending and len(active) < MAX_IN_FLIGHT:
        node = pending.pop(0)
        active[submit(node)] = node
    if active:
        time.sleep(8)  # seconds between status sweeps

print("failures:", fails)

In [0]:
# Fresh scheduler state: every trainable node queued, nothing in flight, no failures.
# NOTE(review): this re-queues *all* nodes. If the earlier scheduler cells
# already ran, every node is trained a second time.
pending, active, fails = list(nodes), {}, []

In [0]:
# Fill the concurrency window; `active` maps run_id -> node name.
while pending and len(active) < MAX_IN_FLIGHT:
    node_name = pending.pop(0)
    run_id = submit(node_name)
    active[run_id] = node_name

In [0]:
# Poll in-flight runs, report each completion, record failures as
# (node, run_id, result) and refill free slots until all nodes have been
# submitted and every run has finished.
if MAX_IN_FLIGHT < 1:
    # With no free slots, `pending` never drains and this loop spins forever.
    raise ValueError("MAX_IN_FLIGHT must be at least 1")

while active or pending:
    done = []
    for rid, n in list(active.items()):
        if finished(rid):
            # `result_state` is the status helper; the previous call to an
            # undefined `result` raised NameError on the first completion.
            rs = result_state(rid)
            print(f"completed {n} ({rid}) => {rs}")
            if rs != "SUCCESS":
                fails.append((n, rid, rs))
            done.append(rid)
    for rid in done:
        active.pop(rid, None)

    while pending and len(active) < MAX_IN_FLIGHT:
        n = pending.pop(0)
        rid = submit(n)
        active[rid] = n

    if active:
        time.sleep(8)  # seconds between status sweeps

In [0]:
# Final tally, plus a table of the failed runs when there are any.
summary_line = f"\nSubmitted: {len(nodes)}, failures: {len(fails)}"
print(summary_line)
if fails:
    failures_df = pd.DataFrame(fails, columns=["node", "run_id", "result"])
    display(failures_df)