In [0]:
import io
import json
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor

import joblib
import pandas as pd

# 1. Make the project's source directory importable.
# (This MUST be the same path as in Training_model.ipynb)
SRC_PATH = r"/Workspace/9900-f18a-cake/working_branch/src"
# Guard the append so that re-running this cell does not stack duplicate
# entries onto sys.path.
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

# 2. Define the path to the "Worker" Notebook
# (We assume it's in the same folder as this notebook)
WORKER_NOTEBOOK_PATH = "./Training_model"

print("Environment setup complete.")

In [0]:
# Location of the serialized DiseaseTree object.
JOBLIB_PATH = "/Workspace/9900-f18a-cake/working_branch/data/freeze0525/diseaseTree_mapped.joblib"

# NOTE(review): joblib.load is pickle-based and can execute arbitrary code
# while loading; only point JOBLIB_PATH at a trusted artifact.
try:
    tree_object = joblib.load(JOBLIB_PATH)
except Exception as e:
    print(f"Failed to load DiseaseTree: {e}")
    # Re-raise so the notebook stops here. Swallowing the error would leave
    # `tree_object` undefined (or stale from an earlier run of this cell),
    # and the failure would only surface later as a confusing error.
    raise
else:
    print("DiseaseTree node list loaded successfully.")

In [0]:
def get_nodes_to_train(tree):
    """
    Recursively traverse the DiseaseTree and return the names of all nodes to train.

    A node is selected when its name is not 'ZERO2' and it has a non-empty
    `samples` attribute. The walk is depth-first, visiting a node before its
    children.

    Args:
        tree: Node to start from. Each node must expose `.name`; `.samples`
            and `.children` are used only when present.

    Returns:
        list: Unique node names in first-visit order. The order is
        deterministic, so slicing the result (e.g. taking the first few
        nodes for a test run) selects the same nodes on every run.
    """
    all_nodes = []

    def traverse(node):
        # Logic: Train a model for any non-root node that has samples
        # (You can adjust this logic as needed)
        if node.name != 'ZERO2' and hasattr(node, 'samples') and len(node.samples) > 0:
            all_nodes.append(node.name)

        # Recurse into children
        if hasattr(node, 'children'):
            for child in node.children:
                traverse(child)

    traverse(tree)
    # dict.fromkeys drops duplicate names while keeping first-seen order;
    # a plain set() would make the order depend on string hashing.
    unique_nodes = list(dict.fromkeys(all_nodes))
    print(f"Extracted {len(unique_nodes)} unique nodes from DiseaseTree.")
    return unique_nodes

# Build the list of node names to train from the loaded DiseaseTree,
# then print the first ten entries as a quick sanity check.
nodes_to_train = get_nodes_to_train(tree_object)
print(f"Node list (first 10): {nodes_to_train[:10]}")

In [0]:
def train_node_parallel(node_name, timeout_seconds=7200):
    """
    Train a single node by running the worker notebook; one call per parallel thread.

    Runs the notebook at WORKER_NOTEBOOK_PATH through `dbutils.notebook.run`,
    passing the node name as the "only_node" parameter, and waits for it to
    return or time out. Any exception is caught and reported in the return
    value, so one failing node does not abort the other parallel runs.

    Args:
        node_name: Name of the node to train.
        timeout_seconds: Maximum time to wait for the worker notebook.
            Defaults to 7200 (2 hours).

    Returns:
        tuple: `(node_name, "Success", result)` where `result` is whatever the
        worker notebook run returned (presumably a JSON string; confirm
        against Training_model), or `(node_name, "Failed", error_message)`
        when the run raised.
    """
    print(f"➡️  [START] Node: {node_name}")

    # Parameters to pass to the 'Training_model' notebook
    params = {"only_node": node_name}

    try:
        # Blocks until the worker notebook returns or the timeout is hit.
        result_json = dbutils.notebook.run(WORKER_NOTEBOOK_PATH, timeout_seconds, params)
    except Exception as e:
        # Capture the failure instead of propagating it, so the caller can
        # report it in the summary.
        print(f"❌  [FAILED] Node: {node_name}. Error: {e}")
        return (node_name, "Failed", str(e))
    else:
        print(f"✅  [SUCCESS] Node: {node_name}.")
        return (node_name, "Success", result_json)

In [0]:
# --- [CONFIGURE YOUR RUN] ---

# 1. Set your desired maximum number of parallel runs
MAX_PARALLEL_RUNS = 10

# 2. Limit how many nodes are trained in this run.
#    Keep a small number (e.g. 5) while testing the pipeline;
#    set to None to run all nodes (production).
MAX_NODES_TO_RUN = 5
nodes_to_run = nodes_to_train[:MAX_NODES_TO_RUN]  # a None limit selects every node

# --------------------------

print(f"--- Starting parallel training for {len(nodes_to_run)} nodes, max parallelism: {MAX_PARALLEL_RUNS} ---")
start_time = time.time()
all_results = [] # To store (node_name, status, json_output)

# Use ThreadPoolExecutor to manage the parallel runs
with ThreadPoolExecutor(max_workers=MAX_PARALLEL_RUNS) as executor:
    # Submit all jobs
    futures = [executor.submit(train_node_parallel, node) for node in nodes_to_run]

    # Collect results in submission order; each .result() blocks until that
    # job has finished, so all_results lines up with nodes_to_run.
    for future in futures:
        all_results.append(future.result())

end_time = time.time()
elapsed_seconds = end_time - start_time
print("\n--- [ALL JOBS COMPLETE] ---")
print(f"Total time taken: {elapsed_seconds:.2f} seconds ( {elapsed_seconds/60:.2f} minutes )")


# --- [SUMMARY REPORT] ---
# Tally the (node_name, status, result_data) tuples gathered in all_results
# and combine the metrics returned by the successful runs into one table.
print("\n--- Summary Report ---")
success_count = 0
failed_nodes = []
all_stats_dfs = [] # To store all successful results

for (node_name, status, result_data) in all_results:
    if status != "Success":
        failed_nodes.append(node_name)
        continue

    success_count += 1
    try:
        # Try to convert the returned JSON back into a DataFrame.
        # The string is wrapped in StringIO because passing literal JSON
        # text directly to pd.read_json is deprecated in recent pandas.
        node_stats_df = pd.read_json(io.StringIO(result_data), orient='records')
        node_stats_df['node'] = node_name # Add a column to identify the node
        all_stats_dfs.append(node_stats_df)
    except Exception as e:
        # A run can succeed yet return something unparseable; it still
        # counts as successful, but its metrics are left out of the table.
        print(f"Warning: Could not parse results from node {node_name}: {e}")

print(f"Successful: {success_count} / {len(nodes_to_run)}")
print(f"Failed: {len(failed_nodes)} / {len(nodes_to_run)}")
if failed_nodes:
    print(f"List of failed nodes: {failed_nodes}")

# 4. Combine all successful results and display them
if all_stats_dfs:
    final_stats_summary = pd.concat(all_stats_dfs, ignore_index=True)
    print("\n--- [All Model Performance Metrics] ---")
    display(final_stats_summary)