# BioBatchNet Tutorial

**BioBatchNet** is a deep learning framework for batch effect correction in biological data, supporting both:
- **Imaging Mass Cytometry (IMC)** data
- **Single-cell RNA-seq (scRNA-seq)** data

This tutorial demonstrates three usage patterns:
1. **Quick Start** - Using the simple API with default parameters
2. **Custom Loss Weights** - Fine-tuning training objectives
3. **Direct Model Access** - Full control over architecture and training

---

## 1. Installation and Setup

In [None]:
# Install BioBatchNet (uncomment if needed)
# !pip install biobatchnet

import numpy as np
import pandas as pd
import torch
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
import os

from biobatchnet import correct_batch_effects, IMCVAE

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 2. Load Example Data

We'll use the **IMMUcan IMC dataset** as an example. The data will be downloaded from Google Drive.

In [None]:
import gdown

# Download data from Google Drive
FILE_ID = "1S0AgcT0J7tnRtnnshRzAkECwhse0mTrK"
FILENAME = "IMMUcan_batch.h5ad"

if not os.path.exists(FILENAME):
    print("Downloading IMMUcan dataset...")
    gdown.download(id=FILE_ID, output=FILENAME, quiet=False)
else:
    print(f"{FILENAME} already exists.")

# Load data
adata = ad.read_h5ad(FILENAME)
X = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X

# Extract batch labels
unique_batches = np.unique(adata.obs['BATCH'].values)
batch_to_int = {batch: i for i, batch in enumerate(unique_batches)}
batch_labels = np.array([batch_to_int[b] for b in adata.obs['BATCH'].values])

print(f"✓ Data loaded: {X.shape[0]:,} cells, {X.shape[1]} features, {len(unique_batches)} batches")
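
The dictionary-based encoding above can also be written with `np.unique(..., return_inverse=True)`, which returns the sorted unique values and the integer code of every cell in one call, giving the same mapping. The sketch below runs on a small made-up array of batch names (not the IMMUcan data) and also counts cells per batch, which is a useful check before training.

```python
import numpy as np

# Hypothetical batch names standing in for adata.obs['BATCH'].values
batch_names = np.array(["slide_B", "slide_A", "slide_B", "slide_C", "slide_A", "slide_B"])

# unique_names is sorted; codes[i] is the position of batch_names[i] in unique_names
unique_names, codes = np.unique(batch_names, return_inverse=True)

# Number of cells in each batch, in the same order as unique_names
cells_per_batch = np.bincount(codes, minlength=len(unique_names))

for name, count in zip(unique_names, cells_per_batch):
    print(f"{name}: {count} cells")
```

With the tutorial's variables, `codes` plays the role of `batch_labels` and `unique_names` the role of `unique_batches`.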

In [None]:
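def plot_umap(adata, use_rep="X", title=None, color_by=("BATCH", "celltype")):
    """Plot a UMAP of `adata`, colored by the columns in `color_by` (batch and cell type by default)."""
    # Work on a copy so the neighbors graph and UMAP do not overwrite the caller's AnnData
    adata_vis = adata.copy()
    adata_vis.obs['BATCH'] = adata_vis.obs['BATCH'].astype("category")
    sc.pp.neighbors(adata_vis, use_rep=use_rep)
    sc.tl.umap(adata_vis)
    # A tuple default avoids sharing one mutable list across calls; scanpy receives a list
    sc.pl.umap(adata_vis, color=list(color_by), title=title, frameon=False, wspace=0.5)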

# Visualize original data
plot_umap(adata, title="Original Data (with batch effects)")

---

## 3. Method 1: Quick Start with Simple API

The simplest way to use BioBatchNet: pass your data and batch labels, and rely on the default parameters for everything else.

In [None]:
# Run batch correction with default parameters
bio_embeddings, batch_embeddings = correct_batch_effects(
    data=pd.DataFrame(X),
    batch_info=pd.DataFrame({'BATCH': batch_labels}),
    batch_key='BATCH',
    data_type='imc',        # 'imc' or 'scrna'
    latent_dim=20,          # Latent space dimension
    epochs=100,             # Training epochs
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

print(f"✓ Biological embeddings: {bio_embeddings.shape}")
print(f"✓ Batch embeddings: {batch_embeddings.shape}")

# Add corrected embeddings to AnnData and visualize
adata.obsm['X_biobatchnet'] = bio_embeddings
plot_umap(adata, use_rep="X_biobatchnet", title="After BioBatchNet Correction")
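
UMAP plots are qualitative, so a rough numeric check can complement them. The sketch below is not part of BioBatchNet: `knn_batch_mixing` is a hypothetical helper that reports the average fraction of each cell's nearest neighbors that come from a different batch, using brute-force distances on a random subsample. It is demonstrated on synthetic data only.

```python
import numpy as np

def knn_batch_mixing(embeddings, batch_labels, k=15, n_sample=1000, seed=0):
    """Mean fraction of the k nearest neighbors that belong to another batch.

    Hypothetical helper for illustration. Values near 0 mean batches sit apart;
    higher values mean neighborhoods contain cells from several batches.
    """
    emb = np.asarray(embeddings, dtype=float)
    labels = np.asarray(batch_labels)
    rng = np.random.default_rng(seed)

    # Subsample so the pairwise distance matrix stays small
    idx = rng.choice(emb.shape[0], size=min(n_sample, emb.shape[0]), replace=False)
    if len(idx) < 2:
        raise ValueError("Need at least two cells to compute neighbors")
    sub, sub_labels = emb[idx], labels[idx]

    # Squared Euclidean distances between all pairs in the subsample
    sq_norms = (sub ** 2).sum(axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * sub @ sub.T
    np.fill_diagonal(d2, np.inf)  # a cell is not its own neighbor

    # Indices of the k closest cells for every row (order within the k is arbitrary)
    k = min(k, len(idx) - 1)
    neighbors = np.argpartition(d2, k - 1, axis=1)[:, :k]

    from_other_batch = sub_labels[neighbors] != sub_labels[:, None]
    return float(from_other_batch.mean())

# Synthetic demo: two batches that are far apart vs. two batches drawn from one cloud
rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 200)
separated = np.vstack([rng.normal(0, 1, (200, 5)), rng.normal(50, 1, (200, 5))])
mixed = rng.normal(0, 1, (400, 5))

score_separated = knn_batch_mixing(separated, labels)
score_mixed = knn_batch_mixing(mixed, labels)
print(f"separated: {score_separated:.3f}, mixed: {score_mixed:.3f}")
```

On the tutorial data, one would compare `knn_batch_mixing(X, batch_labels)` with `knn_batch_mixing(bio_embeddings, batch_labels)`. The achievable maximum depends on the number of batches and their relative sizes, so the score is best read as a before/after comparison rather than an absolute value.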

---

## 4. Method 2: Custom Loss Weights

Fine-tune the training objectives by adjusting loss weights.

In [None]:
# Define custom loss weights
custom_loss_weights = {
    'recon_loss': 10,       # Reconstruction loss (default: 10)
    'discriminator': 0.1,   # Batch mixing (default: 0.3, lower = more mixing)
    'classifier': 1.0,      # Batch retention (default: 1)
    'kl_loss_1': 0.005,     # KL divergence for bio encoder (default: 0.005)
    'kl_loss_2': 0.1,       # KL divergence for batch encoder (default: 0.1)
    'ortho_loss': 0.01     # Orthogonality constraint (default: 0.01)
}

# Run with custom weights
bio_embeddings_custom, _ = correct_batch_effects(
    data=pd.DataFrame(X),
    batch_info=pd.DataFrame({'BATCH': batch_labels}),
    batch_key='BATCH',
    data_type='imc',
    latent_dim=20,
    epochs=100,
    loss_weights=custom_loss_weights,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

print("✓ Training complete with custom loss weights")

# Visualize
adata.obsm['X_custom'] = bio_embeddings_custom
plot_umap(adata, use_rep="X_custom", title="Custom Loss Weights")
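
When only one or two weights change, it can be less error-prone to start from the defaults and override selectively. The sketch below is a hypothetical helper, not part of BioBatchNet; the default values are copied from the comments in the cell above, and this tutorial does not say whether `correct_batch_effects` fills in missing keys on its own.

```python
# Defaults as listed in this tutorial's comments (assumed to match the library)
DEFAULT_LOSS_WEIGHTS = {
    'recon_loss': 10,
    'discriminator': 0.3,
    'classifier': 1,
    'kl_loss_1': 0.005,
    'kl_loss_2': 0.1,
    'ortho_loss': 0.01,
}

def make_loss_weights(**overrides):
    """Return a full loss-weight dict: defaults plus the given overrides.

    Hypothetical helper. Raises KeyError on unknown names so that a
    misspelled key fails loudly instead of being silently ignored.
    """
    unknown = set(overrides) - set(DEFAULT_LOSS_WEIGHTS)
    if unknown:
        raise KeyError(f"Unknown loss weight(s): {sorted(unknown)}")
    return {**DEFAULT_LOSS_WEIGHTS, **overrides}

# Only the discriminator weight differs from the defaults
weights = make_loss_weights(discriminator=0.1)
print(weights)
```

The resulting dictionary has the same keys as `custom_loss_weights` and could be passed as `loss_weights=weights`.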

---

## 5. Method 3: Direct Model Usage (Advanced)

For full control, use the model classes directly to customize architecture.

In [None]:
# Create IMCVAE model with custom architecture
n_cells, n_features = X.shape
n_batches = len(unique_batches)

model = IMCVAE(
    in_sz=n_features,
    out_sz=n_features,
    latent_sz=20,
    num_batch=n_batches,
    bio_encoder_hidden_layers=[256, 512, 512],        # Custom bio encoder
    batch_encoder_hidden_layers=[128, 256],           # Custom batch encoder
    decoder_hidden_layers=[512, 512, 256],            # Custom decoder
    batch_classifier_layers_power=[500, 2000, 2000],  # Discriminator
    batch_classifier_layers_weak=[128]                # Batch classifier
)

# Train model with custom loss weights
custom_loss = {
    'recon_loss': 10,
    'discriminator': 0.1,
    'classifier': 1.0,
    'kl_loss_1': 0.005,
    'kl_loss_2': 0.1,
    'ortho_loss': 0.01
}

model.fit(
    data=X,
    batch_info=batch_labels,
    epochs=100,
    lr=1e-4,
    batch_size=256,
    loss_weights=custom_loss,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

print("✓ Training complete")

# Extract embeddings and visualize
bio_embeddings_direct = model.get_bio_embeddings(X)
print(f"✓ Bio embeddings: {bio_embeddings_direct.shape}")

adata.obsm['X_direct'] = bio_embeddings_direct
plot_umap(adata, use_rep="X_direct", title="Direct Model Usage")
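
To get a feel for how the hidden-layer lists affect model size, the sketch below counts weights and biases for a plain stack of fully connected layers. This is an estimate under stated assumptions: the internals of `IMCVAE` (normalization layers, separate mean and variance heads, and so on) are not described in this tutorial, so the real parameter count may differ. `mlp_param_count` and the feature count of 40 are hypothetical.

```python
def mlp_param_count(in_sz, hidden_layers, out_sz):
    """Weights plus biases for fully connected layers in_sz -> hidden... -> out_sz."""
    sizes = [in_sz, *hidden_layers, out_sz]
    # Each layer has (inputs * outputs) weights and one bias per output
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))

n_features = 40   # hypothetical; use X.shape[1] for real data
latent_sz = 20

layer_configs = {
    "bio encoder": (n_features, [256, 512, 512], latent_sz),
    "batch encoder": (n_features, [128, 256], latent_sz),
}

param_counts = {name: mlp_param_count(*cfg) for name, cfg in layer_configs.items()}
for name, count in param_counts.items():
    print(f"{name}: ~{count:,} parameters")
```

Comparing the two configurations shows how quickly wide hidden layers dominate the total, which is useful when deciding what to shrink first.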

---

## 6. Summary and Tips

### Three Usage Patterns

| Method | Use Case | Complexity |
|--------|----------|------------|
| **Simple API** | Quick batch correction with defaults | Low |
| **Custom Loss** | Fine-tune training objectives | Medium |
| **Direct Model** | Full control over architecture | High |

### Parameter Tuning Tips

**Loss Weights:**
- `recon_loss` (10): Higher = better reconstruction quality
- `discriminator` (0.3): Lower = stronger batch mixing (use 0.1 for many batches)
- `classifier` (1): Ensures batch information is preserved
- `kl_loss_1` (0.005): Regularization for bio encoder
- `kl_loss_2` (0.1): Regularization for batch encoder
- `ortho_loss` (0.01): Encourages orthogonal bio/batch representations

**Architecture:**
- `latent_dim` (`latent_sz` when constructing `IMCVAE` directly): 20 is the value used for IMC in this tutorial
- Increase hidden layers for complex datasets
- Reduce model size if memory is limited

**Training:**
- Start with 100 epochs, increase if needed
- Use GPU if available for faster training
- Reduce `batch_size` if out of memory
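
Lowering `batch_size` reduces memory per step but increases the number of optimizer steps per epoch. The arithmetic can be sketched as follows; `optimizer_steps` is a hypothetical helper, the cell count is made up, and whether `fit` keeps or drops the last partial batch is not stated in this tutorial, so both options are shown.

```python
import math

def optimizer_steps(n_cells, batch_size, epochs, drop_last=False):
    """Total optimizer steps for a training run (illustrative arithmetic only)."""
    if drop_last:
        per_epoch = n_cells // batch_size            # partial final batch discarded
    else:
        per_epoch = math.ceil(n_cells / batch_size)  # partial final batch kept
    return per_epoch * epochs

n_cells = 100_000  # hypothetical dataset size
steps_by_batch_size = {bs: optimizer_steps(n_cells, bs, epochs=100) for bs in (256, 128, 64)}
for bs, steps in steps_by_batch_size.items():
    print(f"batch_size={bs}: {steps:,} steps over 100 epochs")
```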

### Resources

- **GitHub**: [https://github.com/Manchester-HealthAI/BioBatchNet](https://github.com/Manchester-HealthAI/BioBatchNet)
- **Documentation**: See `README.md` and `USAGE.md`
- **Paper**: [Link to paper when published]

---

**Questions or issues?** Please open an issue on GitHub!