In [1]:
import pandas as pd
import torch
import pickle
from datetime import datetime
import os

# Imports from our project
from src.utils.paths import PROJECT_ROOT, get_data_folder
from src.utils.ontology_utils import load_ontology  # Still need this to access term names

# --- 1. Load Preprocessed Data Artifacts ---
# Preprocessing is NOT re-run here: this cell reads back the files written by
# `run_preprocessing.py` for one fixed preprocessing date.

# Hardcoded date identifying which preprocessing run's files to load.
DATE = '2025-10-24'
PROCESSED_DATA_DIR = get_data_folder(DATE)

print(f"Loading data from: {PROCESSED_DATA_DIR} for date {DATE}")

# The ontology is loaded only so that term names can be printed later on.
cl = load_ontology()


def _artifact_path(suffix):
    """Return the path of the date-prefixed artifact ``{DATE}_{suffix}`` in PROCESSED_DATA_DIR."""
    return PROCESSED_DATA_DIR / f"{DATE}_{suffix}"


def _load_pickle(suffix):
    """Unpickle the artifact ``{DATE}_{suffix}``.

    NOTE: pickle can execute arbitrary code on load; only point this at files
    produced by our own preprocessing pipeline.
    """
    with open(_artifact_path(suffix), "rb") as handle:
        return pickle.load(handle)


# Tables that are later handed to the loss function; each CSV's first column is the index.
marginalization_df, parent_child_df, exclusion_df = (
    pd.read_csv(_artifact_path(f"{stem}.csv"), index_col=0)
    for stem in ("marginalization_df", "parent_child_df", "exclusion_df")
)

# mapping_dict was saved as a one-column frame: CL numbers form the index and the
# integer mappings sit in the first column. Rebuild the {CL number: int} dict from it.
mapping_dict_df = pd.read_csv(_artifact_path("mapping_dict_df.csv"), index_col=0)
mapping_dict = mapping_dict_df.iloc[:, 0].to_dict()

# Leaf and internal ontology nodes (container type is whatever preprocessing pickled).
leaf_values = _load_pickle("leaf_values.pkl")
internal_values = _load_pickle("internal_values.pkl")

print("\nAll data artifacts loaded successfully.")
print(f"Loaded {len(mapping_dict)} cell types.")
print(f"  - {len(leaf_values)} leaf nodes")
print(f"  - {len(internal_values)} internal nodes")

Loading data from: /home/jingqiao/real_McCell/data/processed/10-24 for date 2025-10-24
Loading cached ontology from /home/jingqiao/real_McCell/data/processed/ontology.pkl...
Ontology loaded successfully.

All data artifacts loaded successfully.
Loaded 80 cell types.
  - 23 leaf nodes
  - 57 internal nodes


Data loader


In [9]:
import cellxgene_census
import tiledbsoma as soma
from tiledbsoma_ml import ExperimentDataset, experiment_dataloader
import pandas as pd

# Every cell type present in mapping_dict takes part in the query.
all_cell_values = list(mapping_dict.keys())

# --- Load gene list from BioMart (matching old_reference approach) ---
print("Loading protein-coding genes from BioMart...")
biomart_path = PROJECT_ROOT / "hpc_workaround/data/mart_export.txt"
biomart = pd.read_csv(biomart_path)

# Keep only rows annotated as protein-coding and collect their Ensembl gene IDs
# (row order and any duplicates are preserved).
is_coding = biomart['Gene type'] == 'protein_coding'
coding_only = biomart.loc[is_coding]
gene_list = coding_only['Gene stable ID'].tolist()

print(f"Loaded {len(gene_list)} protein-coding genes from BioMart")

# SOMA value_filter expressions; the Python lists are interpolated via their repr,
# giving e.g. `feature_id in ['ENSG...', ...]`.
var_value_filter = f"feature_id in {gene_list}"
obs_value_filter = f'assay == "10x 3\' v3" and is_primary_data == True and cell_type_ontology_term_id in {all_cell_values}'

print(f"Ready to query {len(all_cell_values)} cell types and {len(gene_list)} protein-coding genes.")

Loading protein-coding genes from BioMart...
Loaded 23262 protein-coding genes from BioMart
Ready to query 80 cell types and 23262 protein-coding genes.


In [None]:
# Point to the local SOMA database (which is already the homo_sapiens experiment).
# NOTE(review): absolute cluster path; consider moving it to a config constant.
soma_uri = "/scratch/sigbio_project_root/sigbio_project25/jingqiao/mccell-single/soma_db_homo_sapiens"
print(f"Opening local SOMA database at: {soma_uri}")

# Open the experiment directly (it's a SOMAExperiment, not a SOMACollection).
# Using it as a context manager closes the handle once the datasets are built instead
# of leaking it for the life of the kernel. The dataloaders were already created after
# the query context exits; per the tiledbsoma_ml examples the dataset re-opens the
# experiment itself when iterated (NOTE(review): confirm with the installed version).
with soma.open(soma_uri, mode="r") as experiment:
    with experiment.axis_query(
        measurement_name="RNA",
        obs_query=soma.AxisQuery(value_filter=obs_value_filter),
        var_query=soma.AxisQuery(value_filter=var_value_filter),
    ) as query:
        experiment_dataset = ExperimentDataset(
            query,
            obs_column_names=["cell_type_ontology_term_id"],
            layer_name="raw",
            batch_size=256,
            shuffle=True,
            seed=111,
        )

        # 80/20 train/validation split of the matched cells.
        train_dataset, val_dataset = experiment_dataset.random_split(0.8, 0.2, seed=42)

        # Get cell count before the contexts close.
        actual_cell_count = len(experiment_dataset.query_ids.obs_joinids)
        print(f'\nTotal matching cells: {actual_cell_count}')
        # len() of these datasets counts batches of 256 cells, not individual cells.
        print(f'Training set size: {len(train_dataset)} batches')
        print(f'Validation set size: {len(val_dataset)} batches')

# Create dataloaders OUTSIDE the context managers.
train_dataloader = experiment_dataloader(train_dataset)
val_dataloader = experiment_dataloader(val_dataset)

# Summary of the split datasets; shape[1] is the number of genes (the model input width).
print("\nTrain dataset shape:", train_dataset.shape)
print("Validation dataset shape:", val_dataset.shape)

In [30]:
# DEBUG: show the query sizes and the head of each SOMA filter expression
# (each filter is cut to its first 200 characters and always followed by "...").
_rule = "=" * 80
print("\n" + _rule)
print("DEBUG: Filter strings")
print(_rule)
print(f"Number of cell types: {len(all_cell_values)}")
print(f"Number of genes: {len(gene_list)}")
for _label, _filter in (("obs_value_filter", obs_value_filter),
                        ("var_value_filter", var_value_filter)):
    print(f"\n{_label} (first 200 chars):")
    print(f"{_filter[:200]}...")
print(_rule + "\n")


DEBUG: Filter strings
Number of cell types: 80
Number of genes: 23262

obs_value_filter (first 200 chars):
assay == "10x 3' v3" and is_primary_data == True and cell_type_ontology_term_id in ['CL:0000233', 'CL:0000895', 'CL:0000900', 'CL:0000904', 'CL:0000905', 'CL:0000910', 'CL:0000912', 'CL:0000913', 'CL:...

var_value_filter (first 200 chars):
feature_id in ['ENSG00000198888', 'ENSG00000198763', 'ENSG00000198804', 'ENSG00000198712', 'ENSG00000228253', 'ENSG00000198899', 'ENSG00000198938', 'ENSG00000198840', 'ENSG00000212907', 'ENSG000001988...



In [None]:
from src.train.model import SimpleNN
from src.train.loss import MarginalizationLoss
import torch.optim as optim
import matplotlib.pyplot as plt

# --- 1. Setup with Multi-GPU Support ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Check for multiple GPUs
num_gpus = torch.cuda.device_count()
print(f"Available GPUs: {num_gpus}")
if num_gpus > 1:
    print(f"GPU names: {[torch.cuda.get_device_name(i) for i in range(num_gpus)]}")

# Seed weight initialisation so a fresh Restart & Run All rebuilds the same model
# (111 matches the ExperimentDataset shuffle seed).
torch.manual_seed(111)

# The input dimension is the number of genes from our dataset object;
# the model only predicts leaf nodes.
input_dim = train_dataset.shape[1]
output_dim = len(leaf_values)

model = SimpleNN(input_dim=input_dim, output_dim=output_dim)

# Wrap model with DataParallel for multi-GPU training
if num_gpus > 1:
    print(f"\n🚀 Using DataParallel with {num_gpus} GPUs")
    model = torch.nn.DataParallel(model)
    # DataParallel splits each loader batch along dim 0 across the GPUs, so every
    # optimizer step still sees one 256-cell batch: the effective batch size does
    # NOT grow with the number of GPUs.
    print(f"   Effective batch size: 256 (~{256 // num_gpus} per GPU × {num_gpus} GPUs)")
elif device.type == "cuda":
    print("\n⚠️  Single GPU mode")
else:
    print("\n⚠️  CPU mode (no GPU available)")

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Instantiate the loss with every preprocessing artifact plus the target device.
loss_fn = MarginalizationLoss(
    marginalization_df=marginalization_df,
    parent_child_df=parent_child_df,
    exclusion_df=exclusion_df,
    leaf_values=leaf_values,
    internal_values=internal_values,
    mapping_dict=mapping_dict,
    device=device
)

print("\nModel, optimizer, and loss function are ready.")

In [None]:
import time

num_epochs = 5
# len() of the training dataset is its number of batches, so one epoch is a full pass.
batches_per_epoch = len(train_dataset)  # Use ALL training batches (19,665)
batch_loss_history = []  # one total-loss value per optimizer step, across all epochs

print(f"\n{'='*80}")
print(f"TRAINING CONFIGURATION")
print(f"{'='*80}")
print(f"Total training batches: {len(train_dataset):,}")
print(f"Total training cells: ~{len(train_dataset) * 256:,}")
print(f"Batches per epoch: {batches_per_epoch:,} (100% of data)")
print(f"Number of epochs: {num_epochs}")
print(f"Total batches to process: {batches_per_epoch * num_epochs:,}")
print(f"{'='*80}\n")

print(f"Starting training for {num_epochs} epochs...")
print(f"⏱️  Estimating ~15-20 hours per epoch on single GPU")
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    # Float division (integer division printed 0 hours for >15 GPUs). Linear scaling is
    # a best case: data loading does not necessarily speed up with more GPUs.
    print(f"⚡ With {gpu_count} GPUs, expecting ~{15 / gpu_count:.1f}-{20 / gpu_count:.1f} hours per epoch (best case)")
print()

epoch_times = []

for epoch in range(num_epochs):
    model.train()
    epoch_start_time = time.time()
    print(f'\n{"="*80}')
    print(f'EPOCH {epoch + 1}/{num_epochs}')
    print(f'{"="*80}')

    epoch_losses = []

    for i, (X_batch, obs_batch) in enumerate(train_dataloader):
        if i >= batches_per_epoch:
            break

        # Data preparation: numpy batch -> float tensor -> log1p -> device
        X_batch = torch.from_numpy(X_batch).float()
        X_batch = torch.log1p(X_batch)  # Log-transform gene expression
        X_batch = X_batch.to(device)

        label_strings = obs_batch["cell_type_ontology_term_id"]
        y_batch = torch.tensor([mapping_dict[term] for term in label_strings], device=device, dtype=torch.long)

        # Training step
        optimizer.zero_grad()
        outputs = model(X_batch)
        total_loss, loss_leafs, loss_parents = loss_fn(outputs, y_batch)
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

        # Logging: copy the scalar off the device once per batch.
        loss_value = total_loss.item()
        batch_loss_history.append(loss_value)
        epoch_losses.append(loss_value)

        # Progress updates every 1000 batches
        if (i + 1) % 1000 == 0:
            elapsed = time.time() - epoch_start_time
            batches_remaining = batches_per_epoch - (i + 1)
            time_per_batch = elapsed / (i + 1)
            eta_hours = batches_remaining * time_per_batch / 3600

            recent_losses = epoch_losses[-1000:]
            avg_loss_recent = sum(recent_losses) / len(recent_losses)
            print(f'  [Batch {i+1:5d}/{batches_per_epoch}] '
                  f'Loss: {loss_value:.4f} (avg: {avg_loss_recent:.4f}) | '
                  f'Elapsed: {elapsed/3600:.2f}h | ETA: {eta_hours:.2f}h')

    epoch_time = time.time() - epoch_start_time
    epoch_times.append(epoch_time)
    batch_count = len(epoch_losses)

    print(f'\n--- Epoch {epoch + 1} Summary ---')
    print(f'  Time: {epoch_time/3600:.2f} hours ({epoch_time/60:.1f} minutes)')
    print(f'  Batches processed: {batch_count:,}')
    if epoch_losses:
        avg_epoch_loss = sum(epoch_losses) / batch_count
        print(f'  Average loss: {avg_epoch_loss:.4f}')
        print(f'  Final loss: {epoch_losses[-1]:.4f}')
    else:
        # Guard against an empty epoch (avoids ZeroDivisionError / IndexError).
        print('  WARNING: the dataloader yielded no batches this epoch')

    # Estimate the remaining time whenever epochs are left, starting after the first one.
    remaining_epochs = num_epochs - (epoch + 1)
    if remaining_epochs > 0:
        avg_time = sum(epoch_times) / len(epoch_times)
        total_eta = (avg_time * remaining_epochs) / 3600
        print(f'  Estimated time remaining: {total_eta:.1f} hours')

print('\n' + '='*80)
print('TRAINING COMPLETE')
print('='*80)
if epoch_times:
    print(f'Total time: {sum(epoch_times)/3600:.2f} hours')
    print(f'Average time per epoch: {sum(epoch_times)/len(epoch_times)/3600:.2f} hours')

In [None]:
import os
from datetime import datetime

# Create directory for saving models in data folder
save_dir = PROJECT_ROOT / "data" / "saved_models"
save_dir.mkdir(parents=True, exist_ok=True)
print(f"Save directory: {save_dir}")

# Timestamp for this training run; the format sorts lexicographically by time.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_name = f"blood_cell_model_{timestamp}"

# Save the model weights
model_path = save_dir / f"{model_name}.pt"
print(f"\n{'='*80}")
print(f"SAVING MODEL")
print(f"{'='*80}")

# DataParallel prefixes parameter names with "module."; saving the wrapped module's
# state_dict keeps the checkpoint loadable into a plain SimpleNN.
if isinstance(model, torch.nn.DataParallel):
    state_dict_to_save = model.module.state_dict()
    print("Saving DataParallel model (extracting module.state_dict())")
else:
    state_dict_to_save = model.state_dict()
    print("Saving single GPU model")

# Save full checkpoint with training state
checkpoint = {
    'model_state_dict': state_dict_to_save,
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': num_epochs,
    # Same value under the key the reload/validation cell reads.
    'num_epochs': num_epochs,
    'batch_loss_history': batch_loss_history,
    'epoch_times': epoch_times,
    'num_gpus': torch.cuda.device_count(),
    'input_dim': input_dim,
    'output_dim': output_dim,
    'date_preprocessed': DATE,
    # Upper bound: assumes every batch is a full 256 cells.
    'total_training_cells': len(train_dataset) * 256,
    'total_batches_processed': len(batch_loss_history),
}

torch.save(checkpoint, model_path)

print(f"Model saved to: {model_path}")
print(f"  - Model architecture: SimpleNN({input_dim} -> {output_dim})")
print(f"  - Training epochs: {num_epochs}")
print(f"  - Total batches: {len(batch_loss_history):,}")
if batch_loss_history:
    print(f"  - Final loss: {batch_loss_history[-1]:.4f}")
print(f"  - Total training time: {sum(epoch_times)/3600:.2f} hours")

# Also save just the model weights (smaller file, easier to load).
# NOTE: this name also matches the `blood_cell_model_*.pt` glob used to find the latest
# checkpoint, so that lookup must skip `_weights_only` files.
weights_path = save_dir / f"{model_name}_weights_only.pt"
torch.save(state_dict_to_save, weights_path)
print(f"\nWeights-only file saved to: {weights_path}")
print(f"  - File size: ~{os.path.getsize(weights_path) / 1e6:.1f} MB")

# Save training history as CSV for easy plotting later
history_df = pd.DataFrame({
    'batch': range(1, len(batch_loss_history) + 1),
    'loss': batch_loss_history
})
history_path = save_dir / f"{model_name}_training_history.csv"
history_df.to_csv(history_path, index=False)
print(f"\nTraining history saved to: {history_path}")

print(f"\n{'='*80}")
print(f"To load this model later:")
print(f"  checkpoint = torch.load('{model_path}', map_location='cpu')")
print(f"  model.load_state_dict(checkpoint['model_state_dict'])")
print(f"{'='*80}\n")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# --- Plot Training Loss ---
fig, (ax_batch, ax_epoch) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: raw per-batch loss plus a moving average for the trend.
ax_batch.plot(batch_loss_history, alpha=0.6, linewidth=0.8)
ax_batch.set_xlabel('Batch')
ax_batch.set_ylabel('Total Loss')
ax_batch.set_title('Training Loss per Batch')
ax_batch.grid(True, alpha=0.3)

window_size = 20
if len(batch_loss_history) >= window_size:
    moving_avg = np.convolve(batch_loss_history, np.ones(window_size) / window_size, mode='valid')
    # 'valid' value j averages batches j..j+window_size-1, so plot it at the window's end.
    ax_batch.plot(range(window_size - 1, len(batch_loss_history)), moving_avg,
                  color='red', linewidth=2, label=f'{window_size}-batch moving avg')
    ax_batch.legend()

# Right panel: one curve per epoch (if multiple epochs).
if num_epochs > 1:
    # Slices assume every epoch ran exactly `batches_per_epoch` batches (the training loop
    # stops at that count). Kept under a separate name so the training cell's
    # `epoch_losses` (a flat list of floats) is not overwritten with a list of lists.
    per_epoch_losses = [batch_loss_history[i * batches_per_epoch:(i + 1) * batches_per_epoch]
                        for i in range(num_epochs)]
    for i, losses in enumerate(per_epoch_losses):
        if losses:  # skip epochs that never ran (e.g. interrupted training)
            ax_epoch.plot(losses, label=f'Epoch {i+1}', alpha=0.7)
    ax_epoch.set_xlabel('Batch (within epoch)')
    ax_epoch.set_ylabel('Total Loss')
    ax_epoch.set_title('Training Loss by Epoch')
    if ax_epoch.get_lines():
        ax_epoch.legend()
    ax_epoch.grid(True, alpha=0.3)
else:
    ax_epoch.text(0.5, 0.5, 'Run multiple epochs\nto see comparison',
                  ha='center', va='center', transform=ax_epoch.transAxes)

fig.tight_layout()
plt.show()

if batch_loss_history:
    first_loss, last_loss = batch_loss_history[0], batch_loss_history[-1]
    print(f"\nTraining Summary:")
    print(f"  Initial loss: {first_loss:.4f}")
    print(f"  Final loss: {last_loss:.4f}")
    if first_loss != 0:
        print(f"  Loss reduction: {(1 - last_loss/first_loss)*100:.1f}%")
else:
    print("\nNo training losses recorded; run the training cell first.")

In [None]:
import numpy as np
import glob

# --- Load the most recent saved model ---
save_dir = PROJECT_ROOT / "data" / "saved_models"
# The save cell writes `<name>_weights_only.pt` (a bare state_dict) next to every
# `<name>.pt` checkpoint. That file matches the same glob and sorts right after its
# checkpoint ('_' > '.'), so without this filter model_files[-1] would be the
# weights-only file and checkpoint['input_dim'] would raise KeyError.
model_files = sorted(
    path
    for path in glob.glob(str(save_dir / "blood_cell_model_*.pt"))
    if not path.endswith("_weights_only.pt")
)

if not model_files:
    print("ERROR: No saved model found in data/saved_models/")
    print("Please run training first (cell-6 and cell-7)")
else:
    latest_model = model_files[-1]  # timestamped names sort chronologically
    print(f"Loading model from: {latest_model}")

    # map_location="cpu" lets a checkpoint saved from GPU load on any node;
    # the model is moved to `device` below.
    checkpoint = torch.load(latest_model, map_location="cpu")

    # Create fresh model with same architecture
    input_dim = checkpoint['input_dim']
    output_dim = checkpoint['output_dim']

    validation_model = SimpleNN(input_dim=input_dim, output_dim=output_dim)

    # Handle DataParallel if needed (checkpoints hold the unwrapped module's state_dict)
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        validation_model = torch.nn.DataParallel(validation_model)
        validation_model.module.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded model with DataParallel ({num_gpus} GPUs)")
    else:
        validation_model.load_state_dict(checkpoint['model_state_dict'])
        print("Loaded model (single GPU)")

    validation_model = validation_model.to(device)
    validation_model.eval()

    # The save cell stores the epoch count under 'epoch'; accept 'num_epochs' as well.
    trained_epochs = checkpoint.get('num_epochs', checkpoint.get('epoch', 'unknown'))
    total_batches = checkpoint.get('total_batches_processed')
    # The `:,` format spec only works on numbers, so format the fallback separately.
    total_batches_text = f"{total_batches:,}" if isinstance(total_batches, int) else "unknown"
    print(f"Model info:")
    print(f"  - Trained epochs: {trained_epochs}")
    print(f"  - Total batches: {total_batches_text}")
    print(f"  - Architecture: SimpleNN({input_dim} -> {output_dim})")
    print()

    # --- Validate the Model (LEAF NODES ONLY) ---
    print("="*60)
    print("VALIDATION (LEAF NODES ONLY)")
    print("="*60)

    val_batches = 50  # quick check on the first 50 validation batches, not the full split

    val_total_losses = []
    val_leaf_losses = []
    val_parent_losses = []
    all_predictions = []
    all_labels = []

    leaf_indices_set = {mapping_dict[cid] for cid in leaf_values}
    # Accuracy compares argmax over the model's `output_dim` leaf logits directly with
    # mapping_dict ids, which only makes sense if the leaf ids are exactly 0..output_dim-1.
    # NOTE(review): also confirm the logit order matches the one MarginalizationLoss uses.
    if leaf_indices_set != set(range(output_dim)):
        print("WARNING: leaf class ids are not 0..output_dim-1; "
              "the leaf accuracy below compares mismatched index spaces.")
    total_samples = 0
    leaf_samples = 0

    with torch.no_grad():
        for i, (X_batch, obs_batch) in enumerate(val_dataloader):
            if i >= val_batches:
                break

            label_strings = obs_batch["cell_type_ontology_term_id"]
            label_ids = [mapping_dict[term] for term in label_strings]
            total_samples += len(label_ids)

            # FILTER: only keep samples with LEAF node labels. The mask is built from
            # Python ints, avoiding one GPU sync (.item()) per sample.
            leaf_mask = [label_id in leaf_indices_set for label_id in label_ids]
            if not any(leaf_mask):
                continue  # Skip batches with no leaf samples

            is_leaf = torch.tensor(leaf_mask, dtype=torch.bool)
            # Data preparation (same transform as training: float -> log1p), leaf rows only
            X_batch_leaf = torch.log1p(torch.from_numpy(X_batch).float()[is_leaf]).to(device)
            y_batch_leaf = torch.tensor(label_ids, dtype=torch.long)[is_leaf].to(device)
            leaf_samples += len(y_batch_leaf)

            # Forward pass
            outputs = validation_model(X_batch_leaf)
            total_loss, loss_leafs, loss_parents = loss_fn(outputs, y_batch_leaf)

            val_total_losses.append(total_loss.item())
            val_leaf_losses.append(loss_leafs.item())
            val_parent_losses.append(loss_parents.item())

            # Get predictions (argmax of logits)
            predictions = torch.argmax(outputs, dim=1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(y_batch_leaf.cpu().numpy())

    print(f"\nValidation Results (LEAF-LABELED SAMPLES ONLY):")
    print(f"  Total samples processed: {total_samples}")
    if leaf_samples == 0:
        # np.mean([]) would yield NaN and the percentage below would divide by zero.
        print("  No leaf-labeled samples found in the evaluated batches; no metrics to report.")
    else:
        # Calculate metrics
        avg_val_loss = np.mean(val_total_losses)
        avg_leaf_loss = np.mean(val_leaf_losses)
        avg_parent_loss = np.mean(val_parent_losses)

        # Accuracy (only meaningful for leaf nodes)
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        accuracy = (all_predictions == all_labels).mean()

        print(f"  Leaf-labeled samples: {leaf_samples} ({leaf_samples/total_samples*100:.1f}%)")
        print(f"  Internal-labeled samples (skipped): {total_samples - leaf_samples}")
        print(f"\n  Average Total Loss: {avg_val_loss:.4f}")
        print(f"  Average Leaf Loss:  {avg_leaf_loss:.4f}")
        print(f"  Average Parent Loss: {avg_parent_loss:.4f}")
        print(f"  Leaf Accuracy: {accuracy*100:.2f}%")
    print(f"\nNote: This validation only tests the model on leaf node predictions.")
    print(f"      Internal node labels are excluded from evaluation.")