# scVI Batch Correction QC Validation

**Purpose**: Validate the quality of scVI batch correction on the canonical single-cell dataset
(`CODEX_scvi_BioCov_phenotyped_newDuctal.h5ad`).

**Metrics computed**:
1. Silhouette batch score (sklearn) — measures batch mixing in scVI vs PCA space
2. LISI (Local Inverse Simpson Index) — quantifies local batch diversity
3. PCA vs scVI UMAP comparison — visual assessment of donor mixing
4. Per-donor cell type distributions — consistency check across donors
5. Donor 6533 integration verification

**Expected outcome**: scVI latent space should show better batch mixing than PCA while
preserving biological signal (cell type separation).

In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score, silhouette_samples
from sklearn.neighbors import NearestNeighbors
import warnings
# Suppress FutureWarning messages so the cell outputs below stay readable.
warnings.filterwarnings('ignore', category=FutureWarning)

# Scanpy logging level.
# NOTE(review): per the scanpy docs level 2 shows errors, warnings and info — confirm.
sc.settings.verbosity = 2
# Default style for scanpy figures: 100 dpi, no frame drawn around the axes.
sc.settings.set_figure_params(dpi=100, frameon=False)

## 1. Load canonical single-cell data

In [None]:
import os

# The canonical phenotyped AnnData file sits in the sibling
# `single_cell_analysis` directory, relative to the notebook's working directory.
h5ad_path = os.path.join('..', 'single_cell_analysis', 'CODEX_scvi_BioCov_phenotyped_newDuctal.h5ad')
print(f'Loading: {h5ad_path}')
adata = sc.read_h5ad(h5ad_path)

# Structural overview: matrix dimensions, per-cell annotations, stored embeddings.
print(
    f'Shape: {adata.shape}',
    f'Obs columns: {list(adata.obs.columns)}',
    f'Obsm keys: {list(adata.obsm.keys())}',
    sep='\n',
)

# Cell counts per donor status.
print(f'\nDonor status distribution:', adata.obs['Donor Status'].value_counts(), sep='\n')

# Cell counts per image ID; image IDs serve as the batch variable in this notebook.
print(f'\nUnique image IDs (donors/slides): {adata.obs["imageid"].nunique()}')
print(adata.obs['imageid'].value_counts())

## 2. Verify donor 6533 is present and well-represented

In [None]:
# Flag every cell whose image ID, converted to a string, contains '6533'.
# NOTE(review): this is a substring match, so an ID such as '16533' would also
# be counted — confirm against the actual imageid naming scheme.
donor_6533_mask = adata.obs['imageid'].astype(str).str.contains('6533')
n_6533 = donor_6533_mask.sum()
print(f'Donor 6533 cells: {n_6533:,}')

if n_6533 == 0:
    # Nothing matched: the donor has not been integrated into this object.
    print('WARNING: Donor 6533 is MISSING from single-cell data!')
    print('This donor must be integrated before proceeding with trajectory analysis.')
else:
    # Show the ten most frequent phenotypes and the donor status value(s)
    # recorded for the matched cells.
    print(f'\n6533 phenotype distribution:')
    print(adata.obs.loc[donor_6533_mask, 'phenotype'].value_counts().head(10))
    print(f'\n6533 Donor Status: {adata.obs.loc[donor_6533_mask, "Donor Status"].unique()}')

## 3. Compute PCA on raw expression (baseline comparison)

In [None]:
# QC metrics are estimated on a random subset of at most 50,000 cells
# (scVI QC doesn't need all 2.6M cells). The global NumPy seed is fixed first
# so the same cells are drawn on every full run of the notebook.
np.random.seed(42)
n_subsample = min(adata.n_obs, 50000)
idx = np.random.choice(adata.n_obs, size=n_subsample, replace=False)
adata_sub = adata[idx].copy()
print(f'Subsampled to {adata_sub.n_obs:,} cells for QC metrics')

# Baseline embedding: 20 principal components of the expression matrix,
# stored by scanpy in obsm['X_pca'].
sc.pp.pca(adata_sub, n_comps=20)
print(f'PCA computed: {adata_sub.obsm["X_pca"].shape}')

# The scVI latent space must already be stored in the object; report if it is not.
if 'X_scVI' not in adata_sub.obsm:
    print('WARNING: X_scVI not found in obsm!')
    print(f'Available obsm keys: {list(adata_sub.obsm.keys())}')
else:
    print(f'scVI latent space: {adata_sub.obsm["X_scVI"].shape}')

## 4. Silhouette batch score: PCA vs scVI

The **silhouette batch score** measures how well batches (donors/images) are mixed. Silhouette values range from −1 to 1.
- Score near 0 (or below) = good mixing (batches overlap)
- Score near 1 = poor mixing (batches separate)

We want batch silhouette to be **lower** in scVI space than PCA space.

In [None]:
batch_labels = adata_sub.obs['imageid'].values

# Silhouette score using batch (imageid) as labels; lower = better batch mixing.
#
# silhouette_score evaluates a random subsample of `sample_size` cells. Passing
# the same fixed `random_state` to both calls makes the result reproducible and,
# because both embeddings have the same number of rows, scores PCA and scVI on
# the *same* cells. Without it each call draws a different subsample from the
# global RNG, so the difference between the two scores includes sampling noise.
sil_pca = silhouette_score(
    adata_sub.obsm['X_pca'], batch_labels,
    metric='euclidean', sample_size=10000, random_state=42,
)
print(f'Silhouette batch score (PCA):  {sil_pca:.4f}')

if 'X_scVI' in adata_sub.obsm:
    sil_scvi = silhouette_score(
        adata_sub.obsm['X_scVI'], batch_labels,
        metric='euclidean', sample_size=10000, random_state=42,
    )
    print(f'Silhouette batch score (scVI): {sil_scvi:.4f}')
    print(f'\nImprovement: {sil_pca - sil_scvi:.4f} (positive = scVI mixes batches better)')

    if sil_scvi < sil_pca:
        print('PASS: scVI shows better batch mixing than PCA')
    else:
        print('WARNING: scVI does not improve batch mixing over PCA')

## 5. Silhouette cell-type score: biological signal preservation

We also check that scVI preserves cell-type separation.
- Score near 1 = good cell type separation (desirable)
- Score near 0 = poor separation (cell types mixed together)

In [None]:
celltype_labels = adata_sub.obs['phenotype'].values

# Silhouette score using cell type (phenotype) as labels; higher = better separation.
# The same fixed `random_state` is passed to both calls so that PCA and scVI are
# scored on the same random subsample of cells and the result is reproducible.
sil_ct_pca = silhouette_score(
    adata_sub.obsm['X_pca'], celltype_labels,
    metric='euclidean', sample_size=10000, random_state=42,
)
print(f'Silhouette cell-type score (PCA):  {sil_ct_pca:.4f}')

if 'X_scVI' in adata_sub.obsm:
    sil_ct_scvi = silhouette_score(
        adata_sub.obsm['X_scVI'], celltype_labels,
        metric='euclidean', sample_size=10000, random_state=42,
    )
    print(f'Silhouette cell-type score (scVI): {sil_ct_scvi:.4f}')

    print(f'\nBiological signal change: {sil_ct_scvi - sil_ct_pca:.4f}')
    # Allow up to a 20% loss relative to the PCA baseline. The tolerance is
    # taken from |baseline| so that it always lowers the bar: with a negative
    # baseline, `baseline * 0.8` would sit *above* the baseline and flag an
    # scVI score that is equal to (or better than) PCA as a loss. For a
    # non-negative baseline this is identical to `baseline * 0.8`.
    if sil_ct_scvi >= sil_ct_pca - 0.2 * abs(sil_ct_pca):
        print('PASS: scVI preserves cell-type separation')
    else:
        print('WARNING: scVI may be over-correcting (losing biological signal)')

## 6. LISI (Local Inverse Simpson Index)

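LISI measures local neighborhood diversity, estimated here from the labels of each cell's k nearest neighbors. Scores range from 1 (all neighbors share one label) up to the number of categories:
- **Batch LISI** near N_batches = good mixing
- **Cell-type LISI** near 1 = good separation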

In [None]:
def compute_lisi(X, labels, perplexity=30):
    """Compute a k-nearest-neighbour LISI score for every cell.

    Simplified variant of the Local Inverse Simpson Index of Korsunsky et al.
    2019 (Harmony paper). Instead of perplexity-calibrated Gaussian neighbour
    weights, each of the ``k`` nearest neighbours (Euclidean distance) counts
    equally, and the score is the inverse Simpson index of the label
    proportions among those neighbours.

    A cell is never counted as its own neighbour, so its own label does not
    pull its score toward 1.

    Parameters
    ----------
    X : array-like of shape (n_cells, n_features)
        Embedding in which neighbours are searched.
    labels : array-like of length n_cells
        One categorical label per cell (e.g. batch or cell type). Missing
        values are not allowed.
    perplexity : int, default 30
        Controls the neighbourhood size: ``k = min(3 * perplexity, n_cells - 1)``.

    Returns
    -------
    numpy.ndarray of shape (n_cells,)
        One score per cell, between 1 (all neighbours share a label) and the
        number of distinct labels (neighbours spread evenly over all labels).

    Raises
    ------
    ValueError
        If fewer than two cells are given (or ``perplexity`` yields ``k < 1``),
        if ``labels`` does not have one entry per row of ``X``, or if
        ``labels`` contains missing values.
    """
    from sklearn.neighbors import NearestNeighbors

    n_cells = X.shape[0]
    k = min(perplexity * 3, n_cells - 1)
    if k < 1:
        raise ValueError(
            'compute_lisi needs at least two cells and a positive perplexity '
            f'(got {n_cells} cells, perplexity={perplexity})'
        )

    # Convert labels to integer codes; pandas encodes missing values as -1.
    label_codes = np.asarray(pd.Categorical(labels).codes)
    if label_codes.shape[0] != n_cells:
        raise ValueError(
            f'labels has {label_codes.shape[0]} entries but X has {n_cells} rows'
        )
    if (label_codes < 0).any():
        raise ValueError('labels must not contain missing values')
    # Re-index to dense codes 0..n_categories-1, so categories that are
    # declared but never observed do not take up count columns.
    _, dense_codes = np.unique(label_codes, return_inverse=True)
    n_categories = int(dense_codes.max()) + 1

    nn = NearestNeighbors(n_neighbors=k, metric='euclidean')
    nn.fit(X)
    # Querying without an explicit X returns neighbours of the fitted points
    # and leaves each point out of its own neighbour list.
    indices = nn.kneighbors(return_distance=False)

    # Count labels per cell in one pass: shifting row i's codes by
    # i * n_categories gives every (cell, label) pair its own bin.
    neighbor_codes = dense_codes[indices]
    row_offsets = np.arange(n_cells)[:, None] * n_categories
    counts = np.bincount(
        (neighbor_codes + row_offsets).ravel(),
        minlength=n_cells * n_categories,
    ).reshape(n_cells, n_categories)

    # Simpson's index is the sum of squared label proportions; LISI is its inverse.
    proportions = counts / k
    simpson = np.sum(proportions ** 2, axis=1)
    return 1.0 / simpson

# Subsample further for LISI (computationally intensive).
# The draw uses its own seeded generator, so the same cells are selected every
# time this cell runs, regardless of how much of the global NumPy random stream
# earlier cells (or a previous run of this cell) have consumed.
n_lisi = min(10000, adata_sub.n_obs)
lisi_idx = np.random.RandomState(42).choice(adata_sub.n_obs, n_lisi, replace=False)
adata_lisi = adata_sub[lisi_idx].copy()

n_batches = adata_lisi.obs['imageid'].nunique()
print(f'Computing LISI on {n_lisi:,} cells ({n_batches} batches)...\n')

# Batch LISI (higher = better mixing, max = n_batches)
lisi_pca_batch = compute_lisi(adata_lisi.obsm['X_pca'], adata_lisi.obs['imageid'].values)
print(f'Batch LISI (PCA):  median={np.median(lisi_pca_batch):.2f}, mean={np.mean(lisi_pca_batch):.2f} (ideal={n_batches})')

if 'X_scVI' in adata_lisi.obsm:
    lisi_scvi_batch = compute_lisi(adata_lisi.obsm['X_scVI'], adata_lisi.obs['imageid'].values)
    print(f'Batch LISI (scVI): median={np.median(lisi_scvi_batch):.2f}, mean={np.mean(lisi_scvi_batch):.2f} (ideal={n_batches})')

    # The QC criterion compares medians of the per-cell scores.
    if np.median(lisi_scvi_batch) > np.median(lisi_pca_batch):
        print('\nPASS: scVI improves batch mixing (higher batch LISI)')
    else:
        print('\nWARNING: scVI does not improve batch LISI over PCA')

# Cell-type LISI (lower = better separation, ideal = 1)
lisi_pca_ct = compute_lisi(adata_lisi.obsm['X_pca'], adata_lisi.obs['phenotype'].values)
print(f'\nCell-type LISI (PCA):  median={np.median(lisi_pca_ct):.2f}, mean={np.mean(lisi_pca_ct):.2f} (ideal=1)')

if 'X_scVI' in adata_lisi.obsm:
    lisi_scvi_ct = compute_lisi(adata_lisi.obsm['X_scVI'], adata_lisi.obs['phenotype'].values)
    print(f'Cell-type LISI (scVI): median={np.median(lisi_scvi_ct):.2f}, mean={np.mean(lisi_scvi_ct):.2f} (ideal=1)')

## 7. Visual comparison: PCA vs scVI UMAP

In [None]:
# PCA-based UMAP. The neighbour graph is built on a throwaway copy so that it
# does not end up stored on adata_sub; only the resulting coordinates are
# copied back, under a dedicated obsm key.
adata_pca = adata_sub.copy()
sc.pp.neighbors(adata_pca, use_rep='X_pca', n_neighbors=15)
sc.tl.umap(adata_pca)
adata_sub.obsm['X_umap_pca'] = adata_pca.obsm['X_umap'].copy()

# scVI-based UMAP, computed in place on adata_sub when the latent space exists.
# NOTE(review): this graph uses cosine distance while the PCA graph above uses
# the scanpy default metric, so the two UMAPs differ in metric as well as in
# representation — confirm this is intended.
if 'X_scVI' in adata_sub.obsm:
    sc.pp.neighbors(adata_sub, use_rep='X_scVI', metric='cosine', n_neighbors=15)
    sc.tl.umap(adata_sub)
    adata_sub.obsm['X_umap_scvi'] = adata_sub.obsm['X_umap'].copy()

print('UMAPs computed')

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 14))

# One figure row per embedding: (obsm key, donor-panel title, cell-type-panel title).
# The scVI row is only drawn when its UMAP was computed.
umap_rows = [('X_umap_pca', 'PCA UMAP — Donor (batch)', 'PCA UMAP — Cell type')]
if 'X_umap_scvi' in adata_sub.obsm:
    umap_rows.append(('X_umap_scvi', 'scVI UMAP — Donor (batch)', 'scVI UMAP — Cell type'))

for row_axes, (umap_key, donor_title, celltype_title) in zip(axes, umap_rows):
    # sc.pl.umap reads its coordinates from obsm['X_umap'], so point that key
    # at the embedding belonging to this row before plotting.
    adata_sub.obsm['X_umap'] = adata_sub.obsm[umap_key]
    sc.pl.umap(adata_sub, color='imageid', ax=row_axes[0], show=False, title=donor_title, legend_loc='none')
    sc.pl.umap(adata_sub, color='phenotype', ax=row_axes[1], show=False, title=celltype_title, legend_loc='none')

plt.tight_layout()
plt.savefig('../notebooks/scvi_pca_vs_scvi_umap.png', dpi=150, bbox_inches='tight')
plt.show()
print('Figure saved: scvi_pca_vs_scvi_umap.png')

## 8. Per-donor cell type distribution consistency

In [None]:
# Compute per-donor phenotype proportions: one row per imageid, one column per
# phenotype, each row holding that image's phenotype fractions.
ct_proportions = adata.obs.groupby('imageid')['phenotype'].value_counts(normalize=True).unstack(fill_value=0)
print(f'Per-donor cell type proportions ({ct_proportions.shape[0]} donors × {ct_proportions.shape[1]} types):\n')

# Across-donor mean and standard deviation of every cell type, computed once
# and reused by both reports below.
std_props = ct_proportions.std()
mean_props = ct_proportions.mean().sort_values(ascending=False)

# Show top cell types by mean proportion, with the coefficient of variation
# (std / mean) as a measure of donor-to-donor variability.
print('Mean proportions across donors:')
for ct, prop in mean_props.head(10).items():
    cv = std_props[ct] / prop if prop > 0 else 0
    print(f'  {str(ct):25s}: {prop:.3f} (CV={cv:.2f})')

# Check if 6533 has similar distribution. Several image IDs can match the
# donor, so results are grouped under the image ID they belong to; otherwise
# rows from different images cannot be told apart.
donor_6533_ids = [d for d in ct_proportions.index if '6533' in str(d)]
if donor_6533_ids:
    print(f'\n--- Donor 6533 cell type proportions ---')
    for d_id in donor_6533_ids:
        print(f'  [imageid {d_id}]')
        for ct in mean_props.head(8).index:
            val = ct_proportions.loc[d_id, ct]
            mean_val = mean_props[ct]
            # z-score against the across-donor distribution; the small constant
            # avoids division by zero when a cell type has no spread.
            z = (val - mean_val) / (std_props[ct] + 1e-10)
            # ' **' marks proportions more than two standard deviations from the mean.
            flag = ' **' if abs(z) > 2 else ''
            print(f'    {str(ct):25s}: {val:.3f} (mean={mean_val:.3f}, z={z:+.1f}){flag}')

In [None]:
# Heatmap of per-donor cell type proportions
import matplotlib.pyplot as plt

# Keep only the 12 cell types with the highest mean proportion so the axis
# labels stay legible.
top_types = mean_props.head(12).index
plot_data = ct_proportions[top_types]

fig, ax = plt.subplots(figsize=(14, 6))
# Transposed so that cell types run along the y-axis and donors along the x-axis.
im = ax.imshow(plot_data.values.T, cmap='YlOrRd', aspect='auto')

ax.set_title('Per-donor cell type proportions (top 12 types)')
ax.set_xlabel('Donor (imageid)')

# y-axis: one tick per cell type.
ax.set_yticks(range(len(top_types)))
ax.set_yticklabels(top_types, fontsize=9)

# x-axis: one tick per donor image, rotated so long IDs do not overlap.
ax.set_xticks(range(len(plot_data.index)))
ax.set_xticklabels(plot_data.index, rotation=45, ha='right', fontsize=8)

fig.colorbar(im, ax=ax, label='Proportion')
plt.tight_layout()
plt.savefig('../notebooks/scvi_donor_celltype_proportions.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Summary QC Report

In [None]:
# Final report: restates every metric computed above and turns them into
# PASS/FAIL checks with a single overall verdict.
print('='*60)
print('scVI BATCH CORRECTION QC SUMMARY')
print('='*60)
print(f'Dataset: {adata.n_obs:,} cells × {adata.n_vars} proteins')
print(f'Batches (imageids): {adata.obs["imageid"].nunique()}')
print(f'Cell types: {adata.obs["phenotype"].nunique()}')
print(f'Donor 6533 present: {n_6533 > 0} ({n_6533:,} cells)')
print()

# scVI metrics only exist when the latent space was found in the object.
scvi_present = 'X_scVI' in adata_sub.obsm

# PCA baseline values are always available, so they are always printed; only
# the scVI lines are conditional.
print('--- Batch Integration Metrics ---')
print(f'Silhouette batch (PCA):  {sil_pca:.4f}  (lower=better)')
if scvi_present:
    print(f'Silhouette batch (scVI): {sil_scvi:.4f}')
print(f'Silhouette cell-type (PCA):  {sil_ct_pca:.4f}  (higher=better)')
if scvi_present:
    print(f'Silhouette cell-type (scVI): {sil_ct_scvi:.4f}')
print(f'Batch LISI (PCA):  {np.median(lisi_pca_batch):.2f}  (higher=better, max={n_batches})')
if scvi_present:
    print(f'Batch LISI (scVI): {np.median(lisi_scvi_batch):.2f}')
print(f'Cell-type LISI (PCA):  {np.median(lisi_pca_ct):.2f}  (lower=better, min=1)')
if scvi_present:
    print(f'Cell-type LISI (scVI): {np.median(lisi_scvi_ct):.2f}')
else:
    print('scVI metrics unavailable: X_scVI not found in obsm')
print()

# Overall verdict.
# The presence of the scVI latent space is itself a check: without it none of
# the scVI comparisons can run, and the correction must not be reported as
# validated on the strength of the donor check alone.
passes = [('scVI latent space (X_scVI) present', scvi_present)]
if scvi_present:
    passes.append(('Batch silhouette improved', sil_scvi < sil_pca))
    # Up to 20% loss relative to the PCA baseline is tolerated. Taking the
    # tolerance from |baseline| keeps the bar below the baseline even when the
    # baseline silhouette is negative; for a non-negative baseline this equals
    # `sil_ct_pca * 0.8`.
    passes.append((
        'Cell-type silhouette preserved (>80%)',
        sil_ct_scvi >= sil_ct_pca - 0.2 * abs(sil_ct_pca),
    ))
    passes.append(('Batch LISI improved', np.median(lisi_scvi_batch) > np.median(lisi_pca_batch)))
passes.append(('Donor 6533 present', n_6533 > 0))

print('--- QC Checks ---')
for desc, passed in passes:
    status = 'PASS' if passed else 'FAIL'
    print(f'  [{status}] {desc}')
all_pass = all(passed for _, passed in passes)

print()
if all_pass:
    print('OVERALL: All QC checks passed. scVI correction is validated.')
else:
    print('OVERALL: Some QC checks failed. Review metrics above before proceeding.')