# Tutorial 03: Imputation and Batch Integration

This tutorial covers missing value imputation and batch effect correction for single-cell proteomics data.

## Learning Objectives

By the end of this tutorial, you will:
- Understand missing value patterns in SCP data (MCAR vs MNAR)
- Apply various imputation methods (KNN, PPCA, SVD, MissForest)
- Detect and assess batch effects
- Apply batch correction methods (ComBat, Harmony, MNN)
- Evaluate the effectiveness of imputation and integration

---

## 1. Setup

Import required libraries and load an example dataset.

In [None]:
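from pathlib import Path

import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# Apply SciencePlots style (SciencePlots >= 2.0 registers its styles on import)
import scienceplots  # noqa: F401
plt.style.use(["science", "no-latex"])

# Create the directory that the plt.savefig() calls in this tutorial write to
Path("tutorial_output").mkdir(exist_ok=True)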

# Import ScpTensor
import scptensor
from scptensor.datasets import load_simulated_scrnaseq_like
from scptensor import (
    # Imputation
    knn,
    ppca,
    svd_impute,
    missforest,
    # Integration
    combat,
    harmony,
    mnn_correct,
    # QC and utilities
    count_mask_codes,
    MaskCode,
    calculate_qc_metrics,
    # Normalization
    log_normalize,
    zscore,
)

print(f"ScpTensor version: {scptensor.__version__}")

## 2. Load and Prepare Data

Let's load a dataset with batch effects and missing values.

In [None]:
# Load dataset
container = load_simulated_scrnaseq_like()

print(f"Dataset loaded: {container}")
print(f"Samples: {container.n_samples}")
print(f"Features: {container.assays['proteins'].n_features}")

# Check batch distribution
# (.len() replaces the deprecated GroupBy.count() in polars >= 0.20.5)
print("\nBatch distribution:")
print(container.obs.group_by("batch").len().sort("batch"))

# Check cell type distribution
print("\nCell type distribution:")
print(container.obs.group_by("cell_type").len().sort("cell_type"))

## 3. Understanding Missing Values

### 3.1 Missing Value Patterns

Single-cell proteomics data has two types of missing values:

- **MCAR (Missing Completely At Random)**: Technical dropout, unrelated to intensity
- **MNAR (Missing Not At Random)**: Limit of detection (LOD), related to low intensity
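
The difference between the two mechanisms is easiest to see on simulated data. The sketch below is illustrative only: the matrix, the 20% drop rate, the detection limit of 8.0 and the logistic slope are all made-up values, and nothing here uses ScpTensor. Under MCAR the dropped values look like the observed ones; under MNAR the dropped values are systematically lower.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy log-intensity matrix (200 samples x 50 features); values are simulated.
X = rng.normal(loc=10.0, scale=2.0, size=(200, 50))

# MCAR: every entry has the same 20% chance of being dropped.
mcar_mask = rng.random(X.shape) < 0.2

# MNAR: drop probability rises as intensity falls below a hypothetical
# detection limit of 8.0 (logistic curve; the slope of 2.0 is arbitrary).
p_drop = 1.0 / (1.0 + np.exp(2.0 * (X - 8.0)))
mnar_mask = rng.random(X.shape) < p_drop


def observed_minus_missing(values, mask):
    """Mean of the observed entries minus mean of the dropped entries."""
    return values[~mask].mean() - values[mask].mean()


mcar_gap = observed_minus_missing(X, mcar_mask)
mnar_gap = observed_minus_missing(X, mnar_mask)
print(f"MCAR gap: {mcar_gap:.2f}")  # close to zero
print(f"MNAR gap: {mnar_gap:.2f}")  # clearly positive
```

In real data the true intensity of a missing value is unknown, so this gap cannot be measured directly. That is why the analysis below relies on the mask codes stored with the data.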

In [None]:
# Analyze missing value patterns
matrix = container.assays["proteins"].layers["raw"]
mask_counts = count_mask_codes(matrix.M)

print("Missing Value Analysis:")
print("=" * 50)
print(f"Total values: {matrix.M.size}")
print(f"Valid values (0): {mask_counts.get(0, 0)} ({mask_counts.get(0, 0)/matrix.M.size*100:.1f}%)")
print(f"MCAR/MBR (1): {mask_counts.get(1, 0)} ({mask_counts.get(1, 0)/matrix.M.size*100:.1f}%)")
print(f"MNAR/LOD (2): {mask_counts.get(2, 0)} ({mask_counts.get(2, 0)/matrix.M.size*100:.1f}%)")

# Overall missing rate
missing_rate = (matrix.M != 0).sum() / matrix.M.size
print(f"\nOverall missing rate: {missing_rate * 100:.1f}%")

### 3.2 Visualize Missing Value Patterns

In [None]:
# Visualize missing value patterns
fig = plt.figure(figsize=(14, 10))
gs = GridSpec(3, 3, figure=fig)

# Missing rate per sample
ax1 = fig.add_subplot(gs[0, 0])
sample_missing = (matrix.M != 0).sum(axis=1) / matrix.M.shape[1]
ax1.bar(range(len(sample_missing)), sample_missing, color='steelblue')
ax1.set_xlabel('Sample Index')
ax1.set_ylabel('Missing Rate')
ax1.set_title('Missing Rate per Sample')

# Missing rate per feature
ax2 = fig.add_subplot(gs[0, 1])
feature_missing = (matrix.M != 0).sum(axis=0) / matrix.M.shape[0]
ax2.hist(feature_missing, bins=30, color='coral', edgecolor='black')
ax2.set_xlabel('Missing Rate')
ax2.set_ylabel('Number of Features')
ax2.set_title('Distribution of Missing Rate per Feature')

# Missing by batch
ax3 = fig.add_subplot(gs[0, 2])
batch_missing = []
batch_labels = sorted(container.obs["batch"].unique().to_list())
for batch in batch_labels:
    # Boolean mask selecting the samples that belong to this batch
    in_batch = (container.obs["batch"] == batch).to_numpy()
    batch_mr = sample_missing[in_batch].mean()
    batch_missing.append(batch_mr)

ax3.bar(range(len(batch_labels)), batch_missing, color='lightgreen')
ax3.set_xticks(range(len(batch_labels)))
ax3.set_xticklabels(batch_labels)
ax3.set_ylabel('Mean Missing Rate')
ax3.set_title('Missing Rate by Batch')

# Spy plot (missing pattern)
ax4 = fig.add_subplot(gs[1, :])
missing_mask = (matrix.M != 0).astype(float)
ax4.imshow(missing_mask[:100, :100], aspect='auto', cmap='Reds', interpolation='none')
ax4.set_xlabel('Feature Index')
ax4.set_ylabel('Sample Index')
ax4.set_title('Missing Value Pattern (Spy Plot) - First 100 samples x 100 features')

# Intensity vs missing probability
ax5 = fig.add_subplot(gs[2, :])
X_flat = matrix.X.flatten()
M_flat = matrix.M.flatten()

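# Bin by intensity and compute missing rate
# Note: this relies on matrix.X holding a usable intensity at masked positions;
# NaN entries fail every comparison below and so fall outside all bins.
percentiles = np.percentile(X_flat[M_flat == 0], np.linspace(0, 100, 20))
missing_by_intensity = []
intensity_bins = []

n_bins = len(percentiles) - 1
for i in range(n_bins):
    # Bins are half-open, except the last one, which also includes the maximum
    if i == n_bins - 1:
        mask = (X_flat >= percentiles[i]) & (X_flat <= percentiles[i+1])
    else:
        mask = (X_flat >= percentiles[i]) & (X_flat < percentiles[i+1])
    if mask.sum() > 0:
        missing_by_intensity.append((M_flat[mask] != 0).mean())
        intensity_bins.append((percentiles[i] + percentiles[i+1]) / 2)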

ax5.plot(intensity_bins, missing_by_intensity, 'o-', color='darkred')
ax5.set_xlabel('Intensity')
ax5.set_ylabel('Missing Rate')
ax5.set_title('Missing Rate vs Intensity (MNAR pattern)')

plt.tight_layout()
plt.savefig('tutorial_output/missing_patterns.png', dpi=300)
plt.show()

print("Missing pattern visualizations saved to: tutorial_output/missing_patterns.png")

## 4. Imputation Methods

### 4.1 Preprocessing: Log Normalization

First, let's apply log normalization to stabilize variance before imputation.
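
To see why the log transform helps, consider simulated features whose noise is multiplicative, which is a common assumption for intensity data. The sketch below uses plain NumPy and assumes the transform is `log2(x + offset)`, matching the `base=2.0` and `offset=1.0` arguments used in this tutorial; the abundance levels and noise level are made up.

```python
import numpy as np

rng = np.random.default_rng(0)

# Three simulated features with very different abundance and multiplicative noise.
means = np.array([10.0, 100.0, 1000.0])
raw = means * rng.lognormal(mean=0.0, sigma=0.3, size=(1000, 3))

# Assumed form of the transform: log base 2 with an offset of 1.
logged = np.log2(raw + 1.0)

raw_sd = raw.std(axis=0)
log_sd = logged.std(axis=0)
print("SD on raw scale:", np.round(raw_sd, 1))
print("SD on log scale:", np.round(log_sd, 2))
```

On the raw scale the standard deviation grows with abundance, so high-abundance features would dominate any distance-based imputation. On the log scale the three features have similar spread.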

In [None]:
# Apply log normalization
container = log_normalize(
    container,
    assay_name="proteins",
    base_layer="raw",
    new_layer_name="log",
    base=2.0,
    offset=1.0,
)

print("Log normalization completed.")
print(f"Available layers: {list(container.assays['proteins'].layers.keys())}")

### 4.2 K-Nearest Neighbors (KNN) Imputation

KNN imputation fills each missing value with the average of the observed values from its k nearest neighbors.
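
A minimal NumPy sketch of the idea follows. It is not ScpTensor's implementation: it assumes that neighbors are samples (rows), that missing values are encoded as NaN, and that distances are computed over the features observed in both samples. The function name `knn_impute` is hypothetical.

```python
import numpy as np


def knn_impute(X, k=3):
    """Fill NaNs in X (samples x features) with the mean of the k nearest samples.

    Distances use only the features observed in both samples.
    """
    X = np.asarray(X, dtype=float)
    observed = ~np.isnan(X)
    filled = X.copy()
    for i in range(X.shape[0]):
        missing_cols = np.where(~observed[i])[0]
        if missing_cols.size == 0:
            continue
        # Root-mean-square distance over features observed in both samples
        shared = observed & observed[i]
        diff = np.where(shared, X - X[i], 0.0)
        n_shared = shared.sum(axis=1)
        dist = np.full(X.shape[0], np.inf)
        ok = n_shared > 0
        dist[ok] = np.sqrt((diff[ok] ** 2).sum(axis=1) / n_shared[ok])
        dist[i] = np.inf  # a sample is not its own neighbor
        order = np.argsort(dist)
        for j in missing_cols:
            # Nearest samples that are comparable and observed at feature j
            donors = [n for n in order if observed[n, j] and np.isfinite(dist[n])][:k]
            if donors:
                filled[i, j] = X[donors, j].mean()
    return filled


X = np.array([
    [1.0, 2.0, 3.0],
    [1.1, 2.1, np.nan],
    [0.9, 1.9, 2.9],
    [9.0, 9.0, 9.0],
])
imputed = knn_impute(X, k=2)
print(imputed[1, 2])  # mean of 3.0 and 2.9, taken from the two closest samples
```

The distant fourth sample does not contribute, which is the point of using neighbors rather than a global column mean.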

In [None]:
# Apply KNN imputation
print("Running KNN imputation (this may take a moment)...")

container = knn(
    container,
    assay_name="proteins",
    base_layer="log",
    new_layer_name="knn_imputed",
    k=10,  # Number of neighbors
)

print("KNN imputation completed.")

# Check the result
knn_matrix = container.assays["proteins"].layers["knn_imputed"]
knn_missing_rate = (knn_matrix.M != 0).sum() / knn_matrix.M.size
print(f"Missing rate after KNN imputation: {knn_missing_rate * 100:.2f}%")
print(f"Imputed values marked with mask code 5: {(knn_matrix.M == 5).sum()}")

### 4.3 Probabilistic PCA (PPCA) Imputation

In [None]:
# Apply PPCA imputation
print("Running PPCA imputation...")

container = ppca(
    container,
    assay_name="proteins",
    base_layer="log",
    new_layer_name="ppca_imputed",
    n_components=10,  # Number of principal components
)

print("PPCA imputation completed.")

# Check the result
ppca_matrix = container.assays["proteins"].layers["ppca_imputed"]
print(f"Missing rate after PPCA: {(ppca_matrix.M != 0).sum() / ppca_matrix.M.size * 100:.2f}%")

### 4.4 SVD Imputation

In [None]:
# Apply SVD imputation
print("Running SVD imputation...")

container = svd_impute(
    container,
    assay_name="proteins",
    base_layer="log",
    new_layer_name="svd_imputed",
    rank=10,  # Rank for SVD approximation
)

print("SVD imputation completed.")

# Check the result
svd_matrix = container.assays["proteins"].layers["svd_imputed"]
print(f"Missing rate after SVD: {(svd_matrix.M != 0).sum() / svd_matrix.M.size * 100:.2f}%")
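
For intuition, SVD-based imputation can be sketched in a few lines of NumPy. This is an illustrative version under stated assumptions (missing values encoded as NaN, a fixed rank, column-mean initialization), not the algorithm that `svd_impute()` necessarily uses. The function name `iterative_svd_impute` is hypothetical.

```python
import numpy as np


def iterative_svd_impute(X, rank=2, n_iter=50, tol=1e-6):
    """Fill NaNs by repeatedly projecting onto a rank-`rank` approximation."""
    X = np.asarray(X, dtype=float)
    missing = np.isnan(X)
    # Start from the column means of the observed values
    filled = np.where(missing, np.nanmean(X, axis=0), X)
    for _ in range(n_iter):
        U, s, Vt = np.linalg.svd(filled, full_matrices=False)
        low_rank = (U[:, :rank] * s[:rank]) @ Vt[:rank]
        # Observed entries stay fixed; only missing entries are updated
        updated = np.where(missing, low_rank, X)
        change = np.linalg.norm(updated - filled) / (np.linalg.norm(filled) + 1e-12)
        filled = updated
        if change < tol:
            break
    return filled


# Simulated rank-2 matrix with about 10% of entries removed at random
rng = np.random.default_rng(0)
truth = rng.normal(size=(60, 2)) @ rng.normal(size=(2, 20))
mask = rng.random(truth.shape) < 0.1
X = truth.copy()
X[mask] = np.nan

svd_filled = iterative_svd_impute(X, rank=2, n_iter=200)
mean_filled = np.where(mask, np.nanmean(X, axis=0), truth)

svd_rmse = np.sqrt(np.mean((svd_filled[mask] - truth[mask]) ** 2))
mean_rmse = np.sqrt(np.mean((mean_filled[mask] - truth[mask]) ** 2))
print(f"RMSE, iterative SVD: {svd_rmse:.3f}")
print(f"RMSE, column mean:   {mean_rmse:.3f}")
```

The example is favorable by construction because the simulated matrix is exactly low-rank. On real data the gain depends on how much linear structure the features share.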

### 4.5 MissForest Imputation (Random Forest)

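MissForest imputes iteratively: for each feature it trains a random forest on the observed values, predicts the missing ones, and repeats until the imputed values stabilize.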

In [None]:
# Apply MissForest imputation
# Note: This is slower than other methods
print("Running MissForest imputation (this may take longer)...")

try:
    container = missforest(
        container,
        assay_name="proteins",
        base_layer="log",
        new_layer_name="mf_imputed",
        max_iter=10,  # Maximum iterations
        n_estimators=50,  # Number of trees
    )
    print("MissForest imputation completed.")
    
    # Check the result
    mf_matrix = container.assays["proteins"].layers["mf_imputed"]
    print(f"Missing rate after MissForest: {(mf_matrix.M != 0).sum() / mf_matrix.M.size * 100:.2f}%")
except Exception as e:
    print(f"MissForest imputation skipped: {e}")
    print("(MissForest requires scikit-learn to be installed)")

## 5. Comparing Imputation Methods

In [None]:
# Compare imputation results
imputed_layers = ['log', 'knn_imputed', 'ppca_imputed', 'svd_imputed']

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for i, layer in enumerate(imputed_layers):
    if layer in container.assays['proteins'].layers:
        X = container.assays['proteins'].layers[layer].X
        ax = axes[i // 2, i % 2]
        
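        # Plot a random subsample of up to 1000 finite values; the first 1000
        # entries of the flattened matrix would come from the first few samples only
        values = X.flatten()
        values = values[np.isfinite(values)]
        rng = np.random.default_rng(0)
        values = rng.choice(values, size=min(1000, values.size), replace=False)
        ax.hist(values, bins=50, alpha=0.7, edgecolor='black')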
        ax.set_title(f'{layer.replace("_", " ").title()}')
        ax.set_xlabel('Intensity')
        ax.set_ylabel('Frequency')

plt.tight_layout()
plt.savefig('tutorial_output/imputation_comparison.png', dpi=300)
plt.show()

print("Imputation comparison saved to: tutorial_output/imputation_comparison.png")
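
Histograms show whether imputed values look plausible, but not whether they are accurate. One common way to measure accuracy is to hide some observed values, impute them, and compare against the hidden truth. The sketch below assumes NaN-encoded missing values and uses two placeholder imputers on simulated data; `holdout_rmse`, `column_mean_impute` and `zero_impute` are hypothetical helper names, not part of ScpTensor.

```python
import numpy as np


def holdout_rmse(X, impute_fn, holdout_frac=0.1, seed=0):
    """Hide a fraction of observed entries, impute, and score RMSE on them."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    observed = ~np.isnan(X)
    # Choose which observed entries to hide
    hide = observed & (rng.random(X.shape) < holdout_frac)
    X_masked = X.copy()
    X_masked[hide] = np.nan
    X_imputed = impute_fn(X_masked)
    return float(np.sqrt(np.mean((X_imputed[hide] - X[hide]) ** 2)))


def column_mean_impute(X):
    """Placeholder imputer: replace NaNs with the column mean."""
    return np.where(np.isnan(X), np.nanmean(X, axis=0), X)


def zero_impute(X):
    """Placeholder imputer: replace NaNs with zero."""
    return np.where(np.isnan(X), 0.0, X)


# Simulated data: each feature has its own mean level, plus unit-variance noise
rng = np.random.default_rng(1)
X = rng.normal(size=(100, 30)) + rng.normal(scale=3.0, size=(1, 30))
X[rng.random(X.shape) < 0.2] = np.nan

rmse_mean = holdout_rmse(X, column_mean_impute)
rmse_zero = holdout_rmse(X, zero_impute)
print(f"Column-mean imputation RMSE: {rmse_mean:.2f}")
print(f"Zero imputation RMSE:        {rmse_zero:.2f}")
```

Because the hidden entries are chosen at random, this scores performance under MCAR-like missingness only. It says little about how well a method recovers values lost at the detection limit.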

## 6. Batch Effect Detection

Before batch correction, let's detect batch effects using PCA.

In [None]:
from scptensor import pca

# Run PCA on imputed data to visualize batch effects
container = pca(
    container,
    assay_name="proteins",
    base_layer_name="knn_imputed",
    new_assay_name="pca_pre_correction",
    n_components=10,
)

# Get PCA coordinates
pc1 = container.assays["pca_pre_correction"].layers["scores"].X[:, 0]
pc2 = container.assays["pca_pre_correction"].layers["scores"].X[:, 1]

# Visualize batch effects
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Color by batch
batches = container.obs["batch"].to_numpy()
unique_batches = sorted(container.obs["batch"].unique().to_list())
colors_batch = plt.cm.tab10(np.linspace(0, 1, len(unique_batches)))

for batch, color in zip(unique_batches, colors_batch):
    mask = batches == batch
    axes[0].scatter(pc1[mask], pc2[mask], c=[color], label=batch, alpha=0.6, s=30)

axes[0].set_xlabel('PC1')
axes[0].set_ylabel('PC2')
axes[0].set_title('PCA Colored by Batch (Before Correction)')
axes[0].legend()

# Color by cell type
cell_types = container.obs["cell_type"].to_numpy()
unique_celltypes = sorted(container.obs["cell_type"].unique().to_list())
colors_ct = plt.cm.Set2(np.linspace(0, 1, len(unique_celltypes)))

for ct, color in zip(unique_celltypes, colors_ct):
    mask = cell_types == ct
    axes[1].scatter(pc1[mask], pc2[mask], c=[color], label=ct, alpha=0.6, s=30)

axes[1].set_xlabel('PC1')
axes[1].set_ylabel('PC2')
axes[1].set_title('PCA Colored by Cell Type (Before Correction)')
axes[1].legend()

plt.tight_layout()
plt.savefig('tutorial_output/pca_before_correction.png', dpi=300)
plt.show()

print("PCA visualization saved to: tutorial_output/pca_before_correction.png")

## 7. Batch Correction Methods

### 7.1 ComBat (Empirical Bayes)

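ComBat is a widely used batch correction method. It models batch as a per-feature shift and scale, and uses empirical Bayes to stabilize those estimates by pooling information across features.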
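
The core location and scale adjustment can be sketched without the empirical Bayes step. The version below is a deliberate simplification and is not ComBat itself: it has no shrinkage and no covariates, and it assumes every batch has at least two samples. The function name `location_scale_adjust` is hypothetical.

```python
import numpy as np


def location_scale_adjust(X, batches):
    """Per-feature batch adjustment: standardize within batch, then restore
    the overall mean and the pooled within-batch standard deviation."""
    X = np.asarray(X, dtype=float)
    batches = np.asarray(batches)
    labels = np.unique(batches)
    grand_mean = X.mean(axis=0)

    # Pooled within-batch variance per feature
    resid_ss = np.zeros(X.shape[1])
    for b in labels:
        rows = batches == b
        resid_ss += ((X[rows] - X[rows].mean(axis=0)) ** 2).sum(axis=0)
    pooled_sd = np.sqrt(resid_ss / (X.shape[0] - labels.size))

    corrected = np.empty_like(X)
    for b in labels:
        rows = batches == b
        mean_b = X[rows].mean(axis=0)
        sd_b = X[rows].std(axis=0, ddof=1)
        sd_b = np.where(sd_b > 0, sd_b, 1.0)  # guard against constant features
        corrected[rows] = (X[rows] - mean_b) / sd_b * pooled_sd + grand_mean
    return corrected


# Simulated example: batch B is shifted by 3 and its scale is doubled
rng = np.random.default_rng(0)
batches = np.repeat(["A", "B"], 50)
X = rng.normal(size=(100, 5))
X[batches == "B"] = X[batches == "B"] * 2.0 + 3.0

corrected = location_scale_adjust(X, batches)


def max_mean_gap(M):
    """Largest per-feature difference between the two batch means."""
    return np.abs(M[batches == "A"].mean(axis=0) - M[batches == "B"].mean(axis=0)).max()


gap_before = max_mean_gap(X)
gap_after = max_mean_gap(corrected)
print(f"Largest batch mean gap before: {gap_before:.2f}, after: {gap_after:.2e}")
```

Forcing batch means to coincide is only safe when batches contain similar mixtures of cell types. If one batch is enriched for a cell type, this kind of adjustment removes biological signal along with the batch effect.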

In [None]:
# Apply ComBat batch correction
print("Running ComBat batch correction...")

container = combat(
    container,
    batch_key="batch",
    assay_name="proteins",
    base_layer="knn_imputed",
    new_layer_name="combat_corrected",
)

print("ComBat correction completed.")

# Run PCA on corrected data
container = pca(
    container,
    assay_name="proteins",
    base_layer_name="combat_corrected",
    new_assay_name="pca_combat",
    n_components=10,
)

# Get PCA coordinates
pc1_combat = container.assays["pca_combat"].layers["scores"].X[:, 0]
pc2_combat = container.assays["pca_combat"].layers["scores"].X[:, 1]

print("PCA on ComBat-corrected data completed.")

### 7.2 Visualizing ComBat Results

In [None]:
# Visualize ComBat correction
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Before correction - by batch
for batch, color in zip(unique_batches, colors_batch):
    mask = batches == batch
    axes[0, 0].scatter(pc1[mask], pc2[mask], c=[color], label=batch, alpha=0.6, s=30)
axes[0, 0].set_xlabel('PC1')
axes[0, 0].set_ylabel('PC2')
axes[0, 0].set_title('Before ComBat: By Batch')
axes[0, 0].legend()

# Before correction - by cell type
for ct, color in zip(unique_celltypes, colors_ct):
    mask = cell_types == ct
    axes[0, 1].scatter(pc1[mask], pc2[mask], c=[color], label=ct, alpha=0.6, s=30)
axes[0, 1].set_xlabel('PC1')
axes[0, 1].set_ylabel('PC2')
axes[0, 1].set_title('Before ComBat: By Cell Type')
axes[0, 1].legend()

# After correction - by batch
for batch, color in zip(unique_batches, colors_batch):
    mask = batches == batch
    axes[1, 0].scatter(pc1_combat[mask], pc2_combat[mask], c=[color], label=batch, alpha=0.6, s=30)
axes[1, 0].set_xlabel('PC1')
axes[1, 0].set_ylabel('PC2')
axes[1, 0].set_title('After ComBat: By Batch')
axes[1, 0].legend()

# After correction - by cell type
for ct, color in zip(unique_celltypes, colors_ct):
    mask = cell_types == ct
    axes[1, 1].scatter(pc1_combat[mask], pc2_combat[mask], c=[color], label=ct, alpha=0.6, s=30)
axes[1, 1].set_xlabel('PC1')
axes[1, 1].set_ylabel('PC2')
axes[1, 1].set_title('After ComBat: By Cell Type')
axes[1, 1].legend()

plt.tight_layout()
plt.savefig('tutorial_output/combat_correction.png', dpi=300)
plt.show()

print("ComBat visualization saved to: tutorial_output/combat_correction.png")

### 7.3 MNN Correction (Mutual Nearest Neighbors)

In [None]:
# Apply MNN correction
print("Running MNN correction...")

try:
    container = mnn_correct(
        container,
        batch_key="batch",
        assay_name="proteins",
        base_layer="knn_imputed",
        new_layer_name="mnn_corrected",
        k=20,  # Number of nearest neighbors
        sigma=1.0,  # MNN kernel bandwidth
    )
    
    print("MNN correction completed.")
    
    # Run PCA on MNN corrected data
    container = pca(
        container,
        assay_name="proteins",
        base_layer_name="mnn_corrected",
        new_assay_name="pca_mnn",
        n_components=10,
    )
    
    pc1_mnn = container.assays["pca_mnn"].layers["scores"].X[:, 0]
    pc2_mnn = container.assays["pca_mnn"].layers["scores"].X[:, 1]
    
    print("PCA on MNN-corrected data completed.")
    
except Exception as e:
    print(f"MNN correction skipped: {e}")
    pc1_mnn, pc2_mnn = None, None
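
The building block of MNN correction is the mutual nearest neighbor pair: a cell in one batch and a cell in another that each rank the other among their k nearest neighbors. The sketch below finds such pairs for two simulated batches with brute-force distances. It is illustrative only and the function name `mutual_nearest_pairs` is hypothetical; a full MNN method would go on to turn the pairs into correction vectors and smooth them across cells.

```python
import numpy as np


def mutual_nearest_pairs(A, B, k=3):
    """Return (i, j) pairs where A[i] and B[j] are in each other's k nearest neighbors."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # Pairwise Euclidean distances, shape (len(A), len(B))
    dist = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=2)
    # k nearest B cells for each A cell, and k nearest A cells for each B cell
    nn_in_b = np.argsort(dist, axis=1)[:, :k]
    nn_in_a = np.argsort(dist, axis=0)[:k, :].T
    pairs = []
    for i in range(A.shape[0]):
        for j in nn_in_b[i]:
            if i in nn_in_a[j]:
                pairs.append((i, int(j)))
    return pairs


# Two simulated populations (rows 0-19 and 20-39) measured in two batches;
# batch B carries an offset of (1, 0) plus a little extra noise.
rng = np.random.default_rng(0)
A = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(5, 0.1, (20, 2))])
B = A + np.array([1.0, 0.0]) + rng.normal(0, 0.05, A.shape)

pairs = mutual_nearest_pairs(A, B, k=3)
same_population = [(i < 20) == (j < 20) for i, j in pairs]
print(f"{len(pairs)} mutual pairs, all within the same population: {all(same_population)}")
```

Because the pairs link cells of the same population across batches, the method does not require batches to have identical composition, only some shared populations.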

### 7.4 Harmony Integration

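Harmony corrects a low-dimensional embedding (here, the PCA scores) rather than the intensity matrix. It alternates between soft clustering of cells and removing batch-specific offsets within each cluster until the clusters are well mixed across batches.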

In [None]:
# Apply Harmony integration (requires harmonypy)
print("Running Harmony integration...")

try:
    container = harmony(
        container,
        batch_key="batch",
        assay_name="pca_pre_correction",
        base_layer="scores",
        new_layer_name="harmony_scores",
        lambda_val=1.0,  # Ridge penalty on the correction (larger = gentler correction)
        theta_val=2.0,  # Diversity penalty (larger = stronger push to mix batches)
    )
    
    print("Harmony integration completed.")
    
    # Get Harmony coordinates
    harm1 = container.assays["pca_pre_correction"].layers["harmony_scores"].X[:, 0]
    harm2 = container.assays["pca_pre_correction"].layers["harmony_scores"].X[:, 1]
    
except Exception as e:
    print(f"Harmony integration skipped: {e}")
    print("(Harmony requires harmonypy to be installed: pip install harmonypy)")
    harm1, harm2 = None, None

## 8. Quantifying Batch Effect Removal

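Let's quantify how much batch signal remains after each method. The check below runs a one-way ANOVA of the first component across batches and reports the F-statistic. It is a coarse measure: it looks at a single component, and each PCA is fit separately, so treat the comparison as indicative.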
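
The F-statistic is unbounded and grows with sample size, so a bounded companion measure can be easier to compare across datasets. One option is eta squared, the fraction of variance in a component that is explained by batch membership. The sketch below is a plain NumPy version on simulated values; the function name `eta_squared` is hypothetical.

```python
import numpy as np


def eta_squared(values, groups):
    """Fraction of variance in `values` explained by group membership (0 to 1)."""
    values = np.asarray(values, dtype=float)
    groups = np.asarray(groups)
    grand_mean = values.mean()
    ss_total = ((values - grand_mean) ** 2).sum()
    # Between-group sum of squares: group size times squared mean offset
    ss_between = sum(
        (groups == g).sum() * (values[groups == g].mean() - grand_mean) ** 2
        for g in np.unique(groups)
    )
    return float(ss_between / ss_total)


# Simulated component scores for two batches of 100 cells each
rng = np.random.default_rng(0)
batch = np.repeat(["A", "B"], 100)
with_batch_effect = rng.normal(size=200) + np.where(batch == "A", 0.0, 3.0)
without_batch_effect = rng.normal(size=200)

eta_strong = eta_squared(with_batch_effect, batch)
eta_none = eta_squared(without_batch_effect, batch)
print(f"Eta squared with batch effect:    {eta_strong:.2f}")
print(f"Eta squared without batch effect: {eta_none:.2f}")
```

In this tutorial the same function could be applied to `pc1` and `batches` before and after correction.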

In [None]:
# Function to calculate batch effect strength (one-way ANOVA F-statistic of a component across batches)
from scipy.stats import f_oneway

def batch_effect_strength(pc, batches):
    """
    Calculate batch effect strength using ANOVA F-statistic.
    Higher values indicate stronger batch effects.
    """
    groups = [pc[batches == b] for b in np.unique(batches)]
    f_stat, _ = f_oneway(*groups)
    return f_stat

# Calculate batch effect strength for each method
results = []
methods = []

# Before correction
be_before = batch_effect_strength(pc1, batches)
results.append(be_before)
methods.append('Before Correction')

# After ComBat
be_combat = batch_effect_strength(pc1_combat, batches)
results.append(be_combat)
methods.append('ComBat')

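# After MNN
if pc1_mnn is not None:
    be_mnn = batch_effect_strength(pc1_mnn, batches)
    results.append(be_mnn)
    methods.append('MNN')

# After Harmony (first Harmony-adjusted component)
if harm1 is not None:
    be_harmony = batch_effect_strength(harm1, batches)
    results.append(be_harmony)
    methods.append('Harmony')

# Plot comparison
fig, ax = plt.subplots(figsize=(10, 6))
colors = ['coral', 'skyblue', 'lightgreen', 'plum'][:len(methods)]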
bars = ax.bar(methods, results, color=colors)
ax.set_ylabel('ANOVA F-Statistic (Batch Effect Strength)')
ax.set_title('Batch Effect Removal Comparison')

# Add value labels on bars
for bar, val in zip(bars, results):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(results)*0.01,
            f'{val:.1f}', ha='center', va='bottom')

plt.tight_layout()
plt.savefig('tutorial_output/batch_effect_comparison.png', dpi=300)
plt.show()

print("\nBatch effect strength (lower means less batch signal on the first component):")
for method, strength in zip(methods, results):
    print(f"  {method:20s}: {strength:8.2f}")

## Summary

In this tutorial, you learned:

### Imputation:
1. **KNN Imputation**: Uses k-nearest neighbors (`knn()`)
2. **PPCA Imputation**: Probabilistic PCA (`ppca()`)
3. **SVD Imputation**: Iterative SVD (`svd_impute()`)
4. **MissForest**: Random forest based (`missforest()`)

### Batch Correction:
1. **ComBat**: Empirical Bayes method (`combat()`)
2. **MNN**: Mutual Nearest Neighbors (`mnn_correct()`)
3. **Harmony**: Iterative clustering (`harmony()`)

### Best Practices:
- Always visualize missing value patterns before imputation
- Choose imputation method based on missingness pattern (MCAR vs MNAR)
- Apply log normalization before most imputation methods
- Use KNN for smaller datasets, PPCA/SVD for larger datasets
- Always verify batch correction results visually
- Preserve biological signal while removing batch effects

### Choosing a Method:
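- **KNN**: Simple and local; works well for MCAR data, but the neighbor search gets costly on large datasets
- **PPCA/SVD**: Good for large datasets, assumes linear structure
- **MissForest**: Handles non-linear relationships between features, but slow
- **ComBat**: Fast; assumes a per-feature shift and scale per batch, and works best when batches have similar cell-type composition
- **MNN**: Corrects locally using matched cells across batches, so batches only need to share some cell populations
- **Harmony**: Works on a low-dimensional embedding and handles many batches; returns corrected scores rather than corrected intensities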

### Next Steps:
- **Tutorial 04**: Clustering and Visualization