# Tutorial 2: Batch Effect Correction with ComBat

**ScpTensor v0.1.0-beta**

This tutorial demonstrates batch effect correction using the ComBat method. Batch effects are technical variations that can obscure biological signals in single-cell proteomics data.

### What you will learn:

1. **Understanding Batch Effects** - How to detect and visualize batch effects
2. **Data Preparation** - Preprocessing before batch correction
3. **ComBat Correction** - Apply empirical Bayes batch correction
4. **Integration Verification** - Assess correction quality
5. **Before/After Visualization** - Compare results

---

## 1. Setup and Imports

In [None]:
# Core imports
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances

# ScpTensor imports
from scptensor.core import ScpContainer, Assay, ScpMatrix
from scptensor.normalization import log_normalize
from scptensor.impute import knn
from scptensor.integration import combat
from scptensor.dim_reduction import pca
from scptensor.viz.recipes import embedding, qc_completeness

# Configure plotting
import scienceplots  # noqa: F401  (registers the "science" styles; pip install SciencePlots)
plt.style.use(["science", "no-latex"])
plt.rcParams['figure.dpi'] = 100
plt.rcParams['figure.figsize'] = (10, 6)

print("Imports successful!")

## 2. Generate Data with Strong Batch Effects

We'll create synthetic data with:
- **3 batches** with strong technical variation
- **2 biological groups** (distributed across batches)
- **Missing values** following realistic patterns

The batch effect will be intentionally strong so we can clearly see the correction.
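
Before generating the full dataset, it can help to see what an additive shift and a multiplicative scale do to a single protein. The snippet below is a minimal NumPy sketch, separate from ScpTensor; the variable names (`x_true`, `effects`, `batch_stats`) are illustrative only, and the shift/scale pairs are the ones used for the three batches in the generator below.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical "true" intensities for one protein, shared by all batches
x_true = rng.lognormal(mean=2.0, sigma=0.3, size=10_000)

# (scale, shift) per batch, matching the values used in the generator
effects = {"Batch1": (1.0, 0.0), "Batch2": (1.5, 0.5), "Batch3": (2.0, 1.0)}

batch_stats = {}
for name, (scale, shift) in effects.items():
    x_obs = x_true * scale + shift  # multiplicative scaling, then additive shift
    batch_stats[name] = (x_obs.mean(), x_obs.std())
    print(f"{name}: mean={x_obs.mean():.2f}, std={x_obs.std():.2f}")
```

The scale factor multiplies both the mean and the standard deviation, while the shift moves only the mean, so the same underlying biology looks different in each batch.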

In [None]:
def generate_batched_data(n_samples=300, n_features=500, n_batches=3):
    """
    Generate synthetic SCP data with strong batch effects.
    
    The batch effect includes:
    - Additive shift (different baseline per batch)
    - Multiplicative scaling (different variance per batch)
    """
    np.random.seed(42)
    
    # Create sample metadata
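    samples_per_batch = n_samples // n_batches
    n_samples = samples_per_batch * n_batches  # drop any remainder so all metadata columns have equal length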
    batches = []
    groups = []
    
    for i in range(n_batches):
        batch_name = f"Batch{i+1}"
        batches.extend([batch_name] * samples_per_batch)
        # Distribute groups across batches
        group_labels = ['GroupA'] * (samples_per_batch // 2) + ['GroupB'] * (samples_per_batch - samples_per_batch // 2)
        groups.extend(group_labels)
    
    obs = pl.DataFrame({
        'sample_id': [f'S{i+1:03d}' for i in range(n_samples)],
        'batch': batches,
        'group': groups
    })
    
    # Generate base expression (biological signal)
    X_bio = np.random.lognormal(mean=2, sigma=0.3, size=(n_samples, n_features))
    # Add group effect (first 50 proteins are higher in GroupB)
    group_mask = np.array([g == 'GroupB' for g in groups])
    X_bio[group_mask, :50] *= 1.8
    
    # Add strong batch effects
    X_batched = X_bio.copy()
    for i in range(n_batches):
        mask = np.array([b == f"Batch{i+1}" for b in batches])
        # Different baseline per batch (additive): 0, 0.5, 1.0, ...
        baseline_shift = 0.5 * i  # the last batch has the highest baseline
        # Different scaling per batch (multiplicative): 1.0, 1.5, 2.0, ...
        scale_factor = 1.0 + 0.5 * i  # the last batch has the highest variance
        X_batched[mask] = X_batched[mask] * scale_factor + baseline_shift
    
    # Introduce missing values
    X_observed = X_batched.copy()
    M = np.zeros((n_samples, n_features), dtype=int)
    
    # LOD missing (15%)
    threshold = np.percentile(X_batched, 15)
    lod_mask = X_batched < threshold
    X_observed[lod_mask] = 0
    M[lod_mask] = 2
    
    # Random missing (25%)
    n_random_missing = int(n_samples * n_features * 0.25)
    valid_indices = np.argwhere(M == 0)
    random_indices = valid_indices[np.random.choice(len(valid_indices), size=n_random_missing, replace=False)]
    X_observed[random_indices[:, 0], random_indices[:, 1]] = 0
    M[random_indices[:, 0], random_indices[:, 1]] = 1
    
    # Create feature metadata
    var = pl.DataFrame({
        'protein_id': [f'P{i+1:04d}' for i in range(n_features)],
        '_index': [f'P{i+1:04d}' for i in range(n_features)]
    })
    
    matrix = ScpMatrix(X=X_observed, M=M)
    assay = Assay(var=var, layers={'raw': matrix}, feature_id_col='protein_id')
    
    container = ScpContainer(
        assays={'protein': assay},
        obs=obs.with_columns(pl.Series(name="_index", values=obs["sample_id"].to_list())),
        sample_id_col='sample_id'
    )
    
    return container

# Generate data
container = generate_batched_data(n_samples=300, n_features=500, n_batches=3)

print("Generated data with strong batch effects:")
print(f"  - Samples: {container.n_samples}")
print(f"  - Features: {container.n_features}")
print(f"  - Batches: {container.obs['batch'].unique().to_list()}")
print(f"  - Groups: {container.obs['group'].unique().to_list()}")
print(f"  - Missing rate: {np.mean(container.assays['protein'].layers['raw'].M != 0):.1%}")

## 3. Preprocessing Pipeline

Before batch correction, we need to:
1. Normalize the data (log transform)
2. Impute missing values (required for most batch correction methods)
3. Run PCA to visualize the batch effect
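
To illustrate the first two steps outside ScpTensor, here is a minimal sketch using NumPy and scikit-learn's `KNNImputer`. It assumes missing entries are stored as 0 with a separate boolean mask, as in the synthetic data above; whether `scptensor.impute.knn` works the same way internally is not confirmed here, so treat this as an illustration of the idea only.

```python
import numpy as np
from sklearn.impute import KNNImputer

rng = np.random.default_rng(0)
X = rng.lognormal(mean=2.0, sigma=0.3, size=(20, 6))

# Mark roughly 20% of entries as missing (stored as 0, tracked in a mask)
missing = rng.random(X.shape) < 0.2
X[missing] = 0.0

# log2(x + 1): a stored 0 maps to exactly 0, so the mask is still needed
X_log = np.log2(X + 1.0)

# Replace masked entries with NaN so the imputer treats them as missing
X_nan = np.where(missing, np.nan, X_log)
X_imputed = KNNImputer(n_neighbors=5).fit_transform(X_nan)

print(f"Missing before: {missing.mean():.1%}, NaN after: {np.isnan(X_imputed).mean():.1%}")
```

Observed values pass through unchanged; only the masked entries are filled in from the nearest samples.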

In [None]:
# Step 1: Log normalization
container = log_normalize(
    container,
    assay_name='protein',
    base_layer='raw',
    new_layer_name='log',
    base=2.0,
    offset=1.0
)
print("Step 1: Log normalization complete")

# Step 2: KNN imputation
container = knn(
    container,
    assay_name='protein',
    base_layer='log',
    new_layer_name='imputed',
    k=5
)
print("Step 2: KNN imputation complete")

# Step 3: PCA for visualization
container = pca(
    container,
    assay_name='protein',
    base_layer_name='imputed',
    new_assay_name='pca_before',
    n_components=10,
    center=True,
    scale=False
)
print("Step 3: PCA complete")

print(f"\nAvailable layers: {list(container.assays['protein'].layers.keys())}")
print(f"Available assays: {list(container.assays.keys())}")
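
For readers who want to see what a PCA with `center=True` and `scale=False` computes, the following is a small NumPy sketch based on the singular value decomposition. The helper `pca_scores` is a hypothetical name introduced for this illustration and is not part of ScpTensor.

```python
import numpy as np

def pca_scores(X, n_components=2):
    """Project samples (rows) onto the top principal components via SVD."""
    Xc = X - X.mean(axis=0)  # center each feature; no scaling to unit variance
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = U[:, :n_components] * S[:n_components]
    explained = (S ** 2) / np.sum(S ** 2)  # fraction of total variance per component
    return scores, explained[:n_components]

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(60, 8))
X_demo[30:] += 3.0  # shift the second half of the samples, mimicking a batch offset

scores, explained = pca_scores(X_demo, n_components=2)
print(f"Explained variance (PC1, PC2): {explained.round(3)}")
```

Because the offset affects every feature of the shifted samples, the first component ends up aligned with it, which is the pattern the plots in the next section look for.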

## 4. Visualize Batch Effect (Before Correction)

Let's examine the data before batch correction. We expect to see samples clustering by batch rather than biological group.

In [None]:
# PCA visualization before correction
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

plt.sca(axes[0])
embedding(container, basis='pca_before', color='batch')
plt.title("Before ComBat: By Batch")

plt.sca(axes[1])
embedding(container, basis='pca_before', color='group')
plt.title("Before ComBat: By Biological Group")

plt.sca(axes[2])
# Create a combined label for better visualization
combined = container.obs['batch'] + '_' + container.obs['group']
# Note: this is an alias, not a copy, so the new column is also added to container.obs
container_copy = container
container_copy.obs = container_copy.obs.with_columns(pl.Series('batch_group', combined))
embedding(container_copy, basis='pca_before', color='batch_group')
plt.title("Before ComBat: Batch + Group")

plt.tight_layout()
plt.show()

### Quantify Batch Effect

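We can quantify the batch effect by comparing the mean pairwise distance between samples from different batches with the mean distance between samples from the same batch, using the first two principal components.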

In [None]:
def quantify_batch_effect(container, pca_assay='pca_before'):
    """
    Quantify batch effect using the first two PCA components.
    Returns the ratio of between-batch distance to within-batch distance.
    """
    scores = container.assays[pca_assay].layers['scores'].X[:, :2]
    
    # Calculate within-batch distances (samples within same batch should be similar)
    batches = container.obs['batch'].to_numpy()
    unique_batches = np.unique(batches)
    
    within_distances = []
    for batch in unique_batches:
        batch_scores = scores[batches == batch]
        # Average pairwise distance within batch (excluding the zero self-distances)
        if len(batch_scores) > 1:
            dists = pairwise_distances(batch_scores)
            within_distances.append(np.mean(dists[np.triu_indices_from(dists, k=1)]))
    
    # Calculate between-batch distances
    between_distances = []
    for i, batch1 in enumerate(unique_batches):
        for batch2 in unique_batches[i+1:]:
            scores1 = scores[batches == batch1]
            scores2 = scores[batches == batch2]
            dists = pairwise_distances(scores1, scores2)
            between_distances.append(np.mean(dists))
    
    within_mean = np.mean(within_distances)
    between_mean = np.mean(between_distances)
    
    return {
        'within_batch_distance': within_mean,
        'between_batch_distance': between_mean,
        'batch_severity': between_mean / within_mean if within_mean > 0 else float('inf')
    }

metrics_before = quantify_batch_effect(container, 'pca_before')
print("Batch Effect Quantification (Before Correction):")
print(f"  Within-batch distance:  {metrics_before['within_batch_distance']:.4f}")
print(f"  Between-batch distance: {metrics_before['between_batch_distance']:.4f}")
print(f"  Severity ratio:         {metrics_before['batch_severity']:.4f}")
print("\n(Severity > 1 means samples are, on average, closer to their own batch than to other batches)")
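
The severity ratio is distance-based. A complementary view is the fraction of variance in each principal component that is explained by a label such as batch or group (the R² of a one-way ANOVA). The helper below, `variance_explained_by_label`, is a hypothetical function written for this sketch, demonstrated on toy data rather than on the tutorial container.

```python
import numpy as np

def variance_explained_by_label(scores, labels):
    """Per column: between-label sum of squares divided by total sum of squares."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    grand_mean = scores.mean(axis=0)
    ss_total = ((scores - grand_mean) ** 2).sum(axis=0)
    ss_between = np.zeros(scores.shape[1])
    for label in np.unique(labels):
        members = scores[labels == label]
        ss_between += len(members) * (members.mean(axis=0) - grand_mean) ** 2
    return ss_between / ss_total

rng = np.random.default_rng(0)
batch = np.repeat(["B1", "B2", "B3"], 50)
group = np.tile(["A", "B"], 75)
demo = rng.normal(size=(150, 2))
demo[:, 0] += np.repeat([0.0, 4.0, 8.0], 50)  # column 0 is driven by batch
demo[:, 1] += np.tile([0.0, 1.0], 75)         # column 1 is weakly driven by group

r2_batch = variance_explained_by_label(demo, batch)
r2_group = variance_explained_by_label(demo, group)
print(f"R^2 by batch: {r2_batch.round(3)}")
print(f"R^2 by group: {r2_group.round(3)}")
```

In the tutorial, the same function could be applied to `container.assays['pca_before'].layers['scores'].X[:, :2]` with `container.obs['batch'].to_numpy()` and `container.obs['group'].to_numpy()` as labels.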

## 5. Apply ComBat Batch Correction

**ComBat** uses an empirical Bayes approach to adjust for batch effects while preserving biological signals.

**How it works:**
1. Standardize each feature across samples, using a model that accounts for batch and any covariates
2. Estimate batch effect parameters (additive and multiplicative)
3. Apply empirical Bayes shrinkage to stabilize estimates
4. Adjust data to remove batch effects

**Key parameters:**
- `batch_key`: Column name in obs containing batch labels
- `covariates`: Optional biological variables to preserve (e.g., 'group')
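
To make the additive and multiplicative terms concrete, here is a deliberately simplified location-scale adjustment in NumPy. It is not ComBat and not ScpTensor's implementation: it omits the empirical Bayes shrinkage and ignores covariates, and the function name `location_scale_adjust` is hypothetical.

```python
import numpy as np

def location_scale_adjust(X, batches):
    """Align each batch's per-feature mean and std to pooled values.

    Simplified illustration only: no empirical Bayes shrinkage, no covariates.
    """
    X = np.asarray(X, dtype=float)
    batches = np.asarray(batches)
    batch_ids = np.unique(batches)
    grand_mean = X.mean(axis=0)
    # Pooled within-batch variance per feature, weighted by batch size
    pooled_var = sum(
        (batches == b).sum() * X[batches == b].var(axis=0) for b in batch_ids
    ) / len(batches)
    pooled_std = np.sqrt(pooled_var)

    X_adj = np.empty_like(X)
    for b in batch_ids:
        idx = batches == b
        mean_b = X[idx].mean(axis=0)  # additive (location) effect
        std_b = X[idx].std(axis=0)    # multiplicative (scale) effect
        X_adj[idx] = (X[idx] - mean_b) / std_b * pooled_std + grand_mean
    return X_adj

rng = np.random.default_rng(0)
batch_demo = np.repeat(["B1", "B2"], 50)
X_demo = rng.normal(size=(100, 4))
X_demo[50:] = X_demo[50:] * 2.0 + 1.0  # B2 gets a scale and a shift

X_adj = location_scale_adjust(X_demo, batch_demo)
```

This naive version would also remove real group differences whenever groups are unevenly distributed across batches. ComBat addresses this by including the covariates in its model and by shrinking the per-batch estimates toward values shared across features, which stabilizes them for small batches.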

In [None]:
# Apply ComBat correction
# We use 'group' as a covariate to preserve the biological signal
container = combat(
    container,
    batch_key='batch',
    assay_name='protein',
    base_layer='imputed',
    new_layer_name='combat_corrected',
    covariates=['group']  # Preserve biological group differences
)

print("ComBat batch correction complete!")
print(f"\nAvailable layers: {list(container.assays['protein'].layers.keys())}")

# Verify the correction
X_before = container.assays['protein'].layers['imputed'].X
X_after = container.assays['protein'].layers['combat_corrected'].X

print(f"\nData shape unchanged: {X_before.shape} -> {X_after.shape}")
print(f"Data range before: [{np.min(X_before):.2f}, {np.max(X_before):.2f}]")
print(f"Data range after:  [{np.min(X_after):.2f}, {np.max(X_after):.2f}]")
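
Shape and range alone say little about whether the batch offsets are gone. A per-batch mean and standard deviation is a quick additional check. The helper `per_batch_summary` below is hypothetical and is shown on toy arrays that stand in for the `imputed` and `combat_corrected` matrices.

```python
import numpy as np

def per_batch_summary(X, batches):
    """Return {batch: (mean, std)} computed over all values of each batch."""
    X = np.asarray(X, dtype=float)
    batches = np.asarray(batches)
    return {
        str(b): (float(X[batches == b].mean()), float(X[batches == b].std()))
        for b in np.unique(batches)
    }

rng = np.random.default_rng(0)
toy_batches = np.repeat(["Batch1", "Batch2"], 40)
toy_before = rng.normal(size=(80, 5))
toy_before[40:] += 2.0  # Batch2 carries an additive offset

# Crude stand-in for a corrected matrix: align Batch2's feature means to Batch1
toy_after = toy_before.copy()
toy_after[40:] -= toy_after[40:].mean(axis=0) - toy_after[:40].mean(axis=0)

summary_before = per_batch_summary(toy_before, toy_batches)
summary_after = per_batch_summary(toy_after, toy_batches)
for name in summary_before:
    print(name, "before:", summary_before[name], "after:", summary_after[name])
```

In the tutorial, calling `per_batch_summary(X_before, container.obs['batch'].to_numpy())` and the same for `X_after` should show the per-batch means moving closer together if the correction worked.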

## 6. Visualize Results (After Correction)

Now let's run PCA on the corrected data and visualize the results.

In [None]:
# Run PCA on corrected data
container = pca(
    container,
    assay_name='protein',
    base_layer_name='combat_corrected',
    new_assay_name='pca_after',
    n_components=10,
    center=True,
    scale=False
)

print("PCA on corrected data complete!")

# PCA visualization after correction
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

plt.sca(axes[0])
embedding(container, basis='pca_after', color='batch')
plt.title("After ComBat: By Batch")

plt.sca(axes[1])
embedding(container, basis='pca_after', color='group')
plt.title("After ComBat: By Biological Group")

plt.sca(axes[2])
# Use the container returned by pca() so that the 'pca_after' assay is guaranteed to be present
container.obs = container.obs.with_columns(pl.Series('batch_group', combined))
embedding(container, basis='pca_after', color='batch_group')
plt.title("After ComBat: Batch + Group")

plt.tight_layout()
plt.show()

## 7. Before/After Comparison

Let's compare the results side by side to clearly see the improvement.

In [None]:
# Side-by-side comparison
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Top row: Before correction
plt.sca(axes[0, 0])
embedding(container, basis='pca_before', color='batch')
plt.title("BEFORE: Colored by Batch")

plt.sca(axes[0, 1])
embedding(container, basis='pca_before', color='group')
plt.title("BEFORE: Colored by Group")

# Bottom row: After correction
plt.sca(axes[1, 0])
embedding(container, basis='pca_after', color='batch')
plt.title("AFTER: Colored by Batch")

plt.sca(axes[1, 1])
embedding(container, basis='pca_after', color='group')
plt.title("AFTER: Colored by Group")

plt.tight_layout()
plt.show()

## 8. Quantify Correction Quality

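Let's measure how well the batch correction worked by recomputing the severity ratio on the corrected PCA scores. A ratio close to 1 means samples are, on average, about as far from other batches as from their own.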

In [None]:
# Quantify batch effect after correction
metrics_after = quantify_batch_effect(container, 'pca_after')

print("Batch Effect Quantification:")
print("\n" + "="*50)
print(f"{'Metric':<30} {'Before':>12} {'After':>12}")
print("="*50)
print(f"{'Within-batch distance':<30} {metrics_before['within_batch_distance']:>12.4f} {metrics_after['within_batch_distance']:>12.4f}")
print(f"{'Between-batch distance':<30} {metrics_before['between_batch_distance']:>12.4f} {metrics_after['between_batch_distance']:>12.4f}")
print(f"{'Severity ratio':<30} {metrics_before['batch_severity']:>12.4f} {metrics_after['batch_severity']:>12.4f}")
print("="*50)

# Calculate improvement
severity_reduction = (metrics_before['batch_severity'] - metrics_after['batch_severity']) / metrics_before['batch_severity'] * 100
print(f"\nSeverity reduction: {severity_reduction:.1f}%")

if severity_reduction > 50:
    print("Result: Excellent batch correction!")
elif severity_reduction > 20:
    print("Result: Good batch correction.")
else:
    print("Result: Batch effect still present (inspect the PCA plots and revisit preprocessing or covariates).")
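
A drop in the severity ratio does not by itself show that the biological signal survived. One possible check, sketched here with scikit-learn's `silhouette_score` on a toy 2-D embedding, is to score the same embedding twice: once with batch labels (lower is better after correction) and once with group labels (which should not get worse). The toy arrays are assumptions for illustration and do not come from the tutorial data.

```python
import numpy as np
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)
batch = np.repeat(["B1", "B2"], 100)
group = np.tile(np.repeat(["A", "B"], 50), 2)

# Toy embedding: group separates along axis 0, batch along axis 1
base = rng.normal(scale=0.5, size=(200, 2))
base[:, 0] += np.where(group == "B", 3.0, 0.0)
before = base.copy()
before[:, 1] += np.where(batch == "B2", 5.0, 0.0)
after = base  # same data with the batch offset removed

results = {}
for name, emb in [("before", before), ("after", after)]:
    results[name] = {
        "batch": silhouette_score(emb, batch),
        "group": silhouette_score(emb, group),
    }
    print(f"{name}: batch={results[name]['batch']:.3f}, group={results[name]['group']:.3f}")
```

With the tutorial data, the embeddings would be the first few columns of the `pca_before` and `pca_after` score matrices, and the labels would come from `container.obs`.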

## 9. Summary

### What we covered:

| Step | Description | Function |
|------|-------------|----------|
| 1. Data Generation | Created synthetic data with strong batch effects | `generate_batched_data()` |
| 2. Normalization | Log transform for variance stabilization | `log_normalize()` |
| 3. Imputation | Fill missing values | `knn()` |
| 4. Pre-PCA | Visualize batch effect before correction | `pca()` |
| 5. Batch Correction | ComBat empirical Bayes correction | `combat()` |
| 6. Post-PCA | Visualize corrected data | `pca()` |
| 7. Verification | Quantify correction quality | `quantify_batch_effect()` |

### Key Takeaways:

1. **Batch Effects Matter**: Technical variation can mask biological signals
2. **ComBat Parameters**:
   - `batch_key`: Required - identifies batch membership
   - `covariates`: Optional - preserves biological variables
   - Always impute missing values before correction

3. **Verification is Essential**:
   - Visual inspection (PCA plots)
   - Quantitative metrics (batch severity ratio)
   - Check that biological signals are preserved

4. **When to Use Batch Correction**:
   - Multiple experimental batches
   - Different instruments or protocols
   - Data collected at different times
   - When PCA shows batch clustering

### Next Steps:
- Try adjusting ComBat parameters (with/without covariates)
- Explore other integration methods (`harmony`, `scanorama`, `mnn`)
- Apply to your own single-cell proteomics data
- Check out Tutorial 1 for the basic workflow