In [1]:
import pandas as pd
import torch
import pickle
from datetime import datetime
from src.utils.ontology_utils import load_ontology
from src.data_pipeline.data_loader import load_filtered_cell_metadata
from src.data_pipeline.preprocess_ontology import preprocess_data_ontology
from src.utils.paths import PROJECT_ROOT
from src.utils.ontology_utils import get_sub_DAG

In [2]:
# 1. Load the cached ontology object
cl = load_ontology()

# Define the root of the ontology subgraph to be processed
root_cl_id = 'CL:0000988'  # hematopoietic cell

cell_types = get_sub_DAG(cl, root_cl_id)
print(f"Total cell types including root: {len(cell_types)}")
for t in list(cell_types)[:5]:
    print(t)


Loading cached ontology from /Users/jzhao/dev/Welch-lab/McCell/data/processed/ontology.pkl...
Ontology loaded successfully.
Total cell types including root: 708
Term('CL:0002344', name='CD56-negative, CD161-positive immature natural killer cell, human')
Term('CL:0000985', name='IgG plasma cell')
Term('CL:0000958', name='T1 B cell')
Term('CL:0002401', name='mature dendritic epithelial T cell precursor')
Term('CL:0000926', name='CD4-positive type I NK T cell secreting interferon-gamma')


In [3]:
# 2. Load filtered cell metadata from CellXGene Census
cell_obs_metadata = load_filtered_cell_metadata(cl, root_cl_id=root_cl_id)
print(f"Loaded cell metadata with {cell_obs_metadata.shape[0]} cells and {cell_obs_metadata.shape[1]} metadata fields.")
print("Metadata columns:", cell_obs_metadata.columns.tolist())

# 3. Preprocess the ontology and cell data
target_column = 'cell_type_ontology_term_id'

print("Starting ontology preprocessing...")
# Note the 6 return values, including our new DataFrames
mapping_dict, leaf_values, internal_values, marginalization_df, target_BCE, exclusion_df = preprocess_data_ontology(
    cl, cell_obs_metadata, target_column,
    upper_limit=root_cl_id,
    cl_only=True
)

print(f"Preprocessing complete. Found {len(leaf_values)} leaf values and {len(internal_values)} internal values.")

Fetching descendants of CL:0000988...
Connecting to CellXGene Census...


The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.


Reading cell metadata to filter cell types...
Found 160 cell types with > 5000 cells.
Querying for final cell metadata...
Finished loading and filtering cell metadata.
Loaded cell metadata with 6383545 cells and 3 metadata fields.
Metadata columns: ['cell_type_ontology_term_id', 'assay', 'is_primary_data']
Starting ontology preprocessing...
141 cell types in the dataset 41 leaf types, 100 internal types
Preprocessing complete. Found 41 leaf values and 100 internal values.


In [4]:
print("Cell Parent Mask:")
print(target_BCE)
print(target_BCE.shape)

Cell Parent Mask:
            CL:0000233  CL:0000559  CL:0000794  CL:0000895  CL:0000899  \
CL:0000233           1           0           0           0           0   
CL:0000559           0           1           0           0           0   
CL:0000794           0           0           1           0           0   
CL:0000895           0           0           0           1           0   
CL:0000899           0           0           0           0           1   
...                ...         ...         ...         ...         ...   
CL:0002489           0           0           0           0           0   
CL:0002496           0           0           0           0           0   
CL:0008001           0           0           0           0           0   
CL:1001603           0           0           0           0           0   
CL:4033039           0           0           0           0           0   

            CL:0000900  CL:0000903  CL:0000904  CL:0000905  CL:0000907  ...  \
CL:0000233    

In [6]:
blood_cell_term = cl.get('CL:0000988')

if blood_cell_term:
    print(f"Found {blood_cell_term.name} term: {blood_cell_term.id}")

    # Get all subclasses of blood cell
    blood_cell_subclasses = list(blood_cell_term.subclasses())
    print(f"Found {len(blood_cell_subclasses)} blood cell subclasses.")

    # Directly access the row in the DataFrame using the CL ID.
    # The row represents the "child" and the columns with a '1' are its "parents".
    print(f"\nInspecting ontology_df for: {blood_cell_term.name}")

    # Use .loc for clear, label-based indexing.
    ancestor_vector = target_BCE.loc[blood_cell_term.id]

    print("\nRow (Pandas Series) from ontology_df:")
    # Displaying the first 25 entries of the vector
    print(ancestor_vector.head(25))

    # To make it more readable, let's find and print the names of the ancestors
    ancestors = ancestor_vector[ancestor_vector == 1].index.tolist()
    print(f"\nFound {len(ancestors)} ancestors in the DataFrame (including self):")
    for ancestor_id in ancestors[:10]:  # Print first 10
        print(f"- {cl[ancestor_id].name} ({ancestor_id})")
else:
    print("Blood cell term 'CL:0000988' not found in the ontology.")

Found hematopoietic cell term: CL:0000988
Found 708 blood cell subclasses.

Inspecting ontology_df for: hematopoietic cell

Row (Pandas Series) from ontology_df:
CL:0000233    0
CL:0000559    0
CL:0000794    0
CL:0000895    0
CL:0000899    0
CL:0000900    0
CL:0000903    0
CL:0000904    0
CL:0000905    0
CL:0000907    0
CL:0000910    0
CL:0000912    0
CL:0000913    0
CL:0000915    0
CL:0000917    0
CL:0000934    0
CL:0000936    0
CL:0000938    0
CL:0000939    0
CL:0000940    0
CL:0000985    0
CL:0000987    0
CL:0001043    0
CL:0001044    0
CL:0001049    0
Name: CL:0000988, dtype: int64

Found 1 ancestors in the DataFrame (including self):
- hematopoietic cell (CL:0000988)


In [10]:

# 1. Pick a leaf node to verify
leaf_to_inspect = leaf_values[0]
leaf_node_term = cl[leaf_to_inspect]
print(f"Verifying leaf node: '{leaf_node_term.name}' ({leaf_to_inspect})")

# 2. Directly access the corresponding row in target_BCE using its CL ID
print(f"\nFetching ancestor vector from target_BCE...")
ancestor_vector = target_BCE.loc[leaf_to_inspect]

# 3. Find and print the names of the ancestors from the vector
# This is more informative than printing a long tensor of 0s and 1s.
ancestors = ancestor_vector[ancestor_vector == 1].index.tolist()

print(f"\nFound {len(ancestors)} ancestors (including self):")
# To avoid printing too many, let's just show the first 15
for i, ancestor_id in enumerate(ancestors):
    if i < 15:
        print(f"- {cl[ancestor_id].name} ({ancestor_id})")

if len(ancestors) > 15:
    print(f"... and {len(ancestors) - 15} more.")

Verifying leaf node: 'platelet' (CL:0000233)

Fetching ancestor vector from target_BCE...

Found 4 ancestors (including self):
- platelet (CL:0000233)
- blood cell (CL:0000081)
- myeloid cell (CL:0000763)
- hematopoietic cell (CL:0000988)


In [12]:
# Using a known internal node for consistency, e.g., monocyte
internal_to_inspect = 'CL:0000576'

if internal_to_inspect in internal_values:
    internal_node_term = cl[internal_to_inspect]
    print(f"Verifying exclusion_df for internal node: '{internal_node_term.name}' ({internal_to_inspect})")
    print("We expect to see its descendants listed here, as they will be excluded from the loss (value = 0).")

    # 2. Directly access the corresponding row in exclusion_df
    exclusion_vector = exclusion_df.loc[internal_to_inspect]

    # 3. Find and print the names of the excluded descendants (where value is 0)
    excluded_nodes = exclusion_vector[exclusion_vector == 0].index.tolist()

    if excluded_nodes:
        print(f"\nFound {len(excluded_nodes)} excluded descendants:")
        # To avoid printing too many, let's just show the first 15
        for i, child_id in enumerate(excluded_nodes):
            if i < 15:
                print(f"- {cl[child_id].name} ({child_id})")
        if len(excluded_nodes) > 15:
            print(f"... and {len(excluded_nodes) - 15} more.")
    else:
        print("\nFound no descendants in this dataset to exclude for this node.")

else:
    print(f"Node {internal_to_inspect} is not in this dataset's list of internal values.")

Verifying exclusion_df for internal node: 'monocyte' (CL:0000576)
We expect to see its descendants listed here, as they will be excluded from the loss (value = 0).

Found 7 excluded descendants:
- CD14-positive, CD16-negative classical monocyte (CL:0002057)
- CD14-low, CD16-positive monocyte (CL:0002396)
- classical monocyte (CL:0000860)
- non-classical monocyte (CL:0000875)
- CD14-positive monocyte (CL:0001054)
- intermediate monocyte (CL:0002393)
- CD14-positive, CD16-positive monocyte (CL:0002397)
