In [0]:
# Core orchestrator inputs, declared as text widgets: (widget name, default value).
# The defaults are only starting values; whatever is set on the widget at run
# time is what the later `dbutils.widgets.get` calls return.
for _widget_name, _widget_default in (
    ("experiment", "/Workspace/9900-f18a-cake/classifier"),
    ("child_notebook", "/Workspace/9900-f18a-cake/mt-method2/Child"),
    ("labels_table", "cb_prod.`comp9300-9900-f18a-cake`.node_labels"),
    ("id_col", "biosample_id"),
    ("label_col", "direct_label"),
):
    dbutils.widgets.text(_widget_name, _widget_default)

In [0]:
# Settings that this notebook forwards unchanged to every child notebook run.
# Free-text widgets first, declared as (widget name, default value).
for _widget_name, _widget_default in (
    ("repo_src", "/Workspace/9900-f18a-cake/mt-method2/src"),
    ("rf_n_jobs", "1"),
    ("cv_n_jobs", "1"),
    ("prefilter_topk", "200"),
    ("prefilter_scan_max", "5000"),
    ("prefilter_chunk_size", "1000"),
):
    dbutils.widgets.text(_widget_name, _widget_default)

# `disable_dm` is restricted to the two literal choices below.
_TRUE_FALSE = ["true", "false"]
dbutils.widgets.dropdown("disable_dm", _TRUE_FALSE[0], _TRUE_FALSE)

# `raise_on_error` stays a free-text widget (default "true").
dbutils.widgets.text("raise_on_error", "true")

In [0]:
# Optional comma-separated list of node ids; an empty value means "all nodes".
dbutils.widgets.text("include_nodes_csv", "")

# "sequential" runs the child notebook once per node; "plan_only" stops after
# the plan has been logged.
_LAUNCH_MODES = ["sequential", "plan_only"]
dbutils.widgets.dropdown("launch_mode", _LAUNCH_MODES[0], _LAUNCH_MODES)

In [0]:
# Pull the current widget values into module-level constants.
_read_widget = dbutils.widgets.get

EXPERIMENT = _read_widget("experiment")
CHILD_NOTEBOOK = _read_widget("child_notebook")
LABELS_TABLE = _read_widget("labels_table")
ID_COL = _read_widget("id_col")
LABEL_COL = _read_widget("label_col")
# Surrounding whitespace is stripped so a blank-looking value is treated as empty.
INCLUDE_CSV = _read_widget("include_nodes_csv").strip()
LAUNCH_MODE = _read_widget("launch_mode")

In [0]:
# Read the pass-through settings in one go. The values are not parsed or
# converted here; they are handed to the child notebook exactly as read.
(
    REPO_SRC,
    RF_N_JOBS,
    CV_N_JOBS,
    PREF_TOPK,
    PREF_SCAN_MAX,
    PREF_CHUNK,
    DISABLE_DM,
    RAISE_ON_ERROR,
) = (
    dbutils.widgets.get(widget_name)
    for widget_name in (
        "repo_src",
        "rf_n_jobs",
        "cv_n_jobs",
        "prefilter_topk",
        "prefilter_scan_max",
        "prefilter_chunk_size",
        "disable_dm",
        "raise_on_error",
    )
)

In [0]:
# Echo the key settings so they appear in the notebook output of each run.
for _label, _value in (
    ("Experiment:", EXPERIMENT),
    ("Child:", CHILD_NOTEBOOK),
    ("Labels table:", LABELS_TABLE),
    ("Mode:", LAUNCH_MODE),
):
    print(_label, _value)

In [0]:
from pyspark.sql import functions as F

# Build the list of nodes to train, one entry per node id found in the labels table.
# Rows with a null node id or a null label are excluded before counting.
lbl = spark.table(LABELS_TABLE).where(F.col("node_id").isNotNull() & F.col(LABEL_COL).isNotNull())
# Number of distinct ID_COL values per node, largest nodes first.
nodes_df = lbl.groupBy("node_id").agg(F.countDistinct(ID_COL).alias("n_samples")).orderBy(F.desc("n_samples"))
display(nodes_df)

all_nodes = [r["node_id"] for r in nodes_df.collect()]

if INCLUDE_CSV:
    # Comma-separated allow-list; blank entries and surrounding whitespace are ignored.
    include = {s.strip() for s in INCLUDE_CSV.split(",") if s.strip()}
    # Keep the table's ordering (largest node first), not the order typed in the widget.
    model_nodes = [n for n in all_nodes if n in include]
    # Requested names that match no node would otherwise vanish silently (for
    # example a typo in the widget), so report them explicitly.
    missing_nodes = sorted(include.difference(all_nodes))
    if missing_nodes:
        print("WARNING: requested nodes not found in labels table:", ", ".join(missing_nodes))
else:
    model_nodes = all_nodes

print("Planned nodes:", len(model_nodes))
# 1-based index next to each node id, for a readable view of the plan.
display(spark.createDataFrame([(i+1, n) for i, n in enumerate(model_nodes)], "idx int, node string"))

In [0]:
import mlflow, json
from mlflow.tracking import MlflowClient
from datetime import datetime

# Orchestrate the per-node child notebook runs under a single parent MLflow run.
#
# Flow:
#   1. Open a parent run and log the plan (node count and node list).
#   2. Unless LAUNCH_MODE is "plan_only", run CHILD_NOTEBOOK once per node,
#      sequentially, passing the parent run id and the pass-through settings.
#   3. Log the collected child results, close the parent run, then exit the
#      notebook with a JSON status payload.
mlflow.set_experiment(EXPERIMENT)
client = MlflowClient()

session_name = f"mt-method2_pycall_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
results = []
with mlflow.start_run(run_name=session_name) as parent_run:
    parent_id = parent_run.info.run_id
    mlflow.set_tag("orchestrator", "pycall_parent")
    mlflow.log_param("num_nodes_planned", len(model_nodes))
    mlflow.log_table(spark.createDataFrame([(n,) for n in model_nodes], "node_id string").toPandas(),
                     artifact_file="plan/nodes.json")

    if LAUNCH_MODE != "plan_only":
        try:
            for i, node in enumerate(model_nodes, 1):
                print(f"[{i}/{len(model_nodes)}] Launching node: {node}")
                args = {
                    "experiment": EXPERIMENT,
                    "parent_run_id": parent_id,
                    "repo_src": REPO_SRC,
                    "only_node": node,
                    "rf_n_jobs": RF_N_JOBS,
                    "cv_n_jobs": CV_N_JOBS,
                    "prefilter_topk": PREF_TOPK,
                    "prefilter_scan_max": PREF_SCAN_MAX,
                    "prefilter_chunk_size": PREF_CHUNK,
                    "disable_dm": DISABLE_DM,
                    "raise_on_error": RAISE_ON_ERROR,
                }
                # timeout_seconds=0 is passed through unchanged.
                # NOTE(review): presumably this means "no timeout" - confirm against the Databricks docs.
                out = dbutils.notebook.run(CHILD_NOTEBOOK, timeout_seconds=0, arguments=args)
                print("Child result:", out)
                results.append({"node_id": node, "result": out})
        finally:
            # Log whatever was collected even when a child run raises, so the
            # completed nodes stay recorded on the parent run. The exception
            # still propagates after this.
            mlflow.log_dict(results, artifact_file="results/children.json")

# Exit only after the `with` block has closed the parent run, so the run's
# final status does not depend on how notebook exit interacts with the
# active-run context manager.
# NOTE(review): on some runtimes notebook exit may be implemented via an
# exception, which would end the run as FAILED if raised inside the block - confirm.
if LAUNCH_MODE == "plan_only":
    exit_payload = {"status":"planned","parent_run_id":parent_id,"nodes":len(model_nodes)}
else:
    exit_payload = {"status":"ok","parent_run_id":parent_id,"children":len(results)}
dbutils.notebook.exit(json.dumps(exit_payload))