# CellType-NN Demo: Cell Type Prediction with Deep Learning

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yourusername/celltype-nn/blob/main/notebooks/celltype_nn_demo.ipynb)

This notebook demonstrates how to use CellType-NN for automated cell type prediction from single-cell RNA-seq data using deep learning.

## What you'll learn:
- Generate synthetic single-cell data for testing
- Preprocess scRNA-seq data
- Train a neural network classifier
- Evaluate model performance
- Visualize results

## 1. Setup and Installation

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q pytorch-lightning
!pip install -q scanpy anndata muon
!pip install -q scikit-learn matplotlib seaborn

In [None]:
# Clone the CellType-NN repository
!git clone https://github.com/salzcamino/test-celltype-nn.git
%cd test-celltype-nn
!pip install -q -e .

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pytorch_lightning as pl
from sklearn.metrics import confusion_matrix, classification_report

# Import CellType-NN modules
from celltype_nn.data.loader import load_anndata, split_anndata, create_dataloaders
from celltype_nn.preprocessing.preprocess import preprocess_rna
from celltype_nn.models.rna_classifier import RNAClassifier
from celltype_nn.training.lightning_module import CellTypeClassifierModule
from celltype_nn.evaluation.metrics import (
    calculate_metrics,
    plot_confusion_matrix,
    plot_per_class_metrics
)

# Set random seeds for reproducibility
# (seed_everything seeds Python's random module, NumPy, and PyTorch in one call)
pl.seed_everything(42)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

print("✓ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {torch.cuda.is_available()}")

## 2. Generate Synthetic Single-Cell Data

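We'll create a synthetic dataset of 2,000 cells and 2,000 genes across 5 equally sized cell types. Counts are drawn from a negative binomial distribution, each cell type gets its own block of up-regulated marker genes, and 30% of entries are zeroed out to mimic dropout.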
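
As a rough sanity check on the generator defined in the next cell, the expected counts follow from the negative binomial mean, n(1 - p) / p, in NumPy's parameterization. The sketch below works through that arithmetic for the parameters used in this notebook: NB(5, 0.3) background, an NB(10, 0.2) boost for marker genes, and 30% dropout. The helper name `nb_mean` is introduced here for illustration only.

```python
def nb_mean(n: float, p: float) -> float:
    """Mean of a negative binomial as parameterized by NumPy: n * (1 - p) / p."""
    return n * (1 - p) / p


background = nb_mean(5, 0.3)      # baseline expression drawn for every gene
marker_boost = nb_mean(10, 0.2)   # extra counts added to a cell type's marker genes
keep_prob = 1 - 0.3               # dropout zeroes 30% of entries independently

expected_background = keep_prob * background
expected_marker = keep_prob * (background + marker_boost)

print(f"background gene: {expected_background:.2f} counts on average")
print(f"marker gene:     {expected_marker:.2f} counts on average")
print(f"fold change:     {expected_marker / expected_background:.2f}x")
```

Because dropout is applied uniformly, it lowers both averages by the same factor and leaves the marker-to-background ratio unchanged.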

In [None]:
def generate_synthetic_scrna_data(n_cells=2000, n_genes=2000, n_cell_types=5, seed=42):
    """
    Generate synthetic single-cell RNA-seq data.
    
    Creates realistic count data with:
    - Cell type-specific marker genes
    - Biological variability
    - Dropout events (zeros)
    """
    np.random.seed(seed)
    
    # Define cell types
    cell_types = [f'CellType_{i+1}' for i in range(n_cell_types)]
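    cells_per_type = n_cells // n_cell_types
    # Drop any remainder so the label list and the count matrix have the same number of rows
    n_cells = cells_per_type * n_cell_types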
    
    # Create cell type labels
    cell_type_labels = []
    for ct in cell_types:
        cell_type_labels.extend([ct] * cells_per_type)
    
    # Generate base expression (negative binomial for realistic count data)
    base_expression = np.random.negative_binomial(5, 0.3, size=(n_cells, n_genes))
    
    # Add cell type-specific signatures
    markers_per_type = n_genes // (n_cell_types * 2)  # half of all genes are markers, split evenly across cell types
    
    for i, ct in enumerate(cell_types):
        start_cell = i * cells_per_type
        end_cell = (i + 1) * cells_per_type
        
        # Select marker genes for this cell type
        marker_start = i * markers_per_type
        marker_end = (i + 1) * markers_per_type
        
        # Increase expression of marker genes
        base_expression[start_cell:end_cell, marker_start:marker_end] += \
            np.random.negative_binomial(10, 0.2, 
                                       size=(cells_per_type, markers_per_type))
    
    # Add dropout (set some values to zero)
    dropout_mask = np.random.random((n_cells, n_genes)) < 0.3
    base_expression[dropout_mask] = 0
    
    # Create AnnData object
    adata = sc.AnnData(
        X=base_expression,
        obs=pd.DataFrame({
            'cell_type': cell_type_labels,
            'n_counts': base_expression.sum(axis=1),
            'n_genes': (base_expression > 0).sum(axis=1)
        }),
        var=pd.DataFrame(index=[f'Gene_{i}' for i in range(n_genes)])
    )
    
    return adata


# Generate synthetic dataset
print("Generating synthetic single-cell data...")
adata = generate_synthetic_scrna_data(
    n_cells=2000,
    n_genes=2000,
    n_cell_types=5,
    seed=42
)

print(f"\n✓ Generated dataset:")
print(f"  Cells: {adata.n_obs}")
print(f"  Genes: {adata.n_vars}")
print(f"  Cell types: {adata.obs['cell_type'].nunique()}")
print(f"\nCell type distribution:")
print(adata.obs['cell_type'].value_counts())

### Visualize Raw Data

In [None]:
# Quick visualization of the raw data
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Cell type distribution
adata.obs['cell_type'].value_counts().plot(kind='bar', ax=axes[0], color='steelblue')
axes[0].set_title('Cell Type Distribution', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Cell Type')
axes[0].set_ylabel('Number of Cells')
axes[0].tick_params(axis='x', rotation=45)

# Total counts per cell
axes[1].hist(adata.obs['n_counts'], bins=50, color='coral', alpha=0.7)
axes[1].set_title('Total Counts per Cell', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Total Counts')
axes[1].set_ylabel('Number of Cells')

# Genes per cell
axes[2].hist(adata.obs['n_genes'], bins=50, color='seagreen', alpha=0.7)
axes[2].set_title('Genes Detected per Cell', fontsize=14, fontweight='bold')
axes[2].set_xlabel('Number of Genes')
axes[2].set_ylabel('Number of Cells')

plt.tight_layout()
plt.show()

## 3. Data Preprocessing

Preprocessing steps:
1. Quality control filtering
2. Normalization (scaling each cell to 10,000 total counts + log transform)
3. Highly variable gene selection
4. Data splitting (train/val/test)
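
Steps 2 and 3 can be sketched in plain NumPy. This illustrates the general technique and is not the actual implementation of `preprocess_rna`. The function names `normalize_log1p` and `top_variable_genes` are hypothetical, and ranking genes by raw variance is a simplified stand-in for the dispersion-based selection methods that tools such as Scanpy provide.

```python
import numpy as np


def normalize_log1p(counts: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """Scale each cell (row) to `target_sum` total counts, then apply log(1 + x)."""
    counts = counts.astype(np.float64)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # avoid division by zero for empty cells
    return np.log1p(counts / totals * target_sum)


def top_variable_genes(expr: np.ndarray, n_top: int) -> np.ndarray:
    """Return sorted column indices of the `n_top` genes with the highest variance."""
    variances = expr.var(axis=0)
    top = np.argsort(variances)[::-1][:n_top]
    return np.sort(top)


# Toy count matrix: 100 cells x 50 genes
rng = np.random.default_rng(0)
toy_counts = rng.negative_binomial(5, 0.3, size=(100, 50))
toy_counts[:50, :5] += 40  # make the first five genes differ between two groups

expr = normalize_log1p(toy_counts)
hvg_idx = top_variable_genes(expr, n_top=5)
expr_hvg = expr[:, hvg_idx]

print("per-cell totals after scaling:", np.expm1(expr).sum(axis=1)[:3])
print("selected genes:", hvg_idx)
print("reduced shape:", expr_hvg.shape)
```

Undoing the log transform with `np.expm1` and summing per cell recovers `target_sum`, which is a quick way to confirm that normalization was applied before the log step.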

In [None]:
# Preprocess the data
print("Preprocessing data...")

adata_processed = preprocess_rna(
    adata.copy(),
    n_top_genes=500,  # Select 500 highly variable genes
    min_genes=50,
    min_cells=3,
    normalize=True,
    log_transform=True,
    target_sum=1e4
)

print(f"\n✓ Preprocessed dataset:")
print(f"  Cells after filtering: {adata_processed.n_obs}")
print(f"  Highly variable genes: {len(adata_processed.var_names)}")
print(f"  Data shape: {adata_processed.X.shape}")

In [None]:
# Split into train/validation/test sets
print("\nSplitting data...")

train_adata, val_adata, test_adata = split_anndata(
    adata_processed,
    train_size=0.7,
    val_size=0.15,
    test_size=0.15,
    stratify_key='cell_type',
    random_state=42
)

print(f"\n✓ Data split:")
print(f"  Training set: {train_adata.n_obs} cells")
print(f"  Validation set: {val_adata.n_obs} cells")
print(f"  Test set: {test_adata.n_obs} cells")

In [None]:
# Create PyTorch data loaders
print("\nCreating data loaders...")

train_loader = create_dataloaders(
    train_adata,
    label_key='cell_type',
    batch_size=64,
    shuffle=True
)

val_loader = create_dataloaders(
    val_adata,
    label_key='cell_type',
    batch_size=64,
    shuffle=False
)

test_loader = create_dataloaders(
    test_adata,
    label_key='cell_type',
    batch_size=64,
    shuffle=False
)

print(f"\n✓ Data loaders created!")
print(f"  Number of classes: {len(train_adata.obs['cell_type'].unique())}")
print(f"  Input features: {train_adata.n_vars}")
print(f"  Batch size: 64")

## 4. Model Training

We'll train a feedforward neural network with:
- 3 hidden layers (512, 256, 128)
- Batch normalization and dropout
- Adam optimizer with a cosine learning-rate schedule
- Early stopping on validation loss (patience of 10 epochs)
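
The model cell below prints the total parameter count. As a cross-check, the sketch here computes what that number would be under a stated assumption: each hidden layer is a linear layer followed by batch normalization with a learnable scale and shift, and the output is a single linear layer. `RNAClassifier` may be organized differently, so a mismatch tells you something about the architecture rather than signalling an error. The function name `count_mlp_parameters` is made up for this example, and the input size of 500 assumes that HVG selection keeps exactly 500 genes.

```python
def count_mlp_parameters(input_dim, hidden_dims, num_classes, batch_norm=True):
    """Count trainable parameters of a plain MLP (assumed layout, see above)."""
    total = 0
    prev = input_dim
    for width in hidden_dims:
        total += prev * width + width      # linear layer: weight matrix + bias
        if batch_norm:
            total += 2 * width             # batch norm: scale (gamma) + shift (beta)
        prev = width
    total += prev * num_classes + num_classes  # output linear layer
    return total


# 500 highly variable genes in, 5 cell types out, as configured in this notebook
n_params = count_mlp_parameters(500, [512, 256, 128], 5)
print(f"Expected trainable parameters: {n_params:,}")
```

Most of the parameters sit in the first layer (500 x 512 weights), so the number of selected genes has the largest effect on model size.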

In [None]:
# Get number of features and classes
input_dim = train_adata.n_vars
num_classes = len(train_adata.obs['cell_type'].unique())

# Create model
model = RNAClassifier(
    input_dim=input_dim,
    num_classes=num_classes,
    hidden_dims=[512, 256, 128],
    dropout_rate=0.3,
    activation='relu',
    batch_norm=True
)

# Wrap in Lightning module
lit_model = CellTypeClassifierModule(
    model=model,
    learning_rate=1e-3,
    optimizer='adam',
    scheduler='cosine',
    num_classes=num_classes
)

print(f"\n✓ Model created:")
print(f"  Input dimension: {input_dim}")
print(f"  Output classes: {num_classes}")
print(f"  Hidden layers: [512, 256, 128]")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Setup trainer
trainer = pl.Trainer(
    max_epochs=50,
    accelerator='auto',
    devices=1,
    callbacks=[
        pl.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10,
            mode='min'
        ),
        pl.callbacks.ModelCheckpoint(
            monitor='val_accuracy',
            mode='max',
            save_top_k=1
        )
    ],
    enable_progress_bar=True,
    log_every_n_steps=10
)

# Train model
print("\nStarting training...\n")
trainer.fit(lit_model, train_loader, val_loader)
print("\n✓ Training complete!")

## 5. Model Evaluation

Evaluate the trained model on the test set.
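
Besides the aggregate metrics from `trainer.test`, this section also collects per-cell predictions. The conversion from raw network outputs (logits) to probabilities, a predicted class, and a confidence score can be sketched in NumPy as follows. The logits are made-up numbers, not model output.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row maximum keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Made-up logits for 3 cells and 5 cell types
logits = np.array([
    [4.0, 0.5, 0.1, -1.0, 0.0],
    [0.2, 0.1, 0.3, 0.2, 0.1],
    [-2.0, 1.0, 5.0, 0.0, 1.5],
])

probs = softmax(logits)
preds = probs.argmax(axis=1)      # index of the most probable class
confidence = probs.max(axis=1)    # probability assigned to that class

for i, (p, c) in enumerate(zip(preds, confidence)):
    print(f"cell {i}: predicted class {p}, confidence {c:.2f}")
```

The second cell has nearly equal logits, so its top probability stays close to the uniform value of 0.2. This is the kind of low-confidence prediction that the confidence histogram later in the notebook is meant to expose.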

In [None]:
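# Test the model
# Note: because the model object is passed in, this evaluates the weights currently in memory
# (from the last training epoch), not the best checkpoint saved by ModelCheckpoint.
# Pass ckpt_path='best' to evaluate the checkpointed weights instead.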
print("Evaluating on test set...\n")
test_results = trainer.test(lit_model, test_loader)

print(f"\n✓ Test Results:")
print(f"  Test Accuracy: {test_results[0]['test_accuracy']:.4f}")
print(f"  Test Loss: {test_results[0]['test_loss']:.4f}")

In [None]:
# Get predictions for detailed analysis
lit_model.eval()
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for batch in test_loader:
        features, labels = batch
        # Keep inputs on the same device as the model
        logits = lit_model(features.to(lit_model.device))
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)

# Get class names (assumes integer labels were assigned in sorted order of the cell type names)
class_names = sorted(test_adata.obs['cell_type'].unique())

print("✓ Predictions generated!")

## 6. Visualizations
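
The plots in this section are all derived from the confusion matrix. Below is a minimal NumPy sketch of how precision, recall, and F1 fall out of it, using a small hand-made label set rather than the model's predictions. scikit-learn's `confusion_matrix` and `precision_recall_fscore_support`, used in the cells that follow, compute the same quantities with more edge-case handling. This sketch assumes every class appears at least once among both the true and the predicted labels.

```python
import numpy as np

y_true = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 0])
n_classes = 3

# cm[i, j] counts cells whose true class is i and predicted class is j
cm = np.zeros((n_classes, n_classes), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)

true_pos = np.diag(cm)
precision = true_pos / cm.sum(axis=0)   # column sums: everything predicted as the class
recall = true_pos / cm.sum(axis=1)      # row sums: everything truly in the class
f1 = 2 * precision * recall / (precision + recall)

print(cm)
for k in range(n_classes):
    print(f"class {k}: precision={precision[k]:.2f} recall={recall[k]:.2f} f1={f1[k]:.2f}")
```

Rows correspond to true classes and columns to predicted classes, matching the axis labels of the heatmap below.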

In [None]:
# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)

plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    cbar_kws={'label': 'Count'}
)
plt.title('Confusion Matrix', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Cell Type', fontsize=12)
plt.ylabel('True Cell Type', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

In [None]:
# Classification Report
print("\nClassification Report:\n")
print(classification_report(
    all_labels,
    all_preds,
    target_names=class_names,
    digits=4
))

In [None]:
# Per-class precision, recall, and F1 visualization
from sklearn.metrics import precision_recall_fscore_support

precision, recall, f1, support = precision_recall_fscore_support(
    all_labels, all_preds, average=None
)

# Create DataFrame for plotting
metrics_df = pd.DataFrame({
    'Cell Type': class_names,
    'Precision': precision,
    'Recall': recall,
    'F1-Score': f1
})

# Plot
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(class_names))
width = 0.25

ax.bar(x - width, metrics_df['Precision'], width, label='Precision', color='steelblue')
ax.bar(x, metrics_df['Recall'], width, label='Recall', color='coral')
ax.bar(x + width, metrics_df['F1-Score'], width, label='F1-Score', color='seagreen')

ax.set_xlabel('Cell Type', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Per-Class Performance Metrics', fontsize=16, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.legend()
ax.set_ylim([0, 1.1])
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Prediction confidence distribution
max_probs = all_probs.max(axis=1)
correct = (all_preds == all_labels)

fig, ax = plt.subplots(figsize=(10, 6))

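# Use the same bin edges for both groups so the histograms are directly comparable
bins = np.linspace(0, 1, 51)
ax.hist(max_probs[correct], bins=bins, alpha=0.6, label='Correct', color='green')
ax.hist(max_probs[~correct], bins=bins, alpha=0.6, label='Incorrect', color='red')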

ax.set_xlabel('Prediction Confidence', fontsize=12)
ax.set_ylabel('Number of Predictions', fontsize=12)
ax.set_title('Prediction Confidence Distribution', fontsize=16, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

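print(f"\nAverage confidence for correct predictions: {max_probs[correct].mean():.4f}")
# Guard against an empty selection: on this easy synthetic data every prediction may be correct
if (~correct).any():
    print(f"Average confidence for incorrect predictions: {max_probs[~correct].mean():.4f}")
else:
    print("No incorrect predictions on the test set.")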

## 7. Summary

In this notebook, we:

1. ✓ Generated synthetic single-cell RNA-seq data with 5 cell types
2. ✓ Preprocessed the data (QC filtering, normalization, HVG selection)
3. ✓ Split data into train/validation/test sets
4. ✓ Trained a neural network classifier
5. ✓ Evaluated model performance with multiple metrics
6. ✓ Visualized results (confusion matrix, per-class metrics, confidence)

### Next Steps:

- Try with your own single-cell data
- Experiment with different architectures (attention, VAE)
- Test multi-modal models (RNA + protein + ATAC)
- Apply batch correction for multi-batch datasets
- Fine-tune hyperparameters

### Resources:

- [CellType-NN Documentation](https://github.com/salzcamino/test-celltype-nn)
- [Python README](https://github.com/salzcamino/test-celltype-nn/blob/main/README.md)
- [R Implementation](https://github.com/salzcamino/test-celltype-nn/blob/main/R_README.md)

In [None]:
# Optional: Save the trained model
# torch.save(lit_model.state_dict(), 'celltype_classifier.pth')
# print("✓ Model saved to celltype_classifier.pth")