In [2]:
import pandas as pd
import torch
import pickle
from datetime import datetime
from src.utils.ontology_utils import load_ontology
from src.data_pipeline.data_loader import load_filtered_cell_metadata
from src.data_pipeline.preprocess_ontology import preprocess_data_ontology
from src.utils.paths import PROJECT_ROOT
from src.utils.ontology_utils import get_sub_DAG

# 1. Load the cached ontology object
cl = load_ontology()

# Define the root of the ontology subgraph to be processed
root_cl_id = 'CL:0000988'  # hematopoietic cell

# 2. Load filtered cell metadata from CellXGene Census
# This step can take a few minutes
cell_obs_metadata = load_filtered_cell_metadata(cl, root_cl_id=root_cl_id)

# 3. Preprocess the ontology and cell data
target_column = 'cell_type_ontology_term_id'

mapping_dict, leaf_values, internal_values, ontology_df, cell_children_mask = preprocess_data_ontology(
        cl, cell_obs_metadata, target_column,
        upper_limit=root_cl_id,
        cl_only=True, include_leafs=False
    )

reverse_mapping_dict = {v: k for k, v in mapping_dict.items()}

print(f'Loaded {len(mapping_dict)} cell types.')
print("Cell children mask shape:", cell_children_mask.shape)


Loading cached ontology from /Users/jzhao/dev/Welch-lab/McCell/data/processed/ontology.pkl...


The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.


Ontology loaded successfully.
Fetching descendants of CL:0000988...
Connecting to CellXGene Census...
Reading cell metadata to filter cell types...
Found 160 cell types with > 5000 cells.
Querying for final cell metadata...
Finished loading and filtering cell metadata.
141 cell types in the dataset 41 leaf types, 100 internal types
Loaded 141 cell types.
Cell children mask shape: torch.Size([141, 141])


Data loader


In [3]:
import cellxgene_census
import cellxgene_census.experimental.ml as census_ml
import tiledbsoma as soma

all_cell_values = list(mapping_dict.keys())

In [6]:

# Load the gene list from the file you found and filter for protein-coding genes.
print("Loading and filtering gene list from hpc_workaround/data/mart_export.txt...")
biomart = pd.read_csv("/Users/jzhao/dev/Welch-lab/McCell/hpc_workaround/data/mart_export.txt")
coding_only = biomart[biomart['Gene type'] == 'protein_coding']
gene_list = coding_only['Gene stable ID'].tolist()[:500]  # Limit to first 500 genes for speed

# Create filters for the data query, including the assay and primary data filters.
var_value_filter = f"feature_id in {gene_list}"
obs_value_filter = f'''assay == "10x 3' v3" and is_primary_data == True and cell_type_ontology_term_id in {all_cell_values}'''

print(f"Querying {len(all_cell_values)} cell types and {len(gene_list)} protein-coding genes.")

with cellxgene_census.open_soma() as census:
    

    print(f"Querying {len(all_cell_values)} cell types and {len(gene_list)} genes.")

    experiment_datapipe = census_ml.pytorch.ExperimentDataPipe(
        census["census_data"]["homo_sapiens"],
        measurement_name="RNA",
        X_name="raw",
        obs_query=soma.AxisQuery(value_filter=obs_value_filter),
        var_query=soma.AxisQuery(value_filter=var_value_filter),
        obs_column_names=["cell_type_ontology_term_id"],
        batch_size=256,
        shuffle=True,
    )

    # Create the train/validation split
    train_datapipe, val_datapipe = experiment_datapipe.random_split(weights={"train": 0.8, "val": 0.2}, seed=42)
    train_dataloader = census_ml.experiment_dataloader(train_datapipe)
    val_dataloader = census_ml.experiment_dataloader(val_datapipe)

print("\nDataLoaders are ready.") 

Loading and filtering gene list from hpc_workaround/data/mart_export.txt...


The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.


Querying 141 cell types and 500 protein-coding genes.
Querying 141 cell types and 500 genes.


  experiment_datapipe = census_ml.pytorch.ExperimentDataPipe(



DataLoaders are ready.


  train_dataloader = census_ml.experiment_dataloader(train_datapipe)
  val_dataloader = census_ml.experiment_dataloader(val_datapipe)


In [7]:
# --- Verify a single batch of streamed data ---
print("\n--- Verifying a single batch of data ---")

# Take one batch from the train_dataloader
X_batch, y_batch_meta = next(iter(train_dataloader))

print(f"Shape of X_batch: {X_batch.shape}")
print(f"Shape of y_batch_meta: {y_batch_meta.shape}")

# Let's look at the first 5 cells
num_to_show = 5
X_sample = X_batch[:num_to_show]
y_sample_meta = y_batch_meta[:num_to_show]

# Map ontology terms (CL IDs) to integer labels
y_sample = torch.tensor(
    [mapping_dict[term] for term in y_sample_meta[:, 0]],
    dtype=torch.long
)

# Reverse dictionary for displaying CL term back from integer index
reverse_mapping_dict = {v: k for k, v in mapping_dict.items()}

print("\nExample cells (first 5):")
print("Cell # | y_meta (CL term)           | mapped label idx | X nonzero entries")
print("--------------------------------------------------------------------------")
for i in range(num_to_show):
    cl_id = y_sample_meta[i, 0]
    label_idx = y_sample[i].item()
    nnz = int((X_sample[i] != 0).sum())  # number of expressed genes in sparse row
    print(f"{i:<6} | {cl_id:<25} | {label_idx:<16} | {nnz}")



--- Verifying a single batch of data ---


KeyboardInterrupt: 

In [None]:
from src.train.model import SimpleNN
from src.train.loss import MarginalizationLoss

import torch.optim as optim
import matplotlib.pyplot as plt

# --- 1. Setup ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

input_dim = experiment_datapipe.shape[1]
output_dim = len(leaf_values)

model = SimpleNN(input_dim=input_dim, output_dim=output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Instantiate the loss function.
# It now correctly creates the leaf_indices set internally.
loss_fn = MarginalizationLoss(
    ontology_df=ontology_df,
    leaf_values=leaf_values,
    mapping_dict=mapping_dict,
    device=device
)

# --- 2. Training Loop (Short Run) ---
num_epochs = 1
batches_per_epoch = 5  # Limit the run to 500 batches per epoch for speed
batch_loss_history = []

print(f"\nStarting a short training run for {num_epochs} epochs ({batches_per_epoch} batches each)...")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for i, batch in enumerate(train_dataloader):
        if i >= batches_per_epoch:
            break
        
        X_batch, y_batch_meta = batch
        X_batch = X_batch.float().to(device)
        y_batch = torch.tensor([mapping_dict[term] for term in y_batch_meta[:, 0]], device=device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(X_batch)
        total_loss, _, _ = loss_fn(outputs, y_batch)
        total_loss.backward()
        optimizer.step()

        # Record loss for plotting
        batch_loss_history.append(total_loss.item())

        running_loss += total_loss.item()
        if (i + 1) % 100 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

# --- 3. Plotting Loss ---
print("\nPlotting training loss...")
plt.figure(figsize=(12, 6))
plt.plot(batch_loss_history, label='Total Loss per Batch')
plt.xlabel("Batch Number")
plt.ylabel("Loss")
plt.title("Training Loss Over Time")
plt.legend()
plt.grid(True)
plt.show()


Using device: cpu

Starting a short training run for 2 epochs (50 batches each)...


Exception ignored in: <function _EagerIterator.__del__ at 0x30bce5d00>
Traceback (most recent call last):
  File "/Users/jzhao/dev/Welch-lab/McCell/.venv/lib/python3.11/site-packages/cellxgene_census/experimental/util/_eager_iter.py", line 50, in __del__
    self._cleanup()
  File "/Users/jzhao/dev/Welch-lab/McCell/.venv/lib/python3.11/site-packages/cellxgene_census/experimental/util/_eager_iter.py", line 44, in _cleanup
    self._pool.shutdown()
  File "/Users/jzhao/.local/share/uv/python/cpython-3.11.13-macos-aarch64-none/lib/python3.11/concurrent/futures/thread.py", line 235, in shutdown
    t.join()
  File "/Users/jzhao/.local/share/uv/python/cpython-3.11.13-macos-aarch64-none/lib/python3.11/threading.py", line 1119, in join
    self._wait_for_tstate_lock()
  File "/Users/jzhao/.local/share/uv/python/cpython-3.11.13-macos-aarch64-none/lib/python3.11/threading.py", line 1139, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardIn

In [7]:
import socket
print(socket.gethostbyname("cellxgene-census-public-us-west-2.s3.us-west-2.amazonaws.com"))


52.92.136.210
