In [None]:
import pandas as pd
import torch
import pickle
from datetime import datetime
from src.utils.ontology_utils import load_ontology
from src.data_pipeline.data_loader import load_filtered_cell_metadata
from src.data_pipeline.preprocess_ontology import preprocess_data_ontology
from src.utils.paths import PROJECT_ROOT
from src.utils.ontology_utils import get_sub_DAG

# Step 1: obtain the ontology object via the project helper.
cl = load_ontology()

# Root term of the ontology subgraph that the rest of this notebook works on.
root_cl_id = 'CL:0000988'  # hematopoietic cell

# Step 2: fetch the filtered cell metadata from CellXGene Census.
# This can take a few minutes and might not work on HPC, likely because of RAM limits.
cell_obs_metadata = load_filtered_cell_metadata(cl, root_cl_id=root_cl_id)

# Step 3: preprocess the ontology together with the cell metadata.
target_column = 'cell_type_ontology_term_id'

(
    mapping_dict,
    leaf_values,
    internal_values,
    ontology_df,
    cell_children_mask,
) = preprocess_data_ontology(
    cl,
    cell_obs_metadata,
    target_column,
    upper_limit=root_cl_id,
    cl_only=True,
    include_leafs=False,
)

# Inverse lookup: mapped value -> original key of mapping_dict.
reverse_mapping_dict = {mapped: key for key, mapped in mapping_dict.items()}

print(f'Loaded {len(mapping_dict)} cell types.')
print("Cell children mask shape:", cell_children_mask.shape)


Loading cached ontology from /Users/jzhao/dev/Welch-lab/McCell/data/processed/ontology.pkl...


The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.


Ontology loaded successfully.
Fetching descendants of CL:0000988...
Connecting to CellXGene Census...
Reading cell metadata to filter cell types...
Found 160 cell types with > 5000 cells.
Querying for final cell metadata...
Finished loading and filtering cell metadata.
141 cell types in the dataset 41 leaf types, 100 internal types
Loaded 141 cell types.
Cell children mask shape: torch.Size([141, 141])


Data loader


In [11]:
import cellxgene_census
import cellxgene_census.experimental.ml as census_ml
import tiledbsoma as soma
from tiledbsoma_ml import ExperimentDataset, experiment_dataloader
import pandas as pd
import torch

# Every key of mapping_dict, materialised as a list so it can be interpolated
# into the obs value filter further down.
all_cell_values = [*mapping_dict.keys()]

In [13]:
# --- Gene filter ---
# Restrict the query to protein-coding genes listed in a BioMart export.
print("Loading and filtering gene list...")
# NOTE(review): this is a machine-specific absolute path; it should probably be
# derived from the project root so the notebook runs elsewhere — confirm location.
biomart = pd.read_csv("/Users/jzhao/dev/Welch-lab/McCell/hpc_workaround/data/mart_export.txt")
coding_only = biomart[biomart['Gene type'] == 'protein_coding']
full_gene_list = coding_only['Gene stable ID'].tolist()

# Only the first `num_genes_to_use` protein-coding genes (in file order) are kept.
num_genes_to_use = 500
gene_list = full_gene_list[:num_genes_to_use]

# Value filters for the data query. The var filter selects the genes above; the
# obs filter selects the assay, primary data only, and the mapped cell types.
# (The var filter was previously assigned twice with the same expression.)
var_value_filter = f"feature_id in {gene_list}"
obs_value_filter = f'''assay == "10x 3' v3" and is_primary_data == True and cell_type_ontology_term_id in {all_cell_values}'''

print(f"Querying {len(all_cell_values)} cell types and {len(gene_list)} protein-coding genes.")

Loading and filtering gene list...
Querying 141 cell types and 500 protein-coding genes.


In [15]:
# Pin the Census release instead of following the moving "stable" alias, as the
# library's own warning recommends; otherwise a re-run after a new release could
# query different data. "2025-01-30" is the release named in that warning.
# NOTE(review): the metadata-loading helper used earlier presumably should be
# pinned to the same release — confirm in its implementation.
census_version = "2025-01-30"

with cellxgene_census.open_soma(census_version=census_version) as census:
    experiment = census["census_data"]["homo_sapiens"]

    # 1. Create an ExperimentDataset with the full query
    #    (obs/var filters were built in the previous cell).
    with experiment.axis_query(
        measurement_name="RNA",
        obs_query=soma.AxisQuery(value_filter=obs_value_filter),
        var_query=soma.AxisQuery(value_filter=var_value_filter),
    ) as query:
        experiment_dataset = ExperimentDataset(
            query,
            obs_column_names=["cell_type_ontology_term_id"],
            layer_name="raw",
            batch_size=128,
            shuffle=True,
            seed=111
        )

        # 2. Split the Dataset object 80/20 with a fixed seed so the split is repeatable
        train_dataset, val_dataset = experiment_dataset.random_split([0.8, 0.2], seed=42)

        print(f"\nTotal matching cells: {len(experiment_dataset)}")
        print(f"Training set size: {len(train_dataset)}")
        print(f"Validation set size: {len(val_dataset)}")

        # 3. Wrap the subsets in the dataloader
        train_dataloader = experiment_dataloader(
            train_dataset
        )
        val_dataloader = experiment_dataloader(
            val_dataset
        )

print("\nTraining and validation DataLoaders are ready.")



The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.



Total matching cells: 49872
Training set size: 39898
Validation set size: 9975

Training and validation DataLoaders are ready.


In [None]:
from src.train.model import SimpleNN
from src.train.loss import MarginalizationLoss
import matplotlib.pyplot as plt

# --- 1. Setup ---
# Prefer the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Network sizes: input width comes from the second entry of the dataset's shape
# (the number of genes, per the original note); output width is one unit per
# entry in leaf_values.
input_dim = train_dataset.shape[1]
output_dim = len(leaf_values)

model = SimpleNN(input_dim=input_dim, output_dim=output_dim)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

loss_fn = MarginalizationLoss(
    ontology_df=ontology_df, leaf_values=leaf_values, mapping_dict=mapping_dict, device=device
)


Using device: cpu


In [19]:
num_epochs = 1
batches_per_epoch = 1  # Cap on batches consumed per epoch (kept tiny for a quick smoke run)
log_interval = 100  # Print the running-average loss every this many batches
batch_loss_history = []

print(f"\nStarting a short training run for {num_epochs} epochs ({batches_per_epoch} batches each)...")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # range() is deliberately the FIRST argument: zip() stops as soon as the range
    # is exhausted, without asking the dataloader for another batch. The previous
    # enumerate()/break form fetched one extra batch before the cap was checked.
    for i, (X_batch, obs_batch) in zip(range(batches_per_epoch), train_dataloader):
        # Convert numpy array from dataloader to a torch tensor
        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Get label strings from the obs batch and map them to integer indices
        label_strings = obs_batch["cell_type_ontology_term_id"]
        y_batch = torch.tensor([mapping_dict[term] for term in label_strings], device=device, dtype=torch.long)

        # Standard PyTorch training steps
        optimizer.zero_grad()
        outputs = model(X_batch)
        # The loss returns three values; only the first (total) is used here.
        total_loss, _, _ = loss_fn(outputs, y_batch)
        total_loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the history and the running sum
        batch_loss = total_loss.item()
        batch_loss_history.append(batch_loss)

        running_loss += batch_loss
        if (i + 1) % log_interval == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0

print('Finished Training')


Starting a short training run for 1 epochs (1 batches each)...


KeyboardInterrupt: 

In [20]:
import torch.optim as optim
from src.train.model import SimpleNN
from src.train.loss import MarginalizationLoss
import matplotlib.pyplot as plt

# --- 1. Setup ---
# Builds a fresh model, optimizer and loss so the check below starts from
# newly initialised objects.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

input_dim = train_dataset.shape[1]
output_dim = len(leaf_values)

model = SimpleNN(input_dim=input_dim, output_dim=output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
loss_fn = MarginalizationLoss(
    ontology_df=ontology_df,
    leaf_values=leaf_values,
    mapping_dict=mapping_dict,
    device=device
)

# --- 2. Verification Run (1 Batch) ---
# Runs one forward/backward/optimizer step on a single batch and reports the
# three loss values returned by the loss function.
print("\n--- Starting verification run with a single batch ---")
print("Fetching the first batch... (This may take a moment)")

model.train()
try:
    # Get just one batch from the dataloader
    X_batch, obs_batch = next(iter(train_dataloader))
    print("Batch successfully loaded!")

    # --- Data Preparation ---
    X_batch = torch.from_numpy(X_batch).float().to(device)
    label_strings = obs_batch["cell_type_ontology_term_id"]
    y_batch = torch.tensor([mapping_dict[term] for term in label_strings], device=device, dtype=torch.long)
    print(f"Data converted to tensors. X_batch shape: {X_batch.shape}")

    # --- Training Step ---
    optimizer.zero_grad()
    outputs = model(X_batch)
    total_loss, loss_leafs, loss_parents = loss_fn(outputs, y_batch)
    total_loss.backward()
    optimizer.step()
    print("Forward and backward pass completed successfully.")

    print("\n--- VERIFICATION SUCCESSFUL ---")
    print(f"Calculated Total Loss: {total_loss.item():.4f}")
    print(f"  - Leaf Loss component: {loss_leafs.item():.4f}")
    print(f"  - Parent Loss component: {loss_parents.item():.4f}")

except Exception as e:
    # Best-effort check: report the failure instead of raising. The exception
    # type is included because str(e) alone can be uninformative (a KeyError for
    # an unmapped label would otherwise print just the quoted key).
    print("\n--- VERIFICATION FAILED ---")
    print(f"An error occurred: {type(e).__name__}: {e}")



Using device: cpu

--- Starting verification run with a single batch ---
Fetching the first batch... (This may take a moment)


KeyboardInterrupt: 

In [7]:
import socket
# Connectivity check: resolve the hostname below and print the IPv4 address it
# maps to. (Presumably the S3 endpoint that serves the Census data — confirm.)
print(
    socket.gethostbyname(
        "cellxgene-census-public-us-west-2.s3.us-west-2.amazonaws.com"
    )
)


52.92.136.210
