In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import json
import matplotlib.pyplot as plt
from scipy import sparse
from scipy.io import mmread

# MaxFuse imports
import maxfuse as mf
from maxfuse import Fusor, core
from maxfuse.mario import pipelined_mario

import warnings
# NOTE(review): with no category/module argument this silences *every* warning for the
# rest of the kernel session (including deprecation warnings from scanpy/anndata/maxfuse).
# Consider narrowing it, e.g. filterwarnings("ignore", category=FutureWarning).
warnings.filterwarnings("ignore")

In [2]:
# Load preprocessed data from 1_preprocessing.ipynb
# Run 1_preprocessing.ipynb first to generate these files

import os

results_dir = '../results/1_preprocessing'

# The notebook expects to run from <project>/notebooks/. If the kernel was started
# somewhere else, fall back to the project root below. On other machines set the
# MAXFUSE_PROJECT_DIR environment variable instead of editing this path.
PROJECT_DIR = os.environ.get('MAXFUSE_PROJECT_DIR', '/home/smith6jt/maxfuse')
if (not os.path.exists(results_dir)
        and os.path.exists(os.path.join(PROJECT_DIR, 'results', '1_preprocessing'))):
    os.chdir(os.path.join(PROJECT_DIR, 'notebooks'))
    print(f"Changed working directory to: {os.getcwd()}")

if not os.path.exists(results_dir):
    raise FileNotFoundError(
        f"Results directory '{results_dir}' not found. "
        f"Run 1_preprocessing.ipynb first to generate the input files."
    )

# Load processed AnnData objects
protein_adata = sc.read_h5ad(f'{results_dir}/protein_adata.h5ad')
rna_adata = sc.read_h5ad(f'{results_dir}/rna_adata.h5ad')
rna_adata_lognorm = sc.read_h5ad(f'{results_dir}/rna_adata_lognorm.h5ad')

print(f"Loaded from {results_dir}/")
print(f"  Protein data: {protein_adata.shape}")
print(f"  RNA data: {rna_adata.shape}")
print(f"  RNA log-normalized: {rna_adata_lognorm.shape}")

Changed working directory to: /home/smith6jt/maxfuse/notebooks
Loaded from ../results/1_preprocessing/
  Protein data: (1076404, 23)
  RNA data: (10321, 20550)
  RNA log-normalized: (10321, 20550)


## Step 3: Build Protein-Gene Correspondence

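Map CODEX protein markers to their corresponding gene names in the RNA data using the conversion table `../data/protein_gene_conversion.csv`. Non-immune and structural markers are excluded. When one gene maps to several markers, only the first pairing is kept.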

In [3]:
# Protein -> gene name conversion table ('utf-8-sig' drops a leading BOM if present)
conversion_csv_path = '../data/protein_gene_conversion.csv'
correspondence = pd.read_csv(conversion_csv_path, encoding='utf-8-sig')
print(f"Correspondence table: {len(correspondence)} entries")
correspondence.head(10)

Correspondence table: 370 entries


Unnamed: 0,Protein name,RNA name
0,CD80,CD80
1,CD86,CD86
2,CD274,CD274
3,CD273,PDCD1LG2
4,CD275,ICOSLG
5,CD275-1,ICOSLG
6,CD275-2,ICOSLG
7,CD11b,ITGAM
8,CD11b-1,ITGAM
9,CD11b-2,ITGAM


In [None]:
# Find matching features between CODEX markers and RNA genes
# Define markers to exclude (non-immune, structural, or problematic for CD45+ RNA data)
EXCLUDED_MARKERS = [
    'DAPI',           # Nuclear stain
    'ECAD', 'E-cadherin',  # Epithelial
    'IAPP', 'INS', 'GCG', 'SST',  # Pancreatic endocrine
    'Ker8-18', 'Pan-Cytokeratin', 'Keratin 5', 'EpCAM', 'TP63',  # Epithelial
    'Collagen IV', 'Vimentin', 'SMA', 'Caveolin',  # Stromal/structural
    'CD31', 'Podoplanin',  # Endothelial/lymphatic
    'Beta-actin',     # Housekeeping
    'LAG3', 'CD11b', 'CD206', 'Iba1', 'CD68', 'CD163', 'B3TUBB'         # Excluded per user
]

# Alternative spellings tried when a CODEX marker has no direct entry in the table.
# Built once here instead of on every loop iteration.
MARKER_ALIASES = {
    'CD3e': 'CD3',
    'FoxP3': 'FOXP3',
    'HLADR': 'HLA-DR',
    'Lyve1': 'LYVE1',
    'SMActin': 'SMA',
    'CollagenIV': 'collagen IV',
}

# Lower-cased table protein names, computed once for case-insensitive lookups
protein_names_lower = correspondence['Protein name'].str.lower()

rna_protein_correspondence = []
unmatched_proteins = []
excluded_present = []  # excluded markers that actually occur in this panel

for marker in protein_adata.var_names:
    # Skip excluded markers (exact, case-sensitive match against EXCLUDED_MARKERS)
    if marker in EXCLUDED_MARKERS:
        excluded_present.append(marker)
        continue

    # Look up in correspondence table (case-insensitive), then via the alias map
    matches = correspondence[protein_names_lower == marker.lower()]
    if len(matches) == 0:
        alt_marker = MARKER_ALIASES.get(marker, marker)
        matches = correspondence[protein_names_lower == alt_marker.lower()]

    if len(matches) == 0:
        unmatched_proteins.append((marker, 'Not in table'))
        continue

    # Only the first matching row is used; its 'RNA name' may hold several
    # '/'-separated candidate genes.
    rna_names_str = matches.iloc[0]['RNA name']
    if 'Ignore' in str(rna_names_str):
        unmatched_proteins.append((marker, 'Ignored'))
        continue

    # Take the first candidate gene that exists in the RNA data
    rna_name = next(
        (candidate for candidate in str(rna_names_str).split('/')
         if candidate in rna_adata.var_names),
        None,
    )
    if rna_name is None:
        unmatched_proteins.append((marker, rna_names_str))
    else:
        rna_protein_correspondence.append([rna_name, marker])

# reshape keeps the (n_pairs, 2) layout even when nothing matched, so [:, 0] still works
rna_protein_correspondence = np.array(rna_protein_correspondence).reshape(-1, 2)
print(f"Found {len(rna_protein_correspondence)} protein-gene pairs")
# Report markers actually skipped, not the length of the exclusion list
print(f"Excluded {len(excluded_present)} non-immune/structural markers present in the panel "
      f"({len(EXCLUDED_MARKERS)} on the exclusion list)")

if unmatched_proteins:
    print(f"\nUnmatched proteins ({len(unmatched_proteins)}):")
    for prot, reason in unmatched_proteins:
        print(f"  {prot}: {reason}")

Found 23 protein-gene pairs
Excluded 26 non-immune/structural markers


In [5]:
# Each RNA gene may be paired with only one protein marker: when a gene repeats,
# the first pairing wins and later ones are reported and dropped.
marker_for_gene = {}
for gene, marker_name in rna_protein_correspondence:
    if gene in marker_for_gene:
        print(f"Removing duplicate RNA mapping: {gene} -> {marker_name}")
        continue
    marker_for_gene[gene] = marker_name  # dicts keep insertion order

rna_protein_correspondence = np.array([[gene, marker_name] for gene, marker_name in marker_for_gene.items()])
print(f"\nFinal correspondence: {len(rna_protein_correspondence)} pairs")

print("\nMatched features:")
for gene, marker_name in rna_protein_correspondence:
    print(f"  {gene:15} <-> {marker_name}")


Final correspondence: 23 pairs

Matched features:
  LAMP1           <-> CD107a
  CD4             <-> CD4
  CEACAM1         <-> CD66
  CD38            <-> CD38
  CD8A            <-> CD8
  LGALS3          <-> M2Gal3
  HLA-DRA         <-> HLA-DR
  PCNA            <-> PCNA
  FOXP3           <-> FOXP3
  B3GAT1          <-> CD57
  MKI67           <-> Ki67
  MS4A1           <-> CD20
  GZMB            <-> Granzyme B
  VSIR            <-> VISTA
  PDCD1           <-> PD-1
  NCAM1           <-> CD56
  CD79A           <-> CD79a
  TCF7            <-> TCF-1
  TOX             <-> TOX
  CD274           <-> PD-L1
  CD3E            <-> CD3e
  BCL2            <-> Bcl-2
  ENTPD1          <-> CD39


## Step 4: Prepare Arrays for Integration

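Extract and normalize two pairs of matrices:
- **Shared arrays**: the matched protein/gene feature pairs from Step 3, z-scored per modality over detecting cells (used for initial cross-modal matching)
- **Active arrays**: the wider per-modality feature sets (highly variable RNA genes plus selected shared genes; all protein markers) used for refinement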

In [6]:
shared_rna_genes = rna_protein_correspondence[:, 0]  # RNA gene names
shared_protein_markers = rna_protein_correspondence[:, 1]  # Protein marker names

# Verify all features exist. The AnnData indexing below fails on any missing name
# anyway, so stop here with a message that lists every missing feature at once.
missing_rna = [g for g in shared_rna_genes if g not in rna_adata.var_names]
missing_prot = [p for p in shared_protein_markers if p not in protein_adata.var_names]

if missing_rna or missing_prot:
    raise KeyError(
        f"Shared features missing from the data - RNA genes: {missing_rna}; "
        f"protein markers: {missing_prot}"
    )

# Create shared feature AnnData objects (column order follows rna_protein_correspondence)
rna_shared_adata = rna_adata[:, shared_rna_genes].copy()
protein_shared_adata = protein_adata[:, shared_protein_markers].copy()

print("Shared feature AnnData objects created:")
print(f"  rna_shared_adata: {rna_shared_adata.shape}")
print(f"  protein_shared_adata: {protein_shared_adata.shape}")
print(f"\nRNA shared features: {list(rna_shared_adata.var_names[:5])}...")
print(f"Protein shared features: {list(protein_shared_adata.var_names[:5])}...")

Shared feature AnnData objects created:
  rna_shared_adata: (10321, 23)
  protein_shared_adata: (1076404, 23)

RNA shared features: ['LAMP1', 'CD4', 'CEACAM1', 'CD38', 'CD8A']...
Protein shared features: ['CD107a', 'CD4', 'CD66', 'CD38', 'CD8']...


In [7]:
# Keep gated intensities in .X; store a log1p-transformed copy as the 'log' layer,
# which the normalization cell reads for z-scoring the shared protein features.
log1p_intensities = sc.pp.log1p(protein_adata.X, copy=True)
protein_adata.layers['log'] = log1p_intensities
del log1p_intensities

In [8]:
# Normalize shared features for MaxFuse
# BOTH modalities: detection-aware z-score
# - RNA: zeros = no transcript
# - Protein: zeros = below gate (from notebook 1)

from scipy import sparse, stats
from sklearn.preprocessing import StandardScaler

ZERO_VALUE = -1.0  # Value for non-expressing cells

print("=" * 70)
print("SHARED FEATURE NORMALIZATION")
print("=" * 70)


def _dense_copy(matrix):
    """Return an independent copy of ``matrix``, converted to dense if it is sparse."""
    duplicate = matrix.copy()
    return duplicate.toarray() if sparse.issparse(duplicate) else duplicate


# ============================================================
# RNA SHARED
# ============================================================

# Log-normalized values define detection (> 0 means a transcript was seen)
rna_shared_lognorm = _dense_copy(rna_adata_lognorm[:, shared_rna_genes].X)
rna_shared_raw = _dense_copy(rna_shared_adata.X)
rna_expressing = rna_shared_lognorm > 0

# ============================================================
# PROTEIN SHARED
# ============================================================

shared_marker_names = [prot_name for _rna_name, prot_name in rna_protein_correspondence]

# Gated intensities define detection (> 0 means above gate)
protein_gated = _dense_copy(protein_shared_adata.X)
protein_expressing = protein_gated > 0

# The log1p layer supplies the values that get z-scored
protein_log_layer = protein_adata.layers['log']
if sparse.issparse(protein_log_layer):
    protein_log_layer = protein_log_layer.toarray()

protein_var_list = list(protein_adata.var_names)
shared_marker_columns = [protein_var_list.index(name) for name in shared_marker_names]
protein_log_shared = protein_log_layer[:, shared_marker_columns].astype(np.float64)
# ============================================================
# ITERATIVE FILTERING: markers and cells until stable
# ============================================================

def filter_markers(rna_expr, prot_expr, correspondence, *arrays):
    """Flag shared markers detected in at least one cell of BOTH modalities.

    Args:
        rna_expr: boolean (cells x markers) RNA detection mask.
        prot_expr: boolean (cells x markers) protein detection mask.
        correspondence: (markers x 2) array of [rna_gene, protein_marker] rows,
            used only to name the markers being dropped.
        *arrays: accepted for call compatibility; not used.

    Returns:
        (valid, n_removed): boolean per-marker keep mask, and how many are False.
    """
    valid = np.logical_and(rna_expr.any(axis=0), prot_expr.any(axis=0))
    n_removed = (~valid).sum()
    if n_removed > 0:
        dropped_names = [prot for _gene, prot in correspondence[~valid]]
        print(f"  Removing {n_removed} markers: {dropped_names}")
    return valid, n_removed

def filter_cells(rna_expr, prot_expr):
    """Find cells that express at least one shared marker, separately per modality.

    Returns:
        (rna_keep, prot_keep, n_rna_dropped, n_prot_dropped): boolean per-cell keep
        masks for each modality, and the number of False entries in each.
    """
    rna_keep, prot_keep = (mask.any(axis=1) for mask in (rna_expr, prot_expr))
    return rna_keep, prot_keep, (~rna_keep).sum(), (~prot_keep).sum()

print("\nFiltering markers and cells...")
# Alternate marker and cell filtering until nothing changes. Dropping a marker can
# leave a cell with no expressed shared marker, and dropping cells can leave a marker
# with no expressing cells in one modality.
iteration = 0
while True:
    iteration += 1
    changes = 0

    # Filter markers
    marker_valid, n_markers = filter_markers(rna_expressing, protein_expressing, rna_protein_correspondence)
    if n_markers > 0:
        changes += n_markers
        rna_shared_lognorm = rna_shared_lognorm[:, marker_valid]
        rna_shared_raw = rna_shared_raw[:, marker_valid]
        rna_expressing = rna_expressing[:, marker_valid]
        protein_gated = protein_gated[:, marker_valid]
        protein_log_shared = protein_log_shared[:, marker_valid]
        protein_expressing = protein_expressing[:, marker_valid]
        rna_protein_correspondence = rna_protein_correspondence[marker_valid]
        # Keep the shared AnnData objects (and the marker name list) column-aligned
        # with the arrays above. Later steps subset them with masks of the reduced
        # width and overwrite their .X with reduced-width matrices.
        rna_shared_adata = rna_shared_adata[:, marker_valid].copy()
        protein_shared_adata = protein_shared_adata[:, marker_valid].copy()
        shared_marker_names = [prot_name for _, prot_name in rna_protein_correspondence]

    # Filter cells (rows of every per-modality object stay aligned)
    rna_keep, prot_keep, n_rna, n_prot = filter_cells(rna_expressing, protein_expressing)
    if n_rna > 0:
        print(f"  Removing {n_rna:,} RNA cells with no expression")
        changes += n_rna
        rna_shared_lognorm = rna_shared_lognorm[rna_keep]
        rna_shared_raw = rna_shared_raw[rna_keep]
        rna_expressing = rna_expressing[rna_keep]
        rna_shared_adata = rna_shared_adata[rna_keep].copy()
        rna_adata = rna_adata[rna_keep].copy()
        rna_adata_lognorm = rna_adata_lognorm[rna_keep].copy()

    if n_prot > 0:
        print(f"  Removing {n_prot:,} protein cells with no expression")
        changes += n_prot
        protein_gated = protein_gated[prot_keep]
        protein_log_shared = protein_log_shared[prot_keep]
        protein_expressing = protein_expressing[prot_keep]
        protein_shared_adata = protein_shared_adata[prot_keep].copy()
        protein_adata = protein_adata[prot_keep].copy()

    if changes == 0:
        break
    if iteration > 10:
        print("  WARNING: filtering did not converge")
        break

print(f"  Done in {iteration} iteration(s)")

# ============================================================
# DETECTION-AWARE Z-SCORE
# ============================================================

def normalize_expressing(data, expressing_mask, nonexpr_value):
    """Detection-aware z-score, computed column by column.

    In each column, the entries flagged by ``expressing_mask`` are z-scored
    against each other only and clipped to [-3, 3]. Every other entry is set to
    ``nonexpr_value``. A column with fewer than two expressing entries is left
    entirely at ``nonexpr_value``.

    Returns:
        float64 array with the same shape as ``data``.
    """
    result = np.full_like(data, nonexpr_value, dtype=np.float64)
    for col in range(data.shape[1]):
        hits = expressing_mask[:, col]
        if hits.sum() <= 1:
            continue
        observed = data[hits, col]
        centered = observed - observed.mean()
        result[hits, col] = np.clip(centered / (observed.std() + 1e-8), -3, 3)
    return result

rna_shared = normalize_expressing(rna_shared_lognorm, rna_expressing, ZERO_VALUE).astype(np.float32)
protein_shared = normalize_expressing(protein_log_shared, protein_expressing, ZERO_VALUE).astype(np.float32)

# Stats over expressing entries only (non-expressing entries are the constant ZERO_VALUE)
rna_expr_vals = rna_shared[rna_expressing]
prot_expr_vals = protein_shared[protein_expressing]

for modality_label, normed_matrix, detected, expr_values in (
    ("RNA", rna_shared, rna_expressing, rna_expr_vals),
    ("Protein", protein_shared, protein_expressing, prot_expr_vals),
):
    n_rows, n_cols = normed_matrix.shape
    print(f"\n{modality_label} ({n_cols} features, {n_rows:,} cells):")
    print(f"  Expressing: {detected.sum():,} values, mean={expr_values.mean():.3f}, std={expr_values.std():.3f}")

# ============================================================
# REMOVE ZERO-VARIANCE FEATURES  
# ============================================================

# Drop features that are constant in either modality after normalization
# (e.g. a column with fewer than two expressing cells is entirely ZERO_VALUE).
rna_std = rna_shared.std(axis=0)
prot_std = protein_shared.std(axis=0)
valid_mask = (rna_std > 1e-6) & (prot_std > 1e-6)
has_static_features = not valid_mask.all()

if has_static_features:
    n_removed = (~valid_mask).sum()
    print(f"\nRemoving {n_removed} zero-variance features")
    rna_shared = rna_shared[:, valid_mask]
    protein_shared = protein_shared[:, valid_mask]
    rna_expressing = rna_expressing[:, valid_mask]
    protein_expressing = protein_expressing[:, valid_mask]
    rna_protein_correspondence = rna_protein_correspondence[valid_mask]
    # Pre-normalization matrices and AnnData objects, column-aligned with the kept features
    rna_after_log = rna_shared_lognorm[:, valid_mask]
    protein_shared_raw = protein_log_shared[:, valid_mask]
    rna_shared_raw = rna_shared_raw[:, valid_mask]
    rna_shared_adata = rna_shared_adata[:, valid_mask].copy()
    protein_shared_adata = protein_shared_adata[:, valid_mask].copy()
else:
    rna_after_log = rna_shared_lognorm
    protein_shared_raw = protein_log_shared

rna_detection_mask = rna_expressing
protein_detection_mask = protein_expressing

# The shared AnnData objects now carry the normalized (dense float32) matrices in .X
rna_shared_adata.X = rna_shared
protein_shared_adata.X = protein_shared

# ============================================================
# SUMMARY
# ============================================================

n_features = rna_shared.shape[1]
summary_rule = "=" * 70
print("\n" + summary_rule)
print(f"FINAL: {n_features} shared features")
for modality_label, normed_matrix in (("RNA", rna_shared), ("Protein", protein_shared)):
    print(f"  {modality_label}: {normed_matrix.shape[0]:,} cells")
print(summary_rule)

SHARED FEATURE NORMALIZATION

Filtering markers and cells...
  Removing 382 RNA cells with no expression
  Done in 2 iteration(s)

RNA (23 features, 9,939 cells):
  Expressing: 34,970 values, mean=-0.004, std=0.985

Protein (23 features, 1,076,404 cells):
  Expressing: 18,238,329 values, mean=-0.006, std=0.978

FINAL: 23 shared features
  RNA: 9,939 cells
  Protein: 1,076,404 cells


In [None]:
# # Visualize normalization - focus on expressing cells
# from scipy.stats import rankdata, spearmanr
# from adjustText import adjust_text

# fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# feature_names = list(rna_protein_correspondence[:, 0])
# n_features = rna_shared.shape[1]

# # Count expressing cells per marker
# rna_expr_counts = [rna_expressing[:, i].sum() for i in range(n_features)]
# prot_expr_counts = [protein_expressing[:, i].sum() for i in range(n_features)]

# # Expression levels (mean of expressing cells, pre-normalization)
# rna_expr_means = [rna_after_log[rna_expressing[:, i], i].mean() if rna_expressing[:, i].sum() > 0 else 0 
#                   for i in range(n_features)]
# prot_expr_means = [protein_shared_raw[protein_expressing[:, i], i].mean() if protein_expressing[:, i].sum() > 0 else 0 
#                    for i in range(n_features)]

# # Row 1: Expressing cell counts and distributions
# ax = axes[0, 0]
# x_pos = np.arange(n_features)
# width = 0.35
# ax.bar(x_pos - width/2, rna_expr_counts, width, label='RNA', color='steelblue', alpha=0.8)
# ax.bar(x_pos + width/2, prot_expr_counts, width, label='Protein', color='darkorange', alpha=0.8)
# ax.set_ylabel('Expressing cells')
# ax.set_title('Cells Expressing Each Marker')
# ax.set_xticks(x_pos)
# ax.set_xticklabels([f[:5] for f in feature_names], rotation=45, ha='right', fontsize=7)
# ax.legend()
# ax.set_yscale('log')

# ax = axes[0, 1]
# rna_expr_vals = rna_shared[rna_expressing]
# prot_expr_vals = protein_shared[protein_expressing]
# bins = np.linspace(-3, 3, 50)
# ax.hist(rna_expr_vals, bins=bins, alpha=0.6, density=True, label='RNA', color='steelblue')
# ax.hist(prot_expr_vals, bins=bins, alpha=0.6, density=True, label='Protein', color='darkorange')
# ax.set_title('Z-score Distribution\n(expressing cells only)')
# ax.set_xlabel('Z-score')
# ax.legend()

# ax = axes[0, 2]
# # Rank-scale expression means
# rna_ranks = (rankdata(rna_expr_means) - 1) / (n_features - 1) if n_features > 1 else np.zeros(n_features)
# prot_ranks = (rankdata(prot_expr_means) - 1) / (n_features - 1) if n_features > 1 else np.zeros(n_features)
# ax.scatter(rna_ranks, prot_ranks, s=60, alpha=0.7, c='purple', edgecolors='black')
# texts = [ax.text(rna_ranks[i], prot_ranks[i], feature_names[i][:5], fontsize=7) for i in range(n_features)]
# ax.plot([0, 1], [0, 1], 'r--', alpha=0.5)
# spearman_corr, _ = spearmanr(rna_expr_means, prot_expr_means)
# ax.set_title(f'Expression Level Correlation\n(ρ={spearman_corr:.2f})')
# ax.set_xlabel('RNA (rank)')
# ax.set_ylabel('Protein (rank)')
# adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='-', color='gray', alpha=0.5))

# # Row 2: Per-marker details
# ax = axes[1, 0]
# ax.bar(x_pos - width/2, rna_ranks, width, label='RNA', color='steelblue', alpha=0.8)
# ax.bar(x_pos + width/2, prot_ranks, width, label='Protein', color='darkorange', alpha=0.8)
# ax.set_ylabel('Expression level (rank)')
# ax.set_title('Per-Marker Expression')
# ax.set_xticks(x_pos)
# ax.set_xticklabels([f[:5] for f in feature_names], rotation=45, ha='right', fontsize=7)
# ax.legend()

# # Pick a well-expressed marker for example
# best_idx = np.argmax([rna_expr_counts[i] * prot_expr_counts[i] for i in range(n_features)])
# best_name = feature_names[best_idx]

# ax = axes[1, 1]
# rna_vals = rna_shared[rna_expressing[:, best_idx], best_idx]
# ax.hist(rna_vals, bins=30, alpha=0.7, color='steelblue', edgecolor='white')
# ax.set_title(f'RNA: {best_name}\n({len(rna_vals):,} expressing cells)')
# ax.set_xlabel('Z-score')

# ax = axes[1, 2]
# prot_vals = protein_shared[protein_expressing[:, best_idx], best_idx]
# ax.hist(prot_vals, bins=30, alpha=0.7, color='darkorange', edgecolor='white')
# ax.set_title(f'Protein: {best_name}\n({len(prot_vals):,} expressing cells)')
# ax.set_xlabel('Z-score')

# plt.tight_layout()
# plt.show()

# print(f"Expression correlation (Spearman): ρ = {spearman_corr:.3f}")

In [None]:
# # Shared feature summary
# print("=" * 60)
# print("SHARED FEATURE SUMMARY")
# print("=" * 60)

# feature_stats = []
# for i, (rna_gene, prot_marker) in enumerate(rna_protein_correspondence):
#     rna_n = rna_detection_mask[:, i].sum()
#     prot_n = protein_detection_mask[:, i].sum()
#     feature_stats.append({
#         'Marker': prot_marker,
#         'Gene': rna_gene,
#         'RNA_expressing': rna_n,
#         'Protein_expressing': prot_n,
#     })

# stats_df = pd.DataFrame(feature_stats)
# stats_df = stats_df.sort_values('RNA_expressing', ascending=False)

# print(f"\n{'Marker':<12} {'Gene':<12} {'RNA cells':>12} {'Protein cells':>14}")
# print("-" * 52)
# for _, row in stats_df.iterrows():
#     print(f"{row['Marker'][:12]:<12} {row['Gene'][:12]:<12} {row['RNA_expressing']:>12,} {row['Protein_expressing']:>14,}")

# print(f"\n{'='*60}")
# print(f"Total: {len(rna_protein_correspondence)} shared features")
# print(f"RNA cells: {rna_shared.shape[0]:,}")
# print(f"Protein cells: {protein_shared.shape[0]:,}")

In [None]:
# # Protein active features - detection-aware z-score
# # Gated mask defines expressing, z-score LOG values

# protein_markers_active = [m for m in protein_adata.var_names if m not in EXCLUDED_MARKERS]

# # Gated data for expressing mask
# protein_adata_active = protein_adata[:, protein_markers_active].copy()
# protein_gated_active = protein_adata_active.X.copy()
# if sparse.issparse(protein_gated_active):
#     protein_gated_active = protein_gated_active.toarray()

# protein_active_expressing = protein_gated_active > 0

# # Log layer for z-scoring
# protein_log_full = protein_adata.layers['log']
# if sparse.issparse(protein_log_full):
#     protein_log_full = protein_log_full.toarray()

# active_marker_indices = [list(protein_adata.var_names).index(m) for m in protein_markers_active]
# protein_log_active = protein_log_full[:, active_marker_indices]

# # Detection-aware z-score
# protein_active_normalized = normalize_expressing(protein_log_active, protein_active_expressing, ZERO_VALUE)
# protein_active_normalized = protein_active_normalized.astype(np.float32)
# protein_adata_active.X = protein_active_normalized

# prot_active_expr = protein_active_normalized[protein_active_expressing]
# print(f"Protein active: {protein_adata_active.shape}")
# print(f"  {len(protein_adata.var_names)} markers → {len(protein_markers_active)} (excluded {len(EXCLUDED_MARKERS)} non-immune)")
# print(f"  Expressing: mean={prot_active_expr.mean():.3f}, std={prot_active_expr.std():.3f}")

In [None]:
# Make sure no column is static
mask = (
    (rna_shared_adata.X.toarray().std(axis=0) > 0.01)
    & (protein_shared_adata.X.std(axis=0) > 0.01)
)
rna_shared = rna_shared_adata[:, mask].copy()
protein_shared = protein_shared_adata[:, mask].copy()
print([rna_shared.shape,protein_shared.shape])

In [None]:
# rna_active = rna_active[:, rna_var > 0.3]
# protein_active = protein_active[:, protein_var > 0.2]

# print(f"\nActive arrays:")
# print(f"  rna_active: {rna_active.shape}")
# print(f"  protein_active: {protein_active.shape}")

In [None]:
# process all RNA features: library-size normalize, log1p, keep highly variable genes
# (plus a fixed list of shared genes), then scale.
# These scanpy calls modify rna_adata in place, so re-running the cell would
# log-transform and scale the data a second time; a flag in .uns guards against that.
keep_genes = ['LAMP1', 'CD4', 'CEACAM1', 'CD38', 'PCNA', 'FOXP3', 'B3GAT1', 'MKI67', 'VSIR', 'PDCD1', 'NCAM1', 'TCF7', 'CD3E', 'ENTPD1']

if rna_adata.uns.get('active_features_processed', False):
    print("rna_adata already normalized, HVG-filtered and scaled; skipping.")
else:
    sc.pp.normalize_total(rna_adata)
    sc.pp.log1p(rna_adata)
    sc.pp.highly_variable_genes(rna_adata, n_top_genes=5000)
    # Force these genes into the active set even if they are not top-5000 HVGs
    missing_keep = [gene for gene in keep_genes if gene not in rna_adata.var_names]
    if missing_keep:
        print(f"keep_genes not found in rna_adata (skipped): {missing_keep}")
    for gene in keep_genes:
        if gene in rna_adata.var_names:
            rna_adata.var.loc[gene, 'highly_variable'] = True
    rna_adata = rna_adata[:, rna_adata.var.highly_variable].copy()
    sc.pp.scale(rna_adata)
    rna_adata.uns['active_features_processed'] = True

In [None]:
# kNN graph + UMAP of the processed RNA matrix, coloured by CD3E as a T-cell check
rna_umap_neighbors = 15
sc.pp.neighbors(rna_adata, n_neighbors=rna_umap_neighbors)
sc.tl.umap(rna_adata)
sc.pl.umap(rna_adata, color="CD3E")

In [None]:
# Subsample 5,000 protein cells so the kNN graph / UMAP stays tractable
# (the loaded protein matrix has ~1.08M cells per the load-cell output).
# NOTE(review): sc.pp.sample is called without a seed argument; depending on the
# scanpy version the subset may change between runs — pin a seed if this plot
# needs to be reproducible.
protein_adata_sample = sc.pp.sample(protein_adata, n=5000, copy=True)
sc.pp.neighbors(protein_adata_sample, n_neighbors=15)
sc.tl.umap(protein_adata_sample)
# Colour by CD3e as a T-cell sanity check
sc.pl.umap(protein_adata_sample, color='CD3e')

In [None]:
# sc.pp.normalize_total(rna_shared_adata)
# sc.pp.log1p(rna_shared_adata)
# sc.pp.scale(rna_shared_adata, zero_center=False)
# UMAP on the shared-feature RNA AnnData (its .X holds the detection-aware z-scores),
# used to check that the small shared panel still separates cell types.
sc.pp.neighbors(rna_shared, n_neighbors=15)
sc.tl.umap(rna_shared)
sc.pl.umap(rna_shared, color='CD3E')

In [None]:
# Same check for the shared protein features, on a 5,000-cell subsample.
# NOTE(review): sc.pp.sample is unseeded here; the subset may vary between runs.
protein_shared_sample = sc.pp.sample(protein_shared, n=5000, copy=True)
# sc.pp.normalize_total(protein_shared_sample)
# sc.pp.log1p(protein_shared_sample)
# sc.pp.scale(protein_shared_sample, zero_center=False)
sc.pp.neighbors(protein_shared_sample, n_neighbors=15)
sc.tl.umap(protein_shared_sample)
sc.pl.umap(protein_shared_sample, color='CD3e')

In [None]:
# make sure no feature is static
# Densify first: scipy sparse matrices have no .std(), and protein_adata.X may be
# sparse (the normalization cell already guards for that case).
rna_active = rna_adata.X
protein_active = protein_adata.X
if sparse.issparse(rna_active):
    rna_active = rna_active.toarray()
if sparse.issparse(protein_active):
    protein_active = protein_active.toarray()
rna_active = rna_active[:, rna_active.std(axis=0) > 1e-5] # these are fine since already using variable features
protein_active = protein_active[:, protein_active.std(axis=0) > 1e-5] # protein are generally variable

In [None]:
# Shapes of the four matrices handed to MaxFuse: active RNA, active protein,
# shared RNA, shared protein (in that order)
for fusor_input in (rna_active, protein_active, rna_shared, protein_shared):
    print(fusor_input.shape)

In [None]:
# # CRITICAL VALIDATION: Check array dimensions match
# print("=" * 50)
# print("DIMENSION VALIDATION")
# print("=" * 50)
# print(f"RNA shared cells:     {rna_shared.shape[0]}")
# print(f"RNA active cells:     {rna_active.shape[0]}")
# print(f"Protein shared cells: {protein_shared.shape[0]}")
# print(f"Protein active cells: {protein_active.shape[0]}")
# print()

# assert rna_shared.shape[0] == rna_active.shape[0], \
#     f"RNA mismatch: shared={rna_shared.shape[0]}, active={rna_active.shape[0]}"
# assert protein_shared.shape[0] == protein_active.shape[0], \
#     f"Protein mismatch: shared={protein_shared.shape[0]}, active={protein_active.shape[0]}"
# assert rna_shared.shape[1] == protein_shared.shape[1], \
#     f"Shared feature mismatch: RNA={rna_shared.shape[1]}, Protein={protein_shared.shape[1]}"

# print("All dimensions validated!")
# print(f"\nIntegrating {rna_active.shape[0]} RNA cells with {protein_active.shape[0]} protein cells")
# print(f"Using {rna_shared.shape[1]} shared features for initialization")

In [None]:
# # Save checkpoint for cross-modal integration
# # This allows skipping the MARIO section and going directly to MaxFuse

# checkpoint_dir = '../results/2_integration'
# os.makedirs(checkpoint_dir, exist_ok=True)

# # Save arrays needed for MaxFuse
# np.save(f'{checkpoint_dir}/checkpoint_rna_shared.npy', rna_shared)
# np.save(f'{checkpoint_dir}/checkpoint_protein_shared.npy', protein_shared)
# np.save(f'{checkpoint_dir}/checkpoint_rna_active.npy', rna_active)
# np.save(f'{checkpoint_dir}/checkpoint_protein_active.npy', protein_active)

# # Save correspondence
# correspondence_checkpoint = pd.DataFrame(
#     rna_protein_correspondence, 
#     columns=['rna_gene', 'protein_marker']
# )
# correspondence_checkpoint.to_csv(f'{checkpoint_dir}/checkpoint_correspondence.csv', index=False)

# print(f'Checkpoint saved to {checkpoint_dir}/')
# print(f'  - checkpoint_rna_shared.npy: {rna_shared.shape}')
# print(f'  - checkpoint_protein_shared.npy: {protein_shared.shape}')
# print(f'  - checkpoint_rna_active.npy: {rna_active.shape}')
# print(f'  - checkpoint_protein_active.npy: {protein_active.shape}')
# print(f'  - checkpoint_correspondence.csv')
# print(f'\nYou can now skip to MaxFuse Integration (Step 7) if desired.')

In [None]:
# # PRE-FILTER: Remove non-lymphocyte protein cells
# # Use GATED data (not z-scored) to identify expressing cells

# print("=" * 60)
# print("PRE-FILTERING: Non-lymphocyte protein cells")
# print("=" * 60)

# # Use gated data from protein_adata.X (0 = non-expressing, >0 = expressing)
# protein_gated_full = protein_adata.X.copy()
# if hasattr(protein_gated_full, 'toarray'):
#     protein_gated_full = protein_gated_full.toarray()

# marker_names = list(protein_adata.var_names)
# print(f"Total markers: {len(marker_names)}")

# # Define lymphocyte markers
# LYMPHOCYTE_MARKERS = ['CD3e', 'CD4', 'CD8', 'CD20', 'CD79a']
# available_lymph = [m for m in LYMPHOCYTE_MARKERS if m in marker_names]
# print(f"Lymphocyte markers found: {available_lymph}")

# if len(available_lymph) == 0:
#     print("WARNING: No lymphocyte markers found!")
# else:
#     # Get lymphocyte marker expression (gated - binary expressing or not)
#     lymph_idx = [marker_names.index(m) for m in available_lymph]
#     lymph_expressing = protein_gated_full[:, lymph_idx] > 0  # Boolean: expressing or not
    
#     # Count how many lymphocyte markers each cell expresses
#     n_lymph_markers = lymph_expressing.sum(axis=1)
#     any_lymph = lymph_expressing.any(axis=1)  # Expresses at least one
    
#     # Visualize
#     fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
#     # Number of lymphocyte markers expressed
#     ax = axes[0]
#     counts = np.bincount(n_lymph_markers, minlength=len(available_lymph)+1)
#     ax.bar(range(len(counts)), counts, edgecolor='black')
#     ax.set_xlabel('Number of Lymphocyte Markers Expressed')
#     ax.set_ylabel('Cell Count')
#     ax.set_title('Lymphocyte Marker Expression')
#     for i, c in enumerate(counts):
#         if c > 0:
#             ax.text(i, c + counts.max()*0.02, f'{c:,}', ha='center', fontsize=9)
    
#     # Per-marker expression counts
#     ax = axes[1]
#     expr_counts = [lymph_expressing[:, i].sum() for i in range(len(available_lymph))]
#     ax.bar(available_lymph, expr_counts, edgecolor='black')
#     ax.set_ylabel('Cells Expressing')
#     ax.set_title('Per-Marker Expression')
#     ax.tick_params(axis='x', rotation=45)
#     for i, c in enumerate(expr_counts):
#         ax.text(i, c + max(expr_counts)*0.02, f'{c:,}', ha='center', fontsize=9)
    
#     # Tissue breakdown
#     ax = axes[2]
#     tissues = protein_adata.obs['Tissue'].values
#     for tissue in np.unique(tissues):
#         t_mask = tissues == tissue
#         t_lymph = n_lymph_markers[t_mask]
#         ax.hist(t_lymph, bins=range(len(available_lymph)+2), alpha=0.6, label=f'{tissue} ({t_mask.sum():,})')
#     ax.set_xlabel('Number of Lymphocyte Markers')
#     ax.set_ylabel('Count')
#     ax.set_title('By Tissue')
#     ax.legend()
    
#     plt.tight_layout()
#     plt.show()
    
#     # Summary
#     print(f"\nCells expressing 0 lymphocyte markers: {(n_lymph_markers == 0).sum():,}")
#     print(f"Cells expressing 1+ lymphocyte markers: {any_lymph.sum():,}")
    
#     # Store for next cell
#     lymph_score = n_lymph_markers  # Use count of markers as score
#     print("\n→ Run next cell to apply filter")

In [None]:
# # APPLY LYMPHOCYTE FILTER
# # Keep cells expressing at least 1 lymphocyte marker

# MIN_LYMPH_MARKERS = 1  # Cells must express at least this many lymphocyte markers

# keep_mask = lymph_score >= MIN_LYMPH_MARKERS

# n_before = len(keep_mask)
# n_after = keep_mask.sum()
# n_removed = n_before - n_after

# print("=" * 60)
# print("APPLYING LYMPHOCYTE FILTER")
# print("=" * 60)
# print(f"Threshold: express at least {MIN_LYMPH_MARKERS} lymphocyte marker(s)")
# print(f"Cells before: {n_before:,}")
# print(f"Cells after:  {n_after:,}")
# print(f"Removed:      {n_removed:,} ({n_removed/n_before*100:.1f}%)")

# # Per-tissue breakdown
# tissues = protein_adata.obs['Tissue'].values
# for t in np.unique(tissues):
#     t_mask = tissues == t
#     t_kept = keep_mask[t_mask].sum()
#     print(f"  {t}: {t_mask.sum():,} → {t_kept:,} ({t_kept/t_mask.sum()*100:.0f}% retained)")

# # Apply filter
# protein_shared = protein_shared[keep_mask]
# protein_active = protein_active[keep_mask]
# protein_adata = protein_adata[keep_mask].copy()
# # protein_expressing = protein_expressing[keep_mask]
# # protein_detection_mask = protein_detection_mask[keep_mask]

# print(f"\nFiltered arrays:")
# print(f"  protein_shared: {protein_shared.shape}")
# print(f"  protein_active: {protein_active.shape}")

In [None]:
# Create Fusor - let MaxFuse cluster automatically (labels1/labels2 = None).
# The optional lymphocyte pre-filter cells above are commented out, so all cells that
# survived the shared-feature filtering are used.


def _to_dense_matrix(obj):
    """Return a dense ndarray from an AnnData (its .X), a scipy sparse matrix, or an array."""
    matrix = getattr(obj, 'X', obj)
    return matrix.toarray() if sparse.issparse(matrix) else np.asarray(matrix)


# rna_shared / protein_shared were re-created as AnnData objects for the UMAP checks.
# Pass plain matrices to the Fusor without rebinding the notebook-level names.
# NOTE(review): this assumes Fusor expects array inputs (as in the MaxFuse tutorial) — confirm.
rna_shared_mat = _to_dense_matrix(rna_shared)
protein_shared_mat = _to_dense_matrix(protein_shared)
rna_active_mat = _to_dense_matrix(rna_active)
protein_active_mat = _to_dense_matrix(protein_active)

print("=" * 60)
print("CREATING FUSOR WITH CURRENT ARRAYS")
print("=" * 60)
print(f"  rna_shared:     {rna_shared_mat.shape}")
print(f"  protein_shared: {protein_shared_mat.shape}")
print(f"  rna_active:     {rna_active_mat.shape}")
print(f"  protein_active: {protein_active_mat.shape}")

# Sanity check - if protein has >400k cells, filter probably didn't run
# if protein_shared.shape[0] > 400000:
#     print("\n⚠️  WARNING: protein_shared has >400k cells!")
#     print("    Did you run the pre-filter cell first?")
#     print("    Re-run from the pre-filter cell to apply filtering.")

fusor = core.model.Fusor(
    shared_arr1=rna_shared_mat,
    shared_arr2=protein_shared_mat,
    active_arr1=rna_active_mat,
    active_arr2=protein_active_mat,
    labels1=None,  # Let MaxFuse cluster
    labels2=None
)

print("\nFusor created successfully.")

In [None]:
# Derive suggested batching parameters from the modality sizes.
# NOTE: the split_into_batches cell below passes its own hard-coded
# max_outward_size / matching_ratio, so only metacell_sz from this cell is
# actually consumed there; the other values are printed as suggestions.
n_rna = rna_active.shape[0]
n_prot = protein_active.shape[0]
ratio = n_prot / n_rna

print(f"RNA cells: {n_rna}")
print(f"Protein cells: {n_prot}")
print(f"Ratio (protein/RNA): {ratio:.1f}")

# Cap the RNA batch size at 8000 cells (or all RNA cells if fewer).
max_outward = min(8000, n_rna)

# Linear protein-to-RNA ratio (no sqrt scaling is applied). Clamp to >= 1
# so a protein modality smaller than the RNA one never yields ratio 0.
matching_ratio = max(1, int(ratio))

# A value > 1 aggregates cells into metacells; NOTE(review): per the MaxFuse
# docs metacell_size=1 disables metacells - confirm against installed version.
metacell_sz = 2

print(f"\nSuggested batching parameters:")
print(f"  max_outward_size: {max_outward}")
print(f"  matching_ratio: {matching_ratio}")
print(f"  metacell_size: {metacell_sz}")

In [None]:
# Split both modalities into matching batches.
# NOTE: these hard-coded values do NOT use the max_outward / matching_ratio
# computed in the previous cell; they are what MaxFuse actually receives, so
# print them explicitly to keep the log consistent with the run.
BATCH_MAX_OUTWARD_SIZE = 2000
BATCH_MATCHING_RATIO = 5
print(f"split_into_batches: max_outward_size={BATCH_MAX_OUTWARD_SIZE}, "
      f"matching_ratio={BATCH_MATCHING_RATIO}, metacell_size={metacell_sz}")

fusor.split_into_batches(
    max_outward_size=BATCH_MAX_OUTWARD_SIZE,
    matching_ratio=BATCH_MATCHING_RATIO,
    metacell_size=metacell_sz,
    verbose=True
)

In [None]:
# Singular-value plots of the active features, RNA first then protein.
for active_target in ('active_arr1', 'active_arr2'):
    fusor.plot_singular_values(target=active_target, n_components=None)

In [None]:
# Feature counts per modality; they bound the SVD ranks chosen below.
n_rna_features = rna_active.shape[1]
n_prot_features = protein_active.shape[1]
n_shared = rna_shared.shape[1]

print(f"Dataset dimensions:")
print(f"  Protein features: {n_prot_features}")
print(f"  RNA features: {n_rna_features}")
print(f"  Shared features: {n_shared}")

# Graph-construction SVD ranks must stay strictly below each feature count,
# so both are capped; the small protein panel gets the lower ceiling.
svd_comp1_graph, svd_comp2_graph = (
    min(40, n_rna_features - 1),   # RNA: plenty of features
    min(20, n_prot_features - 1),  # protein: small marker panel
)

print(f"\nGraph construction SVD components:")
print(f"  RNA: {svd_comp1_graph}")
print(f"  Protein: {svd_comp2_graph} (capped for {n_prot_features}-marker panel)")

In [None]:
# Construct graphs with automatic clustering
fusor.construct_graphs(
    n_neighbors1=15,
    n_neighbors2=15,
    svd_components1=svd_comp1_graph,
    svd_components2=svd_comp2_graph,
    resolution1=2.0,  
    resolution2=2.0,
    resolution_tol=0.1,
    leiden_runs=1,
    verbose=True)

In [None]:
# Visualize clustering results
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Get cluster labels from first batch
labels1_b0 = fusor._labels1[0]
labels2_b0 = fusor._labels2[0]

# RNA cluster sizes
ax = axes[0]
unique, counts = np.unique(labels1_b0, return_counts=True)
ax.bar(range(len(counts)), sorted(counts, reverse=True))
ax.set_xlabel('Cluster rank')
ax.set_ylabel('Cells')
ax.set_title(f'RNA Clusters (n={len(unique)})')
ax.axhline(y=np.mean(counts), color='r', linestyle='--', label=f'Mean: {np.mean(counts):.0f}')
ax.legend()

# Protein cluster sizes
ax = axes[1]
unique, counts = np.unique(labels2_b0, return_counts=True)
ax.bar(range(len(counts)), sorted(counts, reverse=True))
ax.set_xlabel('Cluster rank')
ax.set_ylabel('Cells')
ax.set_title(f'Protein Clusters (n={len(unique)})')
ax.axhline(y=np.mean(counts), color='r', linestyle='--', label=f'Mean: {np.mean(counts):.0f}')
ax.legend()

# Summary stats
ax = axes[2]
ax.axis('off')
stats_text = f'''Graph Construction Summary
{"="*40}

RNA (Batch 0):
  Clusters: {len(np.unique(labels1_b0))}
  Cells: {len(labels1_b0)}
  Mean cluster size: {np.mean(np.bincount(labels1_b0)):.1f}

Protein (Batch 0):
  Clusters: {len(np.unique(labels2_b0))}
  Cells: {len(labels2_b0)}
  Mean cluster size: {np.mean(np.bincount(labels2_b0)):.1f}
'''
ax.text(0.1, 0.9, stats_text, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# Initial pivots use only the shared features; each SVD rank is capped one
# below the number of shared features (25 for RNA, 20 for protein).
svd_shared1, svd_shared2 = (min(cap, n_shared - 1) for cap in (25, 20))
print(f"Using {svd_shared1}/{svd_shared2} SVD components for shared features")

fusor.find_initial_pivots(
    svd_components1=svd_shared1,
    svd_components2=svd_shared2,
    wt1=0.3,  # smoothing weight
    wt2=0.3,
    verbose=True,
)

In [None]:
# Visualize initial pivot matching
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Get initial matching from first batch
init_match = fusor._init_matching[0]
init_rows, init_cols, init_scores = init_match

# Score distribution
ax = axes[0]
ax.hist(init_scores, bins=50, edgecolor='white', alpha=0.7)
ax.axvline(np.mean(init_scores), color='r', linestyle='--', 
           label=f'Mean: {np.mean(init_scores):.3f}')
ax.axvline(np.median(init_scores), color='g', linestyle='--',
           label=f'Median: {np.median(init_scores):.3f}')
ax.set_xlabel('Matching Score')
ax.set_ylabel('Count')
ax.set_title('Initial Pivot Scores')
ax.legend()

# Matches per RNA cell
ax = axes[1]
matches_per_rna = np.bincount(init_rows)
ax.hist(matches_per_rna[matches_per_rna > 0], bins=20, edgecolor='white', alpha=0.7)
ax.set_xlabel('Matches per RNA cell')
ax.set_ylabel('Count')
ax.set_title('Matching Density (RNA)')

# Summary
ax = axes[2]
ax.axis('off')
n_rna_matched = len(np.unique(init_rows))
n_prot_matched = len(np.unique(init_cols))
stats = f'''Initial Pivot Matching
{"="*40}

Total matches: {len(init_scores):,}
Unique RNA matched: {n_rna_matched:,}
Unique Protein matched: {n_prot_matched:,}

Score statistics:
  Min:    {np.min(init_scores):.4f}
  Max:    {np.max(init_scores):.4f}
  Mean:   {np.mean(init_scores):.4f}
  Median: {np.median(init_scores):.4f}
  Std:    {np.std(init_scores):.4f}
'''
ax.text(0.1, 0.9, stats, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# DIAGNOSTIC: Check if specific RNA clusters drive the bimodality in the
# initial pivot scores. Some clusters may have shared features that don't
# correlate across modalities.
# Lower scores are treated as better matches throughout this cell.
# Per-cluster "n" counts MATCHED PAIRS (batch pair 0) whose RNA cell lies in
# that cluster - it is not the cluster's total cell count.

BAD_SCORE = 0.5              # mean score above this -> "bad" cluster (red)
GOOD_SCORE = 0.3             # mean score below this -> "good" cluster
ANNOTATE_BAND = (0.2, 0.6)   # annotate clusters whose mean falls outside this band
MAX_SHOW = 5                 # max clusters per best / worst histogram panel


def _hist_by_cluster(ax, cluster_rows, scores, cluster_of_match, title):
    """Overlay per-cluster score histograms for (cluster id, mean, n) tuples."""
    for c, mean, n in cluster_rows:
        ax.hist(scores[cluster_of_match == c], bins=30, alpha=0.5,
                label=f'RNA cluster {c} (n={n}, μ={mean:.2f})')
    ax.set_xlabel('Matching Score')
    ax.set_ylabel('Count')
    ax.set_title(title)
    if cluster_rows:  # avoid an empty-legend warning when nothing was plotted
        ax.legend(fontsize=8)


# Get actual initial matching results
init_match = fusor._init_matching[0]
init_rows, init_cols, init_scores = init_match
init_scores = np.array(init_scores)

# Cluster labels of the matched cells in batch pair 0
b1, b2 = fusor._batch1_to_batch2[0]
rna_clusters = fusor._labels1[b1][init_rows]
prot_clusters = fusor._labels2[b2][init_cols]

# (cluster id, mean score, n matched pairs), sorted by mean score, best first
unique_rna_clusters = np.unique(rna_clusters)
cluster_means = []
for c in unique_rna_clusters:
    mask = rna_clusters == c
    cluster_means.append((c, init_scores[mask].mean(), int(mask.sum())))
cluster_means.sort(key=lambda x: x[1])

# Best and worst panels must not show the same cluster twice when there are
# fewer than 2 * MAX_SHOW clusters.
n_show = min(MAX_SHOW, len(cluster_means))
best_clusters = cluster_means[:n_show]
worst_clusters = cluster_means[n_show:][-n_show:]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Score distributions of the best / worst RNA clusters
_hist_by_cluster(axes[0, 0], best_clusters, init_scores, rna_clusters,
                 f'Best {len(best_clusters)} RNA Clusters (lowest scores)')
_hist_by_cluster(axes[0, 1], worst_clusters, init_scores, rna_clusters,
                 f'Worst {len(worst_clusters)} RNA Clusters (highest scores)')

# 2. Mean score per cluster
cluster_ids = [x[0] for x in cluster_means]
means = [x[1] for x in cluster_means]
sizes = [x[2] for x in cluster_means]  # matched pairs per cluster
colors = ['red' if m > BAD_SCORE else 'green' for m in means]

ax = axes[1, 0]
ax.bar(range(len(cluster_ids)), means, color=colors, alpha=0.7)
ax.axhline(BAD_SCORE, color='black', linestyle='--', label=f'Threshold {BAD_SCORE}')
ax.set_xlabel('RNA Cluster (sorted by score)')
ax.set_ylabel('Mean Matching Score')
ax.set_title('Mean Score by RNA Cluster')
ax.legend()

# 3. Matched pairs per cluster vs mean score
ax = axes[1, 1]
ax.scatter(sizes, means, alpha=0.6, c=colors)
ax.set_xlabel('Matched pairs in cluster')
ax.set_ylabel('Mean Matching Score')
ax.set_title('Matched Pairs vs Mean Score')
ax.axhline(BAD_SCORE, color='black', linestyle='--')

# Annotate outlier clusters
low, high = ANNOTATE_BAND
for c, mean, n in cluster_means:
    if mean > high or mean < low:
        ax.annotate(f'{c}', (n, mean), fontsize=8)

plt.tight_layout()
plt.show()

# Summary: which clusters are problematic?
print("\n" + "=" * 60)
print("CLUSTER ANALYSIS SUMMARY")
print("=" * 60)
bad_clusters = [(c, m, n) for c, m, n in cluster_means if m > BAD_SCORE]
good_clusters = [(c, m, n) for c, m, n in cluster_means if m < GOOD_SCORE]

print(f"\nTotal RNA clusters: {len(unique_rna_clusters)}")
print(f"Bad clusters (score > {BAD_SCORE}): {len(bad_clusters)}")
print(f"Good clusters (score < {GOOD_SCORE}): {len(good_clusters)}")

if bad_clusters:
    print(f"\nWorst clusters (candidates for filtering):")
    total_bad = sum(n for c, m, n in bad_clusters)
    for c, m, n in sorted(bad_clusters, key=lambda x: -x[1])[:10]:
        print(f"  Cluster {c}: mean={m:.3f}, n={n:,} matched pairs ({n/len(init_scores)*100:.1f}%)")
    print(f"  Total matched pairs in bad clusters: {total_bad:,} ({total_bad/len(init_scores)*100:.1f}%)")

In [None]:
# Check canonical correlations
cca_comp_check = min(15, n_prot_features - 1)
fusor.plot_canonical_correlations(
    svd_components1=min(30, n_rna_features - 1),
    svd_components2=None,
    cca_components=cca_comp_check
)

In [None]:
# Refine pivots using CCA
cca_components = min(25, n_prot_features - 1)

fusor.refine_pivots(
    wt1=0.3,
    wt2=0.3,
    svd_components1=min(40, n_rna_features - 1),
    svd_components2=None,  
    cca_components=cca_components,
    n_iters=1,
    filter_prop=0.0,
    verbose=True
)

In [None]:
# Compare initial vs refined matching
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Get refined matching
refined_match = fusor._refined_matching[0]
ref_rows, ref_cols, ref_scores = refined_match

# Score comparison
ax = axes[0]
ax.hist(init_scores, bins=50, alpha=0.5, label='Initial', edgecolor='white')
ax.hist(ref_scores, bins=50, alpha=0.5, label='Refined', edgecolor='white')
ax.set_xlabel('Matching Score')
ax.set_ylabel('Count')
ax.set_title('Score Distribution: Initial vs Refined')
ax.legend()

# Score improvement
ax = axes[1]
ax.boxplot([init_scores, ref_scores], labels=['Initial', 'Refined'])
ax.set_ylabel('Score')
ax.set_title('Score Comparison')

# Summary
ax = axes[2]
ax.axis('off')
stats = f'''CCA Refinement Results
{"="*40}

                Initial    Refined
Matches:     {len(init_scores):>10,}  {len(ref_scores):>10,}
Mean score:  {np.mean(init_scores):>10.4f}  {np.mean(ref_scores):>10.4f}
Median:      {np.median(init_scores):>10.4f}  {np.median(ref_scores):>10.4f}
Std:         {np.std(init_scores):>10.4f}  {np.std(ref_scores):>10.4f}

Score change: {np.mean(ref_scores) - np.mean(init_scores):+.4f}
'''
ax.text(0.1, 0.9, stats, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

In [None]:
# Filter the refined pivot matches. NOTE(review): filter_prop=0.5 is
# presumably the fraction of worst-scoring pivots discarded - confirm in MaxFuse docs.
pivot_filter_prop = 0.5
fusor.filter_bad_matches(target='pivot', filter_prop=pivot_filter_prop, verbose=True)

In [None]:
# Pivot matches with modality-2 (protein) indices listed first: order=(2, 1).
pivot_matching = fusor.get_matching(target='pivot', order=(2, 1))

In [None]:
# Inspect the first pivot pair: entry 0 of each of the three components.
[pivot_matching[k][0] for k in range(3)]

In [None]:
# Extend the matching beyond the pivots; RNA SVD rank capped at 40 (or features - 1).
propagate_svd1 = min(40, n_rna_features - 1)
fusor.propagate(
    wt1=0.7,
    wt2=0.7,
    svd_components1=propagate_svd1,
    svd_components2=None,
    verbose=True,
)

In [None]:
# Filter the propagated matches; the proportion is kept as a named variable for tuning.
propagate_filter_prop = 0.3
fusor.filter_bad_matches(target='propagated', filter_prop=propagate_filter_prop, verbose=True)

In [None]:
# Full matching with modality-1 (RNA) indices first: order=(1, 2).
full_matching = fusor.get_matching(target='full_data', order=(1, 2))

In [None]:
# Tabulate the full matching: cell idx in mod1, cell idx in mod2, matching score.
matching_columns = ['mod1_indx', 'mod2_indx', 'score']
pd.DataFrame(
    list(zip(full_matching[0], full_matching[1], full_matching[2])),
    columns=matching_columns,
)

In [None]:
# Embed all RNA cells plus the protein rows referenced by full_matching
# (its mod2 indices) into the joint CCA space.
matched_protein_rows = full_matching[1]
rna_cca, protein_cca_sub = fusor.get_embedding(
    active_arr1=fusor.active_arr1,
    active_arr2=fusor.active_arr2[matched_protein_rows, :],
)

In [None]:
# Shape of the RNA CCA embedding (cells x components)
np.shape(rna_cca)

In [None]:
# Shape of the matched-protein CCA embedding (matched rows x components)
np.shape(protein_cca_sub)

In [None]:
# Stack a random subset of RNA CCA rows with all matched-protein CCA rows into
# one AnnData for joint UMAP. NOTE(review): 5085 looks hand-picked - confirm
# whether it is meant to equal protein_cca_sub.shape[0].
np.random.seed(42)
# Sampling without replacement cannot draw more rows than exist.
subs = min(5085, rna_cca.shape[0])
randix = np.random.choice(rna_cca.shape[0], subs, replace = False)

dim_use = 6 # dimensions of the CCA embedding to be used for UMAP etc

# Cast explicitly instead of passing AnnData(dtype=...), which is deprecated.
cca_adata = ad.AnnData(
    np.concatenate(
        (rna_cca[randix, :dim_use], protein_cca_sub[:, :dim_use]), axis=0
    ).astype(np.float32)
)
cca_adata.obs['data_type'] = ['rna'] * subs + ['protein'] * protein_cca_sub.shape[0]

In [None]:
# Joint kNN graph and UMAP on the stacked CCA embeddings, colored by modality.
neighbor_k = 15
sc.pp.neighbors(cca_adata, n_neighbors=neighbor_k)
sc.tl.umap(cca_adata, spread=1.0, min_dist=0.4)
sc.pl.umap(cca_adata, size=20, color='data_type')

In [None]:
# Group the full matching by RNA cell: rna_idx -> [(prot_idx, score), ...],
# preserving the order in which matches appear.
rna_to_prot_matches = {}
for rna_idx, prot_idx, score in zip(full_matching[0], full_matching[1], full_matching[2]):
    rna_to_prot_matches.setdefault(rna_idx, []).append((prot_idx, score))
