# Tutorial 04: Clustering and Visualization

This tutorial covers dimensionality reduction, clustering, and visualization techniques for single-cell proteomics data.

## Learning Objectives

By the end of this tutorial, you will:
- Apply dimensionality reduction (PCA, UMAP)
- Run K-Means clustering (graph-based methods are suggested as a next step)
- Visualize results with publication-quality plots
- Interpret clustering results and evaluate cluster quality
- Create comprehensive visualization reports

---

## 1. Setup

Import required libraries and load an example dataset.

In [None]:
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# Apply SciencePlots style (SciencePlots >= 2.0 registers its styles on import)
import scienceplots  # noqa: F401
plt.style.use(["science", "no-latex"])

# Import ScpTensor
import scptensor
from scptensor.datasets import load_example_with_clusters
from scptensor import (
    # Dimensionality reduction
    pca,
    umap,
    # Clustering
    run_kmeans,
    # Visualization
    embedding,
    scatter,
    heatmap,
    violin,
    qc_completeness,
    qc_matrix_spy,
    # Normalization
    log_normalize,
    zscore,
    # Imputation
    knn,
    # Integration
    combat,
)

print(f"ScpTensor version: {scptensor.__version__}")

## 2. Load and Prepare Data

We'll use a dataset with known cluster labels to evaluate our clustering results.

In [None]:
# Load dataset with known clusters
container = load_example_with_clusters()

print(f"Dataset loaded: {container}")
print(f"\nSamples: {container.n_samples}")
print(f"Features: {container.assays['proteins'].n_features}")

# View the true cell type labels
print("\nTrue cell type labels (ground truth):")
print(container.obs.group_by("cell_type").count().sort("cell_type"))

# Store true labels for evaluation
true_labels = container.obs["cell_type"].to_numpy()
true_label_ids = container.obs["cell_type_id"].to_numpy()

## 3. Data Preprocessing Pipeline

Before clustering, we preprocess the data in four steps: log-transform, impute missing values, correct batch effects, and z-score standardize.

In [None]:
# Step 1: Log normalization
print("Step 1: Log normalization...")
container = log_normalize(
    container,
    assay_name="proteins",
    base_layer="raw",
    new_layer_name="log",
    base=2.0,
    offset=1.0,
)

# Step 2: Imputation
print("Step 2: KNN imputation...")
container = knn(
    container,
    assay_name="proteins",
    base_layer="log",
    new_layer_name="imputed",
    k=10,
)

# Step 3: Batch correction
print("Step 3: ComBat batch correction...")
container = combat(
    container,
    batch_key="batch",
    assay_name="proteins",
    base_layer="imputed",
    new_layer_name="corrected",
)

# Step 4: Z-score standardization (for clustering)
print("Step 4: Z-score standardization...")
container = zscore(
    container,
    assay_name="proteins",
    base_layer="corrected",
    new_layer_name="zscore",
)

print("\nPreprocessing completed!")
print(f"Available layers: {list(container.assays['proteins'].layers.keys())}")
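
For intuition, the first and last steps of this pipeline can be sketched in plain NumPy. This is an illustration, not ScpTensor's implementation: it assumes that `log_normalize` with `base=2.0, offset=1.0` computes log2(x + 1) and that `zscore` standardizes each feature (column). The helper names `log2_offset` and `zscore_columns` and the toy values are hypothetical.

```python
import numpy as np

def log2_offset(x, offset=1.0):
    """Assumed behaviour of log_normalize(base=2.0, offset=1.0): log2(x + offset)."""
    return np.log2(np.asarray(x, dtype=float) + offset)

def zscore_columns(x):
    """Standardize each column (feature) to mean 0 and standard deviation 1."""
    x = np.asarray(x, dtype=float)
    mean = x.mean(axis=0)
    std = x.std(axis=0)
    std[std == 0] = 1.0  # constant features become 0 instead of dividing by zero
    return (x - mean) / std

# Toy intensity matrix: 4 samples x 3 features (hypothetical values)
raw = np.array([
    [1.0, 3.0, 7.0],
    [3.0, 7.0, 15.0],
    [7.0, 15.0, 31.0],
    [15.0, 31.0, 63.0],
])
logged = log2_offset(raw)
scaled = zscore_columns(logged)

print(logged[0])             # log2 of [2, 4, 8]
print(scaled.mean(axis=0))   # approximately 0 for every feature
print(scaled.std(axis=0))    # approximately 1 for every feature
```

Standardizing per feature keeps high-abundance proteins from dominating the distances that PCA and K-Means rely on.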

## 4. Dimensionality Reduction

### 4.1 Principal Component Analysis (PCA)

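PCA projects the data onto orthogonal axes (principal components) ordered by the variance they explain, so the first few components capture most of the variation in far fewer dimensions.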

In [None]:
# Run PCA
print("Running PCA...")
container = pca(
    container,
    assay_name="proteins",
    base_layer_name="zscore",
    new_assay_name="pca",
    n_components=20,
)

# Get PCA results
pca_scores = container.assays["pca"].layers["scores"].X
pca_variance = container.assays["pca"].layers["scores"].metadata.variance_explained

print(f"PCA completed: {pca_scores.shape[1]} components")
print("\nVariance explained by first 10 PCs:")
for i in range(min(10, len(pca_variance))):
    print(f"  PC{i+1}: {pca_variance[i]:.2f}%")

print(f"\nTotal variance explained (first 10 PCs): {sum(pca_variance[:10]):.2f}%")
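
To make the "variance explained" numbers concrete, here is a minimal NumPy sketch of PCA via singular value decomposition. It is not the code behind `pca()`. The helper name `pca_svd` and the toy data are hypothetical, and the sketch reports variance in percent on the assumption that `variance_explained` above is also in percent.

```python
import numpy as np

def pca_svd(x, n_components):
    """Return (scores, percent variance explained) from an SVD of the centered data."""
    x = np.asarray(x, dtype=float)
    centered = x - x.mean(axis=0)
    # Rows of vt are the principal axes; singular values s are sorted in descending order
    u, s, vt = np.linalg.svd(centered, full_matrices=False)
    scores = u[:, :n_components] * s[:n_components]   # sample coordinates on the PCs
    explained = s**2 / np.sum(s**2) * 100.0           # share of total variance per PC
    return scores, explained[:n_components]

rng = np.random.default_rng(0)
# Toy data: 50 samples x 5 features, with the first feature stretched tenfold
toy = rng.normal(size=(50, 5))
toy[:, 0] *= 10.0

scores, pct = pca_svd(toy, n_components=3)
print(scores.shape)
print(np.round(pct, 1))
```

Because the first toy feature has a much larger spread, PC1 absorbs most of the variance. Z-scoring before PCA guards against exactly this kind of scale effect.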

### 4.2 Visualize PCA Results

In [None]:
# Create PCA visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Get true cell types and batches
cell_types = container.obs["cell_type"].to_numpy()
batches = container.obs["batch"].to_numpy()

unique_ct = sorted(container.obs["cell_type"].unique().to_list())
unique_batch = sorted(container.obs["batch"].unique().to_list())

# PC1 vs PC2 colored by cell type
for ct in unique_ct:
    mask = cell_types == ct
    axes[0].scatter(pca_scores[mask, 0], pca_scores[mask, 1], label=ct, alpha=0.6, s=30)
axes[0].set_xlabel(f'PC1 ({pca_variance[0]:.1f}%)')
axes[0].set_ylabel(f'PC2 ({pca_variance[1]:.1f}%)')
axes[0].set_title('PCA: Colored by Cell Type')
axes[0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# PC1 vs PC2 colored by batch
for batch in unique_batch:
    mask = batches == batch
    axes[1].scatter(pca_scores[mask, 0], pca_scores[mask, 1], label=batch, alpha=0.6, s=30)
axes[1].set_xlabel(f'PC1 ({pca_variance[0]:.1f}%)')
axes[1].set_ylabel(f'PC2 ({pca_variance[1]:.1f}%)')
axes[1].set_title('PCA: Colored by Batch')
axes[1].legend()

# Scree plot
axes[2].bar(range(1, len(pca_variance) + 1), pca_variance, color='steelblue', alpha=0.7)
axes[2].set_xlabel('Principal Component')
axes[2].set_ylabel('Variance Explained (%)')
axes[2].set_title('Scree Plot')
axes[2].set_ylim(0, max(pca_variance) * 1.1)

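plt.tight_layout()
# Create the output directory first; savefig fails if it does not exist
import os
os.makedirs('tutorial_output', exist_ok=True)
plt.savefig('tutorial_output/pca_visualization.png', dpi=300)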
plt.show()

print("PCA visualization saved to: tutorial_output/pca_visualization.png")

### 4.3 UMAP (Uniform Manifold Approximation and Projection)

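UMAP is a non-linear dimensionality reduction method that emphasizes local neighborhood structure, which makes it well suited to 2D visualization. Here it is run on the PCA scores rather than on the full feature matrix.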

In [None]:
# Run UMAP on PCA scores
print("Running UMAP...")
container = umap(
    container,
    assay_name="pca",
    base_layer="scores",
    new_assay_name="umap",
    n_neighbors=30,
    min_dist=0.1,
    n_components=2,
)

# Get UMAP coordinates
umap_scores = container.assays["umap"].layers["scores"].X

print(f"UMAP completed: {umap_scores.shape}")

### 4.4 Visualize UMAP Results

In [None]:
# Create UMAP visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# UMAP colored by cell type
for ct in unique_ct:
    mask = cell_types == ct
    axes[0].scatter(umap_scores[mask, 0], umap_scores[mask, 1], label=ct, alpha=0.6, s=30)
axes[0].set_xlabel('UMAP1')
axes[0].set_ylabel('UMAP2')
axes[0].set_title('UMAP: Colored by Cell Type (Ground Truth)')
axes[0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# UMAP colored by batch
for batch in unique_batch:
    mask = batches == batch
    axes[1].scatter(umap_scores[mask, 0], umap_scores[mask, 1], label=batch, alpha=0.6, s=30)
axes[1].set_xlabel('UMAP1')
axes[1].set_ylabel('UMAP2')
axes[1].set_title('UMAP: Colored by Batch')
axes[1].legend()

plt.tight_layout()
plt.savefig('tutorial_output/umap_visualization.png', dpi=300)
plt.show()

print("UMAP visualization saved to: tutorial_output/umap_visualization.png")

## 5. Clustering

### 5.1 K-Means Clustering

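K-Means is a simple and fast clustering algorithm: it partitions samples into a fixed number of clusters by assigning each sample to its nearest centroid and iteratively updating the centroids. The number of clusters must be chosen in advance; here we use the known number of cell types.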

In [None]:
# Run K-Means clustering
print("Running K-Means clustering...")
n_clusters = len(unique_ct)  # Use true number of clusters
print(f"Number of clusters: {n_clusters}")

container = run_kmeans(
    container,
    assay_name="pca",
    base_layer="scores",
    n_clusters=n_clusters,
    key_added="kmeans_cluster",
    random_state=42,
)

# Get clustering results
kmeans_labels = container.obs["kmeans_cluster"].to_numpy()

print("\nK-Means clustering completed.")
print("\nCluster distribution:")
print(container.obs.group_by("kmeans_cluster").count().sort("kmeans_cluster"))
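
The assign-and-update loop behind K-Means (Lloyd's algorithm) fits in a few lines of NumPy. The sketch below is illustrative only. `run_kmeans` may use a different initialization or stopping rule, and the name `kmeans_lloyd` and the toy blobs are hypothetical.

```python
import numpy as np

def kmeans_lloyd(x, n_clusters, n_iter=100, seed=42):
    """Plain Lloyd's algorithm with random-sample initialization."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    # Initialize centroids from distinct, randomly chosen samples
    centroids = x[rng.choice(len(x), size=n_clusters, replace=False)]
    for _ in range(n_iter):
        # Assignment step: nearest centroid by squared Euclidean distance
        dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Update step: move each centroid to the mean of its members
        new_centroids = np.array([
            x[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
            for k in range(n_clusters)
        ])
        if np.allclose(new_centroids, centroids):
            break  # centroids stopped moving
        centroids = new_centroids
    return labels, centroids

rng = np.random.default_rng(0)
# Toy data: two well-separated 2D blobs of 20 points each
blob_a = rng.normal(loc=0.0, scale=0.5, size=(20, 2))
blob_b = rng.normal(loc=10.0, scale=0.5, size=(20, 2))
points = np.vstack([blob_a, blob_b])

labels, centroids = kmeans_lloyd(points, n_clusters=2)
print(labels)
print(np.round(centroids, 2))
```

Cluster numbers are arbitrary: the same partition can come back with labels swapped, which is why the evaluation metrics used later compare partitions rather than label values.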

### 5.2 Visualize K-Means Results

In [None]:
# Visualize K-Means clusters
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# UMAP with K-Means clusters
unique_clusters = sorted(container.obs["kmeans_cluster"].unique().to_list())
for cluster in unique_clusters:
    mask = kmeans_labels == cluster
    axes[0].scatter(umap_scores[mask, 0], umap_scores[mask, 1], 
                   label=f'Cluster {cluster}', alpha=0.6, s=30)
axes[0].set_xlabel('UMAP1')
axes[0].set_ylabel('UMAP2')
axes[0].set_title('K-Means Clusters on UMAP')
axes[0].legend()

# PCA with K-Means clusters
for cluster in unique_clusters:
    mask = kmeans_labels == cluster
    axes[1].scatter(pca_scores[mask, 0], pca_scores[mask, 1], 
                   label=f'Cluster {cluster}', alpha=0.6, s=30)
axes[1].set_xlabel(f'PC1 ({pca_variance[0]:.1f}%)')
axes[1].set_ylabel(f'PC2 ({pca_variance[1]:.1f}%)')
axes[1].set_title('K-Means Clusters on PCA')
axes[1].legend()

# Contingency table: rows are K-Means clusters, columns are true cell types
confusion_data = np.zeros((len(unique_clusters), len(unique_ct)))
for i, cluster in enumerate(unique_clusters):
    for j, ct in enumerate(unique_ct):
        confusion_data[i, j] = np.sum((kmeans_labels == cluster) & (cell_types == ct))

im = axes[2].imshow(confusion_data, aspect='auto', cmap='Blues')
axes[2].set_xticks(range(len(unique_ct)))
axes[2].set_xticklabels(unique_ct, rotation=45, ha='right')
axes[2].set_yticks(range(len(unique_clusters)))
axes[2].set_yticklabels([f'C{c}' for c in unique_clusters])
axes[2].set_xlabel('True Cell Type')
axes[2].set_ylabel('K-Means Cluster')
axes[2].set_title('Cluster vs Cell Type Overlap')
plt.colorbar(im, ax=axes[2])

plt.tight_layout()
plt.savefig('tutorial_output/kmeans_results.png', dpi=300)
plt.show()

print("K-Means visualization saved to: tutorial_output/kmeans_results.png")

## 6. Clustering Evaluation

Let's evaluate how well our clustering matches the true labels.

In [None]:
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score

# Calculate clustering metrics
ari = adjusted_rand_score(true_label_ids, kmeans_labels)
nmi = normalized_mutual_info_score(true_label_ids, kmeans_labels)
silhouette = silhouette_score(pca_scores, kmeans_labels)

print("Clustering Evaluation Metrics:")
print("=" * 40)
print(f"Adjusted Rand Index (ARI): {ari:.3f}")
print("  - Range: [-0.5, 1], ~0 = chance-level agreement, 1 = perfect match")
print(f"Normalized Mutual Info (NMI): {nmi:.3f}")
print(f"  - Range: [0, 1], 1 = perfect match")
print(f"Silhouette Score: {silhouette:.3f}")
print(f"  - Range: [-1, 1], 1 = well-separated clusters")
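
To see what the Adjusted Rand Index measures, it can be computed by hand from pair counts in the cluster-versus-label contingency table. The sketch below uses only the standard library and follows the usual pair-counting formula. The function name `adjusted_rand_index` is hypothetical and is not part of scikit-learn or ScpTensor.

```python
from collections import Counter
from math import comb

def adjusted_rand_index(labels_true, labels_pred):
    """ARI = (index - expected index) / (max index - expected index)."""
    n = len(labels_true)
    if n < 2:
        return 1.0  # no sample pairs to compare
    # Number of sample pairs grouped together in both labelings, and in each one
    sum_pairs = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    sum_true = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_pred = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = sum_true * sum_pred / comb(n, 2)   # agreement expected by chance
    max_index = 0.5 * (sum_true + sum_pred)
    if max_index == expected:
        return 1.0  # degenerate case, e.g. both labelings use a single cluster
    return (sum_pairs - expected) / (max_index - expected)

# Same partition with swapped label names: perfect agreement
print(adjusted_rand_index([0, 0, 0, 1, 1, 1], [1, 1, 1, 0, 0, 0]))  # 1.0
# One true group split in two: partial agreement
print(round(adjusted_rand_index([0, 0, 1, 1], [0, 0, 1, 2]), 3))    # 0.571
# Every predicted cluster mixes both true groups: worse than chance
print(adjusted_rand_index([0, 0, 1, 1], [0, 1, 0, 1]))              # -0.5
```

The first case shows that ARI ignores the label names, so K-Means cluster numbers do not need to match `cell_type_id` values.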

## 7. Using ScpTensor Visualization Functions

ScpTensor provides convenient visualization functions.

In [None]:
# Use the built-in embedding function
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# PCA by cell type
ax = embedding(container, basis="pca", color="cell_type", ax=axes[0, 0])
ax.set_title('PCA by Cell Type')

# PCA by cluster
ax = embedding(container, basis="pca", color="kmeans_cluster", ax=axes[0, 1])
ax.set_title('PCA by K-Means Cluster')

# UMAP by cell type
ax = embedding(container, basis="umap", color="cell_type", ax=axes[1, 0])
ax.set_title('UMAP by Cell Type')

# UMAP by cluster
ax = embedding(container, basis="umap", color="kmeans_cluster", ax=axes[1, 1])
ax.set_title('UMAP by K-Means Cluster')

plt.tight_layout()
plt.savefig('tutorial_output/embedding_comparison.png', dpi=300)
plt.show()

print("Embedding comparison saved to: tutorial_output/embedding_comparison.png")

## 8. Heatmap Visualization

Heatmaps show expression patterns across clusters.

In [None]:
# Create a heatmap of top variable features
fig, ax = plt.subplots(figsize=(12, 8))

# Get feature names
var_features = container.assays["proteins"].var["protein_id"].to_numpy()

# Use zscore data for heatmap
X_zscore = container.assays["proteins"].layers["zscore"].X

# Select top 20 most variable features
feature_var = np.var(X_zscore, axis=0)
top_features = np.argsort(feature_var)[-20:]

# Order samples by cluster
sorted_indices = np.argsort(kmeans_labels)
X_heatmap = X_zscore[sorted_indices, :][:, top_features].T

# Plot heatmap
im = ax.imshow(X_heatmap, aspect='auto', cmap='RdBu_r', vmin=-2, vmax=2)

# Add cluster boundary lines (imshow centers column i at x = i, so edges sit at i - 0.5)
cluster_boundaries = np.cumsum([np.sum(kmeans_labels == c) for c in unique_clusters])
for boundary in cluster_boundaries[:-1]:
    ax.axvline(x=boundary - 0.5, color='black', linewidth=1)

# Set labels
ax.set_xlabel('Samples (ordered by cluster)')
ax.set_ylabel('Features (top 20 variable)')
ax.set_yticks(range(20))
ax.set_yticklabels([var_features[i][:15] for i in top_features], fontsize=8)
ax.set_title('Heatmap of Top Variable Features')

plt.colorbar(im, ax=ax, label='Z-Score')
plt.tight_layout()
plt.savefig('tutorial_output/heatmap_clusters.png', dpi=300)
plt.show()

print("Heatmap saved to: tutorial_output/heatmap_clusters.png")

## 9. Violin Plots

Violin plots show the distribution of expression within clusters.

In [None]:
# Create violin plots for selected markers
# Let's pick the top 3 most variable features as "markers"
top_3_features = var_features[top_features[-3:]]

# Get zscore data
X_plot = container.assays["proteins"].layers["zscore"].X

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, feat in enumerate(top_3_features):
    feat_idx = np.where(var_features == feat)[0][0]
    
    # Collect data for each cluster
    data_by_cluster = []
    cluster_labels_plot = []
    
    for cluster in sorted(unique_clusters):
        mask = kmeans_labels == cluster
        data_by_cluster.append(X_plot[mask, feat_idx])
        cluster_labels_plot.append(f'C{cluster}')
    
    # Create violin plot
    parts = axes[idx].violinplot(data_by_cluster, 
                                   positions=range(len(cluster_labels_plot)),
                                   showmeans=True, showmedians=True)
    
    # Color the violin plot
    for pc in parts['bodies']:
        pc.set_facecolor(plt.cm.tab10(idx))
        pc.set_alpha(0.6)
    
    axes[idx].set_xticks(range(len(cluster_labels_plot)))
    axes[idx].set_xticklabels(cluster_labels_plot)
    axes[idx].set_ylabel('Z-Score')
    axes[idx].set_title(f'Feature: {feat[:20]}')

plt.tight_layout()
plt.savefig('tutorial_output/violin_plots.png', dpi=300)
plt.show()

print("Violin plots saved to: tutorial_output/violin_plots.png")

## 10. Cluster Summary Statistics

In [None]:
# Calculate cluster statistics
cluster_stats = []

for cluster in sorted(unique_clusters):
    mask = kmeans_labels == cluster
    cluster_data = X_zscore[mask]
    
    # Get dominant cell type
    cluster_cell_types = cell_types[mask]
    unique, counts = np.unique(cluster_cell_types, return_counts=True)
    dominant_ct = unique[np.argmax(counts)]
    purity = counts.max() / counts.sum()
    
    cluster_stats.append({
        'Cluster': cluster,
        'n_samples': mask.sum(),
        'dominant_cell_type': dominant_ct,
        'purity': purity,
        'mean_expression': cluster_data.mean(),
        'std_expression': cluster_data.std(),
    })

# Display cluster statistics
import pandas as pd
cluster_df = pd.DataFrame(cluster_stats)

print("Cluster Summary Statistics:")
print("=" * 80)
print(cluster_df.to_string(index=False))

# Overall clustering quality (unweighted mean, so small clusters count as much as large ones)
avg_purity = cluster_df['purity'].mean()
print(f"\nAverage Cluster Purity: {avg_purity:.3f}")
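
The average above weights every cluster equally. A size-weighted purity, which equals the fraction of all samples that belong to their cluster's dominant class, can tell a different story when cluster sizes are unbalanced. The sketch below contrasts the two on toy labels. The helper name `cluster_purities` and the data are hypothetical.

```python
from collections import Counter

def cluster_purities(cluster_labels, class_labels):
    """Return {cluster: (size, purity)}, purity = share of the dominant class."""
    members = {}
    for cluster, cls in zip(cluster_labels, class_labels):
        members.setdefault(cluster, []).append(cls)
    return {
        cluster: (len(classes), Counter(classes).most_common(1)[0][1] / len(classes))
        for cluster, classes in members.items()
    }

# Toy labels: a large pure cluster and a small mixed one
clusters = [0] * 8 + [1] * 2
classes = ["A"] * 8 + ["A", "B"]

stats = cluster_purities(clusters, classes)
n_total = sum(size for size, _ in stats.values())
unweighted = sum(purity for _, purity in stats.values()) / len(stats)
weighted = sum(size * purity for size, purity in stats.values()) / n_total

print(stats)        # {0: (8, 1.0), 1: (2, 0.5)}
print(unweighted)   # 0.75
print(weighted)     # 0.9
```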

## 11. QC Visualizations

Let's also check data completeness and quality.

In [None]:
# QC completeness plot by cluster
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

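# Detected features per sample, grouped by cluster.
# Counted on the raw layer: after imputation and z-scoring almost no value is exactly 0.
# Assumes the raw layer is a dense array with missing values stored as NaN or 0.
raw_X = container.assays['proteins'].layers['raw'].X
for cluster in sorted(unique_clusters):
    mask = kmeans_labels == cluster
    cluster_data = raw_X[mask]
    n_detected = (np.isfinite(cluster_data) & (cluster_data != 0)).sum(axis=1)
    axes[0].hist(n_detected, bins=20, alpha=0.5, label=f'Cluster {cluster}')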

axes[0].set_xlabel('Number of Detected Features')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Feature Detection by Cluster')
axes[0].legend()

# Missing rate by cluster
for cluster in sorted(unique_clusters):
    mask = kmeans_labels == cluster
    # Polars DataFrames do not accept boolean masks in [], so mask the NumPy array instead
    cluster_missing = container.obs['missing_rate'].to_numpy()[mask]
    axes[1].hist(cluster_missing, bins=20, alpha=0.5, label=f'Cluster {cluster}')

axes[1].set_xlabel('Missing Rate')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Missing Rate by Cluster')
axes[1].legend()

plt.tight_layout()
plt.savefig('tutorial_output/qc_by_cluster.png', dpi=300)
plt.show()

print("QC by cluster saved to: tutorial_output/qc_by_cluster.png")
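
The `missing_rate` column is taken from `container.obs` as provided. For readers who need to derive such a column themselves, the sketch below shows one way to compute per-sample missing rates and detected-feature counts. It assumes missing measurements are encoded as NaN, which this tutorial does not state, and the toy matrix is hypothetical.

```python
import numpy as np

# Toy matrix: 3 samples x 4 features, NaN marks a missing measurement (assumed encoding)
x = np.array([
    [1.2, np.nan, 0.7, 3.1],
    [np.nan, np.nan, 2.2, 1.0],
    [0.5, 1.8, 2.9, 4.4],
])

is_missing = np.isnan(x)
missing_rate = is_missing.mean(axis=1)   # fraction of missing features per sample
n_detected = (~is_missing).sum(axis=1)   # number of observed features per sample

print(missing_rate)   # [0.25 0.5  0.  ]
print(n_detected)     # [3 2 4]
```

If one cluster has a markedly higher missing rate than the others, it may reflect sample quality rather than biology.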

## Summary

In this tutorial, you learned:

### Dimensionality Reduction:
1. **PCA**: Linear dimensionality reduction (`pca()`)
   - Preserves global variance
   - Fast and interpretable
   - Use the leading PCs (20 in this tutorial) as clustering input

2. **UMAP**: Non-linear dimensionality reduction (`umap()`)
   - Preserves local structure
   - Excellent for visualization
   - Run on PCA scores for better performance

### Clustering:
1. **K-Means**: Simple, fast clustering (`run_kmeans()`)
   - Works well on PCA-reduced data
   - Requires specifying number of clusters

### Visualization:
1. **Embedding plots**: `embedding()` for PCA/UMAP
2. **Heatmaps**: Show feature patterns across clusters
3. **Violin plots**: Show distribution within clusters
4. **QC plots**: Check data quality by cluster

### Clustering Evaluation:
- **ARI (Adjusted Rand Index)**: Chance-corrected agreement between clusters and true labels
- **NMI (Normalized Mutual Info)**: Information-theoretic similarity
- **Silhouette Score**: Cluster separation quality
- **Purity**: Dominant class proportion in each cluster

### Best Practices:
- Always preprocess (normalize, impute, batch correct) before clustering
- Cluster on PCA scores rather than on the raw matrix
- Reserve UMAP for visualization; do not use the UMAP embedding as clustering input
- Evaluate clustering with multiple metrics
- Check biological interpretation of clusters
- Use SciencePlots style for publication-quality figures

### Analysis Pipeline:
```
Raw Data -> Log Transform -> Impute -> Batch Correct -> Z-Score
                                                         |
                                                         v
                                                     PCA (20 PCs)
                                                         |
                                    +------------------+------------------+
                                    |                  |                  |
                                    v                  v                  v
                               Clustering          UMAP              Evaluation
                                    |                  |                  |
                                    v                  v                  v
                               K-Means         Visualize           Metrics
```

### Next Steps:
- Try different clustering algorithms (graph-based, hierarchical)
- Perform differential expression analysis between clusters
- Identify cluster-specific marker proteins
- Apply this pipeline to your own data

### Additional Resources:
- API Reference: `docs/design/API_REFERENCE.md`
- Differential Expression: `scptensor.diff_expr` module