In [2]:
import pandas as pd
import torch
import pickle
from datetime import datetime
from src.utils.ontology_utils import load_ontology
from src.data_pipeline.data_loader import load_filtered_cell_metadata
from src.data_pipeline.preprocess_ontology import preprocess_data_ontology
from src.utils.paths import PROJECT_ROOT
from src.utils.ontology_utils import get_sub_DAG

In [3]:
# 1. Load the ontology object (the loader serves it from its cache)
cl = load_ontology()

# Root of the ontology subgraph that the rest of the notebook works on
root_cl_id = 'CL:0000988'  # hematopoietic cell

# Collect every term of the sub-DAG under the root
cell_types = get_sub_DAG(cl, root_cl_id)
n_cell_types = len(cell_types)
print(f"Total cell types including root: {n_cell_types}")

# Sanity check: show up to five of the collected terms
preview_terms = list(cell_types)[:5]
for preview_term in preview_terms:
    print(preview_term)


Loading cached ontology from /Users/jzhao/dev/Welch-lab/McCell/data/processed/ontology.pkl...
Ontology loaded successfully.
Total cell types including root: 708
Term('CL:0000845', name='marginal zone B cell of spleen')
Term('CL:0001077', name='ILC1, human')
Term('CL:0001080', name='NKp44-negative group 3 innate lymphoid cell, human')
Term('CL:4033076', name='cycling macrophage')
Term('CL:0009112', name='centroblast')


In [4]:

# 2. Load filtered cell metadata from CellXGene Census for the chosen root
cell_obs_metadata = load_filtered_cell_metadata(cl, root_cl_id=root_cl_id)
n_cells, n_fields = cell_obs_metadata.shape
print(f"Loaded cell metadata with {n_cells} cells and {n_fields} metadata fields.")
print("Metadata columns:", cell_obs_metadata.columns.tolist())

# 3. Preprocess the ontology together with the cell metadata.
#    `target_column` names the metadata column holding the cell-type term IDs.
target_column = 'cell_type_ontology_term_id'

print("Starting ontology preprocessing...")
(
    mapping_dict,
    leaf_values,
    internal_values,
    ontology_df,
    cell_children_mask,
) = preprocess_data_ontology(
    cl,
    cell_obs_metadata,
    target_column,
    upper_limit=root_cl_id,
    cl_only=True,
    include_leafs=False,
)

n_leaf_values = len(leaf_values)
n_internal_values = len(internal_values)
print(f"Preprocessing complete. Found {n_leaf_values} leaf values and {n_internal_values} internal values.")

Fetching descendants of CL:0000988...
Connecting to CellXGene Census...


The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.


Reading cell metadata to filter cell types...
Found 160 cell types with > 5000 cells.
Querying for final cell metadata...
Finished loading and filtering cell metadata.
Loaded cell metadata with 6383545 cells and 3 metadata fields.
Metadata columns: ['cell_type_ontology_term_id', 'assay', 'is_primary_data']
Starting ontology preprocessing...
Preprocessing complete. Found 41 leaf values and 100 internal values.


In [5]:
# Show the mask produced by preprocessing, then its dimensions.
# NOTE(review): the heading says "Parent" while the variable is named
# `cell_children_mask` -- confirm which orientation is intended.
print("Cell Parent Mask:", cell_children_mask, cell_children_mask.shape, sep="\n")

Cell Parent Mask:
tensor([[0., 1., 1.,  ..., 1., 1., 1.],
        [1., 0., 0.,  ..., 1., 1., 0.],
        [1., 1., 0.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 0., 1., 1.],
        [1., 1., 1.,  ..., 1., 0., 1.],
        [1., 1., 1.,  ..., 1., 1., 0.]])
torch.Size([141, 141])


In [6]:

# Invert the term-ID -> index mapping so that positions in the mask can be
# translated back to ontology term IDs.
reverse_mapping_dict = {v: k for k, v in mapping_dict.items()}
print(reverse_mapping_dict)

# Look the root term up directly by ID instead of scanning every term, and
# reuse `root_cl_id` so this cell always inspects the same root that the
# loading / preprocessing cells used.
blood_cell_term = cl[root_cl_id] if root_cl_id in cl else None

if blood_cell_term is not None:
    print(f"Found {blood_cell_term.name} term: {blood_cell_term.id}")
    # Get all subclasses of the root term
    blood_cell_subclasses = list(blood_cell_term.subclasses())
    print(f"Found {len(blood_cell_subclasses)} blood cell subclasses.")

    # Print the mask row belonging to the root, if the root was given an index
    root_index = mapping_dict.get(root_cl_id)
    if root_index is not None:
        print(f"Index of blood cell in mapping dict: {root_index}")
        print("Mask row for blood cell:")
        print(cell_children_mask[root_index])
    else:
        # Previously this case was skipped silently; say so explicitly.
        print(f"{root_cl_id} has no entry in mapping_dict; no mask row to show.")
else:
    print("Blood cell not found in the ontology.")

{0: 'CL:0000814', 1: 'CL:0000763', 2: 'CL:0000775', 3: 'CL:0000097', 4: 'CL:0000786', 5: 'CL:0000236', 6: 'CL:0000232', 7: 'CL:0000233', 8: 'CL:0000451', 9: 'CL:0000084', 10: 'CL:0000235', 11: 'CL:0000576', 12: 'CL:0000094', 13: 'CL:0000878', 14: 'CL:0000738', 15: 'CL:0000129', 16: 'CL:0000542', 17: 'CL:0000766', 18: 'CL:0001082', 19: 'CL:0000838', 20: 'CL:0000817', 21: 'CL:0002355', 22: 'CL:0002045', 23: 'CL:0000559', 24: 'CL:0000816', 25: 'CL:0001054', 26: 'CL:0000826', 27: 'CL:0000836', 28: 'CL:0000837', 29: 'CL:0000557', 30: 'CL:0000556', 31: 'CL:0000788', 32: 'CL:0000938', 33: 'CL:0000936', 34: 'CL:0000049', 35: 'CL:0000784', 36: 'CL:0000624', 37: 'CL:0000767', 38: 'CL:0000860', 39: 'CL:0000782', 40: 'CL:0000904', 41: 'CL:0000909', 42: 'CL:0000875', 43: 'CL:0000623', 44: 'CL:0000895', 45: 'CL:0000905', 46: 'CL:0000787', 47: 'CL:0000900', 48: 'CL:0000798', 49: 'CL:0000980', 50: 'CL:0001062', 51: 'CL:0001044', 52: 'CL:0000625', 53: 'CL:0001050', 54: 'CL:0002393', 55: 'CL:0000492', 5

In [None]:
# 1. Pick a leaf node to spot-check
leaf_node_id = leaf_values[0]  # Let's pick the first one
leaf_node_term = cl[leaf_node_id]
print(f"Verifying leaf node: {leaf_node_term.name} ({leaf_node_id})")

# 2. Get the index for this node
leaf_node_index = mapping_dict[leaf_node_id]

# 3. Get the mask for this node
cell_mask_for_node = cell_children_mask[leaf_node_index]

# 4. Print the mask
print(f"\nParent mask for {leaf_node_term.name}:")
print(cell_mask_for_node)

# The 0th index looks a bit weird, so inspect it separately.
first_node_id = reverse_mapping_dict[0]
first_node_term = cl[first_node_id]
print(first_node_id, first_node_term.name)
# Is the node at index 0 one of the leaf values? This must be a membership
# test against the collection of leaves; testing against `leaf_node_id`
# would be a substring check on a single ID string.
print(first_node_id in set(leaf_values))
cell_mask_for_node = cell_children_mask[0]

# Print the mask of the index-0 node, labelled with that node's own name
# (not the leaf node inspected above).
print(f"\nParent mask for {first_node_term.name}:")
print(cell_mask_for_node)

# Print the subclasses of the index-0 node, then the mask index of each one
first_node_subclasses = list(first_node_term.subclasses())
print(f"Subclasses of {first_node_term.name}: {[subclass.name for subclass in first_node_subclasses]}")
for subclass in first_node_subclasses:
    print(f"{subclass.name}: {subclass.id} -> {mapping_dict.get(subclass.id, 'Not found')}")

Verifying leaf node: platelet (CL:0000233)

Parent mask for platelet:
tensor([1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
CL:0000814 mature NK T cell
False

Parent mask for platelet:
tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,


In [13]:
# 1. Pick an internal node
internal_node_id = internal_values[10]  # index 10 (the 11th internal node), picked for variety
internal_node_term = cl[internal_node_id]
print(f"Verifying internal node: {internal_node_term.name} ({internal_node_id})")

# 2. Get the parent mask for this node: the node's column of the mask.
#    The mask returned by preprocessing is bound to `cell_children_mask`;
#    `cell_parent_mask` is not defined anywhere in this notebook, so using it
#    fails on a fresh kernel.
#    NOTE(review): confirm that the column view of `cell_children_mask` is the
#    intended "parent" mask.
internal_node_index = mapping_dict[internal_node_id]
parent_mask_for_node = cell_children_mask[:, internal_node_index]

# 3. Find the indices where the mask is 0
zero_indices = (parent_mask_for_node == 0).nonzero(as_tuple=True)[0]
print(len(zero_indices), "masked-out parents indices:", zero_indices.tolist())

# 4. Convert those indices to parent CL terms
#    (assumes the row order of `ontology_df` matches the mask indices -- TODO confirm)
masked_out_parents = ontology_df.index[zero_indices].tolist()
print(f"\nMasked-out parents for {internal_node_term.name}:")
for parent_id in masked_out_parents:
    print(f"- {cl[parent_id].name} ({parent_id})")

# 5. Get the subclasses of the internal node from the ontology.
#    `with_self=True` means the node itself is part of the list.
children_of_node = [term.id for term in internal_node_term.subclasses(with_self=True)]
print(f"\nActual children of {internal_node_term.name}:")
for child_id in children_of_node:
    if child_id in cl:
        print(f"- {cl[child_id].name} ({child_id})")

Verifying internal node: monocyte (CL:0000576)
6 masked-out parents indices: [1, 11, 14, 17, 57, 61]

Masked-out parents for monocyte:
- myeloid cell (CL:0000763)
- monocyte (CL:0000576)
- leukocyte (CL:0000738)
- myeloid leukocyte (CL:0000766)
- mononuclear phagocyte (CL:0000113)
- hematopoietic cell (CL:0000988)

Actual children of monocyte:
- monocyte (CL:0000576)
- classical monocyte (CL:0000860)
- non-classical monocyte (CL:0000875)
- CD115-positive monocyte (CL:0001022)
- CD14-positive monocyte (CL:0001054)
- intermediate monocyte (CL:0002393)
- cycling monocyte (CL:4033073)
- BAG3-positive monocyte (CL:4047060)
- CD14-positive, CD16-negative classical monocyte (CL:0002057)
- Gr1-high classical monocyte (CL:0002395)
- Gr1-low non-classical monocyte (CL:0002058)
- CD14-low, CD16-positive monocyte (CL:0002396)
- Gr1-positive, CD43-positive monocyte (CL:0002398)
- CD14-positive, CD16-positive monocyte (CL:0002397)
- CD14-positive, CD16-low monocyte (CL:0001055)
- MHC-II-negative cla