# Validation with Synthetic Data

Comprehensive validation of the inverse problem approach using synthetic data with known ground truth.

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import inverse_sc as isc
import matplotlib.pyplot as plt
import seaborn as sns

# Notebook-wide plotting defaults: 10x6 figure size and seaborn's
# 'whitegrid' axes style. Cells that pass their own figsize override the size.
plt.rcParams.update({'figure.figsize': (10, 6)})
sns.set_style('whitegrid')

## 1. Generate Data with Different Difficulty Levels

In [None]:
# Difficulty levels to benchmark, in the order they are generated.
scenarios = ['simple', 'moderate', 'hard']

# scenario name -> (adata, truth) pair returned by the benchmark generator.
datasets = {}

for level in scenarios:
    generated, ground_truth = isc.validation.generate_realistic_benchmark(scenario=level)
    datasets[level] = (generated, ground_truth)

    # Fraction of entries in X that are exactly zero, reported as "dropout".
    zero_fraction = (generated.X == 0).mean()
    print(f"{level}: {generated.shape}, dropout={zero_fraction:.2%}")

## 2. Fit Models on Each Scenario

In [None]:
# Fit the inverse model on every scenario, benchmark it against scanpy and
# report the calibration score. One results table is collected per scenario.
results_all = []

# Training settings shared by all scenarios.
fit_options = dict(n_epochs=150, batch_size=256)

for scenario, (adata, truth) in datasets.items():
    print(f"\n=== {scenario.upper()} SCENARIO ===")

    # The return value is not used; later cells read the fitted quantities
    # from adata.obsm, so the fit is expected to annotate adata in place.
    isc.pp.fit_inverse_model(adata, **fit_options)

    # Benchmark table for this scenario, tagged with the scenario name so the
    # tables can be told apart after concatenation.
    results = isc.validation.benchmark_against_scanpy(adata, truth, run_scanpy=True)
    results['scenario'] = scenario
    results_all.append(results)

    # Uncertainty calibration for this scenario.
    calib = isc.validation.uncertainty_calibration(adata, truth)
    calibration_score = calib['calibration_score']
    print(f"Calibration score: {calibration_score:.3f}")

# Single table covering all scenarios (row index rebuilt from 0).
results_df = pd.concat(results_all, ignore_index=True)

## 3. Visualize Results

In [None]:
# Plot each benchmark metric across scenarios, one panel per metric and one
# bar per method.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# DataFrame.pivot returns its index sorted, which would order the bars
# alphabetically ('hard', 'moderate', 'simple'). Keep the order in which the
# scenarios appear in results_df so the x-axis follows the benchmark order.
scenario_order = results_df['scenario'].unique()

for ax, metric in zip(axes, ['mean_cell_correlation', 'global_correlation', 'rmse']):
    # Rows: scenarios, columns: methods, cells: value of the current metric.
    pivot = results_df.pivot(index='scenario', columns='method', values=metric)
    pivot = pivot.reindex(scenario_order)
    pivot.plot(kind='bar', ax=ax)
    ax.set_title(metric)
    ax.set_xlabel('Scenario')
    # Place the legend outside the axes so it does not cover the bars.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()

## 4. Detailed Analysis: Moderate Scenario

In [None]:
# Work on the 'moderate' scenario for the rest of the notebook.
adata, truth = datasets['moderate']

Z_true = truth['Z_true']
Z_inferred = adata.obsm['Z_true_mean']

# Pearson correlation between true and inferred values, one per row (cell).
cell_corrs = [
    np.corrcoef(Z_true[row], Z_inferred[row])[0, 1]
    for row in range(Z_true.shape[0])
]
mean_corr = np.mean(cell_corrs)

fig, (ax_hist, ax_scatter) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: distribution of the per-cell correlations with the mean marked.
ax_hist.hist(cell_corrs, bins=50, edgecolor='black')
ax_hist.set_xlabel('Correlation (True vs Inferred)')
ax_hist.set_ylabel('Number of Cells')
ax_hist.set_title('Cell-wise Recovery Quality')
ax_hist.axvline(mean_corr, color='red', linestyle='--', label=f'Mean: {mean_corr:.3f}')
ax_hist.legend()

# Right panel: true vs inferred values for 5 distinct, randomly chosen cells.
sample_cells = np.random.choice(Z_true.shape[0], 5, replace=False)
for cell_idx in sample_cells:
    ax_scatter.scatter(Z_true[cell_idx], Z_inferred[cell_idx], alpha=0.3, s=1)

# Identity line from the origin to the largest true value.
upper = Z_true.max()
ax_scatter.plot([0, upper], [0, upper], 'r--', label='Perfect recovery')
ax_scatter.set_xlabel('True Expression')
ax_scatter.set_ylabel('Inferred Expression')
ax_scatter.set_title('True vs Inferred (5 sample cells)')
ax_scatter.legend()

fig.tight_layout()
plt.show()

## 5. Uncertainty Analysis

In [None]:
Z_std = adata.obsm['Z_true_std']

# Where uncertainty is high, is recovery worse?
# One uncertainty value per row (cell): the mean of Z_std along axis 1.
avg_uncertainty = Z_std.mean(axis=1)

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.scatter(avg_uncertainty, cell_corrs, alpha=0.5, s=10)
plt.xlabel('Average Uncertainty')
plt.ylabel('Recovery Correlation')
plt.title('Uncertainty vs Recovery Quality')

plt.subplot(1, 2, 2)
# Calibration check: absolute error divided by the reported std
# (1e-8 guards against division by zero).
errors = np.abs(Z_true - Z_inferred)
standardized_errors = errors / (Z_std + 1e-8)

# Bin only the plotted window [0, 5]. Letting the histogram pick its own range
# would allow a few huge ratios (std close to 0) to squash everything into the
# first bin. Counts are divided by the TOTAL number of entries, so the bars
# remain a true density and mass beyond 5 is simply not shown.
flat_errors = np.asarray(standardized_errors).ravel()
bin_edges = np.linspace(0, 5, 51)
bin_widths = np.diff(bin_edges)
counts, _ = np.histogram(flat_errors, bins=bin_edges)
density = counts / (flat_errors.size * bin_widths)
plt.bar(bin_edges[:-1], density, width=bin_widths, align='edge',
        edgecolor='black', alpha=0.7)

# The errors are absolute values, so under Gaussian, well-calibrated
# uncertainty they follow a half-normal distribution: twice the standard
# normal pdf on x >= 0. The plain normal pdf only integrates to 0.5 there.
x = np.linspace(0, 5, 100)
plt.plot(x, np.sqrt(2 / np.pi) * np.exp(-x**2 / 2), 'r-', linewidth=2,
         label='Half-normal (|N(0, 1)|)')
plt.xlabel('Standardized Error')
plt.ylabel('Density')
plt.title('Uncertainty Calibration')
plt.legend()

plt.tight_layout()
plt.show()

## 6. Program Recovery

In [None]:
# Compare inferred programs to true programs.
true_program_weights = truth['program_weights']
inferred_program_weights = adata.obsm['program_weights']

# The two sides may have different numbers of programs (columns), so size
# each axis independently instead of assuming a square matrix.
n_programs = true_program_weights.shape[1]
n_inferred = inferred_program_weights.shape[1]

# Correlation matrix between true and inferred programs.
# np.corrcoef with rowvar=False treats every column as a variable and returns
# the correlations of all true+inferred columns; the upper-right block holds
# the true-vs-inferred pairs (rows: true programs, columns: inferred).
full_corr = np.corrcoef(true_program_weights, inferred_program_weights, rowvar=False)
corr_matrix = full_corr[:n_programs, n_programs:]

plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='RdBu_r', center=0,
            xticklabels=[f'Inf_{j}' for j in range(n_inferred)],
            yticklabels=[f'True_{i}' for i in range(n_programs)])
plt.title('Program Recovery: True vs Inferred')
plt.xlabel('Inferred Programs')
plt.ylabel('True Programs')
plt.tight_layout()
plt.show()

print("\nNote: High off-diagonal values may indicate label switching,")
print("which is expected (programs are unordered).")

## 7. Summary Statistics

In [None]:
# Final report for the scenario currently held in adata/truth: recovery
# metrics first, then uncertainty calibration.
print("=== VALIDATION SUMMARY ===")
print("\nRecovery Quality:")

# Per-cell correlation summary (mean and spread over cells).
corr_mean = np.mean(cell_corrs)
corr_spread = np.std(cell_corrs)
print(f"  Mean cell correlation: {corr_mean:.3f} ± {corr_spread:.3f}")

# Correlation over all entries at once, ignoring the row structure.
global_corr = np.corrcoef(Z_true.flatten(), Z_inferred.flatten())[0, 1]
print(f"  Global correlation: {global_corr:.3f}")

# Root-mean-square error over all entries.
residuals = Z_true - Z_inferred
rmse = np.sqrt(np.mean(residuals**2))
print(f"  RMSE: {rmse:.3f}")

calib = isc.validation.uncertainty_calibration(adata, truth)
coverage_1 = calib['coverage_1std']
coverage_2 = calib['coverage_2std']
score = calib['calibration_score']
print("\nUncertainty Calibration:")
print(f"  Coverage at 1σ: {coverage_1:.3f} (target: 0.68)")
print(f"  Coverage at 2σ: {coverage_2:.3f} (target: 0.95)")
print(f"  Calibration score: {score:.3f}")

print("\n✓ Validation complete!")