In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import json
import matplotlib.pyplot as plt
from scipy import sparse
from scipy.io import mmread

# MaxFuse imports
import maxfuse as mf
from maxfuse import Fusor, Mario
from maxfuse.mario import pipelined_mario

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Load preprocessed data from 1_preprocessing.ipynb
# Run 1_preprocessing.ipynb first to generate these files

import os

# Fix working directory if notebook started from wrong location.
# The notebook expects to run from the notebooks/ directory so that the
# relative path '../results/1_preprocessing' resolves.

# Working directory at the start of this cell (captured before any chdir below).
# The previous os.path.abspath('__file__') resolved the *literal* file name
# '__file__' against the cwd, so its dirname was simply the cwd.
notebook_dir = os.getcwd()

# Fallback notebooks/ directory. Set the MAXFUSE_NOTEBOOK_DIR environment
# variable on machines where the repository lives elsewhere; when it is unset
# the original hard-coded location is used.
fallback_notebook_dir = os.environ.get('MAXFUSE_NOTEBOOK_DIR', '/home/smith6jt/maxfuse/notebooks')
fallback_results_dir = os.path.normpath(
    os.path.join(fallback_notebook_dir, '..', 'results', '1_preprocessing')
)

if not os.path.exists('../results/1_preprocessing') and os.path.exists(fallback_results_dir):
    os.chdir(fallback_notebook_dir)
    print(f"Changed working directory to: {os.getcwd()}")

results_dir = '../results/1_preprocessing'

# Fail early with an actionable message instead of an opaque read error.
if not os.path.exists(results_dir):
    raise FileNotFoundError(
        f"Results directory '{results_dir}' not found. "
        f"Run 1_preprocessing.ipynb first to generate the input files."
    )

# Load processed AnnData objects
protein_adata = sc.read_h5ad(f'{results_dir}/protein_adata.h5ad')
rna_adata = sc.read_h5ad(f'{results_dir}/rna_adata.h5ad')
rna_adata_lognorm = sc.read_h5ad(f'{results_dir}/rna_adata_lognorm.h5ad')

print(f"Loaded from {results_dir}/")
print(f"  Protein data: {protein_adata.shape}")
print(f"  RNA data: {rna_adata.shape}")
print(f"  RNA log-normalized: {rna_adata_lognorm.shape}")

# # Load preprocessing parameters
# with open(f'{results_dir}/preprocessing_params.json', 'r') as f:
#     preprocess_params = json.load(f)
# print(f"\nPreprocessing timestamp: {preprocess_params['timestamp']}")

In [None]:
# Exploratory embedding on a reproducible subsample of pancreas protein cells.
pancreas_cells = protein_adata[protein_adata.obs['Tissue'] == 'Pancreas']

# Cap the request at the subset size so sampling without replacement never
# asks for more cells than the tissue subset holds.
n_sampled_cells = min(5000, pancreas_cells.n_obs)
sampled = sc.pp.sample(pancreas_cells, n=n_sampled_cells, rng=42, copy=True)

# log1p -> scale (no centering) -> PCA -> kNN graph, all in place on `sampled`.
sc.pp.log1p(sampled)
sc.pp.scale(sampled,zero_center=False)
sc.pp.pca(sampled)
sc.pp.neighbors(sampled)

In [None]:
# Cluster the subsample on its neighbor graph; labels land in sampled.obs['leiden'].
sc.tl.leiden(
    sampled,
    flavor='igraph',
    resolution=0.5,
)

In [None]:
# Build the PAGA graph over the leiden clusters.
sc.tl.paga(
    sampled,
    groups='leiden',
)
# plot=False: run the PAGA layout without drawing a figure; the UMAP cell
# below requests init_pos='paga'.
sc.pl.paga(
    sampled,
    plot=False,
)

In [None]:
# UMAP seeded from the PAGA layout, then coloured by leiden cluster.
sc.tl.umap(
    sampled,
    init_pos='paga',
    min_dist=0.5,
    spread=1.5,
)
sc.pl.umap(
    sampled,
    color='leiden',
)

## Step 3: Build Protein-Gene Correspondence

Map CODEX protein markers to their corresponding gene names in the RNA data.

In [None]:
# Load the protein -> gene correspondence table.
# 'utf-8-sig' drops a leading byte-order mark if the CSV has one.
correspondence = pd.read_csv(
    '../data/protein_gene_conversion.csv',
    encoding='utf-8-sig',
)
print(f"Correspondence table: {correspondence.shape[0]} entries")
# Last expression of the cell: preview the first ten rows.
correspondence.head(10)

In [None]:
# Display the protein marker names; these are the names looked up in the
# correspondence table in the next cell.
protein_adata.var_names

In [None]:
# Find matching features between CODEX markers and RNA genes
# Define markers to exclude (non-immune, structural, or problematic for CD45+ RNA data)
EXCLUDED_MARKERS = [
    'DAPI',           # Nuclear stain
    'ECAD', 'E-cadherin',  # Epithelial
    'IAPP', 'INS', 'GCG', 'SST',  # Pancreatic endocrine
    # 'Ker8-18', 'Pan-Cytokeratin', 'Keratin 5', 'EpCAM', 'TP63',  # Epithelial
    # 'Collagen IV', 'Vimentin', 'SMA', 'Caveolin',  # Stromal/structural
    # 'CD31', 'Podoplanin',  # Endothelial/lymphatic
    'Beta-actin',     # Housekeeping
    'LAG3',          # Excluded per user
]

# Alternative spellings tried when a marker has no direct (case-insensitive)
# hit in the correspondence table. Built once instead of on every iteration.
alt_names = {
    'CD3e': 'CD3',
    'FoxP3': 'FOXP3',
    'HLADR': 'HLA-DR',
    'Lyve1': 'LYVE1',
    'SMActin': 'SMA',
    'CollagenIV': 'collagen IV',
}

# Lower-cased protein names, computed once for all lookups below.
protein_names_lower = correspondence['Protein name'].str.lower()

rna_protein_correspondence = []
unmatched_proteins = []
n_excluded_in_panel = 0  # markers actually skipped, not the length of the exclusion list

for marker in protein_adata.var_names:
    # Skip excluded markers
    if marker in EXCLUDED_MARKERS:
        n_excluded_in_panel += 1
        continue

    # Look up in correspondence table, falling back to the alternative name
    matches = correspondence[protein_names_lower == marker.lower()]
    if len(matches) == 0:
        alt_marker = alt_names.get(marker, marker)
        matches = correspondence[protein_names_lower == alt_marker.lower()]

    if len(matches) == 0:
        unmatched_proteins.append((marker, 'Not in table'))
        continue

    # Only the first matching row is used.
    rna_names_str = matches.iloc[0]['RNA name']
    if 'Ignore' in str(rna_names_str):
        unmatched_proteins.append((marker, 'Ignored'))
        continue

    # Try each '/'-separated RNA name option; keep the first one present in the RNA data.
    found = False
    for rna_name in str(rna_names_str).split('/'):
        # Tolerate entries written as "A / B": surrounding whitespace is not part of the gene name.
        rna_name = rna_name.strip()
        if rna_name in rna_adata.var_names:
            rna_protein_correspondence.append([rna_name, marker])
            found = True
            break
    if not found:
        unmatched_proteins.append((marker, rna_names_str))

rna_protein_correspondence = np.array(rna_protein_correspondence)
print(f"Found {len(rna_protein_correspondence)} protein-gene pairs")
print(f"Excluded {n_excluded_in_panel} non-immune/structural markers")

if unmatched_proteins:
    print(f"\nUnmatched proteins ({len(unmatched_proteins)}):")
    for prot, reason in unmatched_proteins:
        print(f"  {prot}: {reason}")

In [None]:
# Drop pairs whose RNA gene was already claimed by an earlier protein
# (first occurrence wins).
seen_rna = set()
unique_pairs = []
for rna, prot in rna_protein_correspondence:
    if rna in seen_rna:
        print(f"Removing duplicate RNA mapping: {rna} -> {prot}")
        continue
    seen_rna.add(rna)
    unique_pairs.append([rna, prot])

rna_protein_correspondence = np.array(unique_pairs)
print(f"\nFinal correspondence: {len(rna_protein_correspondence)} pairs")

# List the surviving (gene, marker) pairs.
print("\nMatched features:")
for rna, prot in rna_protein_correspondence:
    print(f"  {rna:15} <-> {prot}")

## Step 4: Prepare Arrays for Integration

Extract and normalize:
- **Shared arrays**: Corresponding protein/gene features (used for initial matching)
- **Active arrays**: All features (used for refinement)

In [None]:
# Column 0 holds RNA gene names, column 1 the paired protein marker names.
shared_rna_genes, shared_protein_markers = (
    rna_protein_correspondence[:, 0],
    rna_protein_correspondence[:, 1],
)

# Verify all features exist (report only; the subsetting below is not guarded).
missing_rna = []
for gene in shared_rna_genes:
    if gene not in rna_adata.var_names:
        missing_rna.append(gene)

missing_prot = []
for marker_name in shared_protein_markers:
    if marker_name not in protein_adata.var_names:
        missing_prot.append(marker_name)

if missing_rna:
    print(f"WARNING: Missing RNA genes: {missing_rna}")
if missing_prot:
    print(f"WARNING: Missing protein markers: {missing_prot}")

# Shared-feature AnnData objects, columns ordered like the correspondence table.
protein_shared_adata = protein_adata[:, shared_protein_markers].copy()
rna_shared_adata = rna_adata[:, shared_rna_genes].copy()

print(f"Shared feature AnnData objects created:")
print(f"  rna_shared_adata: {rna_shared_adata.shape}")
print(f"  protein_shared_adata: {protein_shared_adata.shape}")
print(f"\nRNA shared features: {list(rna_shared_adata.var_names[:5])}...")
print(f"Protein shared features: {list(protein_shared_adata.var_names[:5])}...")

In [None]:
# Normalize shared features for MaxFuse
# BOTH modalities: detection-aware z-score
# - RNA: zeros = no transcript
# - Protein: zeros = below gate (from notebook 1)

from scipy import sparse, stats
from sklearn.preprocessing import StandardScaler

ZERO_VALUE = -1.0  # Value for non-expressing cells

print("=" * 70)
print("SHARED FEATURE NORMALIZATION")
print("=" * 70)

# ============================================================
# RNA SHARED
# ============================================================

# Log-normalized values of the shared genes (these are z-scored later).
rna_shared_lognorm = rna_adata_lognorm[:, shared_rna_genes].X.copy()
if sparse.issparse(rna_shared_lognorm):
    rna_shared_lognorm = rna_shared_lognorm.toarray()

# Matching values taken from rna_shared_adata.X, densified the same way.
rna_shared_raw = rna_shared_adata.X.copy()
if sparse.issparse(rna_shared_raw):
    rna_shared_raw = rna_shared_raw.toarray()

# A cell "expresses" a gene when its log-normalized value is positive.
rna_expressing = rna_shared_lognorm > 0

# ============================================================
# PROTEIN SHARED
# ============================================================

shared_marker_names = [prot_name for rna_name, prot_name in rna_protein_correspondence]

# protein_shared_adata.X supplies the expressing mask (positive = expressing).
protein_gated = protein_shared_adata.X.copy()
if sparse.issparse(protein_gated):
    protein_gated = protein_gated.toarray()

protein_expressing = protein_gated > 0

# The 'log' layer supplies the values that are z-scored.
protein_log_layer = protein_adata.layers['log']
if sparse.issparse(protein_log_layer):
    protein_log_layer = protein_log_layer.toarray()

# Materialize the marker-name list once; it was previously rebuilt for every
# shared marker inside the loop.
protein_var_name_list = list(protein_adata.var_names)

protein_log_shared = np.zeros((protein_adata.n_obs, len(shared_marker_names)))
for i, marker in enumerate(shared_marker_names):
    marker_idx = protein_var_name_list.index(marker)
    protein_log_shared[:, i] = protein_log_layer[:, marker_idx]

# ============================================================
# ITERATIVE FILTERING: markers and cells until stable
# ============================================================

def filter_markers(rna_expr, prot_expr, correspondence, *arrays):
    """Flag markers that have at least one expressing cell in both modalities.

    Parameters
    ----------
    rna_expr, prot_expr : 2-D boolean arrays (cells x markers)
        Expressing masks; both must have one column per row of `correspondence`.
    correspondence : array of (rna_gene, protein_marker) pairs
        Only used to name the removed markers in the printed message.
    *arrays
        Unused; accepted so existing call sites keep working.

    Returns
    -------
    valid : 1-D boolean array
        True for markers to keep.
    n_removed : int
        Number of markers flagged for removal.
    """
    rna_has = rna_expr.any(axis=0)
    prot_has = prot_expr.any(axis=0)
    valid = rna_has & prot_has
    n_removed = int((~valid).sum())
    if n_removed > 0:
        removed = correspondence[~valid]
        # Cast to plain str so the list prints as ['CD3'] rather than
        # [np.str_('CD3')] under NumPy 2 scalar reprs.
        removed_names = [str(prot) for _, prot in removed]
        print(f"  Removing {n_removed} markers: {removed_names}")
    return valid, n_removed

def filter_cells(rna_expr, prot_expr):
    """Find cells that express at least one marker, per modality.

    Returns (rna_keep, prot_keep, n_rna_dropped, n_prot_dropped): two boolean
    row masks followed by the number of False entries in each.
    """
    keep_masks = [expr.any(axis=1) for expr in (rna_expr, prot_expr)]
    n_dropped = [(~mask).sum() for mask in keep_masks]
    return keep_masks[0], keep_masks[1], n_dropped[0], n_dropped[1]

# Alternate marker and cell filtering until nothing changes: dropping cells can
# leave a marker with no expressing cells, and dropping markers can leave a
# cell with no expression.
print("\nFiltering markers and cells...")
iteration = 0
while True:
    iteration += 1
    changes = 0
    
    # Filter markers
    marker_valid, n_markers = filter_markers(rna_expressing, protein_expressing, rna_protein_correspondence)
    if n_markers > 0:
        changes += n_markers
        rna_shared_lognorm = rna_shared_lognorm[:, marker_valid]
        rna_shared_raw = rna_shared_raw[:, marker_valid]
        rna_expressing = rna_expressing[:, marker_valid]
        protein_gated = protein_gated[:, marker_valid]
        protein_log_shared = protein_log_shared[:, marker_valid]
        protein_expressing = protein_expressing[:, marker_valid]
        rna_protein_correspondence = rna_protein_correspondence[marker_valid]
        # Keep the shared AnnData objects column-aligned with the arrays above.
        # Without this they retain the dropped markers, and the later
        # valid_mask indexing / `.X = ...` assignment fails on a shape mismatch.
        rna_shared_adata = rna_shared_adata[:, marker_valid].copy()
        protein_shared_adata = protein_shared_adata[:, marker_valid].copy()
    
    # Filter cells
    rna_keep, prot_keep, n_rna, n_prot = filter_cells(rna_expressing, protein_expressing)
    if n_rna > 0:
        print(f"  Removing {n_rna:,} RNA cells with no expression")
        changes += n_rna
        rna_shared_lognorm = rna_shared_lognorm[rna_keep]
        rna_shared_raw = rna_shared_raw[rna_keep]
        rna_expressing = rna_expressing[rna_keep]
        # The full RNA objects are row-filtered too, so later cells that read
        # them see the same cells as the shared arrays.
        rna_shared_adata = rna_shared_adata[rna_keep].copy()
        rna_adata = rna_adata[rna_keep].copy()
        rna_adata_lognorm = rna_adata_lognorm[rna_keep].copy()
    
    if n_prot > 0:
        print(f"  Removing {n_prot:,} protein cells with no expression")
        changes += n_prot
        protein_gated = protein_gated[prot_keep]
        protein_log_shared = protein_log_shared[prot_keep]
        protein_expressing = protein_expressing[prot_keep]
        protein_shared_adata = protein_shared_adata[prot_keep].copy()
        protein_adata = protein_adata[prot_keep].copy()
    
    if changes == 0:
        break
    # Safety valve: stop after the 11th pass if filtering is still changing things.
    if iteration > 10:
        print("  WARNING: filtering did not converge")
        break

print(f"  Done in {iteration} iteration(s)")

# ============================================================
# DETECTION-AWARE Z-SCORE
# ============================================================

def normalize_expressing(data, expressing_mask, nonexpr_value):
    """Per-feature z-score over expressing cells only.

    For every column with more than one expressing cell, the expressing values
    are centred on their mean, divided by their standard deviation (plus 1e-8)
    and clipped to [-3, 3]. Every other entry is `nonexpr_value`, including
    all entries of columns with zero or one expressing cell.

    Parameters
    ----------
    data : 2-D array (cells x features)
    expressing_mask : 2-D boolean array, same shape as `data`
    nonexpr_value : fill value for entries that are not z-scored

    Returns
    -------
    float64 array with the shape of `data`.
    """
    _, n_cols = data.shape
    result = np.full_like(data, nonexpr_value, dtype=np.float64)
    for col in range(n_cols):
        is_expr = expressing_mask[:, col]
        if is_expr.sum() <= 1:
            # Too few expressing cells to z-score; column stays at the fill value.
            continue
        observed = data[is_expr, col]
        centred = observed - observed.mean()
        scaled = centred / (observed.std() + 1e-8)
        result[is_expr, col] = np.clip(scaled, -3, 3)
    return result

# Detection-aware z-scores for both modalities, stored as float32.
rna_shared = normalize_expressing(rna_shared_lognorm, rna_expressing, ZERO_VALUE).astype(np.float32)
protein_shared = normalize_expressing(protein_log_shared, protein_expressing, ZERO_VALUE).astype(np.float32)

# Stats over expressing entries only
rna_expr_vals = rna_shared[rna_expressing]
prot_expr_vals = protein_shared[protein_expressing]

print(f"\nRNA ({rna_shared.shape[1]} features, {rna_shared.shape[0]:,} cells):")
print(f"  Expressing: {rna_expressing.sum():,} values, mean={rna_expr_vals.mean():.3f}, std={rna_expr_vals.std():.3f}")

print(f"\nProtein ({protein_shared.shape[1]} features, {protein_shared.shape[0]:,} cells):")
print(f"  Expressing: {protein_expressing.sum():,} values, mean={prot_expr_vals.mean():.3f}, std={prot_expr_vals.std():.3f}")

# ============================================================
# REMOVE ZERO-VARIANCE FEATURES
# ============================================================

# A feature is kept only if it varies in BOTH modalities.
rna_std = rna_shared.std(axis=0)
prot_std = protein_shared.std(axis=0)
valid_mask = np.logical_and(rna_std > 1e-6, prot_std > 1e-6)

if valid_mask.all():
    # Nothing to drop: the pre-normalization arrays are reused as they are.
    rna_after_log = rna_shared_lognorm
    protein_shared_raw = protein_log_shared
else:
    n_removed = (~valid_mask).sum()
    print(f"\nRemoving {n_removed} zero-variance features")

    # Pre-normalization values, kept for the plots in the next cell.
    rna_after_log = rna_shared_lognorm[:, valid_mask]
    protein_shared_raw = protein_log_shared[:, valid_mask]
    rna_shared_raw = rna_shared_raw[:, valid_mask]

    # Normalized arrays, masks and the pair table.
    rna_shared = rna_shared[:, valid_mask]
    protein_shared = protein_shared[:, valid_mask]
    rna_expressing = rna_expressing[:, valid_mask]
    protein_expressing = protein_expressing[:, valid_mask]
    rna_protein_correspondence = rna_protein_correspondence[valid_mask]

    # Shared AnnData objects, column-filtered with the same mask.
    rna_shared_adata = rna_shared_adata[:, valid_mask].copy()
    protein_shared_adata = protein_shared_adata[:, valid_mask].copy()

# Aliases read by the summary cell further down.
rna_detection_mask = rna_expressing
protein_detection_mask = protein_expressing

# Store the normalized matrices on the shared AnnData objects.
rna_shared_adata.X = rna_shared
protein_shared_adata.X = protein_shared

# ============================================================
# SUMMARY
# ============================================================

n_features = rna_shared.shape[1]
print(f"\n" + "=" * 70)
print(f"FINAL: {n_features} shared features")
print(f"  RNA: {rna_shared.shape[0]:,} cells")
print(f"  Protein: {protein_shared.shape[0]:,} cells")
print("=" * 70)

In [None]:
# Visualize normalization - focus on expressing cells
from scipy.stats import rankdata, spearmanr
from adjustText import adjust_text

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

feature_names = list(rna_protein_correspondence[:, 0])
n_features = rna_shared.shape[1]
short_labels = [f[:5] for f in feature_names]  # tick / point labels, truncated to 5 chars

x_pos = np.arange(n_features)
width = 0.35


def _expressing_means(values, mask):
    """Per-feature mean of `values` over expressing cells; 0 when no cell expresses."""
    means = []
    for i in range(n_features):
        column_mask = mask[:, i]
        means.append(values[column_mask, i].mean() if column_mask.sum() > 0 else 0)
    return means


def _rank_scale(means):
    """Rank-transform per-feature means onto [0, 1]; all zeros for a single feature."""
    if n_features > 1:
        return (rankdata(means) - 1) / (n_features - 1)
    return np.zeros(n_features)


def _paired_bars(target_ax, rna_heights, prot_heights, ylabel, title):
    """Side-by-side RNA / protein bars, one pair per shared feature."""
    target_ax.bar(x_pos - width/2, rna_heights, width, label='RNA', color='steelblue', alpha=0.8)
    target_ax.bar(x_pos + width/2, prot_heights, width, label='Protein', color='darkorange', alpha=0.8)
    target_ax.set_ylabel(ylabel)
    target_ax.set_title(title)
    target_ax.set_xticks(x_pos)
    target_ax.set_xticklabels(short_labels, rotation=45, ha='right', fontsize=7)
    target_ax.legend()


# Count expressing cells per marker
rna_expr_counts = [rna_expressing[:, i].sum() for i in range(n_features)]
prot_expr_counts = [protein_expressing[:, i].sum() for i in range(n_features)]

# Expression levels (mean of expressing cells, pre-normalization)
rna_expr_means = _expressing_means(rna_after_log, rna_expressing)
prot_expr_means = _expressing_means(protein_shared_raw, protein_expressing)

# Rank-scaled expression means, shared by the scatter and the per-marker bars
rna_ranks = _rank_scale(rna_expr_means)
prot_ranks = _rank_scale(prot_expr_means)

# --- Row 1, left: number of expressing cells per marker (log y-axis)
ax = axes[0, 0]
_paired_bars(ax, rna_expr_counts, prot_expr_counts, 'Expressing cells', 'Cells Expressing Each Marker')
ax.set_yscale('log')

# --- Row 1, middle: pooled z-score distributions over expressing entries
ax = axes[0, 1]
rna_expr_vals = rna_shared[rna_expressing]
prot_expr_vals = protein_shared[protein_expressing]
bins = np.linspace(-3, 3, 50)
for pooled_vals, modality, colour in (
    (rna_expr_vals, 'RNA', 'steelblue'),
    (prot_expr_vals, 'Protein', 'darkorange'),
):
    ax.hist(pooled_vals, bins=bins, alpha=0.6, density=True, label=modality, color=colour)
ax.set_title('Z-score Distribution\n(expressing cells only)')
ax.set_xlabel('Z-score')
ax.legend()

# --- Row 1, right: rank-rank agreement of mean expression across modalities
ax = axes[0, 2]
ax.scatter(rna_ranks, prot_ranks, s=60, alpha=0.7, c='purple', edgecolors='black')
texts = []
for i in range(n_features):
    texts.append(ax.text(rna_ranks[i], prot_ranks[i], short_labels[i], fontsize=7))
ax.plot([0, 1], [0, 1], 'r--', alpha=0.5)  # identity line
spearman_corr, _ = spearmanr(rna_expr_means, prot_expr_means)
ax.set_title(f'Expression Level Correlation\n(ρ={spearman_corr:.2f})')
ax.set_xlabel('RNA (rank)')
ax.set_ylabel('Protein (rank)')
adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='-', color='gray', alpha=0.5))

# --- Row 2, left: per-marker rank-scaled expression
ax = axes[1, 0]
_paired_bars(ax, rna_ranks, prot_ranks, 'Expression level (rank)', 'Per-Marker Expression')

# Example marker: the one maximizing (RNA expressing cells) x (protein expressing cells)
best_idx = np.argmax([rna_expr_counts[i] * prot_expr_counts[i] for i in range(n_features)])
best_name = feature_names[best_idx]

# --- Row 2, middle: RNA z-scores of the example marker
ax = axes[1, 1]
rna_vals = rna_shared[rna_expressing[:, best_idx], best_idx]
ax.hist(rna_vals, bins=30, alpha=0.7, color='steelblue', edgecolor='white')
ax.set_title(f'RNA: {best_name}\n({len(rna_vals):,} expressing cells)')
ax.set_xlabel('Z-score')

# --- Row 2, right: protein z-scores of the example marker
ax = axes[1, 2]
prot_vals = protein_shared[protein_expressing[:, best_idx], best_idx]
ax.hist(prot_vals, bins=30, alpha=0.7, color='darkorange', edgecolor='white')
ax.set_title(f'Protein: {best_name}\n({len(prot_vals):,} expressing cells)')
ax.set_xlabel('Z-score')

plt.tight_layout()
plt.show()

print(f"Expression correlation (Spearman): ρ = {spearman_corr:.3f}")

In [None]:
# Shared feature summary: per-pair counts of expressing cells in each modality
print("=" * 60)
print("SHARED FEATURE SUMMARY")
print("=" * 60)

# One record per (gene, marker) pair; column i of the detection masks belongs to pair i.
feature_stats = [
    {
        'Marker': prot_marker,
        'Gene': rna_gene,
        'RNA_expressing': rna_detection_mask[:, i].sum(),
        'Protein_expressing': protein_detection_mask[:, i].sum(),
    }
    for i, (rna_gene, prot_marker) in enumerate(rna_protein_correspondence)
]

# Most widely expressed (by RNA) first.
stats_df = pd.DataFrame(feature_stats).sort_values('RNA_expressing', ascending=False)

print(f"\n{'Marker':<12} {'Gene':<12} {'RNA cells':>12} {'Protein cells':>14}")
print("-" * 52)
for _, row in stats_df.iterrows():
    print(f"{row['Marker'][:12]:<12} {row['Gene'][:12]:<12} {row['RNA_expressing']:>12,} {row['Protein_expressing']:>14,}")

print(f"\n{'='*60}")
print(f"Total: {len(rna_protein_correspondence)} shared features")
print(f"RNA cells: {rna_shared.shape[0]:,}")
print(f"Protein cells: {protein_shared.shape[0]:,}")

In [None]:
# Protein active features - detection-aware z-score
# Gated mask defines expressing, z-score LOG values

# Materialize the marker names once; used for both the name filter and the
# column positions below.
protein_var_name_list = list(protein_adata.var_names)
protein_markers_active = [m for m in protein_var_name_list if m not in EXCLUDED_MARKERS]

# Gated data for expressing mask
protein_adata_active = protein_adata[:, protein_markers_active].copy()
protein_gated_active = protein_adata_active.X.copy()
if sparse.issparse(protein_gated_active):
    protein_gated_active = protein_gated_active.toarray()

protein_active_expressing = protein_gated_active > 0

# Log layer for z-scoring
protein_log_full = protein_adata.layers['log']
if sparse.issparse(protein_log_full):
    protein_log_full = protein_log_full.toarray()

# Column positions of the active markers, in var_names order (same order as
# protein_markers_active). Replaces a per-marker list rebuild + .index() scan.
active_marker_indices = [i for i, m in enumerate(protein_var_name_list) if m not in EXCLUDED_MARKERS]
protein_log_active = protein_log_full[:, active_marker_indices]

# Detection-aware z-score
protein_active_normalized = normalize_expressing(protein_log_active, protein_active_expressing, ZERO_VALUE)
protein_active_normalized = protein_active_normalized.astype(np.float32)
protein_adata_active.X = protein_active_normalized

# Report the number of markers actually dropped; EXCLUDED_MARKERS may list
# names that are not in this panel, so its length is not that number.
n_markers_excluded = len(protein_var_name_list) - len(protein_markers_active)

prot_active_expr = protein_active_normalized[protein_active_expressing]
print(f"Protein active: {protein_adata_active.shape}")
print(f"  {len(protein_adata.var_names)} markers → {len(protein_markers_active)} (excluded {n_markers_excluded} non-immune)")
print(f"  Expressing: mean={prot_active_expr.mean():.3f}, std={prot_active_expr.std():.3f}")

In [None]:
# Build the dense "active" matrices handed to the integration step.

# RNA active: every gene in rna_adata_lognorm, z-scored over expressing cells.
rna_lognorm_matrix = rna_adata_lognorm.X
rna_active_lognorm = (
    rna_lognorm_matrix.toarray()
    if sparse.issparse(rna_lognorm_matrix)
    else rna_lognorm_matrix.copy()
)

rna_active_expressing = rna_active_lognorm > 0
rna_active = normalize_expressing(rna_active_lognorm, rna_active_expressing, ZERO_VALUE).astype(np.float32)

# Protein active: read back from protein_adata_active.X (already normalized
# in the previous cell).
protein_active_matrix = protein_adata_active.X
protein_active = (
    protein_active_matrix.toarray()
    if sparse.issparse(protein_active_matrix)
    else protein_active_matrix.copy()
)

# Remove zero-variance features (the names say "var", the values are standard deviations).
rna_var = rna_active.std(axis=0)
protein_var = protein_active.std(axis=0)

rna_keep_features = rna_var > 1e-6
protein_keep_features = protein_var > 1e-6
rna_active = rna_active[:, rna_keep_features]
protein_active = protein_active[:, protein_keep_features]

print(f"\nActive arrays:")
print(f"  rna_active: {rna_active.shape}")
print(f"  protein_active: {protein_active.shape}")

In [None]:
# CRITICAL VALIDATION: Check array dimensions match
# Within each modality the shared and active arrays must have the same number
# of rows (cells), and both shared arrays must have the same number of
# columns (features).
print("=" * 50)
print("DIMENSION VALIDATION")
print("=" * 50)
print(f"RNA shared cells:     {rna_shared.shape[0]}")
print(f"RNA active cells:     {rna_active.shape[0]}")
print(f"Protein shared cells: {protein_shared.shape[0]}")
print(f"Protein active cells: {protein_active.shape[0]}")
print()

# Row counts must agree within each modality.
assert rna_shared.shape[0] == rna_active.shape[0], \
    f"RNA mismatch: shared={rna_shared.shape[0]}, active={rna_active.shape[0]}"
assert protein_shared.shape[0] == protein_active.shape[0], \
    f"Protein mismatch: shared={protein_shared.shape[0]}, active={protein_active.shape[0]}"
# Shared feature counts must agree across modalities.
assert rna_shared.shape[1] == protein_shared.shape[1], \
    f"Shared feature mismatch: RNA={rna_shared.shape[1]}, Protein={protein_shared.shape[1]}"

print("All dimensions validated!")
print(f"\nIntegrating {rna_active.shape[0]} RNA cells with {protein_active.shape[0]} protein cells")
print(f"Using {rna_shared.shape[1]} shared features for initialization")

## Step 5: MARIO - Matchability Test (Pre-Integration Diagnostic)

Before running integration, we test whether the two datasets have meaningful correspondence.
MARIO uses random sign flips to create a null distribution and computes p-values.

- **Low p-value** (< 0.05): Significant correspondence detected, so the datasets are considered matchable
- **High p-value** (≥ 0.05): No significant correspondence detected

### Data Validation

Check for and handle NaN/Inf values that may result from normalization of sparse features.

In [None]:
# # Check and handle NaN/Inf values before MARIO
# # Detection-aware normalization can produce NaN for problematic features

# print("Checking for NaN/Inf values in shared arrays...")
# print(f"  rna_shared: NaN={np.isnan(rna_shared).sum()}, Inf={np.isinf(rna_shared).sum()}")
# print(f"  protein_shared: NaN={np.isnan(protein_shared).sum()}, Inf={np.isinf(protein_shared).sum()}")

# # Replace NaN/Inf with 0 (these are likely failed normalizations for sparse features)
# if np.isnan(rna_shared).any() or np.isinf(rna_shared).any():
#     print("\nCleaning rna_shared...")
#     rna_shared = np.nan_to_num(rna_shared, nan=0.0, posinf=0.0, neginf=0.0)
    
# if np.isnan(protein_shared).any() or np.isinf(protein_shared).any():
#     print("Cleaning protein_shared...")
#     protein_shared = np.nan_to_num(protein_shared, nan=0.0, posinf=0.0, neginf=0.0)

# # Also check active arrays
# print(f"\n  rna_active: NaN={np.isnan(rna_active).sum()}, Inf={np.isinf(rna_active).sum()}")
# print(f"  protein_active: NaN={np.isnan(protein_active).sum()}, Inf={np.isinf(protein_active).sum()}")

# if np.isnan(rna_active).any() or np.isinf(rna_active).any():
#     print("\nCleaning rna_active...")
#     rna_active = np.nan_to_num(rna_active, nan=0.0, posinf=0.0, neginf=0.0)
    
# if np.isnan(protein_active).any() or np.isinf(protein_active).any():
#     print("Cleaning protein_active...")
#     protein_active = np.nan_to_num(protein_active, nan=0.0, posinf=0.0, neginf=0.0)

# print("\nArrays cleaned and ready for integration.")

In [None]:
# # Subsample for MARIO (MARIO requires n1 <= n2, and for speed we subsample)
# np.random.seed(42)

# # MARIO needs RNA (smaller) to be df1 and Protein (larger) to be df2
# n_rna_subsample = min(2000, rna_shared.shape[0])
# n_prot_subsample = min(10000, protein_shared.shape[0])

# rna_idx_subsample = np.random.choice(rna_shared.shape[0], n_rna_subsample, replace=False)
# prot_idx_subsample = np.random.choice(protein_shared.shape[0], n_prot_subsample, replace=False)

# # Create DataFrames with overlapping column names (required by MARIO)
# shared_feature_names = [f"feat_{i}" for i in range(rna_shared.shape[1])]

# # Extract subsamples and ensure no NaN values
# rna_subsample = rna_shared[rna_idx_subsample].copy()
# prot_subsample = protein_shared[prot_idx_subsample].copy()

# # Final NaN check on subsamples
# rna_subsample = np.nan_to_num(rna_subsample, nan=0.0, posinf=0.0, neginf=0.0)
# prot_subsample = np.nan_to_num(prot_subsample, nan=0.0, posinf=0.0, neginf=0.0)

# rna_df_mario = pd.DataFrame(rna_subsample, columns=shared_feature_names)
# prot_df_mario = pd.DataFrame(prot_subsample, columns=shared_feature_names)

# print(f"MARIO subsample sizes:")
# print(f"  RNA: {rna_df_mario.shape}")
# print(f"  Protein: {prot_df_mario.shape}")
# print(f"  NaN in RNA df: {rna_df_mario.isna().sum().sum()}")
# print(f"  NaN in Protein df: {prot_df_mario.isna().sum().sum()}")

In [None]:
# # Initialize MARIO
# mario = Mario(rna_df_mario, prot_df_mario, normalization=True)

# # Specify matching parameters
# # n_matched_per_cell: how many protein cells to match with each RNA cell
# n_matched = max(1, n_prot_subsample // n_rna_subsample)
# mario.specify_matching_params(n_matched_per_cell=n_matched)

# print(f"Matching {n_matched} protein cells per RNA cell")

In [None]:
# # Compute distance using overlapping features
# n_ovlp_components = min(15, rna_shared.shape[1] - 1)
# dist_ovlp, singular_vals = mario.compute_dist_ovlp(n_components=n_ovlp_components)

# print(f"Distance matrix shape: {dist_ovlp.shape}")
# print(f"Singular values: {singular_vals[:5]}...")

# # Plot singular values
# plt.figure(figsize=(8, 4))
# plt.plot(singular_vals, 'bo-')
# plt.xlabel('Component')
# plt.ylabel('Singular Value')
# plt.title('MARIO: Singular Values of Stacked Overlap Features')
# plt.show()

In [None]:
# # Initial matching using overlap features
# print("Finding initial matching using overlap features...")
# matching_ovlp = mario.match_cells('ovlp', sparsity=None, mode='auto')

# Count matched cells
# n_matched_cells = sum(1 for m in matching_ovlp if len(m) > 0)
# print(f"Matched {n_matched_cells}/{len(matching_ovlp)} RNA cells")

In [None]:
# # Add active features (all HVGs) for refined matching
# # For MARIO, we need DataFrames with:
# # - Overlapping columns (shared features) with same names
# # - Non-overlapping columns (active features) with different names

# # RNA: shared features + active features
# rna_active_subsample = rna_active[rna_idx_subsample].copy()
# rna_active_subsample = np.nan_to_num(rna_active_subsample, nan=0.0, posinf=0.0, neginf=0.0)
# rna_active_names = [f"rna_feat_{i}" for i in range(rna_active_subsample.shape[1])]

# rna_df_full = pd.DataFrame(
#     np.hstack([rna_subsample, rna_active_subsample]),
#     columns=shared_feature_names + rna_active_names
# )

# # Protein: shared features + active features
# prot_active_subsample = protein_active[prot_idx_subsample].copy()
# prot_active_subsample = np.nan_to_num(prot_active_subsample, nan=0.0, posinf=0.0, neginf=0.0)
# prot_active_names = [f"prot_feat_{i}" for i in range(prot_active_subsample.shape[1])]

# prot_df_full = pd.DataFrame(
#     np.hstack([prot_subsample, prot_active_subsample]),
#     columns=shared_feature_names + prot_active_names
# )

# print(f"Full DataFrames for MARIO:")
# print(f"  RNA: {rna_df_full.shape} ({len(shared_feature_names)} shared + {len(rna_active_names)} active)")
# print(f"  Protein: {prot_df_full.shape} ({len(shared_feature_names)} shared + {len(prot_active_names)} active)")
# print(f"  NaN check - RNA: {rna_df_full.isna().sum().sum()}, Protein: {prot_df_full.isna().sum().sum()}")

In [None]:
# # Create new MARIO object with full features
# mario_full = Mario(rna_df_full, prot_df_full, normalization=False)
# mario_full.specify_matching_params(n_matched_per_cell=n_matched)

# # Compute distance using overlap features
# _ = mario_full.compute_dist_ovlp(n_components=n_ovlp_components)

# # Initial matching
# _ = mario_full.match_cells('ovlp', sparsity=None, mode='auto')

# # Compute distance using ALL features (CCA refinement)
# # NOTE: Use conservative number of CCA components.
# # With few protein features relative to RNA features and matched samples,
# # CCA can find trivially perfect correlations if given too many components.
# # Rule of thumb: use min(n_shared - 1, sqrt(n_prot_active))
# n_prot_active = prot_df_full.shape[1] - len(shared_feature_names)
# n_cca_components = min(
#     len(shared_feature_names) - 1,  # No more than shared features
#     int(np.sqrt(n_prot_active)) + 1,  # Conservative based on protein features
#     8  # Hard cap for this data
# )
# n_cca_components = max(3, n_cca_components)  # At least 3 components
# print(f"Using {n_cca_components} CCA components")
# print(f"  (shared features: {len(shared_feature_names)}, protein active: {n_prot_active})")

# dist_all, cancor = mario_full.compute_dist_all('ovlp', n_components=n_cca_components)

# # Interpret canonical correlations
# print(f"\nCanonical correlations: {np.round(cancor, 4)}")

# if np.allclose(cancor, 1.0, atol=0.01):
#     print("\nNOTE: Canonical correlations are very high (~1.0).")
#     print("This is common when protein features are few relative to matched samples.")
#     print("The CCA can perfectly align matched pairs in this low-dimensional space.")
#     print("Matching quality depends on how well CCA generalizes to unmatched cells.")
# elif np.mean(cancor) > 0.7:
#     print("\nGood: High canonical correlations indicate strong alignment.")
# else:
#     print("\nNote: Moderate correlations - may indicate weaker cross-modal alignment.")

# # Plot canonical correlations
# plt.figure(figsize=(8, 4))
# plt.bar(range(len(cancor)), cancor)
# plt.xlabel('CCA Component')
# plt.ylabel('Canonical Correlation')
# plt.title('MARIO: Canonical Correlations')
# plt.ylim(0, 1.1)
# plt.axhline(y=0.7, color="orange", linestyle="--", alpha=0.5, label="Good threshold")
# plt.legend()
# plt.show()

In [None]:
# # Match using all features
# matching_all = mario_full.match_cells('all', sparsity=None, mode='auto')

# n_matched_all = sum(1 for m in matching_all if len(m) > 0)
# print(f"Matched {n_matched_all}/{len(matching_all)} RNA cells using all features")

In [None]:
# # DIAGNOSTIC: Check canonical correlations before matchability test

# print("="*60)
# print("MATCHABILITY DIAGNOSTIC")
# print("="*60)

# # Check the canonical correlations from the existing matching
# print("\n1. OBSERVED CANONICAL CORRELATIONS:")
# print(f"   From compute_dist_all (stored): {cancor[:5] if 'cancor' in dir() else 'Not computed'}")

# # Check data properties
# print("\n2. DATA PROPERTIES:")
# print(f"   mario_full.df1 shape: {mario_full.df1.shape}")
# print(f"   mario_full.df2 shape: {mario_full.df2.shape}")
# print(f"   Overlap features: {len(mario_full.ovlp_features)}")

# # Check for zero-variance features
# df1_std = mario_full.df1.std()
# df2_std = mario_full.df2.std()
# print(f"\n   df1 zero-variance features: {(df1_std < 1e-10).sum()}")
# print(f"   df2 zero-variance features: {(df2_std < 1e-10).sum()}")

# # Check data scale
# print("\n3. DATA SCALE:")
# print(f"   df1 mean: {mario_full.df1.values.mean():.4f}, std: {mario_full.df1.values.std():.4f}")
# print(f"   df2 mean: {mario_full.df2.values.mean():.4f}, std: {mario_full.df2.values.std():.4f}")

# # CHECK MATCHING - this is critical
# print("\n4. MATCHING STATISTICS:")
# n_matched_ovlp = sum(1 for m in mario_full.matching['ovlp'] if len(m) > 0)
# n_matched_all = sum(1 for m in mario_full.matching['all'] if len(m) > 0)
# print(f"   Cells matched (overlap): {n_matched_ovlp} / {mario_full.n1}")
# print(f"   Cells matched (all):     {n_matched_all} / {mario_full.n1}")

# # Check the aligned data dimensions for CCA
# from mario import embed
# X_aligned = []
# Y_aligned = []
# for ii in range(mario_full.n1):
#     if len(mario_full.matching['ovlp'][ii]) > 0:
#         X_aligned.append(mario_full.df1.iloc[ii, :].values)
#         Y_aligned.append(mario_full.df2.iloc[mario_full.matching['ovlp'][ii]].mean(axis=0).values)

# X_aligned = np.array(X_aligned)
# Y_aligned = np.array(Y_aligned)
# print(f"\n5. CCA INPUT DIMENSIONS:")
# print(f"   X (RNA) aligned: {X_aligned.shape}")
# print(f"   Y (Protein) aligned: {Y_aligned.shape}")
# print(f"   Ratio features/samples (RNA): {X_aligned.shape[1]/X_aligned.shape[0]:.1f}")

# # THE PROBLEM: CCA with features >> samples gives trivial perfect correlations!
# if X_aligned.shape[1] > X_aligned.shape[0]:
#     print("\n   ⚠️  WARNING: More features than samples!")
#     print("   CCA will overfit and give meaningless correlations of 1.0")
#     print("   This is why matchability test returns p=1")

# # Test CCA with ONLY overlap features
# print("\n6. CCA WITH OVERLAP FEATURES ONLY:")
# X_ovlp = mario_full.df1[mario_full.ovlp_features].iloc[[i for i in range(mario_full.n1) if len(mario_full.matching['ovlp'][i]) > 0]].values
# Y_ovlp = np.array([mario_full.df2[mario_full.ovlp_features].iloc[mario_full.matching['ovlp'][i]].mean(axis=0).values 
#                    for i in range(mario_full.n1) if len(mario_full.matching['ovlp'][i]) > 0])
# print(f"   X_ovlp shape: {X_ovlp.shape}")
# print(f"   Y_ovlp shape: {Y_ovlp.shape}")

# try:
#     n_comp = min(5, X_ovlp.shape[1]-1, X_ovlp.shape[0]-1)
#     cancor_ovlp_only, _ = embed.get_cancor(X_ovlp, Y_ovlp, n_components=n_comp)
#     print(f"   Canonical correlations (overlap only): {cancor_ovlp_only}")
#     print(f"   Mean: {np.mean(cancor_ovlp_only):.4f}")
# except Exception as e:
#     print(f"   Error: {e}")

# print("\n" + "="*60)


In [None]:
# # Run matchability test
# print("=" * 60)
# print("MARIO MATCHABILITY TEST")
# print("=" * 60)
# print("\nRunning statistical test for dataset matchability...")
# print("(This uses random sign flips to create null distribution)")
# print()

# # CRITICAL FIX: Clean NaN/Inf values in MARIO dataframes before matchability test
# # The matchability test internally uses CCA which cannot handle NaN values
# print("Cleaning MARIO dataframes for CCA compatibility...")

# # Clean df1 (RNA) - aggressive column-by-column approach
# mario_full.df1 = mario_full.df1.copy()
# for col in mario_full.df1.columns:
#     mario_full.df1[col] = np.nan_to_num(mario_full.df1[col].values, nan=0.0, posinf=0.0, neginf=0.0)

# # Clean df2 (Protein) - aggressive column-by-column approach
# mario_full.df2 = mario_full.df2.copy()
# for col in mario_full.df2.columns:
#     mario_full.df2[col] = np.nan_to_num(mario_full.df2[col].values, nan=0.0, posinf=0.0, neginf=0.0)

# # Verify no NaN/Inf remain
# df1_clean = not (np.isnan(mario_full.df1.values).any() or np.isinf(mario_full.df1.values).any())
# df2_clean = not (np.isnan(mario_full.df2.values).any() or np.isinf(mario_full.df2.values).any())
# print(f"  df1 clean: {df1_clean} (NaN: {np.isnan(mario_full.df1.values).sum()}, Inf: {np.isinf(mario_full.df1.values).sum()})")
# print(f"  df2 clean: {df2_clean} (NaN: {np.isnan(mario_full.df2.values).sum()}, Inf: {np.isinf(mario_full.df2.values).sum()})")
# assert df1_clean and df2_clean, "Failed to clean NaN/Inf values"

# # Ensure both initial (ovlp) and refined (all) matching are complete
# print("Verifying initial and refined matching are complete...")

# # Check if matching has been done, if not redo it
# if not hasattr(mario_full, 'matching_ovlp') or mario_full.matching_ovlp is None:
#     print("  Re-running initial matching (overlap features)...")
#     mario_full.match_cells('ovlp', sparsity=None, mode='auto')

# if not hasattr(mario_full, 'matching_all') or mario_full.matching_all is None:
#     print("  Re-running refined matching (all features)...")
#     mario_full.match_cells('all', sparsity=None, mode='auto')

# print("  Both matchings confirmed. Proceeding with matchability test...\n")

# # Note: This can take a few minutes
# # Reduce n_sim if it takes too long
# pval_ovlp, pval_all = mario_full.matchable(
#     n_sim=10,           # Number of simulations (increase for more accuracy)
#     top_k=5,            # Use top-k canonical correlations
#     flip_prob=0.3,      # Probability of sign flip
#     subsample_prop=1,   # Subsample for speed
#     verbose=True
# )

# print(f"\n{'='*60}")
# print("MATCHABILITY TEST RESULTS")
# print("="*60)
# print(f"P-value (overlap features only): {pval_ovlp:.4f}")
# print(f"P-value (all features):          {pval_all:.4f}")
# print()

# if pval_ovlp < 0.05 or pval_all < 0.05:
#     print("RESULT: Datasets appear to be MATCHABLE (p < 0.05)")
#     print("  The correspondence between modalities is statistically significant.")
# else:
#     print("RESULT: Datasets may NOT be well-matched (p >= 0.05)")
#     print("  Proceed with caution - results may be unreliable.")



## Step 6: MARIO - Interpolation (Optimal Weight Search)

MARIO searches for the optimal weight between:
- Distance from **overlap features only**
- Distance from **all features** (via CCA)

The optimal weight is selected based on canonical correlations.

In [None]:
# # Run interpolation to find optimal weight
# print("Searching for optimal interpolation weight...")
# print("(Testing weights from 0 to 1)")
# print()

# best_wt, best_matching = mario_full.interpolate(
#     n_wts=10,     # Number of weights to try
#     top_k=5,      # Use top-k canonical correlations to evaluate
#     verbose=True
# )

# print(f"\nOptimal weight: {best_wt:.2f}")
# print(f"  (0 = use only overlap features, 1 = use only CCA features)")

# n_matched_best = sum(1 for m in best_matching if len(m) > 0)
# print(f"\nMatched {n_matched_best}/{len(best_matching)} RNA cells with optimal weight")

In [None]:
# # Filter bad matches using joint regularized clustering
# print("\nFiltering bad matches using joint regularized clustering...")

# n_clusters_filter = min(15, n_rna_subsample // 50)  # Aim for ~50 cells per cluster
# n_clusters_filter = max(5, n_clusters_filter)

# filtered_matching = mario_full.filter_bad_matches(
#     matching='wted',           # Use the interpolated matching
#     n_clusters=n_clusters_filter,
#     n_components=min(15, n_cca_components),
#     bad_prop=0.1,              # Remove ~10% of worst matches
#     max_iter=30,
#     verbose=True
# )

# n_matched_filtered = sum(1 for m in filtered_matching if len(m) > 0)
# print(f"\nAfter filtering: {n_matched_filtered}/{len(filtered_matching)} RNA cells matched")
# print(f"Removed {n_matched_best - n_matched_filtered} bad matches")

In [None]:
# # Optional: KNN matching for softer assignments
# knn_matching = mario_full.knn_matching(dist_mat='wted', k=5)

# print(f"KNN matching: each RNA cell matched to {5} nearest protein cells")

In [None]:
# # Compute CCA embedding for visualization
# from mario import embed

# # Align the datasets using the filtered matching
# X_aligned = []
# Y_aligned = []
# matched_rna_indices_mario = []
# matched_prot_indices_mario = []

# for i, matches in enumerate(filtered_matching):
#     if len(matches) > 0:
#         X_aligned.append(rna_df_full.iloc[i].values)
#         # Average the matched protein cells
#         Y_aligned.append(prot_df_full.iloc[matches].mean(axis=0).values)
#         matched_rna_indices_mario.append(rna_idx_subsample[i])
#         matched_prot_indices_mario.append(prot_idx_subsample[matches[0]])  # Take first match

# X_aligned = np.array(X_aligned)
# Y_aligned = np.array(Y_aligned)

# print(f"Aligned arrays: RNA {X_aligned.shape}, Protein {Y_aligned.shape}")

# # Fit CCA for embedding
# embed_dim = min(20, X_aligned.shape[1], Y_aligned.shape[1])
# cancor_embed, cca = embed.get_cancor(X_aligned, Y_aligned, n_components=embed_dim)

# # Get CCA scores
# rna_cca_mario, prot_cca_mario = cca.transform(X_aligned, Y_aligned)

# print(f"MARIO CCA embedding: {rna_cca_mario.shape}")

In [None]:
# # Visualize MARIO results
# from sklearn.manifold import TSNE

# # Combine embeddings
# combined_mario = np.vstack([rna_cca_mario, prot_cca_mario])
# labels_mario = ['RNA'] * len(rna_cca_mario) + ['Protein'] * len(prot_cca_mario)

# # Run t-SNE (faster than UMAP for small datasets)
# tsne = TSNE(n_components=2, random_state=42, perplexity=30)
# embedding_2d = tsne.fit_transform(combined_mario[:, :10])  # Use first 10 CCA components

# # Plot
# plt.figure(figsize=(10, 8))
# for label in ['RNA', 'Protein']:
#     mask = np.array(labels_mario) == label
#     plt.scatter(embedding_2d[mask, 0], embedding_2d[mask, 1], 
#                 label=label, alpha=0.5, s=10)
# plt.xlabel('t-SNE 1')
# plt.ylabel('t-SNE 2')
# plt.title('MARIO: Joint Embedding (t-SNE of CCA scores)')
# plt.legend()
# plt.show()

---
# MaxFuse Integration (Recommended for Cross-Modal Data)

MaxFuse is specifically designed for **cross-modal integration** with:
- **Weak linkage**: Few or uninformative shared features
- **Different modalities**: RNA-seq, CODEX/protein, ATAC-seq, etc.
- **Scalable architecture**: Batch processing with pivot propagation

## Key Features
1. **Graph-based smoothing**: Reduces noise before matching
2. **Iterative CCA refinement**: Improves alignment quality
3. **Pivot propagation**: Scales to large datasets

---

In [None]:
# # Load checkpoint data if MARIO section was skipped
# # This cell is idempotent - safe to run even if data already loaded

# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# from maxfuse import Fusor

# checkpoint_dir = '../results/2_integration'

# # Check if we need to load from checkpoint
# need_checkpoint = False
# try:
#     # Check if required variables exist and have data
#     _ = rna_shared.shape
#     _ = protein_shared.shape
#     _ = rna_active.shape
#     _ = protein_active.shape
#     print('Data already loaded - using existing arrays.')
# except NameError:
#     need_checkpoint = True
#     print('Data not found - loading from checkpoint...')

# if need_checkpoint:
#     # Load arrays from checkpoint
#     rna_shared = np.load(f'{checkpoint_dir}/checkpoint_rna_shared.npy')
#     protein_shared = np.load(f'{checkpoint_dir}/checkpoint_protein_shared.npy')
#     rna_active = np.load(f'{checkpoint_dir}/checkpoint_rna_active.npy')
#     protein_active = np.load(f'{checkpoint_dir}/checkpoint_protein_active.npy')
    
#     # Load correspondence
#     correspondence_df = pd.read_csv(f'{checkpoint_dir}/checkpoint_correspondence.csv')
#     rna_protein_correspondence = correspondence_df.values
    
#     print(f'Loaded from checkpoint:')
# print(f'  rna_shared: {rna_shared.shape}')
# print(f'  protein_shared: {protein_shared.shape}')
# print(f'  rna_active: {rna_active.shape}')
# print(f'  protein_active: {protein_active.shape}')


In [None]:
# # PRE-FILTER: Remove non-lymphocyte protein cells
# # Use GATED data (not z-scored) to identify expressing cells

# print("=" * 60)
# print("PRE-FILTERING: Non-lymphocyte protein cells")
# print("=" * 60)

# # Use gated data from protein_adata.X (0 = non-expressing, >0 = expressing)
# protein_gated_full = protein_adata.X.copy()
# if hasattr(protein_gated_full, 'toarray'):
#     protein_gated_full = protein_gated_full.toarray()

# marker_names = list(protein_adata.var_names)
# print(f"Total markers: {len(marker_names)}")

# # Define lymphocyte markers
# LYMPHOCYTE_MARKERS = ['CD3e', 'CD4', 'CD8', 'CD20', 'CD79a']
# available_lymph = [m for m in LYMPHOCYTE_MARKERS if m in marker_names]
# print(f"Lymphocyte markers found: {available_lymph}")

# if len(available_lymph) == 0:
#     print("WARNING: No lymphocyte markers found!")
# else:
#     # Get lymphocyte marker expression (gated - binary expressing or not)
#     lymph_idx = [marker_names.index(m) for m in available_lymph]
#     lymph_expressing = protein_gated_full[:, lymph_idx] > 0  # Boolean: expressing or not
    
#     # Count how many lymphocyte markers each cell expresses
#     n_lymph_markers = lymph_expressing.sum(axis=1)
#     any_lymph = lymph_expressing.any(axis=1)  # Expresses at least one
    
#     # Visualize
#     fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
#     # Number of lymphocyte markers expressed
#     ax = axes[0]
#     counts = np.bincount(n_lymph_markers, minlength=len(available_lymph)+1)
#     ax.bar(range(len(counts)), counts, edgecolor='black')
#     ax.set_xlabel('Number of Lymphocyte Markers Expressed')
#     ax.set_ylabel('Cell Count')
#     ax.set_title('Lymphocyte Marker Expression')
#     for i, c in enumerate(counts):
#         if c > 0:
#             ax.text(i, c + counts.max()*0.02, f'{c:,}', ha='center', fontsize=9)
    
#     # Per-marker expression counts
#     ax = axes[1]
#     expr_counts = [lymph_expressing[:, i].sum() for i in range(len(available_lymph))]
#     ax.bar(available_lymph, expr_counts, edgecolor='black')
#     ax.set_ylabel('Cells Expressing')
#     ax.set_title('Per-Marker Expression')
#     ax.tick_params(axis='x', rotation=45)
#     for i, c in enumerate(expr_counts):
#         ax.text(i, c + max(expr_counts)*0.02, f'{c:,}', ha='center', fontsize=9)
    
#     # Tissue breakdown
#     ax = axes[2]
#     tissues = protein_adata.obs['Tissue'].values
#     for tissue in np.unique(tissues):
#         t_mask = tissues == tissue
#         t_lymph = n_lymph_markers[t_mask]
#         ax.hist(t_lymph, bins=range(len(available_lymph)+2), alpha=0.6, label=f'{tissue} ({t_mask.sum():,})')
#     ax.set_xlabel('Number of Lymphocyte Markers')
#     ax.set_ylabel('Count')
#     ax.set_title('By Tissue')
#     ax.legend()
    
#     plt.tight_layout()
#     plt.show()
    
#     # Summary
#     print(f"\nCells expressing 0 lymphocyte markers: {(n_lymph_markers == 0).sum():,}")
#     print(f"Cells expressing 1+ lymphocyte markers: {any_lymph.sum():,}")
    
#     # Store for next cell
#     lymph_score = n_lymph_markers  # Use count of markers as score
#     print("\n→ Run next cell to apply filter")

In [None]:
# # APPLY LYMPHOCYTE FILTER
# # Keep cells expressing at least 1 lymphocyte marker

# MIN_LYMPH_MARKERS = 1  # Cells must express at least this many lymphocyte markers
# MAX_LYMPH_MARKERS = 2

# keep_mask = (lymph_score >= MIN_LYMPH_MARKERS) & (lymph_score <= MAX_LYMPH_MARKERS)

# n_before = len(keep_mask)
# n_after = keep_mask.sum()
# n_removed = n_before - n_after

# print("=" * 60)
# print("APPLYING LYMPHOCYTE FILTER")
# print("=" * 60)
# print(f"Threshold: express at least {MIN_LYMPH_MARKERS} and at most {MAX_LYMPH_MARKERS} lymphocyte marker(s)")
# print(f"Cells before: {n_before:,}")
# print(f"Cells after:  {n_after:,}")
# print(f"Removed:      {n_removed:,} ({n_removed/n_before*100:.1f}%)")

# # Per-tissue breakdown
# tissues = protein_adata.obs['Tissue'].values
# for t in np.unique(tissues):
#     t_mask = tissues == t
#     t_kept = keep_mask[t_mask].sum()
#     print(f"  {t}: {t_mask.sum():,} → {t_kept:,} ({t_kept/t_mask.sum()*100:.0f}% retained)")

# # Apply filter
# protein_shared = protein_shared[keep_mask]
# protein_active = protein_active[keep_mask]
# protein_adata = protein_adata[keep_mask].copy()
# protein_expressing = protein_expressing[keep_mask]
# protein_detection_mask = protein_detection_mask[keep_mask]

# print(f"\nFiltered arrays:")
# print(f"  protein_shared: {protein_shared.shape}")
# print(f"  protein_active: {protein_active.shape}")

In [None]:
# # DIAGNOSTIC: Investigate the gap between -3.0 and -2.8 in z-scored expression
# # This gap shouldn't exist for standard z-scored continuous data

# print("=" * 60)
# print("INVESTIGATING Z-SCORE GAP")
# print("=" * 60)

# # Get marker names - handle case where protein_shared may have fewer columns than protein_shared_adata
# n_markers = protein_shared.shape[1]
# if hasattr(protein_shared_adata, 'var_names') and len(protein_shared_adata.var_names) == n_markers:
#     marker_names = list(protein_shared_adata.var_names)
# else:
#     # Fallback: use generic names or try to get from protein_adata
#     print(f"WARNING: protein_shared has {n_markers} markers but protein_shared_adata has {len(protein_shared_adata.var_names)}")
#     print("Using generic marker names. Check if arrays are out of sync.")
#     marker_names = [f'Marker_{i}' for i in range(n_markers)]

# # Check distribution of protein_shared values
# fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# # 1. Overall histogram with fine bins around the gap
# ax = axes[0, 0]
# hist_vals, bin_edges, _ = ax.hist(protein_shared.flatten(), bins=200, edgecolor='none', alpha=0.7)
# ax.axvline(-3.0, color='red', linestyle='--', label='-3.0')
# ax.axvline(-2.8, color='orange', linestyle='--', label='-2.8')
# ax.set_xlabel('Z-scored Expression')
# ax.set_ylabel('Count')
# ax.set_title('Protein Shared - Full Distribution')
# ax.legend()
# ax.set_xlim(-5, 5)

# # 2. Zoom into the gap region
# ax = axes[0, 1]
# gap_region = protein_shared.flatten()[(protein_shared.flatten() > -3.5) & (protein_shared.flatten() < -2.5)]
# ax.hist(gap_region, bins=100, edgecolor='none', alpha=0.7)
# ax.axvline(-3.0, color='red', linestyle='--')
# ax.axvline(-2.8, color='orange', linestyle='--')
# ax.set_xlabel('Z-scored Expression')
# ax.set_ylabel('Count')
# ax.set_title('Zoomed: -3.5 to -2.5')

# # 3. Check if gap exists per marker
# ax = axes[0, 2]
# gap_counts = []
# for i in range(n_markers):
#     vals = protein_shared[:, i]
#     in_gap = ((vals > -3.0) & (vals < -2.8)).sum()
#     gap_counts.append(in_gap)
# ax.bar(range(n_markers), gap_counts)
# ax.set_xlabel('Marker Index')
# ax.set_ylabel('Values in Gap (-3.0 to -2.8)')
# ax.set_title('Gap Count per Marker')

# # 4. Check the raw protein data before z-scoring
# ax = axes[1, 0]
# # Reload raw protein shared to check
# protein_shared_raw_check = protein_shared_adata.X.copy()
# # Reverse z-score to see original distribution pattern
# ax.hist(protein_shared_raw_check.flatten(), bins=200, edgecolor='none', alpha=0.7)
# ax.set_xlabel('Current protein_shared values')
# ax.set_ylabel('Count')
# ax.set_title('Current protein_shared (should be z-scored)')

# # 5. Check for exact -3.0 values (would indicate detection-aware normalization was applied)
# ax = axes[1, 1]
# exact_neg3 = np.isclose(protein_shared, -3.0, atol=0.01).sum(axis=0)
# ax.bar(range(n_markers), exact_neg3)
# ax.set_xticks(range(n_markers))
# ax.set_xticklabels(marker_names, rotation=45, ha='right', fontsize=8)
# ax.set_ylabel('Count of values ≈ -3.0')
# ax.set_title('Exact -3.0 Values per Marker\n(indicates detection-aware norm was applied)')

# # 6. Summary statistics
# ax = axes[1, 2]
# ax.axis('off')

# n_exact_neg3 = np.isclose(protein_shared, -3.0, atol=0.01).sum()
# n_in_gap = ((protein_shared > -3.0) & (protein_shared < -2.8)).sum()
# n_below_neg3 = (protein_shared < -3.0).sum()
# n_total = protein_shared.size

# summary = f"""Gap Investigation Summary
# {'='*40}

# Total values: {n_total:,}

# Values exactly at -3.0 (±0.01): {n_exact_neg3:,} ({n_exact_neg3/n_total*100:.2f}%)
# Values in gap (-3.0 to -2.8):   {n_in_gap:,} ({n_in_gap/n_total*100:.4f}%)
# Values below -3.0:              {n_below_neg3:,} ({n_below_neg3/n_total*100:.2f}%)

# DIAGNOSIS:
# """

# if n_exact_neg3 > n_total * 0.01:
#     summary += """
# → PROBLEM: Detection-aware normalization was 
#   applied to PROTEIN data (should only be RNA)
# → The -3.0 spike is from setting undetected
#   values to a fixed value
# → FIX: Re-run normalization cell, or check if
#   old cached protein_shared was loaded
# """
# else:
#     summary += """
# → No detection-aware normalization detected
# → Gap may be natural data distribution
# → Check raw protein data for discrete values
# """

# ax.text(0.05, 0.95, summary, transform=ax.transAxes, fontsize=10,
#         verticalalignment='top', fontfamily='monospace')

# plt.tight_layout()
# plt.show()

# # Print per-marker statistics
# print("\nPer-marker statistics for values near -3.0:")
# print("-" * 60)
# print(f"{'Marker':<15} {'≈-3.0':>10} {'In Gap':>10} {'Min':>10} {'Mean':>10}")
# print("-" * 60)
# for i, marker in enumerate(marker_names):
#     vals = protein_shared[:, i]
#     n_neg3 = np.isclose(vals, -3.0, atol=0.01).sum()
#     n_gap = ((vals > -3.0) & (vals < -2.8)).sum()
#     print(f"{marker:<15} {n_neg3:>10,} {n_gap:>10,} {vals.min():>10.2f} {vals.mean():>10.3f}")

## Step 7: MaxFuse Integration

MaxFuse performs the integration in several stages:
1. Split data into batches for scalability
2. Construct k-NN graphs and cluster cells
3. Find initial pivot matches using shared features
4. Refine pivots using CCA on all features
5. Propagate matching to all cells

In [None]:
# Build the MaxFuse Fusor from the four feature matrices currently in the kernel.
# No labels are passed, so clustering is left to MaxFuse.
# IMPORTANT: any pre-filtering of the protein arrays must run BEFORE this cell.

banner = "=" * 60
print(banner)
print("CREATING FUSOR WITH CURRENT ARRAYS")
print(banner)

# Echo the shape of every input so it is obvious which version of each array is fused.
fusor_inputs = [
    ("rna_shared", rna_shared),
    ("protein_shared", protein_shared),
    ("rna_active", rna_active),
    ("protein_active", protein_active),
]
for input_name, input_arr in fusor_inputs:
    tag = input_name + ":"
    print(f"  {tag:<15} {input_arr.shape}")

# Sanity check - if protein has >400k cells, filter probably didn't run
# if protein_shared.shape[0] > 400000:
#     print("\n⚠️  WARNING: protein_shared has >400k cells!")
#     print("    Did you run the pre-filter cell first?")
#     print("    Re-run from the pre-filter cell to apply filtering.")

# arr1 = RNA, arr2 = protein; "shared" and "active" matrices are passed separately.
fusor = Fusor(
    shared_arr1=rna_shared,
    shared_arr2=protein_shared,
    active_arr1=rna_active,
    active_arr2=protein_active,
    labels1=None,  # no precomputed labels: MaxFuse clusters on its own
    labels2=None,
    method='centroid_shrinkage',
)

print("\nFusor created successfully.")

In [None]:
# # DEBUG: Verify filtered data made it to Fusor
# # Run this AFTER creating the Fusor

# print("=" * 60)
# print("DATA FLOW VERIFICATION")
# print("=" * 60)

# # Check global array shapes
# print("\n1. Global array shapes:")
# print(f"   protein_shared: {protein_shared.shape}")
# print(f"   protein_active: {protein_active.shape}")

# # Check what Fusor actually received
# print("\n2. Fusor internal array shapes:")
# print(f"   fusor.shared_arr2: {fusor.shared_arr2.shape}")
# print(f"   fusor.active_arr2: {fusor.active_arr2.shape}")

# # Check if they match
# match_shared = protein_shared.shape == fusor.shared_arr2.shape
# match_active = protein_active.shape == fusor.active_arr2.shape
# print(f"\n3. Do shapes match?")
# print(f"   shared: {match_shared}")
# print(f"   active: {match_active}")

# # Check if Fusor arrays ARE the same object (not a copy)
# print(f"\n4. Are they the same object? (should be True)")
# print(f"   shared: {protein_shared is fusor.shared_arr2}")
# print(f"   active: {protein_active is fusor.active_arr2}")

# # If there's a mismatch, the filter ran AFTER Fusor creation
# if not match_shared or not match_active:
#     print("\n⚠️  MISMATCH DETECTED!")
#     print("   The filter likely ran AFTER Fusor was created.")
#     print("   Re-run cells in order: Filter → Fusor → split_into_batches → ...")
    
# # Check batch data
# if hasattr(fusor, '_batch_to_indices2') and fusor._batch_to_indices2:
#     b2 = list(fusor._batch_to_indices2.keys())[0]
#     batch_size = len(fusor._batch_to_indices2[b2])
#     print(f"\n5. Batch 0 protein indices: {batch_size:,} cells")
#     max_idx = max(fusor._batch_to_indices2[b2])
#     print(f"   Max index in batch: {max_idx}")
#     print(f"   Fusor array size: {fusor.shared_arr2.shape[0]}")
#     if max_idx >= fusor.shared_arr2.shape[0]:
#         print("   ⚠️  INDEX OUT OF BOUNDS - batches reference old array size!")

In [None]:
# Preprocess the RNA modality: normalize -> log1p -> HVG selection -> scale.
# Every step rewrites rna_adata, so this cell is meant to run once per kernel session.
sc.pp.normalize_total(rna_adata)
sc.pp.log1p(rna_adata)
sc.pp.highly_variable_genes(rna_adata, n_top_genes=5000)

# Genes that are forced into the highly-variable set whenever they are present.
keep_genes = ['LAMP1', 'CD4', 'CEACAM1', 'CD38', 'PCNA', 'FOXP3', 'B3GAT1', 'MKI67', 'VSIR', 'PDCD1', 'NCAM1', 'TCF7', 'CD3E', 'ENTPD1']
forced_hvg = rna_adata.var_names.isin(keep_genes)
rna_adata.var.loc[forced_hvg, 'highly_variable'] = True

# Restrict to the flagged genes, then scale the retained features.
hvg_mask = rna_adata.var['highly_variable']
rna_adata = rna_adata[:, hvg_mask].copy()
sc.pp.scale(rna_adata)

In [None]:
# Quick visual check of the processed RNA data: build a 15-neighbor graph,
# compute a UMAP from it, and color the embedding by CD3E.
# Both the graph and the embedding are stored on rna_adata in place.
sc.pp.neighbors(rna_adata, n_neighbors=15)
sc.tl.umap(rna_adata)
sc.pl.umap(rna_adata, color="CD3E")

In [None]:
# Transform the protein modality: log1p followed by sc.pp.scale (default arguments).
# NOTE: both calls modify protein_adata in place, so this cell is not idempotent -
# re-running it would apply log1p to values that have already been scaled.
sc.pp.log1p(protein_adata)
sc.pp.scale(protein_adata)

In [None]:
# Quick visual check of the transformed protein data: 15-neighbor graph, UMAP
# (min_dist=0.05, spread=2.5), colored by the CD3e marker.
# Results are stored on protein_adata in place.
sc.pp.neighbors(protein_adata, n_neighbors=15)
sc.tl.umap(protein_adata, min_dist=0.05, spread=2.5)
sc.pl.umap(protein_adata, color='CD3e')

In [None]:
# Take the "active" matrices from the processed AnnData objects and drop any
# column whose standard deviation is effectively zero (<= 1e-5).
rna_active = rna_adata.X
protein_active = protein_adata.X

# RNA is already restricted to variable genes, so few or no columns should be dropped here.
rna_varying = rna_active.std(axis=0) > 1e-5
# Protein markers are generally variable as well.
protein_varying = protein_active.std(axis=0) > 1e-5

rna_active = rna_active[:, rna_varying]
protein_active = protein_active[:, protein_varying]

In [None]:
# Print the shape of each of the four matrices, one per line, in the order:
# rna_active, protein_active, rna_shared, protein_shared.
for matrix in (rna_active, protein_active, rna_shared, protein_shared):
    print(matrix.shape)

In [None]:
# Derive the batching parameters passed to fusor.split_into_batches from the
# number of cells in each modality.
n_rna = rna_active.shape[0]
n_prot = protein_active.shape[0]
if n_rna == 0:
    # Fail with a clear message instead of a bare ZeroDivisionError below.
    raise ValueError("rna_active contains no cells; cannot compute the protein/RNA ratio.")
ratio = n_prot / n_rna

print(f"RNA cells: {n_rna}")
print(f"Protein cells: {n_prot}")
print(f"Ratio (protein/RNA): {ratio:.1f}")

# Batching parameters
# Never request more RNA cells per batch than exist.
max_outward = min(8000, n_rna)

# Matching ratio: the integer protein/RNA ratio plus 5 cells of headroom,
# clamped to the range [10, 100]. The upper cap keeps very large protein:RNA
# ratios from over-sampling protein cells in the initial batches (matches are
# later propagated via pivot cells, so not every protein cell is needed there).
# NOTE: this is a linear rule with a cap - no sqrt scaling is applied.
matching_ratio = min(100, max(10, int(ratio) + 5))

# Metacell size: 1 means direct cell matching (no metacells); values > 1
# aggregate cells into metacells of that size for noise reduction.
metacell_sz = 2
metacell_state = "disabled" if metacell_sz <= 1 else "enabled"

print(f"\nBatching parameters:")
print(f"  max_outward_size: {max_outward}")
print(f"  matching_ratio: {matching_ratio} (protein/RNA ratio + 5, clamped to [10, 100])")
print(f"  metacell_size: {metacell_sz} (metacells {metacell_state}; {protein_active.shape[1]} protein features)")

In [None]:
# Partition both modalities into batches using the parameters computed in the
# batching-parameters cell (max_outward, matching_ratio, metacell_sz).
# NOTE(review): later cells read per-batch state from `fusor` (e.g. fusor._labels1[0]),
# so this presumably has to run before graph construction - confirm against MaxFuse docs.
fusor.split_into_batches(
    max_outward_size=max_outward,
    matching_ratio=matching_ratio,
    metacell_size=metacell_sz,
    verbose=True
)

### Tissue-Aware Matching Priors

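Since the data contains cells from two tissues (pLN and Pancreas), tissue-aware priors can be configured to penalize cross-tissue matches, so that pLN RNA cells preferentially match pLN protein cells and Pancreas RNA cells preferentially match Pancreas protein cells.

**Note:** these priors are currently **disabled** in the next cell, so that matching is driven by cell-type markers rather than tissue of origin.

When enabled, the weight matrix uses:
- **0.1** for same-tissue matches (favorable)
- **10.0** for cross-tissue matches (penalized 100× relative to same-tissue matches)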

In [None]:
# Tissue-aware matching priors are intentionally switched OFF.
# The goal is to match by CELL TYPE rather than tissue: all RNA comes from pLN,
# so pLN RNA cells must be able to match Pancreas protein T cells.

rule = "=" * 60
print(rule)
print("TISSUE PRIORS: DISABLED")
print(rule)
print("Matching by cell type markers instead of tissue origin.")
print("This allows pLN RNA T cells to match Pancreas protein T cells.")

# Tissue labels are read for reporting only; they do not influence the matching.
rna_tissue = rna_adata.obs['Tissue'].values
protein_tissue = protein_adata.obs['Tissue'].values

print("\nTissue distribution (for reference):")
for modality, tissue_labels in (("RNA", rna_tissue), ("Protein", protein_tissue)):
    n_pln = (tissue_labels == 'pLN').sum()
    n_pancreas = (tissue_labels == 'Pancreas').sum()
    print(f"  {modality} - pLN: {n_pln:,}, Pancreas: {n_pancreas:,}")

# Region priors are deliberately NOT set - cell type markers drive the matching.
# fusor.set_region_priors(...) - DISABLED

print("\nNo tissue priors set - matching will be based on shared feature similarity only.")

In [None]:
# Plot singular values of the two active arrays to help choose the number of
# SVD components used in graph construction.
# NOTE(review): no batch argument is passed to plot_singular_values, so which batch
# is plotted depends on that method's default - the "Batch 0" wording in the
# commented-out titles below is an assumption to confirm.
# The commented-out lines are an optional side-by-side figure layout.
# fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# plt.sca(axes[0])
fusor.plot_singular_values(target='active_arr1',  n_components=None)
# axes[0].set_title('RNA Active - Singular Values (Batch 0)')

# plt.sca(axes[1])
# For protein, cap n_components at n_features - 1
# n_prot_comp = min(18, protein_active.shape[1] - 1)
fusor.plot_singular_values(target='active_arr2',  n_components=None)
# axes[1].set_title(f'Protein Active - Singular Values (Batch 0, {protein_active.shape[1]} features)')

# plt.tight_layout()
# plt.show()

In [None]:

# Feature counts (number of columns) of the matrices used downstream.
n_rna_features = rna_active.shape[1]
n_prot_features = protein_active.shape[1]
n_shared = rna_shared.shape[1]

# print(f"Dataset dimensions:")
# print(f"  Protein features: {n_prot_features}")
# print(f"  RNA features: {n_rna_features}")
# print(f"  Shared features: {n_shared}")

# SVD components for graph construction must stay below the feature count,
# so each preferred value is capped at (n_features - 1).
rna_graph_cap, prot_graph_cap = 40, 15  # preferred upper bounds: RNA has many features, the protein panel is small
svd_comp1_graph = min(rna_graph_cap, n_rna_features - 1)
svd_comp2_graph = min(prot_graph_cap, n_prot_features - 1)

# print(f"\nGraph construction SVD components:")
# print(f"  RNA: {svd_comp1_graph}")
# print(f"  Protein: {svd_comp2_graph} (capped for {n_prot_features}-marker panel)")

In [None]:
# Construct graphs with automatic clustering.
# Suffix 1 = RNA, suffix 2 = protein (same convention as the Fusor constructor).
# The SVD component counts come from the sizing cell above, where they are
# capped below the number of features in each modality.
# NOTE(review): resolution_tol and leiden_runs are passed through unchanged;
# their exact semantics are defined by MaxFuse - confirm before tuning.
fusor.construct_graphs(
    n_neighbors1=15,
    n_neighbors2=15,
    svd_components1=svd_comp1_graph,
    svd_components2=svd_comp2_graph,
    resolution1=2.0,   # Higher resolution = more clusters = finer smoothing
    resolution2=2.0,
    resolution_tol=0.1,
    leiden_runs=1,
    verbose=True)

In [None]:
# Visualize the clustering produced by graph construction (first batch only).

# Cluster labels of the first batch for each modality.
labels1_b0 = fusor._labels1[0]
labels2_b0 = fusor._labels2[0]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Panels 1-2: ranked cluster sizes for RNA and protein, with the mean size marked.
size_panels = (("RNA", labels1_b0), ("Protein", labels2_b0))
for ax, (modality, batch_labels) in zip(axes, size_panels):
    unique, counts = np.unique(batch_labels, return_counts=True)
    mean_size = np.mean(counts)
    ax.bar(range(len(counts)), sorted(counts, reverse=True))
    ax.set_xlabel('Cluster rank')
    ax.set_ylabel('Cells')
    ax.set_title(f'{modality} Clusters (n={len(unique)})')
    ax.axhline(y=mean_size, color='r', linestyle='--', label=f'Mean: {mean_size:.0f}')
    ax.legend()


def _cluster_summary(modality, batch_labels):
    """Return the text summary block (cluster count, cell count, mean size) for one modality."""
    return (
        f"{modality} (Batch 0):\n"
        f"  Clusters: {len(np.unique(batch_labels))}\n"
        f"  Cells: {len(batch_labels)}\n"
        f"  Mean cluster size: {np.mean(np.bincount(batch_labels)):.1f}\n"
    )


# Panel 3: plain-text summary of both modalities.
ax = axes[2]
ax.axis('off')
stats_text = (
    "Graph Construction Summary\n"
    + "=" * 40
    + "\n\n"
    + _cluster_summary("RNA", labels1_b0)
    + "\n"
    + _cluster_summary("Protein", labels2_b0)
)
ax.text(0.1, 0.9, stats_text, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# fusor.plot_singular_values(target='shared_arr1',  n_components=None)

# fusor.plot_singular_values(target='shared_arr2',  n_components=None)

In [None]:
# Find initial pivots on the shared features, with smoothing for weak linkage.
# The number of SVD components has to stay below the number of shared features,
# so both preferred values are capped at (n_shared - 1).
max_shared_rank = n_shared - 1
svd_shared1 = min(25, max_shared_rank)  # RNA side
svd_shared2 = min(20, max_shared_rank)  # protein side
print(f"Using {svd_shared1}/{svd_shared2} SVD components for shared features")
# print(f"  (n_shared = {n_shared})")

pivot_kwargs = dict(
    wt1=0.3,  # smoothing weight for each modality
    wt2=0.3,
    svd_components1=svd_shared1,
    svd_components2=svd_shared2,
    verbose=True,
)
fusor.find_initial_pivots(**pivot_kwargs)

In [None]:
# Visualize the initial pivot matching of the first batch.

# The stored matching unpacks into (row indices, column indices, scores).
init_match = fusor._init_matching[0]
init_rows, init_cols, init_scores = init_match

# Score statistics are computed once and reused by the plot and the summary.
score_min = np.min(init_scores)
score_max = np.max(init_scores)
score_mean = np.mean(init_scores)
score_median = np.median(init_scores)
score_std = np.std(init_scores)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
score_ax, density_ax = axes[0], axes[1]

# Panel 1: distribution of matching scores with mean / median markers.
score_ax.hist(init_scores, bins=50, edgecolor='white', alpha=0.7)
score_ax.axvline(score_mean, color='r', linestyle='--',
                 label=f'Mean: {score_mean:.3f}')
score_ax.axvline(score_median, color='g', linestyle='--',
                 label=f'Median: {score_median:.3f}')
score_ax.set_xlabel('Matching Score')
score_ax.set_ylabel('Count')
score_ax.set_title('Initial Pivot Scores')
score_ax.legend()

# Panel 2: how many matches each matched RNA row index received.
matches_per_rna = np.bincount(init_rows)
matched_only = matches_per_rna[matches_per_rna > 0]
density_ax.hist(matched_only, bins=20, edgecolor='white', alpha=0.7)
density_ax.set_xlabel('Matches per RNA cell')
density_ax.set_ylabel('Count')
density_ax.set_title('Matching Density (RNA)')

# Panel 3: plain-text summary.
ax = axes[2]
ax.axis('off')
n_rna_matched = len(np.unique(init_rows))
n_prot_matched = len(np.unique(init_cols))
summary_lines = [
    "Initial Pivot Matching",
    "=" * 40,
    "",
    f"Total matches: {len(init_scores):,}",
    f"Unique RNA matched: {n_rna_matched:,}",
    f"Unique Protein matched: {n_prot_matched:,}",
    "",
    "Score statistics:",
    f"  Min:    {score_min:.4f}",
    f"  Max:    {score_max:.4f}",
    f"  Mean:   {score_mean:.4f}",
    f"  Median: {score_median:.4f}",
    f"  Std:    {score_std:.4f}",
]
stats = "\n".join(summary_lines) + "\n"
ax.text(0.1, 0.9, stats, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# DIAGNOSTIC: find the pipeline stage at which the score distribution turns bimodal.
# Three stages are compared on a random subsample of batch 0:
#   raw shared features → centroid-shrunk → SVD-denoised
# When metacells are enabled, the RNA side is represented by metacell centroids.

from maxfuse.core import utils as mf_utils
from scipy.stats import kurtosis

# Batch 0: the paired batch ids and the row indices belonging to each side
b1, b2 = fusor._batch1_to_batch2[0]
rna_batch_idx = fusor._batch_to_indices1[b1]
prot_batch_idx = fusor._batch_to_indices2[b2]

# metacell_size > 1 is treated as "metacells enabled"
use_metacells = fusor.metacell_size > 1

if not use_metacells:
    rna_batch = rna_shared[rna_batch_idx, :]
else:
    # Collapse RNA rows to metacell centroids (what the pipeline actually uses)
    metacell_labels1 = fusor._metacell_labels1[b1]
    rna_batch = mf_utils.get_centroids(
        arr=rna_shared[rna_batch_idx, :],
        labels=metacell_labels1
    )
    print(f"Using metacells: {len(rna_batch_idx):,} cells → {rna_batch.shape[0]:,} metacells")

prot_batch = protein_shared[prot_batch_idx, :]

# Cluster labels for this batch (per metacell on the RNA side if metacells are enabled)
rna_labels_full = fusor._labels1[b1]
prot_labels_full = fusor._labels2[b2]

n_rna_rows, n_prot_rows = rna_batch.shape[0], prot_batch.shape[0]
print(f"Batch sizes: RNA={n_rna_rows:,}, Protein={n_prot_rows:,}")
print(f"Shared features: {rna_batch.shape[1]}")
print(f"Cluster counts: RNA={len(np.unique(rna_labels_full))}, Protein={len(np.unique(prot_labels_full))}")

# Subsample both sides to the same size; the pairwise distance matrix is O(n*m).
# The global seed is fixed, and the order of the np.random calls below matters
# for reproducing the same subsample.
np.random.seed(42)
n_sample = min(5000, n_rna_rows, n_prot_rows)
rna_sample_idx = np.random.choice(n_rna_rows, n_sample, replace=False)
prot_sample_idx = np.random.choice(n_prot_rows, n_sample, replace=False)

rna_sample, prot_sample = rna_batch[rna_sample_idx], prot_batch[prot_sample_idx]
rna_labels_sample = rna_labels_full[rna_sample_idx]
prot_labels_sample = prot_labels_full[prot_sample_idx]

# Stage 1: correlation distances on the untouched shared features
print("\nComputing raw correlation distances...")
raw_corr_dist = mf_utils.cdist_correlation(rna_sample, prot_sample)
raw_min_dist = raw_corr_dist.min(axis=1)  # best (smallest) distance per RNA row

# Stage 2: shrink every row towards its cluster centroid, then recompute distances
print("Applying centroid shrinkage (wt=0.3)...")
rna_smoothed = mf_utils.shrink_towards_centroids(rna_sample, rna_labels_sample, wt=0.3)
prot_smoothed = mf_utils.shrink_towards_centroids(prot_sample, prot_labels_sample, wt=0.3)

smoothed_corr_dist = mf_utils.cdist_correlation(rna_smoothed, prot_smoothed)
smoothed_min_dist = smoothed_corr_dist.min(axis=1)

# Stage 3: SVD denoising applied on top of the smoothed features
print("Applying SVD denoising...")
svd_comp = min(15, rna_sample.shape[1] - 1)
rna_svd = mf_utils.svd_denoise(rna_smoothed, n_components=svd_comp)
prot_svd = mf_utils.svd_denoise(prot_smoothed, n_components=svd_comp)

svd_corr_dist = mf_utils.cdist_correlation(rna_svd, prot_svd)
svd_min_dist = svd_corr_dist.min(axis=1)

# One column per stage; top row = best-match distances, bottom row = all pairwise distances
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
stage_titles = ['Raw', 'After Smoothing', 'After SVD']
min_dist_by_stage = [raw_min_dist, smoothed_min_dist, svd_min_dist]
full_dist_by_stage = [raw_corr_dist, smoothed_corr_dist, svd_corr_dist]

# Row 1: distribution of the best-match distance per RNA row
for ax, title, data in zip(axes[0], stage_titles, min_dist_by_stage):
    stage_median = np.median(data)
    ax.hist(data, bins=50, edgecolor='white', alpha=0.7)
    ax.axvline(stage_median, color='r', linestyle='--', label=f'Median: {stage_median:.3f}')
    ax.set_xlabel('Min Correlation Distance')
    ax.set_ylabel('Count')
    ax.set_title(f'{title}: Best Match Distance')
    ax.legend()

    k = kurtosis(data)
    ax.text(0.95, 0.85, f'Kurtosis: {k:.2f}', transform=ax.transAxes, ha='right')

# Row 2: distribution over the whole distance matrix (random subset of entries)
for ax, title, dist_matrix in zip(axes[1], stage_titles, full_dist_by_stage):
    data = dist_matrix.ravel()
    # Histogram at most 100k entries instead of every pairwise distance
    sample = np.random.choice(data, min(100000, len(data)), replace=False)
    ax.hist(sample, bins=50, edgecolor='white', alpha=0.7)
    ax.axvline(np.median(sample), color='r', linestyle='--')
    ax.set_xlabel('Correlation Distance')
    ax.set_ylabel('Count')
    ax.set_title(f'{title}: All Pairwise Distances')

plt.suptitle('Pipeline Stage Analysis: Where Does Bimodality Emerge?', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

# Text summary; the row prefixes are pre-padded so the columns line up
summary_rows = [
    ("  Raw:      ", raw_min_dist),
    ("  Smoothed: ", smoothed_min_dist),
    ("  SVD:      ", svd_min_dist),
]

print("\n" + "=" * 60)
print("BIMODALITY ANALYSIS SUMMARY")
print("=" * 60)
print(f"\nMin distance (best match) statistics:")
for row_prefix, stage_min_dist in summary_rows:
    print(f"{row_prefix}mean={stage_min_dist.mean():.3f}, std={stage_min_dist.std():.3f}")

# Negative excess kurtosis is used here only as a rough bimodality heuristic
print(f"\nKurtosis (negative = bimodal):")
for row_prefix, stage_min_dist in summary_rows:
    print(f"{row_prefix}{kurtosis(stage_min_dist):.2f}")

In [None]:
# # DIAGNOSTIC: Bidirectional supply-demand - which modality has the "orphan" cells?

# # RNA perspective: how many good protein matches exist?
# n_good_for_rna = (svd_corr_dist < threshold_good).sum(axis=1)  # per RNA cell

# # Protein perspective: how many good RNA matches exist?
# n_good_for_prot = (svd_corr_dist < threshold_good).sum(axis=0)  # per protein cell

# fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# # 1. Distribution of good match availability - RNA side
# ax = axes[0]
# ax.hist(n_good_for_rna, bins=50, edgecolor='white', alpha=0.7, color='blue')
# ax.axvline(np.median(n_good_for_rna), color='red', linestyle='--', 
#            label=f'Median: {np.median(n_good_for_rna):.0f}')
# ax.set_xlabel('# Good Protein Matches Available')
# ax.set_ylabel('# RNA Cells')
# ax.set_title('RNA Cells: Match Availability')
# ax.legend()

# # Count "orphan" RNA cells (very few good matches)
# orphan_thresh_rna = np.percentile(n_good_for_rna, 10)
# n_orphan_rna = (n_good_for_rna < orphan_thresh_rna).sum()
# ax.axvline(orphan_thresh_rna, color='orange', linestyle=':', 
#            label=f'10th pct: {orphan_thresh_rna:.0f}')

# # 2. Distribution of good match availability - Protein side  
# ax = axes[1]
# ax.hist(n_good_for_prot, bins=50, edgecolor='white', alpha=0.7, color='green')
# ax.axvline(np.median(n_good_for_prot), color='red', linestyle='--',
#            label=f'Median: {np.median(n_good_for_prot):.0f}')
# ax.set_xlabel('# Good RNA Matches Available')
# ax.set_ylabel('# Protein Cells')
# ax.set_title('Protein Cells: Match Availability')
# ax.legend()

# orphan_thresh_prot = np.percentile(n_good_for_prot, 10)
# n_orphan_prot = (n_good_for_prot < orphan_thresh_prot).sum()

# # 3. Summary comparison
# ax = axes[2]
# ax.axis('off')

# summary = f'''Supply-Demand Imbalance Analysis
# {"="*50}

# Sample sizes:
#   RNA cells (metacells): {len(n_good_for_rna):,}
#   Protein cells: {len(n_good_for_prot):,}

# Good match availability:
#   RNA → Protein:
#     Mean:   {n_good_for_rna.mean():.1f} good matches/cell
#     Median: {np.median(n_good_for_rna):.0f}
#     "Orphans" (<10th pct): {n_orphan_rna:,} ({n_orphan_rna/len(n_good_for_rna)*100:.1f}%)
    
#   Protein → RNA:
#     Mean:   {n_good_for_prot.mean():.1f} good matches/cell
#     Median: {np.median(n_good_for_prot):.0f}
#     "Orphans" (<10th pct): {n_orphan_prot:,} ({n_orphan_prot/len(n_good_for_prot)*100:.1f}%)

# INTERPRETATION:
# '''

# # Determine which side has more orphans
# if n_good_for_rna.mean() < n_good_for_prot.mean() * 0.8:
#     summary += "  → RNA has more 'orphan' cells\n"
#     summary += "  → Some RNA cell types missing in protein data"
# elif n_good_for_prot.mean() < n_good_for_rna.mean() * 0.8:
#     summary += "  → Protein has more 'orphan' cells\n"
#     summary += "  → Some protein cell types missing in RNA data"
# else:
#     summary += "  → Roughly balanced availability\n"
#     summary += "  → Bimodality from competition, not missing types"

# ax.text(0.05, 0.95, summary, transform=ax.transAxes, fontsize=10,
#         verticalalignment='top', fontfamily='monospace')

# plt.tight_layout()
# plt.show()

In [None]:
# DIAGNOSTIC: do particular RNA clusters account for the poorly-scoring pivots?
# Shared features of some clusters may simply not correlate across modalities.

# Initial matching of the first batch pair; scores as an array so boolean masks work
init_match = fusor._init_matching[0]
init_rows, init_cols, init_scores = init_match
init_scores = np.array(init_scores)

# Cluster label of each matched row / column
b1, b2 = fusor._batch1_to_batch2[0]
rna_clusters = fusor._labels1[b1][init_rows]
prot_clusters = fusor._labels2[b2][init_cols]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One (cluster id, mean score, number of matches) tuple per RNA cluster,
# ordered from lowest to highest mean score (lower = better match)
unique_rna_clusters = np.unique(rna_clusters)
cluster_means = sorted(
    (
        (c, init_scores[rna_clusters == c].mean(), (rna_clusters == c).sum())
        for c in unique_rna_clusters
    ),
    key=lambda entry: entry[1],
)

# Up to five clusters from each end of the ranking
n_show = min(5, len(cluster_means))
worst_clusters = cluster_means[-n_show:]
best_clusters = cluster_means[:n_show]

# Top row: overlaid score histograms for the best and the worst clusters
histogram_panels = (
    (axes[0, 0], best_clusters, f'Best {n_show} RNA Clusters (lowest scores)'),
    (axes[0, 1], worst_clusters, f'Worst {n_show} RNA Clusters (highest scores)'),
)
for ax, panel_clusters, panel_title in histogram_panels:
    for c, mean, n in panel_clusters:
        mask = rna_clusters == c
        ax.hist(init_scores[mask], bins=30, alpha=0.5, label=f'RNA cluster {c} (n={n}, μ={mean:.2f})')
    ax.set_xlabel('Matching Score')
    ax.set_ylabel('Count')
    ax.set_title(panel_title)
    ax.legend(fontsize=8)

# Bottom-left: mean score per cluster, red above the 0.5 threshold, green otherwise
ax = axes[1, 0]
cluster_ids = [entry[0] for entry in cluster_means]
means = [entry[1] for entry in cluster_means]
sizes = [entry[2] for entry in cluster_means]
colors = ['red' if m > 0.5 else 'green' for m in means]
ax.bar(range(len(cluster_ids)), means, color=colors, alpha=0.7)
ax.axhline(0.5, color='black', linestyle='--', label='Threshold 0.5')
ax.set_xlabel('RNA Cluster (sorted by score)')
ax.set_ylabel('Mean Matching Score')
ax.set_title('Mean Score by RNA Cluster')
ax.legend()

# Bottom-right: cluster size against mean score
ax = axes[1, 1]
ax.scatter(sizes, means, alpha=0.6, c=colors)
ax.set_xlabel('Cluster Size')
ax.set_ylabel('Mean Matching Score')
ax.set_title('Cluster Size vs Mean Score')
ax.axhline(0.5, color='black', linestyle='--')

# Label only the clusters with extreme mean scores
for c, mean, n in cluster_means:
    if mean > 0.6 or mean < 0.2:
        ax.annotate(f'{c}', (n, mean), fontsize=8)

plt.tight_layout()
plt.show()

# Text summary: clusters above 0.5 are reported as bad, below 0.3 as good
print("\n" + "=" * 60)
print("CLUSTER ANALYSIS SUMMARY")
print("=" * 60)
bad_clusters = [entry for entry in cluster_means if entry[1] > 0.5]
good_clusters = [entry for entry in cluster_means if entry[1] < 0.3]

print(f"\nTotal RNA clusters: {len(unique_rna_clusters)}")
print(f"Bad clusters (score > 0.5): {len(bad_clusters)}")
print(f"Good clusters (score < 0.3): {len(good_clusters)}")

if bad_clusters:
    print(f"\nWorst clusters (candidates for filtering):")
    n_matches = len(init_scores)
    total_bad = sum(entry[2] for entry in bad_clusters)
    # Highest mean score first, at most ten rows
    for c, m, n in sorted(bad_clusters, key=lambda entry: -entry[1])[:10]:
        print(f"  Cluster {c}: mean={m:.3f}, n={n:,} ({n/n_matches*100:.1f}%)")
    print(f"  Total cells in bad clusters: {total_bad:,} ({total_bad/n_matches*100:.1f}%)")

In [None]:
# Inspect canonical correlations before committing to a CCA dimensionality.
# Component counts are capped at (number of features - 1) on each side.
cca_comp_check = min(15, n_prot_features - 1)
canonical_corr_kwargs = dict(
    svd_components1=min(30, n_rna_features - 1),
    svd_components2=None,
    cca_components=cca_comp_check,
)
fusor.plot_canonical_correlations(**canonical_corr_kwargs)

In [None]:
# CCA-based refinement of the initial pivots.
# cca_components is capped at (number of protein features - 1).
cca_components = min(25, n_prot_features - 1)

refine_kwargs = dict(
    wt1=0.3,
    wt2=0.3,
    svd_components1=min(40, n_rna_features - 1),
    svd_components2=None,  # no SVD truncation on the protein side
    cca_components=cca_components,
    n_iters=1,
    filter_prop=0.0,
    verbose=True,
)
fusor.refine_pivots(**refine_kwargs)

In [None]:
# Compare initial vs refined matching for the first batch pair:
# overlaid score histograms, side-by-side boxplots, and a text summary.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Re-read the initial scores from the fusor instead of relying on a variable
# left behind by an earlier cell, so this cell does not depend on hidden state.
init_scores = np.asarray(fusor._init_matching[0][2])

# Refined matching as arrays. Converting here keeps later index-array lookups
# such as ref_rows[remaining_idx] valid.
# NOTE(review): assumes the stored matching components may be plain Python
# lists; np.asarray is a no-op if they are already arrays.
refined_match = fusor._refined_matching[0]
ref_rows, ref_cols, ref_scores = (np.asarray(part) for part in refined_match)

# Panel 1: overlaid score histograms
ax = axes[0]
ax.hist(init_scores, bins=50, alpha=0.5, label='Initial', edgecolor='white')
ax.hist(ref_scores, bins=50, alpha=0.5, label='Refined', edgecolor='white')
ax.set_xlabel('Matching Score')
ax.set_ylabel('Count')
ax.set_title('Score Distribution: Initial vs Refined')
ax.legend()

# Panel 2: boxplots. Tick labels are set explicitly rather than through the
# boxplot `labels=` keyword, which Matplotlib 3.9 deprecated (renamed to
# `tick_labels`); set_xticklabels works on both older and newer versions.
ax = axes[1]
ax.boxplot([init_scores, ref_scores])
ax.set_xticklabels(['Initial', 'Refined'])
ax.set_ylabel('Score')
ax.set_title('Score Comparison')

# Panel 3: text-only summary (axis frame hidden)
ax = axes[2]
ax.axis('off')
stats = f'''CCA Refinement Results
{"="*40}

                Initial    Refined
Matches:     {len(init_scores):>10,}  {len(ref_scores):>10,}
Mean score:  {np.mean(init_scores):>10.4f}  {np.mean(ref_scores):>10.4f}
Median:      {np.median(init_scores):>10.4f}  {np.median(ref_scores):>10.4f}
Std:         {np.std(init_scores):>10.4f}  {np.std(ref_scores):>10.4f}

Score change: {np.mean(ref_scores) - np.mean(init_scores):+.4f}
'''
ax.text(0.1, 0.9, stats, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# Two-component Gaussian mixture on the refined pivot scores.
# Used to quantify the bimodal distribution and suggest a filtering proportion.
from sklearn.mixture import GaussianMixture
from scipy.stats import norm

# Scores of the refined matching for the first batch pair
pivot_match = fusor._refined_matching[0]
ref_scores = np.array(pivot_match[2])

# Fit on a single-column matrix; fit() returns the estimator itself
gmm = GaussianMixture(n_components=2, random_state=42).fit(ref_scores.reshape(-1, 1))
means = gmm.means_.ravel()
weights = gmm.weights_
stds = np.sqrt(gmm.covariances_.ravel())

# Scores are 1 - correlation, so the component with the LOWER mean is the good one
good_mode_idx = np.argmin(means)
bad_mode_idx = np.argmax(means)

print("GMM Analysis of Pivot Matching Scores")
print("=" * 50)
for mode_name, mode_idx in (("Good mode (low)", good_mode_idx), ("Bad mode (high)", bad_mode_idx)):
    print(f"{mode_name}: mean={means[mode_idx]:.3f}, std={stds[mode_idx]:.3f}, weight={weights[mode_idx]:.1%}")

# Cut-off halfway between the two component means
gmm_threshold = (means[0] + means[1]) / 2
bad_mode_fraction = weights[bad_mode_idx]

print(f"\nBad mode fraction: {bad_mode_fraction:.2%}")
print(f"Threshold between modes: {gmm_threshold:.3f}")

# Figure: fitted components over the histogram (left), text summary (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

ax = axes[0]
x = np.linspace(ref_scores.min(), ref_scores.max(), 200)
ax.hist(ref_scores, bins=50, density=True, alpha=0.7, edgecolor='white', label='Scores')

# Each component drawn as its weighted normal density
for i in range(2):
    if i == good_mode_idx:
        color, label = 'green', 'Good mode (low)'
    else:
        color, label = 'red', 'Bad mode (high)'
    component_density = weights[i] * norm.pdf(x, means[i], stds[i])
    ax.plot(x, component_density, color=color, linewidth=2, label=label)

ax.axvline(gmm_threshold, color='black', linestyle='--', linewidth=2, label=f'Threshold: {gmm_threshold:.3f}')
ax.set_xlabel('Matching Score (1 - correlation, lower = better)')
ax.set_ylabel('Density')
ax.set_title('Pivot Score Distribution with GMM Fit')
ax.legend()

# Right panel: text only (axis frame hidden)
ax = axes[1]
ax.axis('off')
summary = f'''GMM Analysis Summary
{"="*40}

Total pivots: {len(ref_scores):,}

Mode Analysis (score = 1 - correlation):
  Good mode (low scores): {weights[good_mode_idx]:.1%}
    mean={means[good_mode_idx]:.3f}, std={stds[good_mode_idx]:.3f}
  
  Bad mode (high scores): {weights[bad_mode_idx]:.1%}
    mean={means[bad_mode_idx]:.3f}, std={stds[bad_mode_idx]:.3f}

Threshold: {gmm_threshold:.3f}
Score range: [{ref_scores.min():.3f}, {ref_scores.max():.3f}]

Recommended filter_prop: {bad_mode_fraction:.2f}
(removes highest {bad_mode_fraction:.0%} of scores)
'''
ax.text(0.1, 0.9, summary, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# # DIAGNOSTIC: Investigate bimodal score distribution
# # Hypothesis: The two modes may correspond to tissue types (pLN vs Pancreas)

# # Get matching data
# pivot_match = fusor._refined_matching[0]
# ref_rows, ref_cols, ref_scores = pivot_match
# ref_scores = np.array(ref_scores)

# # Get batch indices - _batch1_to_batch2 is a list of (b1, b2) tuples
# b1, b2 = fusor._batch1_to_batch2[0]
# rna_indices = np.array(fusor._batch_to_indices1[b1])
# prot_indices = np.array(fusor._batch_to_indices2[b2])

# # Get tissue labels for matched cells
# rna_tissues = rna_adata.obs['Tissue'].values[rna_indices[ref_rows]]
# prot_tissues = protein_adata.obs['Tissue'].values[prot_indices[ref_cols]]

# # Categorize matches by tissue combination
# match_types = []
# for rt, pt in zip(rna_tissues, prot_tissues):
#     if rt == pt:
#         match_types.append(f'Same: {rt}')
#     else:
#         match_types.append(f'Cross: {rt}↔{pt}')
# match_types = np.array(match_types)

# # Analyze scores by match type
# fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# # 1. Score distribution by tissue match type
# ax = axes[0, 0]
# unique_types = np.unique(match_types)
# colors = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))
# for i, mt in enumerate(unique_types):
#     mask = match_types == mt
#     ax.hist(ref_scores[mask], bins=30, alpha=0.6, label=f'{mt} (n={mask.sum():,})', 
#             edgecolor='white', color=colors[i])
# ax.set_xlabel('Matching Score')
# ax.set_ylabel('Count')
# ax.set_title('Score Distribution by Tissue Match Type')
# ax.legend()
# ax.axvline(0.2, color='red', linestyle='--', alpha=0.5)
# ax.axvline(0.8, color='red', linestyle='--', alpha=0.5)

# # 2. Boxplot by tissue type
# ax = axes[0, 1]
# data_for_box = [ref_scores[match_types == mt] for mt in unique_types]
# bp = ax.boxplot(data_for_box, labels=[mt.replace('↔', '\n↔') for mt in unique_types], patch_artist=True)
# for patch, color in zip(bp['boxes'], colors):
#     patch.set_facecolor(color)
#     patch.set_alpha(0.6)
# ax.set_ylabel('Matching Score')
# ax.set_title('Score by Tissue Match Type')
# ax.axhline(0.2, color='red', linestyle='--', alpha=0.5, label='Mode 1 (~0.2)')
# ax.axhline(0.8, color='red', linestyle='--', alpha=0.5, label='Mode 2 (~0.8)')

# # 3. Score vs RNA tissue
# ax = axes[1, 0]
# rna_tissue_unique = np.unique(rna_tissues)
# for i, tissue in enumerate(rna_tissue_unique):
#     mask = rna_tissues == tissue
#     ax.hist(ref_scores[mask], bins=30, alpha=0.6, label=f'RNA: {tissue} (n={mask.sum():,})', edgecolor='white')
# ax.set_xlabel('Matching Score')
# ax.set_ylabel('Count')
# ax.set_title('Score by RNA Tissue')
# ax.legend()

# # 4. Summary statistics
# ax = axes[1, 1]
# ax.axis('off')
# summary_lines = ['Bimodal Score Investigation', '=' * 50, '']

# for mt in unique_types:
#     mask = match_types == mt
#     scores_mt = ref_scores[mask]
#     # Categorize into low (<0.5) and high (>=0.5) modes
#     low_frac = (scores_mt < 0.5).mean()
#     high_frac = (scores_mt >= 0.5).mean()
#     summary_lines.append(f'{mt}:')
#     summary_lines.append(f'  n = {mask.sum():,} ({mask.mean()*100:.1f}% of all)')
#     summary_lines.append(f'  Mean score: {scores_mt.mean():.3f}')
#     summary_lines.append(f'  Low mode (<0.5): {low_frac*100:.1f}%')
#     summary_lines.append(f'  High mode (≥0.5): {high_frac*100:.1f}%')
#     summary_lines.append('')

# # Overall interpretation
# same_tissue_mask = np.array(['Same' in mt for mt in match_types])
# cross_tissue_mask = ~same_tissue_mask
# summary_lines.append('INTERPRETATION:')
# if same_tissue_mask.sum() > 0 and cross_tissue_mask.sum() > 0:
#     same_mean = ref_scores[same_tissue_mask].mean()
#     cross_mean = ref_scores[cross_tissue_mask].mean()
#     summary_lines.append(f'  Same-tissue mean: {same_mean:.3f}')
#     summary_lines.append(f'  Cross-tissue mean: {cross_mean:.3f}')
#     if abs(same_mean - cross_mean) > 0.2:
#         summary_lines.append('  → Bimodality likely due to tissue mismatch!')
#     else:
#         summary_lines.append('  → Tissue type does NOT explain bimodality')

# ax.text(0.05, 0.95, '\n'.join(summary_lines), transform=ax.transAxes, fontsize=10,
#         verticalalignment='top', fontfamily='monospace')

# plt.tight_layout()
# plt.show()

# # Print cross-tabulation
# print("\nCross-tabulation of tissue matches:")
# print(pd.crosstab(rna_tissues, prot_tissues, margins=True))

In [None]:
# # DIAGNOSTIC: Check if protein expression levels correlate with score modes
# # Hypothesis: Low-expressing cells may have worse matching scores

# # Get protein expression for matched cells
# prot_matched_indices = prot_indices[ref_cols]
# prot_expr = protein_active[prot_matched_indices, :]

# # Calculate total protein expression per cell
# prot_total_expr = np.sum(prot_expr, axis=1)

# # Categorize scores into modes
# low_mode = ref_scores < 0.5
# high_mode = ref_scores >= 0.5

# fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# # 1. Score vs total protein expression
# ax = axes[0]
# ax.scatter(prot_total_expr, ref_scores, alpha=0.1, s=1)
# ax.set_xlabel('Total Protein Expression')
# ax.set_ylabel('Matching Score')
# ax.set_title('Score vs Protein Expression')
# # Add correlation
# corr = np.corrcoef(prot_total_expr, ref_scores)[0, 1]
# ax.text(0.05, 0.95, f'Pearson r = {corr:.3f}', transform=ax.transAxes, fontsize=11)

# # 2. Protein expression distribution by score mode
# ax = axes[1]
# ax.hist(prot_total_expr[low_mode], bins=50, alpha=0.6, label=f'Low mode (<0.5, n={low_mode.sum():,})', density=True)
# ax.hist(prot_total_expr[high_mode], bins=50, alpha=0.6, label=f'High mode (≥0.5, n={high_mode.sum():,})', density=True)
# ax.set_xlabel('Total Protein Expression')
# ax.set_ylabel('Density')
# ax.set_title('Protein Expression by Score Mode')
# ax.legend()

# # 3. Check shared feature correlations
# # Get shared features for matched pairs
# rna_matched_indices = rna_indices[ref_rows]
# rna_shared_matched = rna_shared[rna_matched_indices, :]
# prot_shared_matched = protein_shared[prot_matched_indices, :]

# # Calculate per-pair correlation across shared features
# pair_correlations = []
# for i in range(len(ref_rows)):
#     r = np.corrcoef(rna_shared_matched[i, :], prot_shared_matched[i, :])[0, 1]
#     pair_correlations.append(r if not np.isnan(r) else 0)
# pair_correlations = np.array(pair_correlations)

# ax = axes[2]
# ax.scatter(pair_correlations, ref_scores, alpha=0.1, s=1)
# ax.set_xlabel('Shared Feature Correlation (RNA vs Protein)')
# ax.set_ylabel('Matching Score')
# ax.set_title('Score vs Shared Feature Agreement')
# corr2 = np.corrcoef(pair_correlations, ref_scores)[0, 1]
# ax.text(0.05, 0.95, f'Pearson r = {corr2:.3f}', transform=ax.transAxes, fontsize=11)

# plt.tight_layout()
# plt.show()

# # Summary
# print("\nSummary Statistics:")
# print(f"Low mode (<0.5):  mean protein expr = {prot_total_expr[low_mode].mean():.2f}, "
#       f"mean shared corr = {pair_correlations[low_mode].mean():.3f}")
# print(f"High mode (≥0.5): mean protein expr = {prot_total_expr[high_mode].mean():.2f}, "
#       f"mean shared corr = {pair_correlations[high_mode].mean():.3f}")

In [None]:
# # DIAGNOSTIC: Compare initial vs refined - does CCA create the bimodality?

# init_match = fusor._init_matching[0]
# init_rows, init_cols, init_scores_arr = init_match
# init_scores_arr = np.array(init_scores_arr)

# refined_match = fusor._refined_matching[0]
# refined_rows, refined_cols, ref_scores_arr = refined_match
# ref_scores_arr = np.array(ref_scores_arr)

# fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# # 1. Initial score distribution
# ax = axes[0]
# ax.hist(init_scores_arr, bins=50, edgecolor='white', alpha=0.7)
# ax.set_xlabel('Initial Matching Score')
# ax.set_ylabel('Count')
# ax.set_title(f'Initial Scores (n={len(init_scores_arr):,})')
# ax.axvline(np.median(init_scores_arr), color='red', linestyle='--', label=f'Median: {np.median(init_scores_arr):.3f}')
# ax.legend()

# # Check if initial is bimodal
# from scipy.stats import kurtosis
# init_kurtosis = kurtosis(init_scores_arr)
# ax.text(0.95, 0.95, f'Kurtosis: {init_kurtosis:.2f}\n(negative=bimodal)', 
#         transform=ax.transAxes, ha='right', va='top', fontsize=9)

# # 2. Refined score distribution  
# ax = axes[1]
# ax.hist(ref_scores_arr, bins=50, edgecolor='white', alpha=0.7, color='orange')
# ax.set_xlabel('Refined Matching Score')
# ax.set_ylabel('Count')
# ax.set_title(f'Refined Scores (n={len(ref_scores_arr):,})')
# ax.axvline(np.median(ref_scores_arr), color='red', linestyle='--', label=f'Median: {np.median(ref_scores_arr):.3f}')
# ax.legend()

# ref_kurtosis = kurtosis(ref_scores_arr)
# ax.text(0.95, 0.95, f'Kurtosis: {ref_kurtosis:.2f}\n(negative=bimodal)', 
#         transform=ax.transAxes, ha='right', va='top', fontsize=9)

# # 3. Score change scatter (for matched pairs that exist in both)
# ax = axes[2]
# # The refined matching should have same pairs as initial (just reordered by score)
# # Plot score comparison
# ax.hist2d(init_scores_arr, ref_scores_arr, bins=50, cmap='Blues')
# ax.plot([0, 1], [0, 1], 'r--', linewidth=2, label='y=x')
# ax.set_xlabel('Initial Score')
# ax.set_ylabel('Refined Score')
# ax.set_title('Initial vs Refined Scores')
# ax.legend()

# plt.tight_layout()
# plt.show()

# # Summary
# print("\nBimodality Analysis:")
# print("=" * 50)
# print(f"Initial scores:  mean={init_scores_arr.mean():.3f}, std={init_scores_arr.std():.3f}, kurtosis={init_kurtosis:.2f}")
# print(f"Refined scores:  mean={ref_scores_arr.mean():.3f}, std={ref_scores_arr.std():.3f}, kurtosis={ref_kurtosis:.2f}")
# print()
# if ref_kurtosis < init_kurtosis - 0.5:
#     print("→ CCA refinement INCREASED bimodality (kurtosis decreased)")
#     print("  This suggests CCA is separating good/bad matches more clearly")
# elif ref_kurtosis > init_kurtosis + 0.5:
#     print("→ CCA refinement REDUCED bimodality")
# else:
#     print("→ Similar bimodality before and after CCA")
#     print("  Bimodality is inherent in the data, not created by CCA")

In [None]:
# Filter bad pivots with a FIXED proportion.
# NOTE: the active call below hard-codes filter_prop=0.1. The GMM-guided
# proportion sketched in the commented-out lines is currently NOT applied.

# Disabled alternative: derive the proportion from bad_mode_fraction (computed in
# the GMM analysis cell), add a 0.02 margin, and cap the result at 0.2.
# To enable it, uncomment the lines below and pass filter_prop=pivot_filter_prop.
# pivot_filter_prop = min(0.2, bad_mode_fraction + 0.02)

# print(f"GMM-guided pivot filter:")
# print(f"  Bad mode fraction: {bad_mode_fraction:.2%}")
# print(f"  Filter proportion: {pivot_filter_prop:.2%} (bad_mode + 2% margin, max 20%)")

fusor.filter_bad_matches(
    target='pivot',
    filter_prop=0.1,
    verbose=True
)

# print(f"\nFiltered {pivot_filter_prop*100:.1f}% of highest-scoring (worst, score = 1 - correlation) pivots")

In [None]:
# Pivot filtering outcome: kept vs removed score histograms plus a pie chart.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Split the refined scores (ref_scores comes from the GMM cell above) into the
# pivots that survived filtering and the ones that did not
remaining_idx = fusor._remaining_indices_in_refined_matching[0]
kept_scores = ref_scores[remaining_idx]
removed_scores = np.delete(ref_scores, remaining_idx)

n_kept, n_removed = len(kept_scores), len(removed_scores)
has_removed = n_removed > 0
kept_mean = np.mean(kept_scores)
# The removed mean is only evaluated when something was actually removed
removed_mean = np.mean(removed_scores) if has_removed else None

# Left panel: histograms first, then mean markers (same drawing order as before)
ax = axes[0]
ax.hist(kept_scores, bins=30, alpha=0.7, label=f'Kept ({n_kept:,})', color='green', edgecolor='white')
if has_removed:
    ax.hist(removed_scores, bins=30, alpha=0.7, label=f'Removed ({n_removed:,})', color='red', edgecolor='white')
ax.axvline(kept_mean, color='darkgreen', linestyle='--', label=f'Kept mean: {kept_mean:.3f}')
if has_removed:
    ax.axvline(removed_mean, color='darkred', linestyle='--', label=f'Removed mean: {removed_mean:.3f}')
ax.set_xlabel('Matching Score (1 - correlation, lower = better)')
ax.set_ylabel('Count')
ax.set_title('Pivot Filtering: Kept vs Removed')
ax.legend()

# Right panel: share of kept vs removed pivots
ax = axes[1]
sizes = [n_kept, n_removed]
labels = [f'Kept\n{n_kept:,}', f'Removed\n{n_removed:,}']
colors = ['#2ecc71', '#e74c3c']
ax.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
pct_removed = 100 * n_removed / (n_kept + n_removed)
ax.set_title(f'Pivot Filter Results\n({pct_removed:.0f}% removed)')

plt.tight_layout()
plt.show()

# Console summary
print(f"Pivot filtering: {n_kept:,} kept, {n_removed:,} removed")
if not has_removed:
    print(f"Mean score - Kept: {kept_mean:.4f}")
else:
    print(f"Mean score - Kept: {kept_mean:.4f}, Removed: {removed_mean:.4f}")
    print(f"(Removed has HIGHER score = worse matches, since score = 1 - correlation)")

In [None]:
# Propagate the pivot matching to the remaining cells.
# The RNA SVD cap, min(40, n_rna_features - 1), mirrors the one passed to refine_pivots.
propagate_kwargs = dict(
    svd_components1=min(40, n_rna_features - 1),
    svd_components2=None,
    wt1=0.7,
    wt2=0.7,
    verbose=True,
)
fusor.propagate(**propagate_kwargs)

In [None]:
# Retrieve the pivot-level matching from the fusor; the result is indexed as
# pivot_matching[component][pair] by the inspection cell that follows.
pivot_matching = fusor.get_matching(target='pivot')

In [None]:
# Peek at the first pivot pair: entry 0 of each of the three matching components
# (presumably RNA index, protein index, score — confirm against get_matching).
[pivot_matching[component][0] for component in range(3)]

In [None]:
# # DEBUG: Trace where Pancreas RNA cells are lost in the pipeline
# rna_tissue = rna_adata.obs['Tissue'].values
# protein_tissue = protein_adata.obs['Tissue'].values
# pancreas_rna_idx = set(np.where(rna_tissue == 'Pancreas')[0])
# pln_rna_idx = set(np.where(rna_tissue == 'pLN')[0])

# print(f"=== PANCREAS RNA CELL TRACING ===")
# print(f"Total Pancreas RNA cells: {len(pancreas_rna_idx)}")
# print(f"Total pLN RNA cells: {len(pln_rna_idx)}")

# # Check batching - are Pancreas cells in the batches?
# print(f"\n--- Batching ---")
# for b, indices in enumerate(fusor._batch_to_indices1):
#     batch_set = set(indices)
#     batch_pancreas = len(batch_set & pancreas_rna_idx)
#     batch_pln = len(batch_set & pln_rna_idx)
#     if batch_pancreas > 0 or batch_pln > 0:
#         print(f"  Batch {b}: {batch_pancreas} Pancreas, {batch_pln} pLN cells")

# # Build metacell -> cell mapping if using metacells
# metacell_to_cells = {}
# if fusor.metacell_size > 1 and fusor._metacell_labels1 is not None:
#     for b, mc_labels in enumerate(fusor._metacell_labels1):
#         batch_indices = fusor._batch_to_indices1[b]
#         metacell_to_cells[b] = {}
#         for cell_local_idx, mc_idx in enumerate(mc_labels):
#             if mc_idx not in metacell_to_cells[b]:
#                 metacell_to_cells[b][mc_idx] = []
#             metacell_to_cells[b][mc_idx].append(batch_indices[cell_local_idx])

# # Manually convert propagated matching indices to global indices
# print(f"\n--- Propagated Matches (before filtering) ---")
# all_rna_global = []
# all_prot_global = []
# all_scores = []

# for batch_idx, (b1, b2) in enumerate(fusor._batch1_to_batch2):
#     rows, cols, scores = fusor._propagated_matching[batch_idx]
#     batch_indices1 = fusor._batch_to_indices1[b1]
#     batch_indices2 = fusor._batch_to_indices2[b2]
    
#     for r, c, s in zip(rows, cols, scores):
#         # Protein index is straightforward
#         global_prot = batch_indices2[c] if c < len(batch_indices2) else None
#         if global_prot is None:
#             continue
            
#         # RNA index depends on whether metacells are used
#         if fusor.metacell_size > 1 and b1 in metacell_to_cells:
#             # r could be a metacell index - get all cells in that metacell
#             if r in metacell_to_cells[b1]:
#                 for global_rna in metacell_to_cells[b1][r]:
#                     all_rna_global.append(global_rna)
#                     all_prot_global.append(global_prot)
#                     all_scores.append(s)
#             elif r < len(batch_indices1):
#                 # r is a direct cell index (non-pivot cell)
#                 all_rna_global.append(batch_indices1[r])
#                 all_prot_global.append(global_prot)
#                 all_scores.append(s)
#         else:
#             if r < len(batch_indices1):
#                 all_rna_global.append(batch_indices1[r])
#                 all_prot_global.append(global_prot)
#                 all_scores.append(s)

# all_rna_global = np.array(all_rna_global)
# all_prot_global = np.array(all_prot_global)

# print(f"Total matches: {len(all_rna_global):,}")
# print(f"Unique RNA cells: {len(np.unique(all_rna_global)):,}")
# print(f"Unique Protein cells: {len(np.unique(all_prot_global)):,}")

# # Count by tissue
# pancreas_matched = sum(1 for r in all_rna_global if r in pancreas_rna_idx)
# pln_matched = sum(1 for r in all_rna_global if r in pln_rna_idx)
# print(f"\nRNA tissue breakdown:")
# print(f"  Pancreas RNA matches: {pancreas_matched:,}")
# print(f"  pLN RNA matches: {pln_matched:,}")

# # Tissue pair breakdown  
# tissue_pairs = {'Pancreas→Pancreas': 0, 'Pancreas→pLN': 0, 'pLN→Pancreas': 0, 'pLN→pLN': 0}
# for r, c in zip(all_rna_global, all_prot_global):
#     rna_t = rna_tissue[r]
#     prot_t = protein_tissue[c]
#     pair = f"{rna_t}→{prot_t}"
#     tissue_pairs[pair] = tissue_pairs.get(pair, 0) + 1

# print(f"\nTissue pair counts:")
# for pair, count in sorted(tissue_pairs.items()):
#     print(f"  {pair}: {count:,}")

# print(f"===================================")

In [None]:
# Visualize propagation results in three panels:
#   [0] histogram of propagated matching scores
#   [1] RNA coverage: pivot matches vs propagated matches vs all RNA cells
#   [2] plain-text summary of the same numbers
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Propagated matching at index 0, unpacked as (rows, cols, scores).
# NOTE(review): only entry 0 of `_propagated_matching` is inspected; if the
# fusor holds more than one entry, the panels below describe the first one
# only - confirm this is intended.
prop_match = fusor._propagated_matching[0]
prop_rows, prop_cols, prop_scores = prop_match

# Score distribution
ax = axes[0]
ax.hist(prop_scores, bins=50, edgecolor='white', alpha=0.7, color='purple')
ax.axvline(np.mean(prop_scores), color='r', linestyle='--',
           label=f'Mean: {np.mean(prop_scores):.3f}')
ax.set_xlabel('Matching Score')
ax.set_ylabel('Count')
ax.set_title('Propagated Matching Scores')
ax.legend()

# Coverage: pivot vs propagated
# `ref_rows` and `remaining_idx` come from an earlier cell.
ax = axes[1]
pivot_rna = len(np.unique(ref_rows[remaining_idx]))
prop_rna = len(np.unique(prop_rows))
total_rna = rna_active.shape[0]

categories = ['Pivot\nMatches', 'Propagated\nMatches', 'Total\nRNA Cells']
values = [pivot_rna, prop_rna, total_rna]
colors = ['#3498db', '#9b59b6', '#95a5a6']
bars = ax.bar(categories, values, color=colors, edgecolor='white')
ax.set_ylabel('Count')
ax.set_title('Coverage Expansion')
for bar, val in zip(bars, values):
    # Value label slightly above the top of each bar (offset is in data units)
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 50,
            f'{val:,}', ha='center', va='bottom', fontsize=10)

# Summary
ax = axes[2]
ax.axis('off')
prop_prot = len(np.unique(prop_cols))
total_prot = protein_active.shape[0]
# Deliberately NOT named `stats`: that name is used for the scipy.stats module
# later in this notebook, and rebinding it to a string would shadow the module.
propagation_summary = f'''Propagation Summary
{"="*40}

Propagated matches: {len(prop_scores):,}

RNA coverage:
  Pivot:      {pivot_rna:>8,} ({100*pivot_rna/total_rna:.1f}%)
  Propagated: {prop_rna:>8,} ({100*prop_rna/total_rna:.1f}%)
  Total:      {total_rna:>8,}

Protein coverage:
  Propagated: {prop_prot:>8,} ({100*prop_prot/total_prot:.1f}%)
  Total:      {total_prot:>8,}
'''
ax.text(0.1, 0.9, propagation_summary, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# Drop the lowest-scoring propagated matches.
# Only a small share is removed here (5%, previously 10%) to keep coverage
# high; per the original note, poor pivot matches were already filtered earlier.
propagate_filter_prop = 0.05

propagated_filter_kwargs = dict(
    target='propagated',
    filter_prop=propagate_filter_prop,
    verbose=True,
)
fusor.filter_bad_matches(**propagated_filter_kwargs)

print(f"\nFiltered {propagate_filter_prop*100:.0f}% of lowest-scoring propagated matches")

In [None]:
# Side-by-side comparison of the PIVOT and PROPAGATED matching stages.
# The different statistics printed throughout the notebook refer to one of
# these two stages, so this cell reports both in one place.
#
# PIVOT STAGE: initial matching of high-confidence "anchor" cells
# PROPAGATED STAGE: matching extended to all cells via nearest-neighbor search
# Scores here are treated as 1 - correlation, so LOWER = better match.
# NOTE(review): both stages are read from the fusor's private attributes at
# index 0 - confirm whether they reflect the filtering applied in earlier cells.

print("=" * 70, "MATCHING STAGE COMPARISON (PIVOT vs PROPAGATED)", "=" * 70, sep="\n")

# Pivot stage: element 0 -> RNA indices, element 2 -> scores
pivot_match = fusor._refined_matching[0]
pivot_rna_idx, pivot_scores = (np.array(pivot_match[k]) for k in (0, 2))

# Propagated stage, same layout (indices not yet converted to full_data indices)
prop_match = fusor._propagated_matching[0]
prop_rna_idx, prop_scores = (np.array(prop_match[k]) for k in (0, 2))

# Per-stage counts; score < 0.5 means correlation > 0.5, which is good
n_pivot_total, n_prop_total = len(pivot_scores), len(prop_scores)
n_pivot_good = np.sum(pivot_scores < 0.5)
n_prop_good = np.sum(prop_scores < 0.5)
n_pivot_rna = np.unique(pivot_rna_idx).size
n_prop_rna = np.unique(prop_rna_idx).size

print(f"""
┌────────────────────────────────────────────────────────────────────┐
│                     PIVOT STAGE (Anchors)                          │
├────────────────────────────────────────────────────────────────────┤
│  Total pivot matches:     {n_pivot_total:>10,}                              │
│  Unique RNA in pivots:    {n_pivot_rna:>10,}                              │
│  Matches with score<0.5:  {n_pivot_good:>10,} ({100*n_pivot_good/n_pivot_total:>5.1f}%)               │
│  Mean score:              {pivot_scores.mean():>10.3f}                              │
└────────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────────┐
│                   PROPAGATED STAGE (Extended)                      │
├────────────────────────────────────────────────────────────────────┤
│  Total propagated matches:  {n_prop_total:>10,}                            │
│  Unique RNA after prop:     {n_prop_rna:>10,}                            │
│  Matches with score<0.5:    {n_prop_good:>10,} ({100*n_prop_good/n_prop_total:>5.1f}%)             │
│  Mean score:                {prop_scores.mean():>10.3f}                            │
└────────────────────────────────────────────────────────────────────┘

KEY INSIGHT:
  - Score = 1 - correlation, so LOWER score = BETTER match
  - Pivot stage finds high-quality anchor matches (smaller count, lower scores)
  - Propagation extends to more cells (larger count, higher average scores)
  - Both are valid; they measure DIFFERENT things!

⚠️  When comparing stats across cells, check which stage is being reported!
""")

# Figure: score densities (left) and RNA coverage per stage (right)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left panel - overlaid, density-normalised score histograms
ax = axes[0]
for stage_scores, stage_label, stage_color in (
    (pivot_scores, f'Pivot (n={n_pivot_total:,})', 'blue'),
    (prop_scores, f'Propagated (n={n_prop_total:,})', 'green'),
):
    ax.hist(stage_scores, bins=30, alpha=0.6, label=stage_label,
            color=stage_color, density=True)
ax.axvline(0.5, color='red', linestyle='--', label='Good threshold (score<0.5)')
ax.set(
    xlabel='Matching Score (1 - correlation, lower = better)',
    ylabel='Density',
    title='Score Distribution: Pivot vs Propagated',
)
ax.legend()

# Right panel - unique RNA cells reached by each stage vs. all RNA cells
ax = axes[1]
stages = ['Pivot', 'Propagated']
rna_counts = [n_pivot_rna, n_prop_rna]
total_rna = rna_active.shape[0]
x = np.arange(len(stages))
bars = ax.bar(x, rna_counts, color=['blue', 'green'], alpha=0.7)
ax.axhline(total_rna, color='red', linestyle='--', label=f'Total RNA ({total_rna:,})')
ax.set(ylabel='Unique RNA Cells', title='RNA Coverage by Stage')
ax.set_xticks(x)
ax.set_xticklabels(stages)
ax.legend()
for bar, val in zip(bars, rna_counts):
    pct = 100 * val / total_rna
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 100
    ax.text(label_x, label_y, f'{val:,}\n({pct:.1f}%)',
            ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
# Retrieve the matching on the full data set; the result is indexed below as
# [0] RNA indices, [1] protein indices, [2] scores.
full_matching = fusor.get_matching(order=(2, 1), target='full_data')

print(
    "\nMaxFuse Full matching results (RNA-centric):",
    f"  Total matches: {len(full_matching[0])}",
    f"  Unique RNA cells: {len(np.unique(full_matching[0]))}",
    f"  Unique Protein cells: {len(np.unique(full_matching[1]))}",
    f"  Score range: [{min(full_matching[2]):.3f}, {max(full_matching[2]):.3f}]",
    sep="\n",
)

In [None]:
# Apply score threshold to remove negative/low-quality matches.
# Matches scoring below MIN_SCORE_THRESHOLD are dropped and `full_matching` is
# rebound to the surviving (rna indices, protein indices, scores) arrays.
MIN_SCORE_THRESHOLD = 0.0  # Remove anti-correlated matches

# Get original stats
n_original = len(full_matching[0])
scores = np.array(full_matching[2])

# Fail early with a readable message instead of a ZeroDivisionError below.
if n_original == 0:
    raise ValueError(
        "full_matching is empty - there is nothing to filter. "
        "Re-run the matching cells before applying the score threshold."
    )

# Filter by threshold
mask = scores >= MIN_SCORE_THRESHOLD
full_matching_filtered = (
    np.array(full_matching[0])[mask],
    np.array(full_matching[1])[mask],
    scores[mask]
)

n_filtered = len(full_matching_filtered[0])
n_removed = n_original - n_filtered

print(f"Score threshold filtering (min score >= {MIN_SCORE_THRESHOLD}):")
print(f"  Original matches: {n_original:,}")
print(f"  Removed (score < {MIN_SCORE_THRESHOLD}): {n_removed:,} ({100*n_removed/n_original:.1f}%)")
print(f"  Remaining matches: {n_filtered:,}")

# If the threshold removes everything, stop BEFORE `full_matching` is
# overwritten so the unfiltered matching stays available in the kernel.
# (Without this guard np.min() below fails on a zero-size array.)
if n_filtered == 0:
    raise ValueError(
        f"All {n_original:,} matches score below {MIN_SCORE_THRESHOLD}; "
        "nothing would remain after filtering. Lower MIN_SCORE_THRESHOLD "
        "or inspect the matching scores."
    )

print(f"")
print(f"Score statistics after filtering:")
print(f"  Min:    {np.min(full_matching_filtered[2]):.4f}")
print(f"  Max:    {np.max(full_matching_filtered[2]):.4f}")
print(f"  Mean:   {np.mean(full_matching_filtered[2]):.4f}")
print(f"  Median: {np.median(full_matching_filtered[2]):.4f}")

# Update full_matching to use filtered version (later cells read this name)
full_matching = full_matching_filtered

# Visualize what was removed
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left: histogram of removed vs kept scores with the threshold marked
ax = axes[0]
if n_removed > 0:
    ax.hist(scores[~mask], bins=30, alpha=0.7, color='red', label=f'Removed ({n_removed:,})', edgecolor='white')
ax.hist(scores[mask], bins=30, alpha=0.7, color='green', label=f'Kept ({n_filtered:,})', edgecolor='white')
ax.axvline(MIN_SCORE_THRESHOLD, color='black', linestyle='--', linewidth=2, label=f'Threshold ({MIN_SCORE_THRESHOLD})')
ax.set_xlabel('Matching Score')
ax.set_ylabel('Count')
ax.set_title('Score Threshold Filtering')
ax.legend()

# Right: kept/removed pie chart, or a text note when nothing was removed
ax = axes[1]
if n_removed > 0:
    sizes = [n_filtered, n_removed]
    colors = ['#2ecc71', '#e74c3c']
    labels = [f'Kept\n{n_filtered:,}', f'Removed\n{n_removed:,}']
    ax.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
    ax.set_title('Score Threshold Results')
else:
    ax.text(0.5, 0.5, f'All {n_filtered:,} matches\nkept (score >= {MIN_SCORE_THRESHOLD})', 
            ha='center', va='center', fontsize=12, transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor='#2ecc71', alpha=0.3))
    ax.set_title('Score Threshold Results')
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Verify matching cardinality (RNA-centric approach):
# report how many RNA / protein cells take part in `full_matching` and how
# many matches each matched RNA cell has on average.
print("=" * 60)
print("MATCHING CARDINALITY VERIFICATION (RNA-CENTRIC)")
print("=" * 60)

n_total_matches = len(full_matching[0])
n_unique_rna = len(np.unique(full_matching[0]))
n_unique_protein = len(np.unique(full_matching[1]))

# Coerce scores to an ndarray so .min()/.max()/.mean() work whether
# full_matching[2] is a plain list or already an array (i.e. regardless of
# whether the score-threshold cell has been run).
match_scores = np.asarray(full_matching[2])

print(f"Total matches: {n_total_matches:,}")
print(f"Unique RNA matched: {n_unique_rna:,} / {rna_active.shape[0]} ({100*n_unique_rna/rna_active.shape[0]:.1f}%)")
print(f"Unique protein matched: {n_unique_protein:,} / {protein_active.shape[0]} ({100*n_unique_protein/protein_active.shape[0]:.1f}%)")
print(f"")
if match_scores.size > 0:
    print(f"Score range: [{match_scores.min():.3f}, {match_scores.max():.3f}]")
    print(f"Mean score: {match_scores.mean():.3f}")
else:
    # min()/max() of an empty array would raise; report the situation instead
    print("Score range: n/a (no matches)")
    print("Mean score: n/a (no matches)")
print(f"")

# With RNA-centric matching, we expect ~1 match per RNA cell
avg_matches_per_rna = n_total_matches / n_unique_rna if n_unique_rna > 0 else 0
print(f"Avg matches per RNA: {avg_matches_per_rna:.1f}")

# Check RNA coverage (90% is the cut-off between "good" and "low")
rna_coverage = n_unique_rna / rna_active.shape[0] * 100
if rna_coverage >= 90:
    print(f"\n✓ Good RNA coverage ({rna_coverage:.1f}%)")
else:
    print(f"\n⚠️  Low RNA coverage ({rna_coverage:.1f}%) - some RNA cells unmatched")

In [None]:
# Investigate unmatched RNA cells:
# split RNA cells into matched / unmatched according to `full_matching`, then
# compare per-cell summary metrics of the two groups.
from scipy import stats


def _plot_matched_vs_unmatched(ax, matched_values, unmatched_values, xlabel, title):
    """Overlay density histograms of one per-cell metric for both groups.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    matched_values, unmatched_values : per-cell values of the metric.
    xlabel, title : axis label and panel title.
    """
    ax.hist(matched_values, bins=30, alpha=0.6, label='Matched', color='green', density=True)
    ax.hist(unmatched_values, bins=30, alpha=0.6, label='Unmatched', color='red', density=True)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Density')
    ax.set_title(title)
    ax.legend()


# Validate that the score-threshold filtering cell has been run: it defines
# MIN_SCORE_THRESHOLD and removes every match scoring below it.
assert len(full_matching[2]) > 0, \
    "full_matching is empty - run the matching and score filtering cells first."
assert np.min(full_matching[2]) >= MIN_SCORE_THRESHOLD, \
    f"Run the score filtering cell first! Found scores below {MIN_SCORE_THRESHOLD} in full_matching."

print("=" * 60)
print("UNMATCHED RNA CELL ANALYSIS")
print("=" * 60)

# Find matched and unmatched RNA cells
matched_rna_idx = np.unique(full_matching[0])
all_rna_idx = np.arange(rna_active.shape[0])
unmatched_rna_idx = np.setdiff1d(all_rna_idx, matched_rna_idx)

n_matched = len(matched_rna_idx)
n_unmatched = len(unmatched_rna_idx)
n_total = rna_active.shape[0]

print(f"\nRNA Cell Coverage:")
print(f"  Matched:   {n_matched:,} ({100*n_matched/n_total:.1f}%)")
print(f"  Unmatched: {n_unmatched:,} ({100*n_unmatched/n_total:.1f}%)")
print(f"  Total:     {n_total:,}")

# Save unmatched indices for further analysis (written to disk by the save cell)
unmatched_rna_indices = unmatched_rna_idx

if n_unmatched == 0:
    print("\n*** All RNA cells matched! ***")
    print("No unmatched analysis needed.")
else:
    # Per-cell metrics (one value per RNA cell, reduced over the feature axis)
    matched_shared_mean = rna_shared[matched_rna_idx].mean(axis=1)
    unmatched_shared_mean = rna_shared[unmatched_rna_idx].mean(axis=1)
    matched_shared_var = rna_shared[matched_rna_idx].var(axis=1)
    unmatched_shared_var = rna_shared[unmatched_rna_idx].var(axis=1)
    # Detection rate: row mean of the detection mask defined in an earlier cell
    matched_detection = rna_detection_mask[matched_rna_idx].mean(axis=1)
    unmatched_detection = rna_detection_mask[unmatched_rna_idx].mean(axis=1)
    matched_active_mean = rna_active[matched_rna_idx].mean(axis=1)
    unmatched_active_mean = rna_active[unmatched_rna_idx].mean(axis=1)
    matched_active_var = rna_active[matched_rna_idx].var(axis=1)
    unmatched_active_var = rna_active[unmatched_rna_idx].var(axis=1)

    # Compare feature distributions: matched vs unmatched (five histogram
    # panels plus one text panel)
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    _plot_matched_vs_unmatched(
        axes[0, 0], matched_shared_mean, unmatched_shared_mean,
        'Mean z-scored expression', 'Shared Features (z-scored): Matched vs Unmatched')
    _plot_matched_vs_unmatched(
        axes[0, 1], matched_shared_var, unmatched_shared_var,
        'Variance of shared features', 'Feature Variance: Matched vs Unmatched')
    _plot_matched_vs_unmatched(
        axes[0, 2], matched_detection, unmatched_detection,
        'Fraction of detected features', 'Detection Rate: Matched vs Unmatched')
    _plot_matched_vs_unmatched(
        axes[1, 0], matched_active_mean, unmatched_active_mean,
        'Mean active feature value', 'Active Features: Matched vs Unmatched')
    _plot_matched_vs_unmatched(
        axes[1, 1], matched_active_var, unmatched_active_var,
        'Variance of active features', 'Active Variance: Matched vs Unmatched')

    # Summary statistics
    ax = axes[1, 2]
    ax.axis('off')

    # Statistical comparison (two-sample t-tests on the per-cell metrics)
    t_stat_shared, p_shared = stats.ttest_ind(matched_shared_mean, unmatched_shared_mean)
    t_stat_detect, p_detect = stats.ttest_ind(matched_detection, unmatched_detection)

    summary = f"""Unmatched Cell Characteristics
{"="*45}

Shared feature mean:
  Matched:   {np.mean(matched_shared_mean):.4f}
  Unmatched: {np.mean(unmatched_shared_mean):.4f}
  p-value:   {p_shared:.2e}

Detection rate:
  Matched:   {np.mean(matched_detection):.2%}
  Unmatched: {np.mean(unmatched_detection):.2%}
  p-value:   {p_detect:.2e}

Interpretation:
"""

    # Heuristic interpretation: detection gap > 0.05 is checked first, then a
    # shared-expression gap > 0.1; otherwise no clear pattern is reported.
    if np.mean(unmatched_detection) < np.mean(matched_detection) - 0.05:
        summary += "  Unmatched cells have LOWER detection\n"
        summary += "  (sparse profiles harder to match)"
    elif np.mean(unmatched_shared_mean) < np.mean(matched_shared_mean) - 0.1:
        summary += "  Unmatched cells have LOWER expression\n"
        summary += "  (low signal harder to match)"
    else:
        summary += "  No clear pattern - may be cell type\n"
        summary += "  specific (check clustering)"

    ax.text(0.05, 0.95, summary, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.tight_layout()
    plt.show()

print(f"\nUnmatched RNA cell indices saved to 'unmatched_rna_indices'")
print(f"Use these indices to investigate in notebook 3 (visualization)")


In [None]:
# Final matching overview on a 2x3 grid:
#   row 0: score histogram | matches per RNA cell | matches per protein cell
#   row 1: coverage bars   | sorted score curve   | text summary
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Values shared by several panels
final_scores = full_matching[2]
final_score_mean = np.mean(final_scores)
final_score_median = np.median(final_scores)
final_pair_count = len(full_matching[0])
rna_cell_total = rna_active.shape[0]
prot_cell_total = protein_active.shape[0]

# --- [0, 0] score distribution with mean / median markers
ax = axes[0, 0]
ax.hist(final_scores, bins=50, edgecolor='white', alpha=0.7, color='#2ecc71')
ax.axvline(final_score_mean, color='r', linestyle='--',
           label=f'Mean: {final_score_mean:.3f}')
ax.axvline(final_score_median, color='orange', linestyle='--',
           label=f'Median: {final_score_median:.3f}')
ax.set(xlabel='Matching Score', ylabel='Count',
       title='Final Matching Score Distribution')
ax.legend()

# --- [0, 1] number of protein matches for every matched RNA cell
ax = axes[0, 1]
rna_match_counts = np.bincount(full_matching[0], minlength=rna_cell_total)
rna_counts_nonzero = rna_match_counts[rna_match_counts > 0]
ax.hist(rna_counts_nonzero, bins=30, edgecolor='white', alpha=0.7)
ax.set(xlabel='Protein matches per RNA cell', ylabel='RNA cells',
       title=f'Matching Density\n(mean: {np.mean(rna_counts_nonzero):.1f})')

# --- [0, 2] number of RNA matches for every matched protein cell
ax = axes[0, 2]
prot_match_counts = np.bincount(full_matching[1], minlength=prot_cell_total)
prot_counts_nonzero = prot_match_counts[prot_match_counts > 0]
ax.hist(prot_counts_nonzero, bins=30, edgecolor='white', alpha=0.7, color='orange')
ax.set(xlabel='RNA matches per Protein cell', ylabel='Protein cells',
       title=f'Reverse Matching Density\n(mean: {np.mean(prot_counts_nonzero):.1f})')

# --- [1, 0] matched vs total cells per modality
ax = axes[1, 0]
n_rna_matched = len(np.unique(full_matching[0]))
n_prot_matched = len(np.unique(full_matching[1]))
categories = ['RNA', 'Protein']
matched = [n_rna_matched, n_prot_matched]
total = [rna_cell_total, prot_cell_total]
x = np.arange(len(categories))
width = 0.35
bars1 = ax.bar(x - width/2, matched, width, label='Matched', color='#2ecc71')
bars2 = ax.bar(x + width/2, total, width, label='Total', color='#95a5a6', alpha=0.7)
ax.set(ylabel='Cells', title='Coverage Summary')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()
for bar, val in zip(bars1, matched):
    bar_centre = bar.get_x() + bar.get_width()/2
    ax.text(bar_centre, bar.get_height(),
            f'{val:,}', ha='center', va='bottom', fontsize=9)

# --- [1, 1] scores sorted from highest to lowest
ax = axes[1, 1]
sorted_scores = np.sort(final_scores)[::-1]
ax.plot(sorted_scores, linewidth=0.5)
ax.axhline(final_score_mean, color='r', linestyle='--', alpha=0.7)
ax.fill_between(range(len(sorted_scores)), sorted_scores, alpha=0.3)
ax.set(xlabel='Match rank', ylabel='Score', title='Sorted Match Scores')

# --- [1, 2] text summary
ax = axes[1, 2]
ax.axis('off')
coverage_rna = 100 * n_rna_matched / rna_cell_total
coverage_prot = 100 * n_prot_matched / prot_cell_total
summary = f'''MAXFUSE INTEGRATION SUMMARY
{"="*45}

Total matches: {final_pair_count:,}

RNA cells:
  Matched: {n_rna_matched:,} / {rna_cell_total:,} ({coverage_rna:.1f}%)
  Avg matches/cell: {final_pair_count/n_rna_matched:.1f}

Protein cells:
  Matched: {n_prot_matched:,} / {prot_cell_total:,} ({coverage_prot:.1f}%)
  Avg matches/cell: {final_pair_count/n_prot_matched:.1f}

Score statistics:
  Mean:   {final_score_mean:.4f}
  Median: {final_score_median:.4f}
  Std:    {np.std(final_scores):.4f}
  Min:    {np.min(final_scores):.4f}
  Max:    {np.max(final_scores):.4f}
'''
ax.text(0.05, 0.95, summary, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()

## Save Integration Results

Save integration outputs for use in subsequent visualization notebooks.

In [None]:
# Create aligned arrays and CCA embedding for visualization
# This prepares data for notebook 3 without recomputing
#
# Steps:
#   1. group the filtered matches by RNA cell and keep the best-scoring
#      protein partner for each RNA cell
#   2. build row-aligned RNA / protein arrays (shared and active features)
#   3. fit CCA on the standardized shared features and standardize the
#      resulting canonical variates per modality

from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler

print("=" * 60)
print("CREATING ALIGNED ARRAYS AND CCA EMBEDDING")
print("=" * 60)

# Build aligned arrays from the filtered matching
# Group matches by RNA cell (one RNA can match multiple protein cells)
rna_to_prot_matches = {}
for rna_idx, prot_idx, score in zip(full_matching[0], full_matching[1], full_matching[2]):
    rna_to_prot_matches.setdefault(rna_idx, []).append((prot_idx, score))

if not rna_to_prot_matches:
    raise ValueError("full_matching is empty - cannot build aligned arrays or fit CCA.")

# Create aligned arrays using BEST MATCH only (not averaged)
# IMPORTANT: Averaging protein cells creates variance asymmetry with single RNA cells
# Using best match preserves comparable variance in both modalities
aligned_rna_indices = []
aligned_prot_indices = []  # Track which protein cell was used
aligned_match_scores = []

for rna_idx, matches in rna_to_prot_matches.items():
    # Use BEST match (highest score); max() keeps the first one on ties
    best_prot_idx, best_score = max(matches, key=lambda x: x[1])
    aligned_rna_indices.append(rna_idx)
    aligned_prot_indices.append(best_prot_idx)
    aligned_match_scores.append(best_score)

aligned_rna_indices = np.array(aligned_rna_indices)
aligned_prot_indices = np.array(aligned_prot_indices)
aligned_match_scores = np.array(aligned_match_scores)

# Row i of every aligned_* array describes the same matched pair.
# Shared features feed the CCA; active features are for downstream visualization.
aligned_rna_shared = rna_shared[aligned_rna_indices]
aligned_prot_shared = protein_shared[aligned_prot_indices]
aligned_rna_active = rna_active[aligned_rna_indices]
aligned_prot_active = protein_active[aligned_prot_indices]

print(f"\nAligned arrays created (best match per RNA cell):")
print(f"  Matched pairs: {len(aligned_rna_indices):,}")
print(f"  Shared features: RNA {aligned_rna_shared.shape}, Protein {aligned_prot_shared.shape}")
print(f"  Active features: RNA {aligned_rna_active.shape}, Protein {aligned_prot_active.shape}")
print(f"  Mean match score: {aligned_match_scores.mean():.3f}")
print(f"  Score range: [{aligned_match_scores.min():.3f}, {aligned_match_scores.max():.3f}]")

# Verify variance is comparable between modalities
rna_var = aligned_rna_shared.var(axis=0).mean()
prot_var = aligned_prot_shared.var(axis=0).mean()
print(f"\nVariance check (should be similar):")
print(f"  RNA mean variance:     {rna_var:.4f}")
print(f"  Protein mean variance: {prot_var:.4f}")
print(f"  Ratio: {rna_var/prot_var:.2f}")

# Fit CCA on shared features
print(f"\n" + "=" * 60)
print("FITTING CCA ON SHARED FEATURES")
print("=" * 60)

# Standardize INPUT features before CCA
scaler_rna = StandardScaler()
scaler_prot = StandardScaler()
rna_scaled = scaler_rna.fit_transform(aligned_rna_shared)
prot_scaled = scaler_prot.fit_transform(aligned_prot_shared)

# CCA components - use as many as possible for better UMAP structure.
# Capped at 15, at (n_features - 1) of either modality, and at (n_pairs - 1)
# so the request never exceeds what the data can support.
n_cca_vis = min(
    15,
    aligned_rna_shared.shape[1] - 1,
    aligned_prot_shared.shape[1] - 1,
    len(aligned_rna_indices) - 1,
)
if n_cca_vis < 1:
    raise ValueError(
        f"Cannot fit CCA: need at least 2 shared features and 2 matched pairs "
        f"(got RNA {aligned_rna_shared.shape}, Protein {aligned_prot_shared.shape})."
    )
print(f"Fitting CCA with {n_cca_vis} components...")

cca_vis = CCA(n_components=n_cca_vis, max_iter=1000)
cca_vis.fit(rna_scaled, prot_scaled)

# Get CCA scores - these are the canonical variates
cca_rna_scores, cca_prot_scores = cca_vis.transform(rna_scaled, prot_scaled)

# Compute canonical correlations (Pearson r between paired variates)
canonical_correlations = np.array([
    np.corrcoef(cca_rna_scores[:, i], cca_prot_scores[:, i])[0, 1]
    for i in range(n_cca_vis)
])

print(f"\nCanonical correlations: {canonical_correlations.round(3)}")
print(f"  Mean: {canonical_correlations.mean():.3f}")
print(f"  Components > 0.5: {(canonical_correlations > 0.5).sum()}/{len(canonical_correlations)}")
print(f"  Components > 0.3: {(canonical_correlations > 0.3).sum()}/{len(canonical_correlations)}")

# Standardize CCA scores SEPARATELY per modality
# This ensures both modalities have mean=0, std=1, making matched pairs close
scaler_cca_rna = StandardScaler()
scaler_cca_prot = StandardScaler()
cca_rna_for_vis = scaler_cca_rna.fit_transform(cca_rna_scores)
cca_prot_for_vis = scaler_cca_prot.fit_transform(cca_prot_scores)

# Verify alignment is preserved after separate standardization
pair_distances = np.linalg.norm(cca_rna_for_vis - cca_prot_for_vis, axis=1)
print(f"\nMatched pair distances in standardized CCA space:")
print(f"  Mean: {pair_distances.mean():.3f}")
print(f"  Median: {np.median(pair_distances):.3f}")
print(f"  (Lower = better alignment)")

# Plot canonical correlations and pair distances
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left: canonical correlation per component, colour-coded by strength
ax = axes[0]
colors = ['darkgreen' if c > 0.7 else 'green' if c > 0.5 else 'orange' if c > 0.3 else 'steelblue' 
          for c in canonical_correlations]
ax.bar(range(len(canonical_correlations)), canonical_correlations, color=colors)
ax.axhline(y=0.5, color='r', linestyle='--', alpha=0.7, label='r=0.5')
ax.axhline(y=0.3, color='orange', linestyle='--', alpha=0.7, label='r=0.3')
ax.set_xlabel('CCA Component')
ax.set_ylabel('Canonical Correlation')
ax.set_ylim(0, 1.1)
ax.set_title(f'CCA Canonical Correlations (n={n_cca_vis})')
ax.legend()

# Right: distribution of distances between matched pairs in CCA space
ax = axes[1]
ax.hist(pair_distances, bins=30, edgecolor='white', alpha=0.7)
ax.axvline(pair_distances.mean(), color='red', linestyle='-', linewidth=2, label=f'Mean={pair_distances.mean():.2f}')
ax.set_xlabel('Euclidean Distance (CCA space)')
ax.set_ylabel('Count')
ax.set_title('Matched Pair Distances')
ax.legend()

plt.tight_layout()
plt.show()

print(f"\n" + "=" * 60)
print("READY TO SAVE")
print("=" * 60)

In [None]:
# Persist every integration output under ../results/2_integration so that the
# visualization notebook can load them without recomputing anything.
import os
import pickle
from datetime import datetime
import json as json_module

# Output folder (created on demand)
results_dir = '../results/2_integration'
os.makedirs(results_dir, exist_ok=True)

# --- Matching (already filtered by the score threshold): pickle + CSV
matching_rna, matching_prot, matching_scores = (full_matching[k] for k in (0, 1, 2))

matching_data = {
    'rna_indices': matching_rna,
    'protein_indices': matching_prot,
    'scores': matching_scores
}
with open(f'{results_dir}/maxfuse_matching.pkl', 'wb') as f:
    pickle.dump(matching_data, f)
print(f'Saved MaxFuse matching: {len(matching_rna):,} matches')

# CSV copy for easy inspection
matching_df = pd.DataFrame({
    'rna_idx': matching_rna,
    'protein_idx': matching_prot,
    'score': matching_scores
})
matching_df.to_csv(f'{results_dir}/maxfuse_matching.csv', index=False)

# --- Unmatched RNA indices for further analysis
np.save(f'{results_dir}/unmatched_rna_indices.npy', unmatched_rna_indices)
print(f'Saved unmatched RNA indices: {len(unmatched_rna_indices):,} cells')

# --- Normalized arrays used for integration
for stem, arr in (
    ('rna_shared', rna_shared),
    ('rna_active', rna_active),
    ('protein_shared', protein_shared),
    ('protein_active', protein_active),
):
    np.save(f'{results_dir}/{stem}.npy', arr)
print('Saved normalized arrays')

# --- Aligned arrays for visualization (best match per RNA cell)
for stem, arr in (
    ('aligned_rna_shared', aligned_rna_shared),
    ('aligned_prot_shared', aligned_prot_shared),
    ('aligned_rna_active', aligned_rna_active),
    ('aligned_prot_active', aligned_prot_active),
    ('aligned_rna_indices', aligned_rna_indices),
    ('aligned_prot_indices', aligned_prot_indices),
    ('aligned_match_scores', aligned_match_scores),
):
    np.save(f'{results_dir}/{stem}.npy', arr)
print(f'Saved aligned arrays: {len(aligned_rna_indices):,} matched pairs (best match per RNA)')

# --- CCA embedding for visualization (the per-modality standardized scores)
for stem, arr in (
    ('cca_rna_scores', cca_rna_for_vis),
    ('cca_prot_scores', cca_prot_for_vis),
    ('canonical_correlations', canonical_correlations),
):
    np.save(f'{results_dir}/{stem}.npy', arr)
print(f'Saved CCA embedding: {cca_rna_for_vis.shape[1]} components')

# --- AnnData objects as they stand after the pre-filtering done in this notebook
# (per the original note: fewer cells than the preprocessing outputs)
protein_adata.write(f'{results_dir}/protein_adata_filtered.h5ad')
rna_adata.write(f'{results_dir}/rna_adata_filtered.h5ad')
print(f'Saved filtered AnnData: protein {protein_adata.shape}, rna {rna_adata.shape}')

# --- RNA gene <-> protein marker correspondence table
correspondence_df = pd.DataFrame(
    rna_protein_correspondence,
    columns=['rna_gene', 'protein_marker']
)
correspondence_df.to_csv(f'{results_dir}/correspondence.csv', index=False)

# --- Integration parameters and summary statistics (JSON)
n_matched_rna = len(np.unique(matching_rna))
n_matched_prot = len(np.unique(matching_prot))
n_rna_cells, n_rna_active_features = rna_active.shape[0], rna_active.shape[1]
n_prot_cells, n_prot_active_features = protein_active.shape[0], protein_active.shape[1]

fusor_params_section = {
    'max_outward_size': max_outward,
    'matching_ratio': matching_ratio,
    'smoothing_method': fusor.method,
    'n_shared_features': rna_shared.shape[1],
    # 'cca_components': cca_components,
    # 'pivot_filter_prop': pivot_filter_prop,
    'propagate_filter_prop': propagate_filter_prop
}
data_shapes_section = {
    'rna_cells': n_rna_cells,
    'protein_cells': n_prot_cells,
    'rna_active_features': n_rna_active_features,
    'protein_active_features': n_prot_active_features,
    'shared_features': rna_shared.shape[1]
}
matching_stats_section = {
    'total_matches': len(matching_rna),
    'unique_rna_matched': n_matched_rna,
    'unique_protein_matched': n_matched_prot,
    'rna_coverage_pct': 100 * n_matched_rna / n_rna_cells,
    'protein_coverage_pct': 100 * n_matched_prot / n_prot_cells,
    'unmatched_rna_cells': len(unmatched_rna_indices),
    'mean_score': float(np.mean(matching_scores)),
    'min_score': float(np.min(matching_scores)),
    'max_score': float(np.max(matching_scores))
}
visualization_section = {
    'n_cca_components': int(cca_rna_for_vis.shape[1]),
    'n_aligned_pairs': len(aligned_rna_indices),
    'alignment_method': 'best_match',  # Changed from 'averaged'
    'canonical_correlations': canonical_correlations.tolist()
}

integration_params = {
    'timestamp': datetime.now().isoformat(),
    'method': 'maxfuse',
    'score_threshold': float(MIN_SCORE_THRESHOLD),
    # `threshold_reason` is optional: fall back to 'unknown' when no cell defined it
    'score_threshold_reason': globals().get('threshold_reason', 'unknown'),
    'fusor_params': fusor_params_section,
    'data_shapes': data_shapes_section,
    'matching_stats': matching_stats_section,
    'visualization_embedding': visualization_section
}
with open(f'{results_dir}/integration_params.json', 'w') as f:
    json_module.dump(integration_params, f, indent=2)

# --- Final listing of everything written
print(
    f'\nAll outputs saved to {results_dir}/',
    '  - maxfuse_matching.pkl, maxfuse_matching.csv',
    '  - unmatched_rna_indices.npy',
    '  - rna_shared.npy, rna_active.npy, protein_shared.npy, protein_active.npy',
    '  - aligned_*.npy (7 files for visualization, best match per RNA)',
    '  - cca_rna_scores.npy, cca_prot_scores.npy, canonical_correlations.npy',
    '  - protein_adata_filtered.h5ad, rna_adata_filtered.h5ad',
    '  - correspondence.csv, integration_params.json',
    '\nRun 3_visualization.ipynb next.',
    sep='\n',
)

In [None]:
# Find the RNA cells with no shared marker expression.
# This has to be checked on the data as it was BEFORE filtering, so the
# log-normalized preprocessing output is reloaded from disk.
# (`sc` and `np` are already imported in the first cell of the notebook.)

rna_adata_orig = sc.read_h5ad('../results/1_preprocessing/rna_adata_lognorm.h5ad')
print(f"Original RNA cells: {rna_adata_orig.n_obs:,}")

# Get shared genes: first element of every (rna_gene, protein_marker) pair
correspondence = rna_protein_correspondence
shared_genes = [g for g, p in correspondence]
print(f"Shared genes: {shared_genes}")

# Check which cells have zero expression in ALL shared genes
rna_shared_check = rna_adata_orig[:, shared_genes].X.copy()
if hasattr(rna_shared_check, 'toarray'):
    # Sparse matrix -> dense array so the element-wise comparison below works
    rna_shared_check = rna_shared_check.toarray()

no_expr_mask = (rna_shared_check == 0).all(axis=1)
n_no_expr = no_expr_mask.sum()
print(f"\nRNA cells with NO shared marker expression: {n_no_expr}")

# Look at these cells
no_expr_cells = rna_adata_orig[no_expr_mask]
print(f"\nThese {n_no_expr} cells:")
print(f"  Total counts: {no_expr_cells.obs['total_counts'].median():.0f} median")
print(f"  Genes detected: {no_expr_cells.obs['n_genes_by_counts'].median():.0f} median")
print(f"  Tissue: {no_expr_cells.obs['Tissue'].value_counts().to_dict()}")

In [None]:
# What ARE these cells expressing?
# Rank genes by mean expression within the cells that lack shared-marker
# expression, then spot-check a fixed panel of immune markers.

no_expr_X = no_expr_cells.X.copy()
if hasattr(no_expr_X, 'toarray'):
    no_expr_X = no_expr_X.toarray()

# Number of cells actually analysed; used in the header below instead of a
# hard-coded count so the printout stays correct when the data changes.
n_cells_analyzed = no_expr_X.shape[0]

# Mean expression per gene across these cells
mean_expr = no_expr_X.mean(axis=0)
gene_names = no_expr_cells.var_names.tolist()

# Most expressed genes, highest mean first
top_n = 30
top_idx = np.argsort(mean_expr)[::-1][:top_n]
print(f"Top {top_n} expressed genes in the {n_cells_analyzed} cells with no shared marker expression:")
print(f"{'Gene':<15} {'Mean expr':>10} {'% cells':>10}")
print("-" * 37)
for idx in top_idx:
    # Share of these cells with a nonzero value for the gene
    pct_cells = (no_expr_X[:, idx] > 0).mean() * 100
    print(f"{gene_names[idx]:<15} {mean_expr[idx]:>10.2f} {pct_cells:>9.1f}%")

# Check key immune markers specifically
print("\n\nKey markers in these cells:")
key_markers = ['PTPRC', 'CD3E', 'CD3D', 'MS4A1', 'CD79A', 'CD68', 'CD14', 'NKG7', 'GZMB', 'CD4', 'CD8A']
for marker in key_markers:
    if marker in gene_names:
        idx = gene_names.index(marker)
        pct = (no_expr_X[:, idx] > 0).mean() * 100
        mean = no_expr_X[:, idx].mean()
        print(f"  {marker:<10}: {pct:5.1f}% cells, mean={mean:.2f}")
    else:
        # Marker is absent from var_names of the reloaded data
        print(f"  {marker:<10}: not in data")

In [None]:
# These are NOT immune cells - they're pancreatic exocrine contamination!
# PRSS1, PRSS2, CLPS, CELA3B = pancreatic digestive enzymes
#
# This cell prints the diagnosis and compares PTPRC (CD45) detection between
# cells that express at least one shared marker and cells that express none.
# Cell counts and PTPRC percentages are computed from the data rather than
# hard-coded, so the printout cannot drift from what was actually loaded.
# NOTE(review): the enzyme and T/B/myeloid evidence lines are still fixed
# text taken from the earlier exploratory output — re-confirm them if the
# input data changes.

# Boolean row mask of the no-shared-marker cells, flattened for indexing
bad_mask = np.asarray(no_expr_mask).ravel()
n_bad = int(bad_mask.sum())

# Pull ONLY the PTPRC column instead of densifying the full cells x genes
# matrix; `.index` raises ValueError if PTPRC is missing, as before.
ptprc_idx = rna_adata_orig.var_names.tolist().index('PTPRC')
ptprc_col = rna_adata_orig.X[:, [ptprc_idx]]
if hasattr(ptprc_col, 'toarray'):
    ptprc_col = ptprc_col.toarray()
ptprc_positive = np.asarray(ptprc_col).ravel() > 0

# Percent of cells with nonzero PTPRC in each group
ptprc_good = ptprc_positive[~bad_mask].mean() * 100
ptprc_bad = ptprc_positive[bad_mask].mean() * 100

print(f"DIAGNOSIS: These {n_bad} cells are pancreatic exocrine contamination")
print("\nEvidence:")
print("  - High expression of pancreatic enzymes: PRSS1, PRSS2, CLPS, CELA3B")
print(f"  - Low CD45 (PTPRC): only {ptprc_bad:.0f}% positive")
print("  - Zero T/B/myeloid markers")
print("\nThese snuck through CD45+ FACS sorting")

# Check if these are mostly from Pancreas tissue
print(f"\nTissue breakdown:")
print(no_expr_cells.obs['Tissue'].value_counts())

# What about PTPRC expression in expressing vs non-expressing cells?
# View of the cells that express at least one shared marker
expr_cells = rna_adata_orig[~no_expr_mask]

print(f"\nPTPRC (CD45) expression:")
print(f"  Good cells (express shared markers): {ptprc_good:.1f}%")
print(f"  Bad cells (no shared markers): {ptprc_bad:.1f}%")