In [1]:
import pandas as pd
import torch
import pickle
from datetime import datetime
import os

# Imports from our project
from src.utils.paths import PROJECT_ROOT
from src.utils.ontology_utils import load_ontology  # Still need this to access term names

# --- 1. Load Preprocessed Data Artifacts ---
# Nothing is recomputed in this notebook: every artifact below is read back from
# the files produced by `run_preprocessing.py`, named `<DATE>_<artifact>.<ext>`.

# Date stamp of the preprocessing run whose outputs should be loaded
DATE = '2025-10-17'
PROCESSED_DATA_DIR = PROJECT_ROOT / "data" / "processed"


def _artifact_path(file_name):
    """Return the path of the date-stamped artifact `file_name` in PROCESSED_DATA_DIR."""
    return PROCESSED_DATA_DIR / f"{DATE}_{file_name}"


print(f"Loading data from: {PROCESSED_DATA_DIR} for date {DATE}")

# Ontology object, kept so term names can be looked up when printing
cl = load_ontology()

# Tabular artifacts: CSV files whose first column is the index
marginalization_df = pd.read_csv(_artifact_path("marginalization_df.csv"), index_col=0)
parent_child_df = pd.read_csv(_artifact_path("parent_child_df.csv"), index_col=0)
exclusion_df = pd.read_csv(_artifact_path("exclusion_df.csv"), index_col=0)

# Label mapping: the CSV index holds the CL identifiers and the first data
# column holds the integer each identifier maps to.
mapping_dict_df = pd.read_csv(_artifact_path("mapping_dict_df.csv"), index_col=0)
mapping_dict = mapping_dict_df.iloc[:, 0].to_dict()

# Leaf / internal node collections are stored as pickles.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open(_artifact_path("leaf_values.pkl"), "rb") as pkl_file:
    leaf_values = pickle.load(pkl_file)
with open(_artifact_path("internal_values.pkl"), "rb") as pkl_file:
    internal_values = pickle.load(pkl_file)

print("\nAll data artifacts loaded successfully.")
print(f"Loaded {len(mapping_dict)} cell types.")
print(f"  - {len(leaf_values)} leaf nodes")
print(f"  - {len(internal_values)} internal nodes")




Loading data from: /home/jingqiao/real_McCell/data/processed for date 2025-10-17
Loading cached ontology from /home/jingqiao/real_McCell/data/processed/ontology.pkl...
Ontology loaded successfully.

All data artifacts loaded successfully.
Loaded 141 cell types.
  - 41 leaf nodes
  - 100 internal nodes


Data loader


In [2]:
import cellxgene_census
import tiledbsoma as soma
from tiledbsoma_ml import ExperimentDataset, experiment_dataloader

# --- Create Filters for SOMA Query ---

# Every ontology term id present in the label mapping is included in the obs filter
all_cell_values = list(mapping_dict)

# Only a small slice of the protein-coding genes is queried, to keep things fast
num_genes_to_use = 500

print("Loading and filtering gene list...")
# BioMart export shipped with the project (path is anchored at PROJECT_ROOT)
biomart = pd.read_csv(PROJECT_ROOT / "hpc_workaround/data/mart_export.txt")
is_protein_coding = biomart['Gene type'] == 'protein_coding'
coding_only = biomart.loc[is_protein_coding]
full_gene_list = coding_only['Gene stable ID'].tolist()
gene_list = full_gene_list[:num_genes_to_use]

# 'value_filter' expressions for the query; the Python lists are interpolated
# through their repr, e.g. "feature_id in ['ENSG...', ...]".
var_value_filter = f"feature_id in {gene_list}"
obs_value_filter = (
    "assay == '10x 3\\' v3' "
    "and is_primary_data == True "
    f"and cell_type_ontology_term_id in {all_cell_values}"
)

print(f"Querying {len(all_cell_values)} cell types and {len(gene_list)} protein-coding genes.")



Loading and filtering gene list...
Querying 141 cell types and 500 protein-coding genes.


In [3]:
# Read from a local copy of the SOMA database instead of the live one.
# The location defaults to the shared scratch path below, but can be overridden
# by setting the MCCELL_SOMA_URI environment variable so the notebook also runs
# on machines where the database lives elsewhere. An unset or empty variable
# falls back to the default.
DEFAULT_SOMA_URI = "/scratch/sigbio_project_root/sigbio_project25/jingqiao/mccell-single/soma_db_homo_sapiens"
soma_uri = os.environ.get("MCCELL_SOMA_URI") or DEFAULT_SOMA_URI
print(f"Opening local SOMA database at: {soma_uri}")

with soma.open(uri=soma_uri) as experiment:

    # Query the "RNA" measurement, restricted to the cells (obs) and genes (var)
    # selected by the value filters built in the previous cell.
    with experiment.axis_query(
        measurement_name="RNA",
        obs_query=soma.AxisQuery(value_filter=obs_value_filter),
        var_query=soma.AxisQuery(value_filter=var_value_filter),
    ) as query:
        # Dataset over the "raw" layer; each batch also carries the
        # cell_type_ontology_term_id column, which is used as the label.
        experiment_dataset = ExperimentDataset(
            query,
            obs_column_names=["cell_type_ontology_term_id"],
            layer_name="raw",
            batch_size=128,
            shuffle=True,
            seed=111
        )

        # 80/20 train/validation split with a fixed seed
        train_dataset, val_dataset = experiment_dataset.random_split([0.8, 0.2], seed=42)

        print(f'\nTotal matching cells: {len(experiment_dataset)}')
        print(f'Training set size: {len(train_dataset)}')
        print(f'Validation set size: {len(val_dataset)}')

        train_dataloader = experiment_dataloader(train_dataset)
        val_dataloader = experiment_dataloader(val_dataset)

# NOTE(review): the datasets and dataloaders are used after both contexts have
# closed (here and in the training cell). This presumably relies on
# ExperimentDataset not needing the open query handle -- confirm against the
# tiledbsoma_ml documentation.

# Show a summary of the loaded train and validation datasets
print("\nTrain dataset shape:", train_dataset.shape)
print("Validation dataset shape:", val_dataset.shape)


Opening local SOMA database at: /scratch/sigbio_project_root/sigbio_project25/jingqiao/mccell-single/soma_db_homo_sapiens

Total matching cells: 49872
Training set size: 39898
Validation set size: 9975

Train dataset shape: (39898, 389)
Validation dataset shape: (9975, 389)


In [1]:
from src.train.model import SimpleNN
from src.train.loss import MarginalizationLoss
import torch.optim as optim
import matplotlib.pyplot as plt

# --- 1. Setup ---
# NOTE: this cell depends on state created by the cells above (`torch`,
# `train_dataset`, `leaf_values`, `internal_values`, the loaded DataFrames and
# `mapping_dict`). On a fresh kernel, run those cells first (Restart & Run All);
# executing this cell on its own fails with a NameError.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed torch's global RNG before the model is built so that weight
# initialisation is repeatable from run to run (the dataset shuffle and the
# train/validation split are already seeded).
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

# The input dimension is the number of genes, taken from the dataset's shape
input_dim = train_dataset.shape[1]
output_dim = len(leaf_values)  # Model only predicts leaf nodes

model = SimpleNN(input_dim=input_dim, output_dim=output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# The loss is built from every preprocessed artifact loaded in the first cell
loss_fn = MarginalizationLoss(
    marginalization_df=marginalization_df,
    parent_child_df=parent_child_df,
    exclusion_df=exclusion_df,
    leaf_values=leaf_values,
    internal_values=internal_values,
    mapping_dict=mapping_dict,
    device=device
)

print("Model, optimizer, and loss function are ready.")

NameError: name 'torch' is not defined

In [None]:
# Short training run: a couple of epochs, each capped at a fixed number of batches.
num_epochs = 2
batches_per_epoch = 200  # Limit batches for a quick run
log_every = 50  # print a detailed log line every this many batches
batch_loss_history = []  # total loss of every optimisation step, in order

print(f"\nStarting training for {num_epochs} epochs ({batches_per_epoch} batches each)...")
for epoch in range(num_epochs):
    model.train()
    print(f'\n--- Epoch {epoch + 1} ---')

    # `step` is the 1-based batch number within the current epoch
    for step, (counts_np, obs_frame) in enumerate(train_dataloader, start=1):
        if step > batches_per_epoch:
            break

        # --- Data Preparation ---
        # Expression values arrive as a NumPy array; labels arrive as ontology
        # term ids, which are translated to integer ids through mapping_dict.
        features = torch.from_numpy(counts_np).float()
        features = features.to(device)
        term_ids = obs_frame["cell_type_ontology_term_id"]
        targets = torch.tensor(
            [mapping_dict[term] for term in term_ids],
            device=device,
            dtype=torch.long,
        )

        # --- Training Step ---
        optimizer.zero_grad()
        predictions = model(features)
        # The loss function returns the total plus its leaf and parent components
        total_loss, loss_leafs, loss_parents = loss_fn(predictions, targets)
        total_loss.backward()
        optimizer.step()

        # --- Logging ---
        loss_value = total_loss.item()
        batch_loss_history.append(loss_value)
        if step % log_every == 0:
            print(f'  [Batch {step:3d}] Total Loss: {loss_value:.4f} (Leaf: {loss_leafs.item():.4f}, Parent: {loss_parents.item():.4f})')

print('\nFinished Training.')




Starting training for 2 epochs (200 batches each)...

--- Epoch 1 ---


In [None]:
import socket

# Connectivity check: resolve a hostname via DNS and print the IPv4 address it
# maps to. NOTE(review): the hostname looks like the public CELLxGENE Census S3
# bucket endpoint -- confirm; this cell is not needed by the training cells above.
census_s3_host = "cellxgene-census-public-us-west-2.s3.us-west-2.amazonaws.com"
resolved_ip = socket.gethostbyname(census_s3_host)
print(resolved_ip)


52.92.136.210
