# Tutorial 09: Advanced Differential Expression Analysis

This tutorial covers advanced statistical methods for differential expression (DE) analysis in single-cell proteomics data, focusing on count-based models and non-parametric tests, each suited to particular data characteristics.

## Learning Objectives

By the end of this tutorial, you will:
- Understand when to use count-based models vs non-parametric methods
- Apply VOOM + limma analysis for count data
- Use limma-trend for mean-variance dependency
- Perform DESeq2-like negative binomial analysis
- Apply Wilcoxon rank-sum and Brunner-Munzel tests
- Handle paired samples with non-parametric tests
- Compare and validate results across multiple methods

---

## Methods Covered

### Count-Based Models:
- **VOOM**: Log-CPM transform with precision weights for RNA-seq-like count data
- **limma-trend**: Empirical Bayes with trend correction
- **DESeq2**: Negative binomial GLM

### Non-Parametric Methods:
- **Wilcoxon**: Rank-sum test (paired/unpaired)
- **Brunner-Munzel**: Heteroscedastic robust test

## 1. Setup

Import required libraries and configure the plotting environment.

In [None]:
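from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import scienceplots  # noqa: F401  (SciencePlots >= 2.0 must be imported to register its styles)

# Apply SciencePlots style for publication-quality figures
plt.style.use(["science", "no-latex"])
plt.rcParams["figure.dpi"] = 100

# Figures in this tutorial are saved to tutorial_output/; create it if missing
Path("tutorial_output").mkdir(exist_ok=True)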

# Import ScpTensor
import scptensor
from scptensor import (
    create_test_container,
)
from scptensor.diff_expr import (
    diff_expr_brunner_munzel,
    diff_expr_deseq2,
    diff_expr_limma_trend,
    diff_expr_voom,
    diff_expr_wilcoxon,
)

print(f"ScpTensor version: {scptensor.__version__}")

## 2. Method Selection Guide

Choosing the right statistical method is crucial for valid results:

| Method | Data Type | Key Assumption | Best For |
|--------|-----------|----------------|----------|
| **VOOM** | Counts | Mean-variance trend | Small samples, RNA-seq-like counts |
| **limma-trend** | Counts | Mean-variance dependency | Trended variance |
| **DESeq2** | Counts | Negative binomial | Over-dispersed counts |
| **Wilcoxon** | Any | Similar distribution shapes (no normality) | Non-normal data |
| **Brunner-Munzel** | Any | Independent samples (no normality or equal-variance requirement) | Unequal variances |

### Quick Decision Flow:
1. **Count data with many samples (>10 per group)**: DESeq2
2. **Count data with few samples**: VOOM or limma-trend
3. **Continuous/normalized data**: Wilcoxon or Brunner-Munzel
4. **Unequal variances**: Brunner-Munzel (preferred over Wilcoxon)
5. **Paired samples**: Wilcoxon with `paired=True`
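The quick decision flow can be written as a small lookup helper. This is a hypothetical function, not part of ScpTensor: it assumes the paired check takes precedence over the other rules (the list above does not state an order), and it returns candidate `scptensor.diff_expr` function names as strings.

```python
def suggest_de_methods(
    is_count: bool,
    n_per_group: int,
    paired: bool = False,
    unequal_variance: bool = False,
) -> tuple[str, ...]:
    """Map the quick decision flow to candidate ScpTensor function names.

    Hypothetical helper: assumes pairing overrides the other rules.
    """
    if paired:
        # Rule 5: matched samples
        return ("diff_expr_wilcoxon(paired=True)",)
    if is_count:
        # Rules 1-2: DESeq2 for >10 samples per group, otherwise VOOM or limma-trend
        if n_per_group > 10:
            return ("diff_expr_deseq2",)
        return ("diff_expr_voom", "diff_expr_limma_trend")
    # Rules 3-4: continuous/normalized data
    if unequal_variance:
        return ("diff_expr_brunner_munzel",)
    return ("diff_expr_wilcoxon",)


# The simulated dataset in Section 3 has 20 count samples per group
print(suggest_de_methods(is_count=True, n_per_group=20))  # ('diff_expr_deseq2',)
```

The tutorial still runs every method on the same data so the results can be compared.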

## 3. Create Example Dataset

For this tutorial, we'll create a simulated dataset with count-like characteristics that demonstrates the differences between methods.

In [None]:
# Create simulated single-cell proteomics data
# with two groups and differential expression
np.random.seed(42)

# Parameters
n_samples_a = 20
n_samples_b = 20
n_features = 200

# Generate count-like data (negative binomial distribution)
# Group A: baseline expression
counts_a = np.random.negative_binomial(10, 0.5, size=(n_samples_a, n_features))

# Group B: differential expression for some features
counts_b = np.random.negative_binomial(10, 0.5, size=(n_samples_b, n_features))

# Add differential expression:
# Features 0-19: up-regulated in group A (mean ~30 in A vs ~3.3 in B)
# Features 20-39: down-regulated in group A (mean ~3.3 in A vs ~30 in B)
# Features 40-199: no change (mean ~10 in both groups)
counts_a[:, :20] = np.random.negative_binomial(20, 0.4, size=(n_samples_a, 20))  # Up in A
counts_b[:, :20] = np.random.negative_binomial(5, 0.6, size=(n_samples_b, 20))  # Down in B

counts_a[:, 20:40] = np.random.negative_binomial(5, 0.6, size=(n_samples_a, 20))  # Down in A
counts_b[:, 20:40] = np.random.negative_binomial(20, 0.4, size=(n_samples_b, 20))  # Up in B

# Combine data
X_counts = np.vstack([counts_a, counts_b])

# Create metadata
sample_ids = [f"sample_{i}" for i in range(n_samples_a + n_samples_b)]
groups = ["A"] * n_samples_a + ["B"] * n_samples_b
batches = np.random.choice(["batch1", "batch2"], size=n_samples_a + n_samples_b)

# Create feature names
feature_ids = [f"protein_{i}" for i in range(n_features)]

# Create container
container = create_test_container(
    n_samples=n_samples_a + n_samples_b,
    n_features=n_features,
    sparse=False,
)

# Replace with our count data
container.assays["proteins"].layers["raw"].X = X_counts
container.obs = container.obs.with_columns(
    [
        pl.Series("_index", sample_ids),
        pl.Series("group", groups),
        pl.Series("batch", batches),
    ]
)

# Update feature IDs
container.assays["proteins"].var = container.assays["proteins"].var.with_columns(
    [pl.Series("feature_id", feature_ids)]
)
container.assays["proteins"].feature_id_col = "feature_id"

print("Dataset created:")
print(f"  Samples: {container.n_samples}")
print(f"  Features: {container.assays['proteins'].n_features}")
print(f"  Group A: {n_samples_a} samples")
print(f"  Group B: {n_samples_b} samples")
print("\nGroup distribution:")
print(container.obs.group_by("group").len())

### 3.1 Explore Data Characteristics

Before choosing a method, let's examine the data properties.

In [None]:
# Visualize mean-variance relationship
means = np.mean(X_counts, axis=0)
variances = np.var(X_counts, axis=0, ddof=1)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Mean-variance plot
axes[0].scatter(means, variances, alpha=0.5, s=20)
axes[0].set_xlabel("Mean Expression")
axes[0].set_ylabel("Variance")
axes[0].set_title("Mean-Variance Relationship")
axes[0].loglog()

# Add theoretical lines
x_line = np.linspace(min(means), max(means), 100)
axes[0].plot(x_line, x_line, "r--", label="Poisson (var = mean)", alpha=0.7)
axes[0].plot(x_line, x_line**1.5, "g--", label="Over-dispersed (var = mean^1.5, illustrative)", alpha=0.7)
axes[0].legend()

# Count distribution histogram
axes[1].hist(X_counts.flatten(), bins=50, edgecolor="black", alpha=0.7)
axes[1].set_xlabel("Count Value")
axes[1].set_ylabel("Frequency")
axes[1].set_title("Count Distribution")
axes[1].set_yscale("log")

plt.tight_layout()
plt.savefig("tutorial_output/data_characteristics.png", dpi=300)
plt.show()

print("Data characteristics:")
print(
    f"  Mean-variance correlation: {np.corrcoef(np.log10(means + 1), np.log10(variances + 1))[0, 1]:.3f}"
)
print(f"  Sparsity: {(X_counts == 0).sum() / X_counts.size * 100:.1f}% zeros")
print(f"  Median count: {np.median(X_counts):.1f}")

## 4. Count-Based Models

Count-based models are designed for data where:
- Values are non-negative integers (counts)
- Variance depends on mean (heteroscedastic)
- Distribution may be over-dispersed
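The simulated data in Section 3 come from `np.random.negative_binomial(n, p)`, which counts failures before `n` successes. Its mean is n(1 - p)/p and its variance is n(1 - p)/p^2, which equals mean + mean^2/n, so the variance always exceeds the mean (over-dispersion relative to Poisson). A quick numerical check of that relationship, using the baseline parameters from Section 3:

```python
import numpy as np


def nb_mean_var(n: float, p: float) -> tuple[float, float]:
    """Mean and variance of NumPy's negative_binomial(n, p)."""
    mean = n * (1 - p) / p
    var = n * (1 - p) / p**2  # equals mean + mean**2 / n
    return mean, var


mean, var = nb_mean_var(10, 0.5)  # baseline used for most features in Section 3
print(mean, var)  # 10.0 20.0

# Empirical check with a large sample
rng = np.random.default_rng(0)
draws = rng.negative_binomial(10, 0.5, size=200_000)
print(f"sample mean = {draws.mean():.2f}, sample variance = {draws.var(ddof=1):.2f}")
```

The variance-to-mean ratio of 2 here is what places the simulated features above the Poisson line in the mean-variance plot.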

### 4.1 VOOM + limma Analysis

VOOM (mean-variance modelling at the observational level) transforms counts to log2-CPM with precision weights, then applies limma's empirical Bayes moderation.
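The VOOM transformation and weights can be sketched in NumPy. This is a simplified illustration, not ScpTensor's implementation: it assumes an intercept-only model (no design matrix, so each feature's fitted value is its mean log-CPM) and substitutes a cubic polynomial for the lowess curve that voom fits to sqrt(residual SD) against average log2 count.

```python
import numpy as np


def voom_sketch(counts: np.ndarray, degree: int = 3) -> tuple[np.ndarray, np.ndarray]:
    """Simplified voom: log2-CPM values plus observation-level precision weights.

    counts: samples x features matrix of raw counts.
    """
    lib_size = counts.sum(axis=1, keepdims=True).astype(float)  # per-sample library size
    log_cpm = np.log2((counts + 0.5) / (lib_size + 1.0) * 1e6)

    # Intercept-only fit: fitted value = feature mean, residual SD around it
    fitted_log_cpm = log_cpm.mean(axis=0)
    sqrt_sd = np.sqrt(log_cpm.std(axis=0, ddof=1))

    # Mean-variance trend: sqrt(SD) against average log2 count
    # (voom uses lowess here; a polynomial is a stand-in)
    avg_log_count = fitted_log_cpm + np.mean(np.log2(lib_size + 1.0)) - np.log2(1e6)
    coef = np.polyfit(avg_log_count, sqrt_sd, deg=degree)

    # Predict sqrt(SD) at each observation's fitted log2 count; weight = 1 / predicted SD^2
    fitted_log_count = fitted_log_cpm + np.log2(lib_size + 1.0) - np.log2(1e6)
    pred_sqrt_sd = np.clip(np.polyval(coef, fitted_log_count), 1e-3, None)
    weights = pred_sqrt_sd**-4.0
    return log_cpm, weights


rng = np.random.default_rng(0)
demo_counts = rng.negative_binomial(10, 0.5, size=(20, 200))
log_cpm, weights = voom_sketch(demo_counts)
print(log_cpm.shape, weights.shape)  # (20, 200) (20, 200)
```

Because the predicted value is sqrt(SD), raising it to the fourth power gives the variance, so each weight is an inverse predicted variance that limma then uses in its weighted least-squares fit.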

In [None]:
# Run VOOM analysis
result_voom = diff_expr_voom(
    container=container,
    assay_name="proteins",
    layer="raw",
    groupby="group",
    group1="A",
    group2="B",
    min_count=10,
    normalize="tmm",
)

print("VOOM Analysis Results:")
print(f"  Method: {result_voom.method}")
print(f"  Features tested: {np.sum(~np.isnan(result_voom.p_values))}")
print(f"  Significant (FDR < 0.05): {np.sum(result_voom.p_values_adj < 0.05)}")
print(
    f"  Up-regulated in A: {np.sum((result_voom.p_values_adj < 0.05) & (result_voom.log2_fc > 1))}"
)
print(
    f"  Down-regulated in A: {np.sum((result_voom.p_values_adj < 0.05) & (result_voom.log2_fc < -1))}"
)

### 4.2 limma-trend Analysis

limma-trend applies empirical Bayes variance shrinkage with trend correction for the mean-variance relationship.
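The moderation step can be sketched as follows: each feature's residual variance s^2 (with d residual degrees of freedom) is pulled toward a prior variance s0^2 that depends on mean expression, giving s2_post = (d0 * s0^2 + d * s^2) / (d0 + d). In limma, d0 and the trend are estimated from the data; here d0 is fixed and the trend is a quadratic fit of log s^2 on mean expression, both simplifying assumptions.

```python
import numpy as np


def moderated_variances(
    s2: np.ndarray, mean_expr: np.ndarray, d: int, d0: float = 4.0, degree: int = 2
) -> np.ndarray:
    """Shrink per-feature variances toward a mean-dependent prior (limma-trend idea).

    s2: residual variance per feature; mean_expr: mean (log) expression per feature;
    d: residual degrees of freedom; d0: prior degrees of freedom (fixed here).
    """
    # Prior variance as a smooth function of mean expression, fitted on the log scale
    coef = np.polyfit(mean_expr, np.log(s2), deg=degree)
    s0_sq = np.exp(np.polyval(coef, mean_expr))
    # Posterior variance: degrees-of-freedom-weighted average of prior and observed
    return (d0 * s0_sq + d * s2) / (d0 + d)


rng = np.random.default_rng(0)
mean_expr = rng.uniform(2, 10, size=500)
true_var = 0.1 + 1.0 / mean_expr  # variance decreases with abundance
d = 38  # e.g. 20 + 20 samples minus 2 group means
s2 = true_var * rng.chisquare(d, size=500) / d
s2_post = moderated_variances(s2, mean_expr, d)
print(f"spread of raw s2: {s2.std():.4f}, after moderation: {s2_post.std():.4f}")
```

The moderated t-statistic then replaces s^2 with s2_post in the usual two-group t formula and uses d0 + d degrees of freedom, which is where the extra power for small groups comes from.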

In [None]:
# Run limma-trend analysis
result_trend = diff_expr_limma_trend(
    container=container,
    assay_name="proteins",
    layer="raw",
    groupby="group",
    group1="A",
    group2="B",
    trend=True,
    robust=True,
)

print("limma-trend Analysis Results:")
print(f"  Method: {result_trend.method}")
print(f"  Features tested: {np.sum(~np.isnan(result_trend.p_values))}")
print(f"  Significant (FDR < 0.05): {np.sum(result_trend.p_values_adj < 0.05)}")
print(
    f"  Up-regulated in A: {np.sum((result_trend.p_values_adj < 0.05) & (result_trend.log2_fc > 1))}"
)
print(
    f"  Down-regulated in A: {np.sum((result_trend.p_values_adj < 0.05) & (result_trend.log2_fc < -1))}"
)

### 4.3 DESeq2-like Analysis

DESeq2 uses a negative binomial generalized linear model, accounting for overdispersion common in count data.
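One building block of DESeq2 is median-of-ratios normalization: each sample's size factor is the median, over features, of its count divided by that feature's geometric mean across samples. The sketch below implements only this normalization step (dispersion estimation and the Wald test are not shown) and is not taken from ScpTensor's `diff_expr_deseq2`.

```python
import numpy as np


def median_of_ratios_size_factors(counts: np.ndarray) -> np.ndarray:
    """DESeq2-style size factors for a samples x features count matrix."""
    with np.errstate(divide="ignore"):
        log_counts = np.log(counts.astype(float))  # log(0) -> -inf
    # Pseudo-reference sample: per-feature geometric mean (on the log scale)
    log_geo_means = log_counts.mean(axis=0)
    # Features with any zero count have a -inf geometric mean and are skipped
    usable = np.isfinite(log_geo_means)
    log_ratios = log_counts[:, usable] - log_geo_means[usable]
    return np.exp(np.median(log_ratios, axis=1))


# Sample 2 is sample 1 measured at twice the depth; the last feature contains a zero
counts = np.array([[10, 20, 30, 5, 0], [20, 40, 60, 10, 7]])
size_factors = median_of_ratios_size_factors(counts)
print(size_factors)  # approximately [0.707 1.414], a ratio of 2
normalized = counts / size_factors[:, None]
```

Using the median makes the size factors insensitive to the minority of features that are truly differentially expressed.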

In [None]:
# Run DESeq2-like analysis
result_deseq2 = diff_expr_deseq2(
    container=container,
    assay_name="proteins",
    layer="raw",
    groupby="group",
    group1="A",
    group2="B",
    fit_type="parametric",
    test="wald",
    min_count=10,
)

print("DESeq2 Analysis Results:")
print(f"  Method: {result_deseq2.method}")
print(f"  Features tested: {np.sum(~np.isnan(result_deseq2.p_values))}")
print(f"  Significant (FDR < 0.05): {np.sum(result_deseq2.p_values_adj < 0.05)}")
print(
    f"  Up-regulated in A: {np.sum((result_deseq2.p_values_adj < 0.05) & (result_deseq2.log2_fc > 1))}"
)
print(
    f"  Down-regulated in A: {np.sum((result_deseq2.p_values_adj < 0.05) & (result_deseq2.log2_fc < -1))}"
)

### 4.4 Comparing Count-Based Methods

In [None]:
# Compare results across methods
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# P-value comparison
valid = ~np.isnan(result_voom.p_values_adj) & ~np.isnan(result_deseq2.p_values_adj)
axes[0, 0].scatter(
    result_voom.p_values_adj[valid], result_deseq2.p_values_adj[valid], alpha=0.5, s=15
)
axes[0, 0].plot([0, 1], [0, 1], "r--", linewidth=1)
axes[0, 0].set_xlabel("VOOM Adjusted P-value")
axes[0, 0].set_ylabel("DESeq2 Adjusted P-value")
axes[0, 0].set_title("VOOM vs DESeq2 P-values")

# Log2 FC comparison
axes[0, 1].scatter(result_voom.log2_fc, result_deseq2.log2_fc, alpha=0.5, s=15)
# nanmin/nanmax: features filtered out of a test may carry NaN fold changes
min_fc = min(np.nanmin(result_voom.log2_fc), np.nanmin(result_deseq2.log2_fc))
max_fc = max(np.nanmax(result_voom.log2_fc), np.nanmax(result_deseq2.log2_fc))
axes[0, 1].plot([min_fc, max_fc], [min_fc, max_fc], "r--", linewidth=1)
axes[0, 1].set_xlabel("VOOM log2 FC")
axes[0, 1].set_ylabel("DESeq2 log2 FC")
axes[0, 1].set_title("VOOM vs DESeq2 Log2 Fold Change")

# Venn diagram of significant features
from matplotlib_venn import venn3

sig_voom = set(np.where(result_voom.p_values_adj < 0.05)[0])
sig_trend = set(np.where(result_trend.p_values_adj < 0.05)[0])
sig_deseq2 = set(np.where(result_deseq2.p_values_adj < 0.05)[0])

venn3([sig_voom, sig_trend, sig_deseq2], ("VOOM", "limma-trend", "DESeq2"), ax=axes[1, 0])
axes[1, 0].set_title("Overlap of Significant Features (FDR < 0.05)")

# Method concordance
methods = ["VOOM", "limma-trend", "DESeq2"]
n_sig = [len(sig_voom), len(sig_trend), len(sig_deseq2)]
bars = axes[1, 1].bar(methods, n_sig, color=["#1f77b4", "#ff7f0e", "#2ca02c"])
axes[1, 1].set_ylabel("Number of Significant Features")
axes[1, 1].set_title("Significant Features by Method")
axes[1, 1].set_ylim(0, max(n_sig) * 1.2)

# Add count labels on bars
for bar, count in zip(bars, n_sig, strict=False):
    axes[1, 1].text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 1,
        str(count),
        ha="center",
        va="bottom",
    )

plt.tight_layout()
plt.savefig("tutorial_output/count_methods_comparison.png", dpi=300)
plt.show()

print("\nConcordance Analysis:")
print(f"  VOOM & limma-trend overlap: {len(sig_voom & sig_trend)} features")
print(f"  VOOM & DESeq2 overlap: {len(sig_voom & sig_deseq2)} features")
print(f"  All three methods: {len(sig_voom & sig_trend & sig_deseq2)} features")

## 5. Non-Parametric Methods

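Non-parametric methods drop the normality assumption (and any specific count distribution) and are robust to:
- Non-normal distributions
- Outliers

They are not assumption-free: the Wilcoxon rank-sum test expects similarly shaped distributions in both groups, while the Brunner-Munzel test also tolerates unequal variances.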

### 5.1 Wilcoxon Rank-Sum Test

The Wilcoxon rank-sum test (Mann-Whitney U) tests whether samples from one group tend to have higher values than the other.
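The rank-sum statistic can be computed directly from ranks: U for group A equals the sum of A's ranks in the pooled sample minus n_A(n_A + 1)/2, which is also the number of (A, B) pairs in which the A value is larger (ties count one half). The sketch below checks this against `scipy.stats.mannwhitneyu` on small made-up values. Dividing U by n_A * n_B gives the estimate of P(A > B) + 0.5 * P(A = B), the same relative effect that the Brunner-Munzel pHat in Section 5.2 describes.

```python
import numpy as np
from scipy.stats import mannwhitneyu, rankdata


def u_statistic(x: np.ndarray, y: np.ndarray) -> float:
    """Mann-Whitney U for sample x, computed from pooled ranks (average ranks for ties)."""
    ranks = rankdata(np.concatenate([x, y]))
    rank_sum_x = ranks[: len(x)].sum()
    return float(rank_sum_x - len(x) * (len(x) + 1) / 2)


group_a = np.array([12, 15, 9, 20, 18, 15])
group_b = np.array([8, 11, 9, 10, 14])

u_manual = u_statistic(group_a, group_b)
res = mannwhitneyu(group_a, group_b, alternative="two-sided")
print(u_manual, res.statistic)  # 25.5 25.5
print(f"U / (n_A * n_B) = {u_manual / (len(group_a) * len(group_b)):.2f}")  # 0.85
```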

In [None]:
# Run Wilcoxon rank-sum test
result_wilcoxon = diff_expr_wilcoxon(
    container=container,
    assay_name="proteins",
    layer="raw",
    groupby="group",
    group1="A",
    group2="B",
    paired=False,
    alternative="two-sided",
    min_samples_per_group=3,
    missing_strategy="ignore",
)

print("Wilcoxon Rank-Sum Test Results:")
print(f"  Method: {result_wilcoxon.method}")
print(f"  Features tested: {np.sum(~np.isnan(result_wilcoxon.p_values))}")
print(f"  Significant (FDR < 0.05): {np.sum(result_wilcoxon.p_values_adj < 0.05)}")
print(
    f"  Up-regulated in A: {np.sum((result_wilcoxon.p_values_adj < 0.05) & (result_wilcoxon.log2_fc > 1))}"
)
print(
    f"  Down-regulated in A: {np.sum((result_wilcoxon.p_values_adj < 0.05) & (result_wilcoxon.log2_fc < -1))}"
)

### 5.2 Brunner-Munzel Test

The Brunner-Munzel test is robust to unequal variances (heteroscedasticity) and handles unequal sample sizes well.
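The relative effect pHat printed in this section can be estimated by comparing every (A, B) pair. The sketch computes pHat = P(A > B) + 0.5 * P(A = B) directly, matching the interpretation given in the code cell, and obtains a p-value from `scipy.stats.brunnermunzel`. The simulated groups are hypothetical and deliberately have unequal variances and sample sizes. SciPy's test statistic may follow a different sign convention from this pHat, so read the direction from pHat rather than from the statistic's sign.

```python
import numpy as np
from scipy.stats import brunnermunzel


def relative_effect(a, b) -> float:
    """pHat = P(A > B) + 0.5 * P(A == B), averaged over all (a, b) pairs."""
    a = np.asarray(a, dtype=float)[:, None]
    b = np.asarray(b, dtype=float)[None, :]
    return float(np.mean((a > b) + 0.5 * (a == b)))


rng = np.random.default_rng(1)
group_a = rng.normal(loc=1.0, scale=3.0, size=25)  # wide spread
group_b = rng.normal(loc=0.0, scale=0.5, size=15)  # narrow spread, fewer samples

p_hat = relative_effect(group_a, group_b)
res = brunnermunzel(group_a, group_b, alternative="two-sided")
print(f"pHat = {p_hat:.3f}, p-value = {res.pvalue:.3g}")
```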

In [None]:
# Run Brunner-Munzel test
result_bm = diff_expr_brunner_munzel(
    container=container,
    assay_name="proteins",
    layer="raw",
    groupby="group",
    group1="A",
    group2="B",
    alternative="two-sided",
    min_samples_per_group=3,
    missing_strategy="ignore",
)

print("Brunner-Munzel Test Results:")
print(f"  Method: {result_bm.method}")
print(f"  Features tested: {np.sum(~np.isnan(result_bm.p_values))}")
print(f"  Significant (FDR < 0.05): {np.sum(result_bm.p_values_adj < 0.05)}")
print(f"  Up-regulated in A: {np.sum((result_bm.p_values_adj < 0.05) & (result_bm.log2_fc > 1))}")
print(
    f"  Down-regulated in A: {np.sum((result_bm.p_values_adj < 0.05) & (result_bm.log2_fc < -1))}"
)

# Interpret relative effects
print("\nRelative Effects (pHat):")
p_hat = result_bm.effect_sizes
print(f"  Mean pHat: {np.nanmean(p_hat):.3f}")
print("  pHat = 0.5: Stochastic equality (no difference)")
print("  pHat > 0.5: Group A tends to have larger values")
print("  pHat < 0.5: Group B tends to have larger values")

### 5.3 Comparing Non-Parametric Methods

In [None]:
# Compare Wilcoxon and Brunner-Munzel
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# P-value comparison
valid_np = ~np.isnan(result_wilcoxon.p_values_adj) & ~np.isnan(result_bm.p_values_adj)
axes[0].scatter(
    result_wilcoxon.p_values_adj[valid_np], result_bm.p_values_adj[valid_np], alpha=0.5, s=15
)
axes[0].plot([0, 1], [0, 1], "r--", linewidth=1)
axes[0].set_xlabel("Wilcoxon Adjusted P-value")
axes[0].set_ylabel("Brunner-Munzel Adjusted P-value")
axes[0].set_title("Wilcoxon vs Brunner-Munzel P-values")

# Log2 FC comparison
axes[1].scatter(result_wilcoxon.log2_fc, result_bm.log2_fc, alpha=0.5, s=15)
# nanmin/nanmax: features filtered out of a test may carry NaN fold changes
min_fc = min(np.nanmin(result_wilcoxon.log2_fc), np.nanmin(result_bm.log2_fc))
max_fc = max(np.nanmax(result_wilcoxon.log2_fc), np.nanmax(result_bm.log2_fc))
axes[1].plot([min_fc, max_fc], [min_fc, max_fc], "r--", linewidth=1)
axes[1].set_xlabel("Wilcoxon log2 FC")
axes[1].set_ylabel("Brunner-Munzel log2 FC")
axes[1].set_title("Wilcoxon vs Brunner-Munzel Log2 Fold Change")

# Significant feature comparison
sig_wilcoxon = set(np.where(result_wilcoxon.p_values_adj < 0.05)[0])
sig_bm = set(np.where(result_bm.p_values_adj < 0.05)[0])

from matplotlib_venn import venn2

venn2([sig_wilcoxon, sig_bm], ("Wilcoxon", "Brunner-Munzel"), ax=axes[2])
axes[2].set_title("Overlap of Significant Features (FDR < 0.05)")

plt.tight_layout()
plt.savefig("tutorial_output/nonparametric_comparison.png", dpi=300)
plt.show()

print("\nNon-Parametric Method Concordance:")
print(f"  Wilcoxon only: {len(sig_wilcoxon - sig_bm)} features")
print(f"  Brunner-Munzel only: {len(sig_bm - sig_wilcoxon)} features")
print(f"  Both methods: {len(sig_wilcoxon & sig_bm)} features")

## 6. Paired Sample Analysis

For paired/matched samples (e.g., the same patient before and after treatment), use the Wilcoxon signed-rank test (`paired=True`), which tests the within-pair differences. Let's create a paired dataset for demonstration.

In [None]:
# Create paired dataset
n_pairs = 15

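# Simulate matched pairs (e.g., same patient before/after treatment)
# Each pair has one control and one treatment sample sharing the same baseline
paired_counts_control = np.random.negative_binomial(15, 0.5, size=(n_pairs, n_features))

# Treatment = shared baseline + independent measurement noise, so features without a
# treatment effect still have non-zero within-pair differences for the signed-rank test
noise = np.random.poisson(3, size=(n_pairs, n_features)) - np.random.poisson(
    3, size=(n_pairs, n_features)
)
paired_counts_treatment = paired_counts_control + noise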

# Add treatment effect to some features
paired_counts_treatment[:, :15] += np.random.poisson(10, size=(n_pairs, 15))
paired_counts_treatment[:, 15:30] -= np.random.poisson(5, size=(n_pairs, 15))
paired_counts_treatment = np.maximum(paired_counts_treatment, 0)

# Combine data
paired_X = np.vstack([paired_counts_control, paired_counts_treatment])

# Create metadata with pair IDs
paired_sample_ids = []
paired_groups = []
pair_ids = []

for i in range(n_pairs):
    paired_sample_ids.append(f"pair_{i}_control")
    paired_groups.append("control")
    pair_ids.append(f"pair_{i}")

    paired_sample_ids.append(f"pair_{i}_treatment")
    paired_groups.append("treatment")
    pair_ids.append(f"pair_{i}")

# Create paired container
paired_container = create_test_container(
    n_samples=n_pairs * 2,
    n_features=n_features,
    sparse=False,
)

paired_container.assays["proteins"].layers["raw"].X = paired_X
paired_container.obs = paired_container.obs.with_columns(
    [
        pl.Series("_index", paired_sample_ids),
        pl.Series("group", paired_groups),
        pl.Series("pair_id", pair_ids),
    ]
)

paired_container.assays["proteins"].var = paired_container.assays["proteins"].var.with_columns(
    [pl.Series("feature_id", feature_ids)]
)
paired_container.assays["proteins"].feature_id_col = "feature_id"

print("Paired dataset created:")
print(f"  Total samples: {paired_container.n_samples}")
print(f"  Number of pairs: {n_pairs}")
print(f"  Control: {n_pairs} samples")
print(f"  Treatment: {n_pairs} samples")

### 6.1 Paired Wilcoxon Test
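Why pairing matters can be seen on a toy example with SciPy, independent of ScpTensor: when baselines vary widely between pairs but the treatment shifts every pair by a small, consistent amount, the signed-rank test on within-pair differences detects the shift while the unpaired rank-sum test does not. The numbers are made up for illustration.

```python
import numpy as np
from scipy.stats import mannwhitneyu, wilcoxon

# 20 hypothetical pairs: baselines spread from 0 to 190,
# treatment adds a small positive shift (3.0 to 4.9) within each pair
baseline = np.arange(20) * 10.0
treatment = baseline + 3.0 + 0.1 * np.arange(20)

paired = wilcoxon(treatment, baseline)  # signed-rank test on treatment - baseline
unpaired = mannwhitneyu(treatment, baseline, alternative="two-sided")  # ignores pairing
print(f"paired p = {paired.pvalue:.2g}, unpaired p = {unpaired.pvalue:.2f}")
```

Every within-pair difference is positive, so the signed-rank p-value is tiny, whereas the pooled ranks of the two groups are almost interleaved.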

In [None]:
# Run paired Wilcoxon test
result_paired = diff_expr_wilcoxon(
    container=paired_container,
    assay_name="proteins",
    layer="raw",
    groupby="group",
    group1="treatment",
    group2="control",
    paired=True,
    pair_id_col="pair_id",
    alternative="two-sided",
    min_samples_per_group=3,
    missing_strategy="ignore",
)

print("Paired Wilcoxon Test Results:")
print(f"  Method: {result_paired.method}")
print(f"  Number of pairs analyzed: {result_paired.params.get('n_pairs', 'N/A')}")
print(f"  Features tested: {np.sum(~np.isnan(result_paired.p_values))}")
print(f"  Significant (FDR < 0.05): {np.sum(result_paired.p_values_adj < 0.05)}")
print(
    f"  Up in treatment: {np.sum((result_paired.p_values_adj < 0.05) & (result_paired.log2_fc > 1))}"
)
print(
    f"  Down in treatment: {np.sum((result_paired.p_values_adj < 0.05) & (result_paired.log2_fc < -1))}"
)

### 6.2 Paired vs Unpaired Comparison

In [None]:
# Compare paired vs unpaired analysis
result_unpaired = diff_expr_wilcoxon(
    container=paired_container,
    assay_name="proteins",
    layer="raw",
    groupby="group",
    group1="treatment",
    group2="control",
    paired=False,
    alternative="two-sided",
    missing_strategy="ignore",
)

# Visualize comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# P-value comparison
valid_pair = ~np.isnan(result_paired.p_values_adj) & ~np.isnan(result_unpaired.p_values_adj)
axes[0].scatter(
    result_unpaired.p_values_adj[valid_pair],
    result_paired.p_values_adj[valid_pair],
    alpha=0.5,
    s=15,
)
axes[0].plot([0, 1], [0, 1], "r--", linewidth=1)
axes[0].set_xlabel("Unpaired Adjusted P-value")
axes[0].set_ylabel("Paired Adjusted P-value")
axes[0].set_title("Paired vs Unpaired Wilcoxon P-values")

# Significant features
sig_paired = set(np.where(result_paired.p_values_adj < 0.05)[0])
sig_unpaired = set(np.where(result_unpaired.p_values_adj < 0.05)[0])

from matplotlib_venn import venn2

venn2([sig_unpaired, sig_paired], ("Unpaired", "Paired"), ax=axes[1])
axes[1].set_title("Significant Features (FDR < 0.05)")

plt.tight_layout()
plt.savefig("tutorial_output/paired_vs_unpaired.png", dpi=300)
plt.show()

print("\nPaired vs Unpaired Analysis:")
print(f"  Unpaired only: {len(sig_unpaired - sig_paired)} features")
print(f"  Paired only: {len(sig_paired - sig_unpaired)} features")
print(f"  Both methods: {len(sig_paired & sig_unpaired)} features")
print("\nNote: Pairing typically adds power when between-pair variability (e.g., between patients) is large relative to within-pair differences.")

## 7. Result Visualization

### 7.1 Volcano Plots for Different Methods

In [None]:
# Create volcano plots for all methods
results = {
    "VOOM": result_voom,
    "limma-trend": result_trend,
    "DESeq2": result_deseq2,
    "Wilcoxon": result_wilcoxon,
    "Brunner-Munzel": result_bm,
}

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, (name, result) in enumerate(results.items()):
    if idx >= len(axes):
        break

    ax = axes[idx]

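    log2fc = result.log2_fc
    # Floor adjusted p-values at the smallest positive float so p = 0 maps to a finite height
    neg_log_p = -np.log10(np.maximum(result.p_values_adj, np.finfo(float).tiny))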

    # Color points
    colors = np.full(len(log2fc), "gray", dtype=object)
    valid = ~np.isnan(result.p_values_adj)

    sig = result.p_values_adj < 0.05
    colors[valid & sig & (log2fc > 1)] = "#d62728"  # red - up
    colors[valid & sig & (log2fc < -1)] = "#1f77b4"  # blue - down

    ax.scatter(log2fc[valid], neg_log_p[valid], c=colors[valid], alpha=0.6, s=20)

    # Threshold lines
    ax.axhline(-np.log10(0.05), color="black", linestyle="--", linewidth=0.5)
    ax.axvline(1, color="black", linestyle="--", linewidth=0.5)
    ax.axvline(-1, color="black", linestyle="--", linewidth=0.5)

    ax.set_xlabel("log2 Fold Change")
    ax.set_ylabel("-log10 Adjusted P-value")
    ax.set_title(f"{name}")
    ax.set_ylim(0, max(neg_log_p[valid]) * 1.1)

# Remove empty subplot
axes[-1].axis("off")

plt.tight_layout()
plt.savefig("tutorial_output/volcano_plots_all_methods.png", dpi=300)
plt.show()

print("Volcano plots generated for all methods.")

### 7.2 MA Plot (Intensity vs Fold Change)

In [None]:
# Create MA plot for VOOM results
fig, ax = plt.subplots(figsize=(10, 8))

# Calculate M (log2 fold change) and A (average expression)
M = result_voom.log2_fc
A = (result_voom.group_stats["A_mean"] + result_voom.group_stats["B_mean"]) / 2
A = np.log2(A + 1)  # Log transform average expression

# Color by significance
colors = np.full(len(M), "gray", dtype=object)
valid = ~np.isnan(result_voom.p_values_adj)
sig = result_voom.p_values_adj < 0.05

colors[valid & sig & (M > 1)] = "#d62728"  # red - up
colors[valid & sig & (M < -1)] = "#1f77b4"  # blue - down

ax.scatter(A[valid], M[valid], c=colors[valid], alpha=0.6, s=20)

# Add threshold lines
ax.axhline(0, color="black", linestyle="-", linewidth=0.5)
ax.axhline(1, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
ax.axhline(-1, color="black", linestyle="--", linewidth=0.5, alpha=0.5)

ax.set_xlabel("Average Expression (log2)")
ax.set_ylabel("log2 Fold Change")
ax.set_title("MA Plot: VOOM Analysis")

plt.tight_layout()
plt.savefig("tutorial_output/ma_plot.png", dpi=300)
plt.show()

print("MA plot generated.")
print("The MA plot shows:")
print("  X-axis (A): average expression across groups (log2 scale)")
print("  Y-axis (M): log2 fold change between groups")

### 7.3 P-value Distribution

In [None]:
# Compare p-value distributions across methods
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram of raw p-values: roughly uniform under the null, with a spike near 0
# when true differences are present (adjusted p-values are not uniform under the null)
bins = np.linspace(0, 1, 21)
for name, result in results.items():
    pvals = result.p_values[~np.isnan(result.p_values)]
    axes[0].hist(pvals, bins=bins, alpha=0.5, label=name, density=True)

axes[0].axvline(0.05, color="red", linestyle="--", linewidth=1, label="p = 0.05")
axes[0].set_xlabel("Raw P-value")
axes[0].set_ylabel("Density")
axes[0].set_title("Raw P-value Distributions")
axes[0].legend(fontsize=8)

# Q-Q plot (quantile-quantile) for VOOM

pvals = result_voom.p_values[~np.isnan(result_voom.p_values)]
observed = -np.log10(sorted(pvals))
expected = -np.log10(np.linspace(1 / len(pvals), 1, len(pvals)))

axes[1].scatter(expected, observed, alpha=0.5, s=15)
axes[1].plot([0, max(expected)], [0, max(expected)], "r--", linewidth=1, label="Expected (null)")
axes[1].set_xlabel("Expected -log10 P-value")
axes[1].set_ylabel("Observed -log10 P-value")
axes[1].set_title("Q-Q Plot: VOOM Analysis")
axes[1].legend()

plt.tight_layout()
plt.savefig("tutorial_output/pvalue_distributions.png", dpi=300)
plt.show()

print("P-value distribution analysis complete.")
print("A deviation above the diagonal in the Q-Q plot indicates enrichment of low p-values.")

## 8. Handling Common Challenges

### 8.1 Low Count Filtering

The `min_count` parameter controls feature filtering. Features with insufficient counts are excluded from analysis.
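This tutorial does not define exactly how `min_count` is applied, so the sketch below shows one common style of rule under that assumption: keep a feature if at least `min_samples` samples reach `min_count`. Both the function and `min_samples` are hypothetical; ScpTensor's criterion may instead use totals, CPM, or group sizes.

```python
import numpy as np


def filter_low_counts(counts: np.ndarray, min_count: int, min_samples: int = 3) -> np.ndarray:
    """Hypothetical filter: mask of features with >= min_count in >= min_samples samples.

    counts: samples x features matrix of raw counts.
    """
    return (counts >= min_count).sum(axis=0) >= min_samples


counts = np.array(
    [
        [0, 12, 50],
        [1, 15, 3],
        [2, 9, 60],
        [0, 11, 70],
    ]
)
keep = filter_low_counts(counts, min_count=10)
print(keep)  # [False  True  True]
filtered = counts[:, keep]
```

Filtering before fitting removes features whose variance estimates are dominated by sampling noise, which also reduces the multiple-testing burden.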

In [None]:
# Demonstrate effect of min_count filtering
min_counts = [0, 5, 10, 20, 50]
n_tested = []
n_significant = []

for mc in min_counts:
    result = diff_expr_voom(
        container=container,
        assay_name="proteins",
        layer="raw",
        groupby="group",
        group1="A",
        group2="B",
        min_count=mc,
        normalize="tmm",
    )
    n_tested.append(np.sum(~np.isnan(result.p_values)))
    n_significant.append(np.sum(result.p_values_adj < 0.05))

# Plot results
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(min_counts))
width = 0.35

bars1 = ax.bar(x - width / 2, n_tested, width, label="Features Tested", color="#1f77b4")
bars2 = ax.bar(
    x + width / 2, n_significant, width, label="Significant (FDR < 0.05)", color="#d62728"
)

ax.set_xlabel("min_count Threshold")
ax.set_ylabel("Number of Features")
ax.set_title("Effect of Low Count Filtering")
ax.set_xticks(x)
ax.set_xticklabels(min_counts)
ax.legend()

# Add count labels
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{int(height)}",
            ha="center",
            va="bottom",
            fontsize=8,
        )

plt.tight_layout()
plt.savefig("tutorial_output/mincount_filtering.png", dpi=300)
plt.show()

print("\nEffect of min_count parameter:")
for mc, tested, sig in zip(min_counts, n_tested, n_significant, strict=False):
    print(f"  min_count={mc:3d}: {tested:3d} tested, {sig:3d} significant")

### 8.2 Handling Missing Values

ScpTensor methods natively handle missing values via the mask matrix. The `missing_strategy` parameter controls how NaN values are treated:
- **ignore**: Skip missing values (default)
- **zero**: Replace with zeros
- **median**: Replace with feature median
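What each strategy does to a single feature can be illustrated with NumPy. This is a sketch of the semantics described above, not ScpTensor's code; in particular it assumes the median is taken over all observed values of the feature, whereas the library might compute it differently (for example within each group).

```python
import numpy as np


def apply_missing_strategy(values: np.ndarray, strategy: str) -> np.ndarray:
    """Apply one missing-value strategy to a single feature's values (NaN = missing)."""
    values = np.asarray(values, dtype=float)
    if strategy == "ignore":
        return values[~np.isnan(values)]  # drop missing observations
    if strategy == "zero":
        return np.where(np.isnan(values), 0.0, values)
    if strategy == "median":
        return np.where(np.isnan(values), np.nanmedian(values), values)
    raise ValueError(f"unknown strategy: {strategy!r}")


x = np.array([4.0, np.nan, 10.0, 6.0])
for s in ("ignore", "zero", "median"):
    print(s, apply_missing_strategy(x, s))
```

Zero-filling treats every missing value as a true absence, which can manufacture differences when missingness is uneven between groups; `ignore` avoids that at the cost of smaller effective sample sizes.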

In [None]:
# Create data with missing values for demonstration
container_na = container.clone()
X_na = container_na.assays["proteins"].layers["raw"].X.astype(float)  # float copy: integer arrays cannot hold NaN

# Add random missing values
np.random.seed(123)
missing_mask = np.random.random(X_na.shape) < 0.1  # 10% missing
X_na[missing_mask] = np.nan

container_na.assays["proteins"].layers["raw"].X = X_na

# Compare missing value strategies
strategies = ["ignore", "zero", "median"]
results_na = {}

for strategy in strategies:
    try:
        result = diff_expr_wilcoxon(
            container=container_na,
            assay_name="proteins",
            layer="raw",
            groupby="group",
            group1="A",
            group2="B",
            missing_strategy=strategy,
        )
        results_na[strategy] = result
    except Exception as e:
        print(f"Strategy '{strategy}' failed: {e}")

# Compare results
if len(results_na) > 1:
    print("\nMissing Value Strategy Comparison:")
    for strategy, result in results_na.items():
        n_sig = np.sum(result.p_values_adj < 0.05)
        print(f"  {strategy:8s}: {n_sig} significant features")

print("\nRecommendation: Use 'ignore' (default) for most applications.")

## 9. Cross-Method Validation

Robust differential expression findings should be consistent across multiple statistical methods.
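The validation strategy at the end of this section also asks for a consistent fold-change direction. A minimal sketch of that filter, assuming each method's results are available as feature-aligned arrays of adjusted p-values and log2 fold changes (the helper is hypothetical; its inputs could be built from, e.g., `result_voom.p_values_adj` and `result_voom.log2_fc`): a feature is kept if at least `min_methods` methods call it significant and all of those methods agree on the sign of its fold change.

```python
import numpy as np
from numpy.typing import ArrayLike


def consensus_features(
    p_adj: dict[str, ArrayLike],
    log2_fc: dict[str, ArrayLike],
    alpha: float = 0.05,
    min_methods: int = 2,
) -> np.ndarray:
    """Indices of features significant in >= min_methods methods with a consistent direction."""
    names = list(p_adj)
    # NaN p-values (untested features) count as not significant
    sig = np.vstack([np.nan_to_num(np.asarray(p_adj[n], float), nan=1.0) < alpha for n in names])
    signs = np.vstack([np.sign(np.nan_to_num(np.asarray(log2_fc[n], float))) for n in names])
    n_sig = sig.sum(axis=0)
    n_up = ((signs > 0) & sig).sum(axis=0)
    n_down = ((signs < 0) & sig).sum(axis=0)
    consistent = (n_up == n_sig) | (n_down == n_sig)
    return np.flatnonzero((n_sig >= min_methods) & consistent)


p_adj = {
    "VOOM": [0.01, 0.01, 0.50],
    "Wilcoxon": [0.02, 0.03, 0.01],
    "Brunner-Munzel": [0.90, 0.04, 0.02],
}
log2_fc = {
    "VOOM": [1.5, 2.0, 0.1],
    "Wilcoxon": [1.2, -2.0, 0.3],
    "Brunner-Munzel": [0.1, 1.0, 0.2],
}
print(consensus_features(p_adj, log2_fc))  # [0 2]
```

Feature 1 is significant in all three methods but is dropped because one method reports the opposite direction.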

In [None]:
# Perform cross-method validation
from matplotlib_venn import venn3

# Get significant features from each method type
sig_count = set(np.where(result_voom.p_values_adj < 0.05)[0])
sig_nonpar = set(np.where(result_wilcoxon.p_values_adj < 0.05)[0])
sig_brunner = set(np.where(result_bm.p_values_adj < 0.05)[0])

# Consensus significant features (found by at least 2 methods)
from collections import Counter

all_sig = [sig_count, sig_nonpar, sig_brunner]
feature_counts = Counter()
for sig_set in all_sig:
    feature_counts.update(sig_set)

consensus = {f for f, c in feature_counts.items() if c >= 2}

print("Cross-Method Validation:")
print(f"  Count-based (VOOM): {len(sig_count)} features")
print(f"  Non-parametric (Wilcoxon): {len(sig_nonpar)} features")
print(f"  Non-parametric (Brunner-Munzel): {len(sig_brunner)} features")
print(f"  Consensus (2+ methods): {len(consensus)} features")

# Visualize overlap
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Venn diagram
venn3(
    [sig_count, sig_nonpar, sig_brunner],
    ("Count-based\n(VOOM)", "Non-parametric\n(Wilcoxon)", "Robust\n(Brunner-Munzel)"),
    ax=axes[0],
)
axes[0].set_title("Method Overlap (FDR < 0.05)")

# Concordance heatmap
methods_list = ["VOOM", "Wilcoxon", "Brunner-Munzel"]
sig_sets = [sig_count, sig_nonpar, sig_brunner]

concordance = np.zeros((3, 3))
for i in range(3):
    for j in range(3):
        if i == j:
            concordance[i, j] = 1.0
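        else:
            union = sig_sets[i] | sig_sets[j]
            # Jaccard index; defined as 0 when neither method finds any feature
            concordance[i, j] = len(sig_sets[i] & sig_sets[j]) / len(union) if union else 0.0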

im = axes[1].imshow(concordance, cmap="Blues", vmin=0, vmax=1)
axes[1].set_xticks(range(3))
axes[1].set_yticks(range(3))
axes[1].set_xticklabels(methods_list)
axes[1].set_yticklabels(methods_list)
axes[1].set_title("Method Concordance (Jaccard Index)")

# Add values to heatmap
for i in range(3):
    for j in range(3):
        text = axes[1].text(
            j, i, f"{concordance[i, j]:.2f}", ha="center", va="center", color="black", fontsize=12
        )

plt.colorbar(im, ax=axes[1])
plt.tight_layout()
plt.savefig("tutorial_output/cross_method_validation.png", dpi=300)
plt.show()

print("\nValidation Strategy:")
print("  1. Use consensus features (detected by 2+ methods)")
print("  2. Prioritize features with consistent fold change direction")
print("  3. Consider biological plausibility of findings")

## 10. Summary and Best Practices

### Key Takeaways

1. **Match method to data characteristics**
   - Count data: VOOM, limma-trend, or DESeq2
   - Continuous/normalized: Wilcoxon or Brunner-Munzel
   - Heteroscedastic: Brunner-Munzel preferred

2. **Always inspect data before testing**
   - Check mean-variance relationship
   - Assess sparsity
   - Verify group sizes

3. **Use multiple methods for validation**
   - Cross-method validation increases confidence
   - Focus on consensus findings
   - Report method used in publications

4. **Report adjusted p-values (FDR)**
   - Always report FDR-adjusted p-values
   - Specify correction method (Benjamini-Hochberg default)
   - Include fold change thresholds
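The Benjamini-Hochberg adjustment named above is short enough to write out, which also makes clear what an FDR-adjusted p-value is: sort the m tested p-values, scale the i-th smallest by m/i, then take a running minimum from the largest down so adjusted values stay monotone. The sketch treats NaN (untested) features as excluded from m, an assumption that matches how this tutorial counts tested features. Recent SciPy releases also provide `scipy.stats.false_discovery_control` for the same purpose.

```python
import numpy as np


def benjamini_hochberg(p_values) -> np.ndarray:
    """Benjamini-Hochberg adjusted p-values; NaN entries are passed through untouched."""
    p = np.asarray(p_values, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    tested = ~np.isnan(p)
    pv = p[tested]
    m = pv.size
    if m == 0:
        return adjusted
    order = np.argsort(pv)
    scaled = pv[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity: cumulative minimum from the largest p-value down
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    out = np.empty(m)
    out[order] = np.minimum(scaled, 1.0)
    adjusted[tested] = out
    return adjusted


print(benjamini_hochberg([0.01, 0.04, 0.03, 0.2]))
```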

### Quick Reference

| Scenario | Recommended Method | Function |
|----------|-------------------|----------|
| Count data, small sample | VOOM | `diff_expr_voom()` |
| Count data, mean-variance trend | limma-trend | `diff_expr_limma_trend()` |
| Count data, over-dispersed | DESeq2 | `diff_expr_deseq2()` |
| Non-normal distribution | Wilcoxon | `diff_expr_wilcoxon()` |
| Unequal variances | Brunner-Munzel | `diff_expr_brunner_munzel()` |
| Paired samples | Wilcoxon paired | `diff_expr_wilcoxon(paired=True)` |

### Common Pitfalls

- Using parametric tests on highly non-normal data
- Ignoring the mean-variance relationship in count data
- Not filtering low-count features before count-based analysis
- Forgetting to use paired tests for matched samples
- Over-interpreting single-method results

### Next Steps
- **Tutorial 10**: [Advanced Topics]
- **Documentation**: See `docs/design/API_REFERENCE.md` for full API details
- **Examples**: Check `examples/` for more usage patterns

In [None]:
# Final summary of all results
print("=" * 70)
print("DIFFERENTIAL EXPRESSION ANALYSIS SUMMARY")
print("=" * 70)

summary_data = []
for name, result in results.items():
    n_tested = int(np.sum(~np.isnan(result.p_values)))
    n_sig = int(np.sum(result.p_values_adj < 0.05))
    n_up = int(np.sum((result.p_values_adj < 0.05) & (result.log2_fc > 1)))
    n_down = int(np.sum((result.p_values_adj < 0.05) & (result.log2_fc < -1)))
    summary_data.append([name, n_tested, n_sig, n_up, n_down])

summary_df = pl.DataFrame(
    summary_data,
    schema=["Method", "Tested", "Significant", "Up in A", "Down in A"],
    orient="row",  # each inner list is one method's row
)
print(summary_df)

print("\n" + "=" * 70)
print("Analysis complete! Results saved to tutorial_output/")
print("=" * 70)