# Multi-Sample Integration with MaxFuse

This notebook demonstrates how to process **multiple samples** through the MaxFuse integration pipeline.

## Use Cases
- Integrating CODEX data from multiple tissue sections
- Processing time-course or condition comparisons
- Batch processing for high-throughput studies

## Workflow Overview
1. Define sample metadata and paths
2. Create wrapper functions for preprocessing
3. Process each sample through integration
4. Combine results across samples
5. (Optional) Batch effect correction

**Prerequisites**: Familiarity with the single-sample workflows in `preprocessing.ipynb` and `integration.ipynb`; their steps are wrapped into reusable functions below.

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy.io import mmread
from scipy import sparse
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Import from maxfuse package
from maxfuse import Fusor, Mario
from maxfuse.core import model as mf_model
from maxfuse.mario.match import pipelined_mario

# Seed NumPy's global RNG so any stochastic step that draws from it repeats
np.random.seed(42)

# Default figure size/resolution for every plot in this notebook
plt.rcParams.update({
    'figure.figsize': [10, 8],
    'figure.dpi': 100,
})

print("MaxFuse loaded")
print(f"Scanpy version: {sc.__version__}")

## Step 1: Define Sample Metadata

Create a configuration dictionary for each sample with paths to data files.
This allows batch processing with sample-specific parameters.

In [None]:
# Per-sample configuration: input paths plus RNA QC thresholds.
# Edit paths for your own samples. To give one sample its own QC value,
# write that key *after* the ``**_DEFAULT_QC`` line so it overrides the default.

_DEFAULT_QC = {
    'min_umi': 500,
    'max_umi': 25000,
    'min_genes': 200,
    'max_mt_pct': 25,
}

SAMPLES = {
    'sample_1': {
        'name': 'Spleen_Section_1',
        'codex_path': '../data/sample1_cells.tsv',
        'rna_matrix_path': '../data/sample1_raw_feature_bc_matrix/',
        'protein_gene_map': '../data/protein_gene_conversion.csv',
        **_DEFAULT_QC,
    },
    'sample_2': {
        'name': 'Spleen_Section_2',
        'codex_path': '../data/sample2_cells.tsv',
        'rna_matrix_path': '../data/sample2_raw_feature_bc_matrix/',
        'protein_gene_map': '../data/protein_gene_conversion.csv',
        **_DEFAULT_QC,
    },
    # Add more samples as needed
}

# Output location shared by every sample
RESULTS_DIR = Path('../results/multi_sample')
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print(f"Configured {len(SAMPLES)} samples:")
for sid, cfg in SAMPLES.items():
    print(f"  - {sid}: {cfg['name']}")

## Step 2: Define Preprocessing Functions

Wrapper functions encapsulate the preprocessing logic from `preprocessing.ipynb` for reuse across samples.

In [None]:
def load_codex_data(tsv_path, sample_name=None):
    """
    Load CODEX data from QuPath TSV export.
    
    Parameters
    ----------
    tsv_path : str
        Path to CODEX TSV file
    sample_name : str, optional
        Name to add to obs for sample tracking
        
    Returns
    -------
    AnnData
        Protein expression with spatial coordinates
    """
    print(f"Loading CODEX data from: {tsv_path}")
    codex_df = pd.read_csv(tsv_path, sep='\t')
    
    # Extract marker columns (Cell: *: Mean)
    marker_cols = [col for col in codex_df.columns 
                   if col.startswith('Cell:') and col.endswith(': Mean')]
    
    # Parse marker names
    marker_names = []
    for col in marker_cols:
        marker = col.split(':')[1].strip().split('(')[0].strip()
        marker_names.append(marker)
    
    # Create expression matrix
    protein_matrix = codex_df[marker_cols].values
    
    # Extract spatial coordinates
    x_col = [col for col in codex_df.columns if 'Centroid X' in col][0]
    y_col = [col for col in codex_df.columns if 'Centroid Y' in col][0]
    x_coords = codex_df[x_col].values
    y_coords = codex_df[y_col].values
    
    # Create AnnData
    adata = ad.AnnData(protein_matrix.astype(np.float32))
    adata.var_names = marker_names
    adata.obs['X_centroid'] = x_coords
    adata.obs['Y_centroid'] = y_coords
    adata.obs_names = [f"cell_{i}" for i in range(adata.n_obs)]
    
    if sample_name:
        adata.obs['sample'] = sample_name
    
    print(f"  Loaded {adata.n_obs:,} cells, {adata.n_vars} markers")
    return adata

In [None]:
def load_rna_data(matrix_path, sample_name=None, 
                  min_counts=500, max_counts=25000, 
                  min_genes=200, max_mt_pct=25):
    """
    Load and filter raw 10x RNA-seq data.
    
    Parameters
    ----------
    matrix_path : str
        Path to directory containing matrix.mtx.gz, features.tsv.gz, barcodes.tsv.gz
    sample_name : str, optional
        Name to add to obs for sample tracking
    min_counts, max_counts : int
        UMI count thresholds
    min_genes : int
        Minimum genes detected per cell
    max_mt_pct : float
        Maximum mitochondrial percentage
        
    Returns
    -------
    AnnData
        Filtered RNA expression data
    """
    print(f"Loading RNA data from: {matrix_path}")
    
    # Load raw matrix
    mtx_path = Path(matrix_path)
    rna_mtx = mmread(mtx_path / 'matrix.mtx.gz')
    rna_names = pd.read_csv(mtx_path / 'features.tsv.gz', sep='\t', header=None)[1].to_numpy()
    rna_barcodes = pd.read_csv(mtx_path / 'barcodes.tsv.gz', header=None)[0].values
    
    adata = ad.AnnData(rna_mtx.T.tocsr(), dtype=np.float32)
    adata.var_names = rna_names
    adata.var_names_make_unique()
    adata.obs_names = rna_barcodes
    
    print(f"  Raw: {adata.n_obs:,} barcodes")
    
    # Calculate QC metrics
    adata.var['mt'] = adata.var_names.str.startswith('MT-')
    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
    
    # Apply filters
    sc.pp.filter_cells(adata, min_counts=min_counts)
    sc.pp.filter_cells(adata, min_genes=min_genes)
    adata = adata[adata.obs['total_counts'] < max_counts, :].copy()
    adata = adata[adata.obs['pct_counts_mt'] < max_mt_pct, :].copy()
    sc.pp.filter_genes(adata, min_cells=3)
    
    if sample_name:
        adata.obs['sample'] = sample_name
    
    print(f"  Filtered: {adata.n_obs:,} cells, {adata.n_vars:,} genes")
    return adata

In [None]:
def _first_positions(names):
    """Map each name to the index of its first occurrence (same result as ``list.index``)."""
    positions = {}
    for i, name in enumerate(names):
        positions.setdefault(name, i)
    return positions


def build_correspondence(protein_adata, rna_adata, mapping_path):
    """
    Build protein-gene correspondence for integration.

    Each row of the mapping CSV (columns ``protein`` and ``gene``) contributes
    one pair when both names are present in the respective ``var_names``; rows
    with a missing side are skipped. Output order follows the CSV row order.

    Parameters
    ----------
    protein_adata : AnnData
        Protein expression data
    rna_adata : AnnData
        RNA expression data
    mapping_path : str
        Path to protein-gene mapping CSV

    Returns
    -------
    tuple
        (shared_proteins, shared_genes) - parallel lists of column indices of
        the matching features (first occurrence when a name repeats)
    """
    mapping_df = pd.read_csv(mapping_path)

    # Build name -> column lookups once instead of rebuilding a list and doing
    # a linear .index() scan over all genes for every mapping row.
    protein_pos = _first_positions(protein_adata.var_names)
    gene_pos = _first_positions(rna_adata.var_names)

    shared_proteins = []
    shared_genes = []
    for protein, gene in zip(mapping_df['protein'], mapping_df['gene']):
        if protein in protein_pos and gene in gene_pos:
            shared_proteins.append(protein_pos[protein])
            shared_genes.append(gene_pos[gene])

    print(f"  Shared features: {len(shared_proteins)} protein-gene pairs")
    return shared_proteins, shared_genes

## Step 3: Define Integration Function

Wrap the MaxFuse integration pipeline for batch processing.

In [None]:
def run_maxfuse_integration(protein_adata, rna_adata, shared_protein_idx, shared_gene_idx,
                            n_batches=2, n_neighbors=15, leiden_resolution=0.8,
                            n_components=20, verbose=True, n_top_genes=2000,
                            svd_components1=30, svd_components2=20, asinh_cofactor=5):
    """
    Run MaxFuse integration pipeline.

    RNA is passed to the Fusor as dataset 1 (library-size normalised to 1e4,
    then log1p) and protein as dataset 2 (``arcsinh(x / asinh_cofactor)``).
    Shared features are the paired gene/protein columns; active features are
    the RNA highly variable genes and all protein markers.

    Parameters
    ----------
    protein_adata : AnnData
        Protein expression data (CODEX)
    rna_adata : AnnData
        RNA expression data
    shared_protein_idx : list
        Indices of shared protein features
    shared_gene_idx : list
        Indices of shared gene features (same length/order as shared_protein_idx)
    n_batches : int
        Number of batches for splitting large datasets
    n_neighbors : int
        Number of neighbors for graph construction
    leiden_resolution : float
        Resolution for Leiden clustering
    n_components : int
        Number of CCA components
    verbose : bool
        Print progress messages
    n_top_genes : int
        Number of highly variable genes used as RNA active features
        (capped at the number of genes available)
    svd_components1, svd_components2 : int
        SVD components for RNA / protein passed to ``find_initial_pivots`` and
        ``propagate``
    asinh_cofactor : float
        Cofactor of the protein arcsinh transform

    Returns
    -------
    dict
        'matching' (from ``fusor.get_matching``; presumably element 0 indexes
        RNA rows and element 1 protein rows, as callers assume),
        'embedding_rna', 'embedding_protein', and the fitted 'fusor'

    Raises
    ------
    ValueError
        If there are no shared features or the two index lists differ in length
    """
    # Fail early with a clear message instead of deep inside the Fusor
    if len(shared_protein_idx) == 0 or len(shared_gene_idx) == 0:
        raise ValueError("No shared protein-gene features; cannot run MaxFuse integration.")
    if len(shared_protein_idx) != len(shared_gene_idx):
        raise ValueError(
            f"Shared feature lists differ in length: {len(shared_protein_idx)} proteins "
            f"vs {len(shared_gene_idx)} genes"
        )

    if verbose:
        print("Normalizing data...")

    # RNA: log-normalize (on a copy so the caller's AnnData is untouched)
    rna_norm = rna_adata.copy()
    sc.pp.normalize_total(rna_norm, target_sum=1e4)
    sc.pp.log1p(rna_norm)

    # Protein: asinh transform; densify first so a sparse X is handled too
    protein_x = protein_adata.X
    if sparse.issparse(protein_x):
        protein_x = protein_x.toarray()
    protein_norm = np.arcsinh(np.asarray(protein_x) / asinh_cofactor)

    # Shared features: paired gene/protein columns in matching order
    rna_shared = rna_norm.X[:, shared_gene_idx]
    if sparse.issparse(rna_shared):
        rna_shared = rna_shared.toarray()
    protein_shared = protein_norm[:, shared_protein_idx]

    # Active features: highly variable genes (RNA), all markers (protein)
    sc.pp.highly_variable_genes(rna_norm, n_top_genes=min(n_top_genes, rna_norm.n_vars))
    hvg_mask = rna_norm.var['highly_variable'].values
    rna_active = rna_norm.X[:, hvg_mask]
    if sparse.issparse(rna_active):
        rna_active = rna_active.toarray()
    protein_active = protein_norm.copy()

    if verbose:
        print("Initializing MaxFuse...")

    # NOTE(review): the Fusor method/keyword names below follow this project's
    # maxfuse.core.model API; confirm against that module if it changes.
    fusor = mf_model.Fusor(
        shared_arr1=rna_shared,
        shared_arr2=protein_shared,
        active_arr1=rna_active,
        active_arr2=protein_active,
    )

    fusor.split_into_batches(
        n_batches1=n_batches,
        n_batches2=n_batches
    )

    if verbose:
        print("Constructing graphs...")
    fusor.construct_graphs(
        n_neighbors1=n_neighbors,
        n_neighbors2=n_neighbors,
        leiden_resolution1=leiden_resolution,
        leiden_resolution2=leiden_resolution
    )

    if verbose:
        print("Finding initial pivots...")
    fusor.find_initial_pivots(
        wt_on_active=0.0,
        svd_components1=svd_components1,
        svd_components2=svd_components2
    )

    if verbose:
        print("Refining pivots...")
    fusor.refine_pivots(
        n_iters=3,
        cca_components=n_components
    )

    # Drop the lowest-quality 20% of pivot matches before propagation
    fusor.filter_bad_matches(
        target='pivot',
        filter_prop=0.2
    )

    if verbose:
        print("Propagating matches...")
    fusor.propagate(svd_components1=svd_components1, svd_components2=svd_components2)

    matching = fusor.get_matching(target='full_data')
    embedding1, embedding2 = fusor.get_embedding(target='full_data', cca_components=n_components)

    if verbose:
        print(f"Integration complete: {len(matching[0]):,} matches")

    return {
        'matching': matching,
        'embedding_rna': embedding1,
        'embedding_protein': embedding2,
        'fusor': fusor
    }

## Step 4: Process All Samples

Loop through samples and run the complete pipeline for each.

In [None]:
# Run every configured sample end-to-end: load -> correspondence -> MaxFuse ->
# save. Per-sample outputs go to RESULTS_DIR/<sample_id>/ and the in-memory
# objects are kept for the cross-sample steps below.
all_results = {}
all_protein_adata = {}
all_rna_adata = {}

for sample_id, config in SAMPLES.items():
    print(f"\n{'='*60}")
    print(f"Processing: {config['name']} ({sample_id})")
    print(f"{'='*60}")

    # Skip (rather than abort the whole batch) when any input is missing
    if not Path(config['codex_path']).exists():
        print(f"  SKIPPED: CODEX file not found at {config['codex_path']}")
        continue
    if not Path(config['rna_matrix_path']).exists():
        print(f"  SKIPPED: RNA matrix not found at {config['rna_matrix_path']}")
        continue
    if not Path(config['protein_gene_map']).exists():
        print(f"  SKIPPED: protein-gene map not found at {config['protein_gene_map']}")
        continue

    print("\n[1/4] Loading data...")
    protein_adata = load_codex_data(config['codex_path'], sample_name=config['name'])
    rna_adata = load_rna_data(
        config['rna_matrix_path'],
        sample_name=config['name'],
        min_counts=config['min_umi'],
        max_counts=config['max_umi'],
        min_genes=config['min_genes'],
        max_mt_pct=config['max_mt_pct']
    )

    print("\n[2/4] Building correspondence...")
    shared_protein_idx, shared_gene_idx = build_correspondence(
        protein_adata, rna_adata, config['protein_gene_map']
    )
    # MaxFuse needs at least one linked feature; skip instead of crashing the batch
    if len(shared_protein_idx) == 0:
        print("  SKIPPED: no protein-gene pairs shared between the two modalities")
        continue

    print("\n[3/4] Running MaxFuse integration...")
    results = run_maxfuse_integration(
        protein_adata, rna_adata,
        shared_protein_idx, shared_gene_idx,
        n_batches=2, n_neighbors=15
    )

    print("\n[4/4] Saving results...")
    sample_dir = RESULTS_DIR / sample_id
    sample_dir.mkdir(exist_ok=True)

    # Positional indices alone are hard to interpret after reload (filtering
    # changes row order), so store the matched cells' obs names alongside them.
    matching = results['matching']
    rna_idx = np.asarray(matching[0], dtype=int)
    protein_idx = np.asarray(matching[1], dtype=int)
    matching_df = pd.DataFrame({
        'rna_idx': matching[0],
        'protein_idx': matching[1],
        'distance': matching[2],
        'rna_barcode': rna_adata.obs_names[rna_idx].to_numpy(),
        'protein_cell': protein_adata.obs_names[protein_idx].to_numpy(),
    })
    matching_df.to_csv(sample_dir / 'matching.csv', index=False)

    np.save(sample_dir / 'embedding_rna.npy', results['embedding_rna'])
    np.save(sample_dir / 'embedding_protein.npy', results['embedding_protein'])

    all_results[sample_id] = results
    all_protein_adata[sample_id] = protein_adata
    all_rna_adata[sample_id] = rna_adata

    print(f"  Saved to: {sample_dir}")

print(f"\n{'='*60}")
print(f"Completed processing {len(all_results)} samples")
print(f"{'='*60}")

## Step 5: Combine Results Across Samples

Merge AnnData objects and embeddings for cross-sample analysis.

In [None]:
# Combine protein data across samples.
# load_codex_data names cells cell_0, cell_1, ... in every sample, so raw
# obs_names collide; index_unique='-' appends '-<sample_id>' to make them unique.
# join='outer' keeps markers present in only some panels (missing values padded).
# label='sample' writes the sample_id key to obs['sample'] — presumably replacing
# the display name load_codex_data stored there; verify if you need the names.
protein_combined = None
if len(all_protein_adata) > 1:
    protein_combined = ad.concat(
        list(all_protein_adata.values()),
        join='outer',
        label='sample',
        keys=list(all_protein_adata.keys()),
        index_unique='-',
    )

    print(f"Combined protein data: {protein_combined.shape}")
    print(f"Samples: {protein_combined.obs['sample'].value_counts().to_dict()}")
elif len(all_protein_adata) == 1:
    protein_combined = list(all_protein_adata.values())[0]
    print(f"Single sample: {protein_combined.shape}")
else:
    print("No samples processed successfully.")

In [None]:
# Combine RNA data across samples.
# Barcodes from separate 10x runs can repeat, so index_unique='-' appends
# '-<sample_id>' to every obs name to keep the combined index unique.
# rna_combined is None when no sample was processed, so later cells can test it.
rna_combined = None
if len(all_rna_adata) > 1:
    rna_combined = ad.concat(
        list(all_rna_adata.values()),
        join='outer',
        label='sample',
        keys=list(all_rna_adata.keys()),
        index_unique='-',
    )

    print(f"Combined RNA data: {rna_combined.shape}")
    print(f"Samples: {rna_combined.obs['sample'].value_counts().to_dict()}")
elif len(all_rna_adata) == 1:
    rna_combined = list(all_rna_adata.values())[0]
    print(f"Single sample: {rna_combined.shape}")

In [None]:
# Stack every sample's protein-side embedding into one matrix, plus a parallel
# array holding the sample_id that each row came from.
if all_results:
    protein_embeddings = [res['embedding_protein'] for res in all_results.values()]
    combined_embedding = np.vstack(protein_embeddings)
    sample_labels = np.array([
        sid
        for sid, emb in zip(all_results, protein_embeddings)
        for _ in range(len(emb))
    ])

    print(f"Combined embedding shape: {combined_embedding.shape}")
    print(f"Sample distribution: {pd.Series(sample_labels).value_counts().to_dict()}")

## Step 6: Cross-Sample Visualization

Visualize combined results to assess batch effects and integration quality.

In [None]:
# Compute a 2-D UMAP of the combined protein-side embedding and colour it by
# sample to inspect batch effects. Skipped when there are no results or the
# embedding has 100 rows or fewer.
if len(all_results) > 0 and combined_embedding.shape[0] > 100:
    import umap

    # Subsample very large embeddings to keep UMAP tractable
    max_cells = 50000
    if combined_embedding.shape[0] > max_cells:
        # Dedicated seeded generator: the subsample is reproducible regardless of
        # how many draws earlier cells took from the global NumPy RNG.
        rng = np.random.default_rng(42)
        idx = rng.choice(combined_embedding.shape[0], max_cells, replace=False)
        emb_sub = combined_embedding[idx]
        labels_sub = sample_labels[idx]
    else:
        emb_sub = combined_embedding
        labels_sub = sample_labels

    print("Computing UMAP...")
    reducer = umap.UMAP(n_neighbors=30, min_dist=0.3, random_state=42)
    umap_coords = reducer.fit_transform(emb_sub)

    fig, ax = plt.subplots(figsize=(10, 8))

    unique_samples = np.unique(labels_sub)
    # One tab10 entry per sample by integer index (cycling after 10 samples)
    colors = plt.cm.tab10(np.arange(len(unique_samples)) % 10)

    for i, sample in enumerate(unique_samples):
        mask = labels_sub == sample
        ax.scatter(umap_coords[mask, 0], umap_coords[mask, 1],
                   c=[colors[i]], s=1, alpha=0.5, label=sample)

    ax.set_xlabel('UMAP1')
    ax.set_ylabel('UMAP2')
    ax.set_title('Combined Embedding - Colored by Sample')
    ax.legend(markerscale=10)
    plt.tight_layout()
    plt.savefig(RESULTS_DIR / 'combined_umap_by_sample.png', dpi=150)
    plt.show()

    print("Saved: combined_umap_by_sample.png")

## Step 7: (Optional) Batch Effect Correction

If samples show strong batch effects, apply correction using Harmony or similar methods.

In [None]:
# Optional: Harmony batch correction
# Uncomment if batch effects are observed in the UMAP above

# try:
#     import harmonypy as hm
#     
#     print("Running Harmony batch correction...")
#     harmony_out = hm.run_harmony(
#         combined_embedding, 
#         pd.DataFrame({'sample': sample_labels}), 
#         'sample'
#     )
#     combined_embedding_corrected = harmony_out.Z_corr.T
#     
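#     # Recompute UMAP with corrected embedding, restricted to the same cells
#     # plotted above (labels_sub/colors refer to the possibly subsampled set)
#     print("Computing UMAP on corrected embedding...")
#     if combined_embedding.shape[0] > max_cells:
#         combined_embedding_corrected = combined_embedding_corrected[idx]
#     umap_corrected = reducer.fit_transform(combined_embedding_corrected)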
#     
#     # Plot comparison
#     fig, axes = plt.subplots(1, 2, figsize=(16, 6))
#     
#     for ax, coords, title in zip(
#         axes, 
#         [umap_coords, umap_corrected],
#         ['Before Harmony', 'After Harmony']
#     ):
#         for i, sample in enumerate(unique_samples):
#             mask = labels_sub == sample
#             ax.scatter(coords[mask, 0], coords[mask, 1], 
#                        c=[colors[i]], s=1, alpha=0.5, label=sample)
#         ax.set_title(title)
#         ax.legend(markerscale=10)
#     
#     plt.tight_layout()
#     plt.savefig(RESULTS_DIR / 'harmony_comparison.png', dpi=150)
#     plt.show()
#     
# except ImportError:
#     print("harmonypy not installed. Install with: pip install harmonypy")

## Step 8: Summary Statistics

Generate summary table of all processed samples.

In [None]:
# One summary row per successfully processed sample, printed and saved as CSV.
summary_rows = []
for sid, res in all_results.items():
    prot = all_protein_adata[sid]
    rna = all_rna_adata[sid]
    n_matches = len(res['matching'][0])
    summary_rows.append({
        'Sample': sid,
        'CODEX Cells': prot.n_obs,
        'RNA Cells': rna.n_obs,
        'Protein Markers': prot.n_vars,
        'RNA Genes': rna.n_vars,
        'Matches': n_matches,
        # Matches as a percentage of the smaller modality's cell count
        'Match Rate': f"{100*n_matches/min(prot.n_obs, rna.n_obs):.1f}%",
    })

summary_df = pd.DataFrame(summary_rows)
print("\nMulti-Sample Integration Summary")
print("=" * 80)
print(summary_df.to_string(index=False))

summary_path = RESULTS_DIR / 'summary.csv'
summary_df.to_csv(summary_path, index=False)
print(f"\nSaved summary to: {summary_path}")

## Next Steps

After multi-sample integration:
1. **Cell type annotation**: Transfer labels using the combined embedding
2. **Differential analysis**: Compare cell type proportions across samples
3. **Spatial analysis**: Use region-aware matching (see analysis.ipynb)
4. **Export for visualization**: Save combined results for external tools

See `visualization.ipynb` and `analysis.ipynb` for downstream analysis workflows.