# Tutorial 02: Quality Control and Normalization

This tutorial covers quality control (QC) and normalization techniques for single-cell proteomics data.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Calculate and visualize QC metrics
- Detect and filter low-quality samples and features
- Apply various normalization methods
- Compare normalization approaches
- Understand the impact of normalization on data distribution

---

## 1. Setup

Import required libraries and load an example dataset.

In [None]:
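from pathlib import Path

import numpy as np
import polars as pl
import matplotlib.pyplot as plt
import scienceplots  # noqa: F401  (registers the "science" styles; needed with SciencePlots >= 2.0)

# Apply SciencePlots style
plt.style.use(["science", "no-latex"])

# Create the directory that the figures in this tutorial are saved to
Path("tutorial_output").mkdir(exist_ok=True)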

# Import ScpTensor
import scptensor
from scptensor.datasets import load_simulated_scrnaseq_like
from scptensor import (
    # QC functions
    calculate_qc_metrics,
    detect_outliers,
    filter_samples_by_missing_rate,
    filter_samples_by_total_count,
    filter_features_by_missing_rate,
    filter_features_by_variance,
    # Normalization functions
    log_normalize,
    sample_median_normalization,
    global_median_normalization,
    tmm_normalization,
    upper_quartile_normalization,
    zscore,
)

print(f"ScpTensor version: {scptensor.__version__}")

## 2. Load a Larger Dataset

For this tutorial, we'll use a larger simulated dataset that includes QC metrics.

In [None]:
# Load simulated dataset
container = load_simulated_scrnaseq_like()

print(f"Dataset loaded: {container}")
print(f"\nSamples: {container.n_samples}")
print(f"Features: {container.assays['proteins'].n_features}")
print(f"\nAvailable columns in obs:")
print(container.obs.columns)

### Expected Output:
```
Dataset loaded: ScpContainer with 500 samples and 1 assay

Samples: 500
Features: 200

Available columns in obs:
['sample_id', 'batch', 'cell_type', 'batch_id', 'cell_type_id', 
 'n_detected', 'missing_rate', 'total_intensity', 'mean_intensity', 
 'median_intensity', 'mad_intensity']
```

## 3. Quality Control Metrics

### 3.1 Calculate QC Metrics

Let's calculate comprehensive QC metrics for our data.

In [None]:
# Calculate QC metrics
container = calculate_qc_metrics(container, assay_name="proteins")

# View the added QC metrics
print("QC metrics added to obs:")
print("=" * 50)
print(container.obs.select([
    "sample_id", "n_detected", "missing_rate", 
    "total_intensity", "mean_intensity"
]).head(10))
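
The per-sample metrics above can also be reproduced by hand, which helps when checking what each column means. The following is a minimal NumPy sketch, not the implementation of `calculate_qc_metrics()`. It assumes a dense intensity matrix (samples x features) and a mask in which 0 marks a valid measurement, the same convention section 6.1 relies on. The arrays `X_toy` and `M_toy` are hypothetical.

```python
import numpy as np

# Toy data: 3 samples x 4 features (hypothetical values).
X_toy = np.array([
    [10.0, 20.0, 0.0, 30.0],
    [5.0, 0.0, 0.0, 15.0],
    [8.0, 12.0, 16.0, 4.0],
])
# Assumed mask convention: 0 = valid measurement, non-zero = missing.
M_toy = np.array([
    [0, 0, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
])

valid = M_toy == 0
n_detected = valid.sum(axis=1)                       # detected features per sample
missing_rate = 1.0 - n_detected / X_toy.shape[1]     # fraction of missing features

# Hide missing entries so they do not contribute to the intensity summaries.
X_valid = np.where(valid, X_toy, np.nan)
total_intensity = np.nansum(X_valid, axis=1)
mean_intensity = np.nanmean(X_valid, axis=1)

print("n_detected:", n_detected)
print("missing_rate:", missing_rate)
print("total_intensity:", total_intensity)
print("mean_intensity:", mean_intensity)
```

Comparing these hand-computed values with the columns in `container.obs` is a quick way to confirm which mask convention your data uses.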

### 3.2 Visualize QC Metrics

In [None]:
# Create QC visualization
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Missing rate distribution
axes[0, 0].hist(container.obs["missing_rate"].to_numpy(), bins=30, edgecolor='black', alpha=0.7)
axes[0, 0].set_xlabel('Missing Rate')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('Distribution of Missing Rate per Sample')

# Number of detected features
axes[0, 1].hist(container.obs["n_detected"].to_numpy(), bins=30, edgecolor='black', alpha=0.7)
axes[0, 1].set_xlabel('Number of Detected Proteins')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Distribution of Detected Proteins')

# Total intensity
axes[1, 0].hist(container.obs["total_intensity"].to_numpy(), bins=30, edgecolor='black', alpha=0.7)
axes[1, 0].set_xlabel('Total Intensity')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].set_title('Distribution of Total Intensity')

# Mean intensity
axes[1, 1].hist(container.obs["mean_intensity"].to_numpy(), bins=30, edgecolor='black', alpha=0.7)
axes[1, 1].set_xlabel('Mean Intensity')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].set_title('Distribution of Mean Intensity')

plt.tight_layout()
plt.savefig('tutorial_output/qc_distributions.png', dpi=300)
plt.show()

print("QC plots saved to: tutorial_output/qc_distributions.png")

### 3.3 QC by Batch and Cell Type

In [None]:
# Group statistics by batch and cell type
print("QC Statistics by Batch:")
print("=" * 50)
batch_stats = container.obs.group_by("batch").agg(
    pl.col("missing_rate").mean().alias("mean_missing"),
    pl.col("n_detected").mean().alias("mean_detected"),
    pl.col("total_intensity").mean().alias("mean_total_intensity"),
    pl.len().alias("n_samples")
).sort("batch")
print(batch_stats)

print("\nQC Statistics by Cell Type:")
print("=" * 50)
celltype_stats = container.obs.group_by("cell_type").agg(
    pl.col("missing_rate").mean().alias("mean_missing"),
    pl.col("n_detected").mean().alias("mean_detected"),
    pl.col("total_intensity").mean().alias("mean_total_intensity"),
    pl.len().alias("n_samples")
).sort("cell_type")
print(celltype_stats)

## 4. Outlier Detection

### 4.1 Detect Outlier Samples

In [None]:
# Detect outliers using median absolute deviation (MAD)
outliers = detect_outliers(
    container,
    assay_name="proteins",
    metric="mad",  # Options: 'mad', 'zscore', 'iqr'
    threshold=3.0,  # MAD threshold
    key_added="is_outlier"
)

print(f"Outlier detection completed.")
print(f"Number of outliers detected: {outliers.obs['is_outlier'].sum()}")
print(f"Outlier percentage: {outliers.obs['is_outlier'].mean() * 100:.2f}%")
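
To make the MAD rule concrete, here is a minimal sketch of a MAD-based outlier test on a single metric. It is an illustration under stated assumptions, not the code behind `detect_outliers()`: the helper `mad_outliers` is hypothetical, and the 1.4826 consistency factor (which makes the MAD comparable to a standard deviation for normal data) may or may not be applied by the library.

```python
import numpy as np

def mad_outliers(values, threshold=3.0):
    """Flag values more than `threshold` scaled MADs away from the median."""
    values = np.asarray(values, dtype=float)
    median = np.median(values)
    mad = np.median(np.abs(values - median))
    scaled_mad = 1.4826 * mad  # assumption: normal-consistency scaling
    if scaled_mad == 0:
        # All values are (nearly) identical, so nothing is flagged.
        return np.zeros(values.shape, dtype=bool)
    return np.abs(values - median) / scaled_mad > threshold

# Hypothetical total intensities; the last sample is far from the rest.
toy_total = np.array([100.0, 102.0, 98.0, 101.0, 99.0, 250.0])
flags = mad_outliers(toy_total)
print(flags)
```

Because the median and MAD are barely affected by the extreme value itself, this rule stays stable even when a few samples are far off, which is the usual reason to prefer it over a plain z-score.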

### 4.2 Visualize Outliers

In [None]:
# Plot total intensity with outliers highlighted
fig, ax = plt.subplots(figsize=(10, 6))

total_intensity = outliers.obs["total_intensity"].to_numpy()
is_outlier = outliers.obs["is_outlier"].to_numpy()

# Plot non-outliers
ax.scatter(np.where(~is_outlier)[0], total_intensity[~is_outlier], 
           c='blue', alpha=0.6, s=30, label='Normal')

# Plot outliers
ax.scatter(np.where(is_outlier)[0], total_intensity[is_outlier], 
           c='red', alpha=0.8, s=50, label='Outlier', marker='x')

ax.set_xlabel('Sample Index')
ax.set_ylabel('Total Intensity')
ax.set_title('Outlier Detection Based on Total Intensity')
ax.legend()

plt.tight_layout()
plt.savefig('tutorial_output/outlier_detection.png', dpi=300)
plt.show()

print("Outlier plot saved to: tutorial_output/outlier_detection.png")

## 5. Filtering

### 5.1 Filter Samples by Missing Rate

In [None]:
# Filter samples with >50% missing values
print("Before filtering:")
print(f"  Samples: {container.n_samples}")
print(f"  Features: {container.assays['proteins'].n_features}")

container_filtered = filter_samples_by_missing_rate(
    container,
    assay_name="proteins",
    threshold=0.5,  # Keep samples with <=50% missing
)

print("\nAfter filtering by missing rate:")
print(f"  Samples: {container_filtered.n_samples}")
print(f"  Features: {container_filtered.assays['proteins'].n_features}")
print(f"  Samples removed: {container.n_samples - container_filtered.n_samples}")

### 5.2 Filter Samples by Total Count

In [None]:
# Filter samples with low total intensity (bottom 5%)
min_intensity = np.percentile(container_filtered.obs["total_intensity"].to_numpy(), 5)

container_filtered = filter_samples_by_total_count(
    container_filtered,
    assay_name="proteins",
    min_total_count=min_intensity,
)

print(f"\nAfter filtering by total count:")
print(f"  Samples: {container_filtered.n_samples}")
print(f"  Features: {container_filtered.assays['proteins'].n_features}")

### 5.3 Filter Features by Missing Rate

In [None]:
# Filter proteins detected in <20% of samples
print("Before feature filtering:")
print(f"  Features: {container_filtered.assays['proteins'].n_features}")

container_filtered = filter_features_by_missing_rate(
    container_filtered,
    assay_name="proteins",
    threshold=0.8,  # Keep features with <=80% missing (detected in >=20%)
)

print("\nAfter feature filtering:")
print(f"  Samples: {container_filtered.n_samples}")
print(f"  Features: {container_filtered.assays['proteins'].n_features}")
print(f"  Features removed: {container.assays['proteins'].n_features - container_filtered.assays['proteins'].n_features}")

### 5.4 Filter Features by Variance

In [None]:
# Filter low-variance features (bottom 10%)
container_filtered = filter_features_by_variance(
    container_filtered,
    assay_name="proteins",
    percentile=10,  # Remove bottom 10% variance
)

print(f"\nAfter filtering by variance:")
print(f"  Samples: {container_filtered.n_samples}")
print(f"  Features: {container_filtered.assays['proteins'].n_features}")
print(f"\nFinal data shape: {container_filtered.n_samples} x {container_filtered.assays['proteins'].n_features}")
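
The percentile-based variance filter can be sketched in a few lines of NumPy. This is an illustrative version only: whether `filter_features_by_variance()` keeps features at or strictly above the cutoff, and how it treats missing values, are assumptions here. The matrix `X_var_toy` is hypothetical.

```python
import numpy as np

# Hypothetical matrix: 4 samples x 5 features whose variances are
# 0.01, 1, 4, 9 and 16 (each column is a scaled copy of [-1, 1, -1, 1]).
scales = np.array([0.1, 1.0, 2.0, 3.0, 4.0])
X_var_toy = np.outer([-1.0, 1.0, -1.0, 1.0], scales)

variances = np.nanvar(X_var_toy, axis=0)   # per-feature variance
cutoff = np.percentile(variances, 10)      # 10th percentile of the variances
keep = variances >= cutoff                 # assumption: keep features at or above the cutoff

X_kept = X_var_toy[:, keep]
print("cutoff:", cutoff)
print("kept features:", int(keep.sum()), "of", keep.size)
```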

## 6. Normalization

Normalization removes technical variation (e.g., sample-specific effects) to make samples comparable.

### 6.1 Data Distribution Before Normalization

In [None]:
# Get raw data
X_raw = container_filtered.assays["proteins"].layers["raw"].X
M_raw = container_filtered.assays["proteins"].layers["raw"].M

# Create masked array for visualization
X_raw_masked = X_raw.copy().astype(float)
X_raw_masked[M_raw != 0] = np.nan

# Plot distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histogram of all valid intensities (counts on a log scale)
axes[0].hist(X_raw_masked.flatten(), bins=50, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Intensity')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Raw Data Distribution')
axes[0].set_yscale('log')

# Box plot of the per-sample median intensities
sample_medians = np.nanmedian(X_raw_masked, axis=1)
axes[1].boxplot(sample_medians, vert=False)
axes[1].set_xlabel('Median Intensity')
axes[1].set_yticks([])
axes[1].set_title('Distribution of Sample Medians (Raw)')

plt.tight_layout()
plt.savefig('tutorial_output/before_normalization.png', dpi=300)
plt.show()

print(f"Raw data statistics:")
print(f"  Mean: {np.nanmean(X_raw_masked):.4f}")
print(f"  Median: {np.nanmedian(X_raw_masked):.4f}")
print(f"  Std: {np.nanstd(X_raw_masked):.4f}")
print(f"  CV: {np.nanstd(X_raw_masked) / np.nanmean(X_raw_masked) * 100:.2f}%")

### 6.2 Log Normalization

Log transformation stabilizes variance and makes the data more normally distributed.
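
As a small worked example of what the transform does, the sketch below applies `log2(x + 1)` to a few hypothetical intensities with plain NumPy. It shows the arithmetic only and makes no claim about how `log_normalize()` handles missing values.

```python
import numpy as np

# Hypothetical raw intensities spanning three orders of magnitude.
raw_toy = np.array([0.0, 1.0, 7.0, 1023.0])

# The offset of 1 keeps zero intensities finite: log2(0 + 1) = 0.
log2_toy = np.log2(raw_toy + 1.0)

print(log2_toy)  # values 0, 1, 3 and 10
```

The raw values range from 0 to 1023, while the transformed values range from 0 to 10, which is why large intensities no longer dominate the distribution.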

In [None]:
# Apply log2 normalization with offset
container_normalized = log_normalize(
    container_filtered,
    assay_name="proteins",
    base_layer="raw",
    new_layer_name="log",
    base=2.0,  # Log base
    offset=1.0,  # Pseudocount to avoid log(0)
)

# Check that log layer was created
print("Layers after log normalization:")
print(list(container_normalized.assays['proteins'].layers.keys()))

# Get log-transformed data
X_log = container_normalized.assays["proteins"].layers["log"].X

# Plot distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].hist(X_log.flatten(), bins=50, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Log2 Intensity')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Log Normalized Data Distribution')

sample_medians_log = np.nanmedian(X_log, axis=1)
axes[1].boxplot(sample_medians_log, vert=False)
axes[1].set_xlabel('Median Log2 Intensity')
axes[1].set_yticks([])
axes[1].set_title('Distribution of Sample Medians (Log)')

plt.tight_layout()
plt.savefig('tutorial_output/log_normalization.png', dpi=300)
plt.show()

print(f"\nLog normalized statistics:")
print(f"  Mean: {np.nanmean(X_log):.4f}")
print(f"  Median: {np.nanmedian(X_log):.4f}")
print(f"  Std: {np.nanstd(X_log):.4f}")

### 6.3 Median Normalization

Median normalization scales each sample to have the same median intensity.
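
On log-transformed data, one common way to do this is to subtract each sample's median and add back a shared reference median. The sketch below illustrates that variant with NumPy. Whether `sample_median_normalization()` subtracts or divides, and whether it re-adds a global median, are assumptions here. The array `toy_log` is hypothetical.

```python
import numpy as np

# Hypothetical log2 intensities: 3 samples x 3 features, one missing value.
toy_log = np.array([
    [10.0, 11.0, 12.0],
    [12.0, 13.0, np.nan],
    [8.0, 9.0, 10.0],
])

# Per-sample medians, ignoring missing values.
toy_medians = np.nanmedian(toy_log, axis=1, keepdims=True)
# Shared reference: the median of the sample medians.
reference = np.nanmedian(toy_medians)

# Shift every sample so its median equals the reference.
toy_centered = toy_log - toy_medians + reference

print(np.nanmedian(toy_centered, axis=1))
```

After the shift, every sample has the same median while the differences between features within a sample are unchanged.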

In [None]:
# Apply sample median normalization to log data
container_normalized = sample_median_normalization(
    container_normalized,
    assay_name="proteins",
    base_layer="log",
    new_layer_name="log_median",
)

X_log_median = container_normalized.assays["proteins"].layers["log_median"].X

# Plot sample medians before and after
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].boxplot([sample_medians_log, np.nanmedian(X_log_median, axis=1)], vert=False)
axes[0].set_yticklabels(['Log', 'Log + Median'])
axes[0].set_xlabel('Median Intensity')
axes[0].set_title('Sample Medians Comparison')

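# CV of the per-sample medians, computed the same way before and after
sample_medians_log_median = np.nanmedian(X_log_median, axis=1)
cv_before = np.std(sample_medians_log) / np.mean(sample_medians_log) * 100
cv_after = np.std(sample_medians_log_median) / np.mean(sample_medians_log_median) * 100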

axes[1].bar(['Before', 'After'], [cv_before, cv_after], color=['coral', 'skyblue'])
axes[1].set_ylabel('Coefficient of Variation (%)')
axes[1].set_title('Cross-Sample Variation')

plt.tight_layout()
plt.savefig('tutorial_output/median_normalization.png', dpi=300)
plt.show()

print(f"CV before median normalization: {cv_before:.2f}%")
print(f"CV after median normalization: {cv_after:.2f}%")

### 6.4 TMM Normalization (Trimmed Mean of M-values)

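TMM is a robust normalization method that was originally developed for RNA-seq count data and is also applied to proteomics intensities. It estimates one scaling factor per sample from a trimmed mean of log-ratios (M-values) against a reference sample.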
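
The core idea can be sketched as follows. This is a deliberately simplified version: it trims only the log-ratios (M-values) and takes an unweighted mean, whereas the original method also trims on average abundance and uses precision weights. The helper `simple_tmm_factors` and its parameters are hypothetical and are not the `tmm_normalization()` implementation.

```python
import numpy as np

def simple_tmm_factors(X, ref_index=0, trim_ratio=0.3):
    """Per-sample scaling factors from a trimmed mean of log2 ratios."""
    X = np.asarray(X, dtype=float)
    ref = X[ref_index]
    factors = np.ones(X.shape[0])
    for i, row in enumerate(X):
        # Use only features that are positive in both sample and reference.
        keep = (row > 0) & (ref > 0)
        m_values = np.sort(np.log2(row[keep] / ref[keep]))
        # Drop `trim_ratio` of the values from each end.
        n_trim = int(np.floor(trim_ratio * m_values.size))
        trimmed = m_values[n_trim:m_values.size - n_trim]
        if trimmed.size:
            factors[i] = 2.0 ** trimmed.mean()
    return factors

# Hypothetical data: sample 1 is twice the reference, plus one extreme feature.
X_tmm_toy = np.array([
    [10.0, 20.0, 30.0, 40.0, 50.0],
    [20.0, 40.0, 60.0, 80.0, 5000.0],
])
tmm_factors = simple_tmm_factors(X_tmm_toy)
X_tmm_scaled = X_tmm_toy / tmm_factors[:, None]
print(tmm_factors)
```

The extreme feature is trimmed away before averaging, so the estimated factor reflects the two-fold difference shared by the remaining features rather than the single outlier.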

In [None]:
# Apply TMM normalization
container_normalized = tmm_normalization(
    container_normalized,
    assay_name="proteins",
    base_layer="raw",
    new_layer_name="tmm",
    trim_ratio=0.3,  # Trim 30% from each end
)

X_tmm = container_normalized.assays["proteins"].layers["tmm"].X

print("TMM normalization completed.")
print(f"TMM data shape: {X_tmm.shape}")
print(f"Mean intensity: {np.nanmean(X_tmm):.4f}")
print(f"Median intensity: {np.nanmedian(X_tmm):.4f}")

### 6.5 Upper Quartile Normalization

In [None]:
# Apply upper quartile normalization
container_normalized = upper_quartile_normalization(
    container_normalized,
    assay_name="proteins",
    base_layer="raw",
    new_layer_name="uq",
    percentile=0.75,  # Upper quartile (75th percentile)
)

X_uq = container_normalized.assays["proteins"].layers["uq"].X

print("Upper quartile normalization completed.")
print(f"UQ data shape: {X_uq.shape}")
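
Upper quartile normalization can be illustrated with plain NumPy: divide each sample by its 75th percentile, then rescale by the mean of those percentiles so the values stay on the original scale. This is a sketch under those assumptions, not the `upper_quartile_normalization()` implementation. Note that `np.nanpercentile` expects the percentile on a 0 to 100 scale, whereas the call above passes `0.75`. The array `X_uq_toy` is hypothetical.

```python
import numpy as np

# Hypothetical data: sample 1 is exactly twice sample 0.
X_uq_toy = np.array([
    [1.0, 2.0, 3.0, 4.0, 5.0],
    [2.0, 4.0, 6.0, 8.0, 10.0],
])

# Per-sample 75th percentile, ignoring missing values.
upper_quartiles = np.nanpercentile(X_uq_toy, 75, axis=1, keepdims=True)

# Scale each sample by its upper quartile, then restore the overall scale.
X_uq_scaled = X_uq_toy / upper_quartiles * upper_quartiles.mean()

print(np.nanpercentile(X_uq_scaled, 75, axis=1))
```

Because the two toy samples differ only by a constant factor, they become identical after scaling.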

### 6.6 Z-Score Standardization

Z-score standardization scales each feature to have mean=0 and std=1.
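
As a minimal sketch of the per-feature computation, the NumPy code below standardizes each column while ignoring missing values. It assumes the population standard deviation (`ddof=0`) and leaves constant features unscaled. Both choices are assumptions and may differ from what `zscore()` does. The array `z_toy` is hypothetical.

```python
import numpy as np

# Hypothetical log intensities: 4 samples x 2 features, one missing value.
z_toy = np.array([
    [1.0, 10.0],
    [2.0, 20.0],
    [3.0, np.nan],
    [4.0, 30.0],
])

feature_mean = np.nanmean(z_toy, axis=0)
feature_std = np.nanstd(z_toy, axis=0)               # ddof=0 (assumption)
feature_std = np.where(feature_std == 0, 1.0, feature_std)  # avoid division by zero

z_scores = (z_toy - feature_mean) / feature_std

print(np.nanmean(z_scores, axis=0))
print(np.nanstd(z_scores, axis=0))
```

Each feature ends up with mean 0 and standard deviation 1, so features measured on very different intensity scales contribute equally to downstream methods.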

In [None]:
# Apply z-score standardization
container_normalized = zscore(
    container_normalized,
    assay_name="proteins",
    base_layer="log",
    new_layer_name="zscore",
)

X_zscore = container_normalized.assays["proteins"].layers["zscore"].X

print("Z-score standardization completed.")
print(f"Mean (should be ~0): {np.nanmean(X_zscore):.6f}")
print(f"Std (should be ~1): {np.nanstd(X_zscore):.6f}")

# Plot z-scored data
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(X_zscore.flatten(), bins=50, edgecolor='black', alpha=0.7)
ax.set_xlabel('Z-Score')
ax.set_ylabel('Frequency')
ax.set_title('Z-Score Standardized Data Distribution')
plt.tight_layout()
plt.savefig('tutorial_output/zscore_distribution.png', dpi=300)
plt.show()

## 7. Comparing Normalization Methods

Let's compare the effects of different normalization methods.

In [None]:
# Collect sample medians for each method
layers = ['raw', 'log', 'log_median', 'tmm', 'uq']
sample_medians_by_method = {}

for layer in layers:
    if layer in container_normalized.assays['proteins'].layers:
        X = container_normalized.assays['proteins'].layers[layer].X
        sample_medians_by_method[layer] = np.nanmedian(X, axis=1)

# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Box plot comparison
data_to_plot = [sample_medians_by_method[layer] for layer in layers if layer in sample_medians_by_method]
labels_to_plot = [layer for layer in layers if layer in sample_medians_by_method]
axes[0].boxplot(data_to_plot)
# Set tick labels separately; the `labels=` keyword is deprecated since Matplotlib 3.9
axes[0].set_xticklabels(labels_to_plot)
axes[0].set_ylabel('Median Intensity')
axes[0].set_title('Comparison of Normalization Methods')
axes[0].tick_params(axis='x', rotation=45)

# CV comparison
cv_values = [np.std(d) / np.mean(d) * 100 for d in data_to_plot]
axes[1].bar(labels_to_plot, cv_values, color='steelblue')
axes[1].set_ylabel('Coefficient of Variation (%)')
axes[1].set_title('Cross-Sample Variation by Method')
axes[1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('tutorial_output/normalization_comparison.png', dpi=300)
plt.show()

print("\nCoefficient of Variation by method:")
for layer, cv in zip(labels_to_plot, cv_values):
    print(f"  {layer:15s}: {cv:6.2f}%")

## 8. Visualizing Normalization Effects

Let's visualize how the data distribution changes with normalization.

In [None]:
# Get first 100 features for visualization
n_features_plot = min(100, container_normalized.assays['proteins'].n_features)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

methods = ['raw', 'log', 'log_median', 'tmm', 'uq', 'zscore']
titles = ['Raw', 'Log2', 'Log + Median', 'TMM', 'Upper Quartile', 'Z-Score']

for i, (method, title) in enumerate(zip(methods, titles)):
    if method in container_normalized.assays['proteins'].layers:
        X = container_normalized.assays['proteins'].layers[method].X
        
        # Plot first 100 features
        for j in range(n_features_plot):
            axes[i].plot(X[:, j], alpha=0.1, color='blue')
        
        axes[i].set_title(title)
        axes[i].set_xlabel('Sample Index')
        if i == 0:
            axes[i].set_ylabel('Intensity')

plt.tight_layout()
plt.savefig('tutorial_output/normalization_profiles.png', dpi=300)
plt.show()

print("Normalization profiles saved to: tutorial_output/normalization_profiles.png")

## Summary

In this tutorial, you learned:

### Quality Control:
1. **Calculate QC Metrics**: Using `calculate_qc_metrics()`
2. **Detect Outliers**: Using `detect_outliers()` with MAD, Z-score, or IQR methods
3. **Filter Samples**: By missing rate (`filter_samples_by_missing_rate()`) and total count (`filter_samples_by_total_count()`)
4. **Filter Features**: By missing rate (`filter_features_by_missing_rate()`) and variance (`filter_features_by_variance()`)

### Normalization:
1. **Log Normalization**: Stabilizes variance (`log_normalize()`)
2. **Median Normalization**: Scales samples to common median (`sample_median_normalization()`)
3. **TMM Normalization**: Robust normalization for proteomics (`tmm_normalization()`)
4. **Upper Quartile Normalization**: Uses 75th percentile (`upper_quartile_normalization()`)
5. **Z-Score Standardization**: Mean=0, Std=1 (`zscore()`)

### Best Practices:
- Always QC before normalization
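- Log transform before median normalization and z-score standardization; in this tutorial, TMM and upper quartile normalization are applied to the raw layer
- Choose normalization based on your data characteristics
- Use z-scores for methods that are sensitive to feature scale (e.g., PCA, distance-based clustering)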

### Next Steps:
- **Tutorial 03**: Imputation and Batch Correction
- **Tutorial 04**: Clustering and Visualization