# CellType-NN Example Usage

This notebook demonstrates how to use the CellType-NN framework for cell type prediction from single-cell RNA-seq data.

## Setup

In [None]:
import sys
sys.path.insert(0, '../src')

import scanpy as sc
import torch
import pytorch_lightning as pl

from celltype_nn.data.loader import create_dataloaders
from celltype_nn.preprocessing.preprocess import preprocess_rna
from celltype_nn.models.rna_classifier import RNAClassifier
from celltype_nn.training.lightning_module import CellTypeClassifierModule
from celltype_nn.evaluation.metrics import evaluate_model, plot_confusion_matrix

# Fix the random seed once, up front, so repeated runs of the notebook start
# from the same random state
RANDOM_SEED = 42
pl.seed_everything(RANDOM_SEED)

## Load and Preprocess Data

For this example, we load the public PBMC 3k dataset that ships with scanpy and derive placeholder cell type labels from Leiden clustering. Replace both with your own annotated data.

In [None]:
# Example: load the PBMC 3k dataset from scanpy (returned as raw counts)
adata = sc.datasets.pbmc3k()

# Basic QC: drop cells with very few detected genes and genes seen in
# very few cells
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

# Add example cell type labels (in real data, these would come from
# annotation). For the demo we use Leiden clusters as stand-in labels.
#
# The clustering pipeline runs on a copy so that `adata` itself is NOT
# normalized or log-transformed here: preprocess_rna() is called on `adata`
# further down with normalize=True and log_transform=True, and running those
# steps on already-transformed values would apply them twice.
adata_clustering = adata.copy()
sc.pp.normalize_total(adata_clustering, target_sum=1e4)
sc.pp.log1p(adata_clustering)
sc.pp.highly_variable_genes(adata_clustering, n_top_genes=2000)
sc.pp.pca(adata_clustering)
sc.pp.neighbors(adata_clustering)
sc.tl.leiden(adata_clustering, resolution=0.5)

# The copy shares the cell index with `adata`, so the labels align by obs name
adata.obs['cell_type'] = adata_clustering.obs['leiden']

print(f"Dataset: {adata.n_obs} cells x {adata.n_vars} genes")
print(f"Cell types: {adata.obs['cell_type'].value_counts()}")

## Preprocess for Neural Network

In [None]:
# Options forwarded to preprocess_rna(); scaling is left off on purpose
# (don't scale for neural networks)
preprocess_options = dict(
    n_top_genes=2000,
    normalize=True,
    log_transform=True,
    scale=False,
)
adata_processed = preprocess_rna(adata, **preprocess_options)

# Keep only the genes flagged as highly variable; .copy() turns the sliced
# view into an independent object
hvg_mask = adata_processed.var['highly_variable']
adata_hvg = adata_processed[:, hvg_mask].copy()

print(f"Using {adata_hvg.n_vars} highly variable genes")

## Create DataLoaders

In [None]:
# Fractions of the data assigned to each split (they sum to 1.0)
split_fractions = {
    'train_size': 0.7,
    'val_size': 0.15,
    'test_size': 0.15,
}

# Build the train/val/test loaders, stratifying on the cell type label
dataloaders = create_dataloaders(
    adata_hvg,
    label_key='cell_type',
    batch_size=128,
    stratify=True,
    num_workers=0,
    **split_fractions,
)

# The training dataset exposes the dimensions the model is built from
train_dataset = dataloaders['datasets']['train']
print(f"Number of classes: {train_dataset.num_classes}")
print(f"Number of features: {train_dataset.num_features}")

## Build Model

In [None]:
# Input and output sizes are taken from the training dataset
n_input_features = train_dataset.num_features
n_cell_types = train_dataset.num_classes

# Classifier with three hidden layers of decreasing width
model = RNAClassifier(
    input_dim=n_input_features,
    num_classes=n_cell_types,
    hidden_dims=[256, 128, 64],
    dropout_rate=0.3,
    batch_norm=True,
    activation='relu',
)

print(f"Model architecture:\n{model}")

## Create Lightning Module

In [None]:
# Optimization settings handed to the Lightning wrapper
optimization_config = dict(
    learning_rate=1e-3,
    weight_decay=1e-5,
    optimizer='adamw',
    scheduler='cosine',
)

# Wrap the classifier in the Lightning module used for training and testing
lightning_module = CellTypeClassifierModule(
    model=model,
    num_classes=train_dataset.num_classes,
    **optimization_config,
)

## Train Model

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Both callbacks watch the same quantity: validation loss, lower is better
monitor_settings = dict(monitor='val/loss', mode='min')

# Keep only the single best checkpoint according to the monitored value
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    filename='best-model',
    **monitor_settings,
)

# Stop training after 10 checks without improvement of the monitored value
early_stop_callback = EarlyStopping(
    patience=10,
    **monitor_settings,
)

# Trainer: at most 50 epochs on one automatically selected device
training_callbacks = [checkpoint_callback, early_stop_callback]
trainer = pl.Trainer(
    max_epochs=50,
    accelerator='auto',
    devices=1,
    callbacks=training_callbacks,
    log_every_n_steps=10,
)

# Fit on the training split, validating on the validation split
trainer.fit(
    lightning_module,
    train_dataloaders=dataloaders['train'],
    val_dataloaders=dataloaders['val'],
)

## Evaluate on Test Set

In [None]:
# Test on the held-out split using the best checkpoint rather than the
# last-epoch weights. Training above tracks 'val/loss' with ModelCheckpoint
# (save_top_k=1) and stops via EarlyStopping (patience=10), so the weights
# currently in memory can be several epochs past the best validation loss.
# ckpt_path="best" makes the trainer reload the checkpoint selected by the
# checkpoint callback into `lightning_module` before testing.
test_results = trainer.test(
    lightning_module,
    dataloaders=dataloaders['test'],
    ckpt_path="best",
)
print(test_results)

## Detailed Evaluation

In [None]:
# Human-readable name for every class index, in index order
label_names = list(map(train_dataset.get_label_name, range(train_dataset.num_classes)))

# Run the evaluation helper on the test split with the underlying network
results = evaluate_model(
    lightning_module.model,
    dataloaders['test'],
    label_names=label_names,
)

# Print every metric except the extra 'f1_*' entries; only the three
# aggregate F1 scores below are kept from that family
aggregate_f1_keys = ('f1_macro', 'f1_micro', 'f1_weighted')

print("\nTest Metrics:")
for metric, value in results['metrics'].items():
    if metric.startswith('f1_') and metric not in aggregate_f1_keys:
        continue
    print(f"  {metric}: {value:.4f}")

## Visualize Results

In [None]:
# Pull the two arrays returned by the evaluation step
evaluated_labels = results['labels']
evaluated_predictions = results['predictions']

# Row-normalized confusion matrix with cell type names on the axes
plot_confusion_matrix(
    evaluated_labels,
    evaluated_predictions,
    label_names=label_names,
    normalize=True,
)

## Make Predictions on New Data

In [None]:
# Get predictions and probabilities for one batch of the test split
batch = next(iter(dataloaders['test']))

# Put the network in inference mode before predicting. Nothing earlier in the
# notebook calls eval() on this module, and it was built with dropout and
# batch norm; in training mode those layers make the outputs stochastic and
# dependent on the composition of the batch.
inference_model = lightning_module.model
inference_model.eval()

# Run the forward pass on whatever device the model's weights live on
model_device = next(inference_model.parameters()).device
features = batch['features'].to(model_device)

with torch.no_grad():
    logits = inference_model(features)
    probs = torch.softmax(logits, dim=1)   # per-class probabilities
    preds = torch.argmax(probs, dim=1)     # most likely class index per cell

# Convert class indices to cell type names
predicted_types = [label_names[class_idx] for class_idx in preds.tolist()]
print(f"Predicted cell types: {predicted_types[:10]}")

## Save Model

In [None]:
# Write a checkpoint of the trainer's current state; the relative path
# resolves against the notebook's working directory
checkpoint_file = "final_model.ckpt"
trainer.save_checkpoint(checkpoint_file)
print("Model saved to final_model.ckpt")

## Load Model for Inference

In [None]:
# Rebuild the Lightning module from the checkpoint file. The network object is
# supplied explicitly through the `model` keyword.
restore_kwargs = {"model": model}
loaded_module = CellTypeClassifierModule.load_from_checkpoint(
    "final_model.ckpt",
    **restore_kwargs,
)

# Switch the restored module to evaluation mode before using it for inference
loaded_module.eval()
print("Model loaded successfully!")