In [0]:
# Project source directory; the notebook appends it to sys.path before loading the tree.
REPO_SRC = "/Workspace/9900-f18a-cake/mt-method2/src"
# Serialized disease tree that the notebook loads with joblib.
JOBLIB_PATH = "/Workspace/9900-f18a-cake/mt-method2/data/freeze0525/diseaseTree_mapped.joblib"

In [0]:
import sys, joblib
# Put the project source directory on sys.path (only once) before unpickling.
# NOTE(review): presumably required so the classes stored in the joblib file can be imported — confirm.
if REPO_SRC not in sys.path:
    sys.path.append(REPO_SRC)

# joblib.load unpickles the file, which can execute arbitrary code; only load trusted artifacts.
tree = joblib.load(JOBLIB_PATH)
# The loaded object is either a wrapper exposing `.root` or is used as the root node itself.
root = getattr(tree, "root", tree)

print("Loaded:", type(tree))
# The root is identified by `name`, falling back to `label`, then "<unknown>".
print("Root:", getattr(root, "name", getattr(root, "label", "<unknown>")))

In [0]:
from collections import defaultdict, deque

def node_name(n):
    """Return the node's `name`, else its `label`, else the string "UNKNOWN"."""
    fallback = getattr(n, "label", "UNKNOWN")
    return getattr(n, "name", fallback)

def node_children(n):
    """Return the node's `children` as a new list; an empty list when missing or falsy."""
    kids = getattr(n, "children", None)
    if not kids:
        return []
    return list(kids)

def gather_own_samples(n):
    """Collect the samples attached directly to this node (descendants excluded).

    Unions the node's `samples`, `training_samples`, `validation_samples` and
    `calibration_samples` attributes; missing or empty attributes are skipped.
    """
    sample_fields = ("samples", "training_samples", "validation_samples", "calibration_samples")
    own = set()
    for field in sample_fields:
        values = getattr(n, field, None)
        if not values:
            continue
        own.update(values)
    return own

def gather_subtree_samples(n):
    """Return the union of samples held by this node and every descendant.

    Walks the subtree breadth-first, adding each visited node's own samples.
    """
    collected = set()
    pending = deque((n,))
    while pending:
        node = pending.popleft()
        collected |= gather_own_samples(node)
        pending.extend(node_children(node))
    return collected

In [0]:
# When True, samples held directly by a parent node but absent from every child
# subtree are emitted under REST_LABEL for that parent (one-vs-rest "OTHER" rows).
ADD_REST_CLASS = False
# Label used for those leftover samples; an empty value makes process_node fall
# back to the parent's own name.
REST_LABEL = "__OTHER__"

# Accumulates (biosample_id, node_id, direct_label) tuples, one per label row.
rows = list()

def process_node(n):
    """Emit label rows for the classifier at node `n`, then recurse into its children.

    For an internal node, every sample found anywhere in a child's subtree
    appends one tuple ``(sample, node_name(n), node_name(child))`` to the
    module-level ``rows`` list. Leaf nodes emit nothing. When ``ADD_REST_CLASS``
    is true, samples held directly by ``n`` that appear in no child subtree are
    emitted under ``REST_LABEL`` (or under the node's own name if ``REST_LABEL``
    is empty).

    Children that resolve to the same name (for example several unnamed nodes
    that all fall back to "UNKNOWN") have their sample sets merged, so a later
    child no longer silently replaces an earlier one.

    Note: a sample present in more than one child subtree produces one row per
    child label at this node.
    """
    kids = node_children(n)
    if not kids:
        return  # leaf; no classifier at this node

    parent_id = node_name(n)

    # Full subtree sample set per child label, merged when two children share a name.
    child_to_samples = {}
    for c in kids:
        label = node_name(c)
        subtree = gather_subtree_samples(c)
        if label in child_to_samples:
            child_to_samples[label] |= subtree
        else:
            child_to_samples[label] = subtree

    # One row per (sample, child label) membership.
    for child_label, sid_set in child_to_samples.items():
        rows.extend((sid, parent_id, child_label) for sid in sid_set)

    # Optional: leftovers present at the parent but in no child subtree go to the REST label.
    if ADD_REST_CLASS:
        # Only needed for the REST computation, so it is built inside this branch.
        union_children = set().union(*child_to_samples.values())
        rest = gather_own_samples(n) - union_children
        rest_label = REST_LABEL if REST_LABEL else parent_id
        rows.extend((sid, parent_id, rest_label) for sid in rest)

    # Recurse so every internal node below gets its own label rows.
    for c in kids:
        process_node(c)

# Walk the whole tree from the root, filling `rows`, then report how many rows were emitted.
process_node(root)
print("Raw label rows:", len(rows))

In [0]:
# Distinct biosample IDs that have rows in the features table.
features_tbl = "cb_prod.`comp9300-9900-f18a-cake`.filter_meth_mvalues_masked_subset_leukaemia_subsampled"
features_ids = spark.table(features_tbl).select("biosample_id").distinct()

# De-duplicated label rows, restricted to biosamples present in the features table.
labels_df = (
    spark.createDataFrame(
        rows,
        schema="biosample_id string, node_id string, direct_label string",
    )
    .dropDuplicates()
    .join(features_ids, on="biosample_id", how="inner")
)

print("Labels after filtering to features:", labels_df.count())
display(labels_df.limit(20))

In [0]:
# Destination Delta table; overwrite mode replaces any existing contents.
TARGET_TABLE = "cb_prod.`comp9300-9900-f18a-cake`.node_labels"
(
    labels_df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable(TARGET_TABLE)
)
print("Wrote:", TARGET_TABLE)