# Neural CATE Estimators: Siamese T-Learner vs JEPA-Style Learning

This notebook benchmarks modern neural architectures for treatment effect estimation:

1. **Siamese T-Learner**: Shared encoder with separate outcome heads
2. **JEPA-Style Causal Learner**: Joint embedding predictive architecture with causal regularization
3. **Baseline methods**: Standard T-Learner and S-Learner

We evaluate on synthetic gene expression data with known ground truth treatment effects.
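
To make the first architecture concrete, here is a minimal NumPy sketch of the forward pass of a shared-encoder, two-head network. It is an illustration only: the weights are random and untrained, the layer sizes are arbitrary, and the names `init_params` and `forward` are hypothetical. The actual `SiameseTLearner` in `causalbiolab` may be structured differently.

```python
import numpy as np


def relu(a):
    return np.maximum(a, 0.0)


def init_params(input_dim, hidden_dim, rng):
    """Random, untrained weights for one shared encoder and two linear heads."""
    return {
        "W_enc": rng.normal(0.0, 1.0 / np.sqrt(input_dim), size=(input_dim, hidden_dim)),
        "b_enc": np.zeros(hidden_dim),
        "w_head0": rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), size=hidden_dim),
        "w_head1": rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), size=hidden_dim),
    }


def forward(params, X):
    """Return the shared embedding, both outcome predictions, and their difference."""
    z = relu(X @ params["W_enc"] + params["b_enc"])  # shared encoder
    mu0 = z @ params["w_head0"]                      # control-outcome head
    mu1 = z @ params["w_head1"]                      # treated-outcome head
    return z, mu0, mu1, mu1 - mu0                    # last item is the CATE estimate


rng = np.random.default_rng(0)
params = init_params(input_dim=20, hidden_dim=8, rng=rng)
X_demo = rng.normal(size=(5, 20))
z, mu0, mu1, tau_hat = forward(params, X_demo)
print(z.shape, tau_hat.shape)
```

In a trained version of this design, each head's loss would typically be computed only on samples from its own arm, while the encoder receives gradients from both arms.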

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import torch

import sys
sys.path.append('../../src')

from causalbiolab.estimation.cate import SLearner, TLearner
from causalbiolab.estimation.neural_cate import (
    SiameseTLearner,
    JEPACausalLearner,
    benchmark_neural_cate,
)

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Plotting setup
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Check for GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

---

## Part 1: Generate Synthetic Gene Expression Data

We simulate a realistic gene expression dataset with:
- **High-dimensional expression** (1,000 genes by default for faster training; can be increased to 20,000)
- **Heterogeneous treatment effects** (τ varies by cell state)
- **Confounding** (cell cycle affects both treatment and outcome)
- **Known ground truth** for evaluation
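
Why confounding matters can be seen in a much smaller toy example than the generator used in this notebook. The sketch below is an assumption-laden illustration (a single scalar confounder `c`, a constant true effect of 2.0, a linear outcome), not the notebook's data-generating process. It compares a naive difference in means with a regression-adjusted estimate.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000

# Toy setup (not the notebook's generator): one confounder c drives both
# treatment assignment and the outcome; the true effect is constant at 2.0.
c = rng.normal(size=n)
propensity = 1.0 / (1.0 + np.exp(-2.0 * c))
t = rng.binomial(1, propensity)
y = 3.0 * c + 2.0 * t + rng.normal(size=n)

# Naive estimate: difference in group means, ignoring the confounder.
naive_ate = y[t == 1].mean() - y[t == 0].mean()

# Adjusted estimate: OLS of y on [1, t, c]; the coefficient on t is the effect.
design = np.column_stack([np.ones(n), t, c])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
adjusted_ate = coef[1]

print(f"naive: {naive_ate:.2f}, adjusted: {adjusted_ate:.2f}, truth: 2.00")
```

Because treated units have systematically higher `c`, the naive estimate absorbs the confounder's effect on the outcome, while the adjusted estimate stays close to the true value.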

In [None]:
def generate_gene_expression_data(
    n_samples: int = 2000,
    n_genes: int = 1000,
    n_informative: int = 50,
    effect_heterogeneity: float = 0.5,
    noise_level: float = 0.1,
    random_state: int = 42,
):
    """
    Generate synthetic gene expression data with treatment effects.
    
    Data generating process:
    1. Sample gene expression from log-normal (realistic for RNA-seq)
    2. Define cell states based on first few PCs
    3. Treatment assignment depends on cell state (confounding)
    4. Treatment effect varies by cell state (heterogeneity)
    
    Returns:
        X: Gene expression (n_samples, n_genes)
        T: Treatment (n_samples,)
        Y: Outcome (n_samples,)
        tau_true: True CATE (n_samples,)
        cell_states: PCA cell-state scores (n_samples, 3)
    """
    rng = np.random.RandomState(random_state)
    
    # 1. Generate gene expression (log-normal)
    # Most genes have low expression, few have high
    X = rng.lognormal(mean=0, sigma=1.5, size=(n_samples, n_genes))
    
    # Add structure: some genes are correlated (gene modules)
    n_modules = 10
    genes_per_module = n_genes // n_modules
    for i in range(n_modules):
        start = i * genes_per_module
        end = start + genes_per_module
        module_factor = rng.randn(n_samples, 1)
        X[:, start:end] += module_factor * 0.5
    
    # 2. Define cell states (first 3 PCs as proxy)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=3)
    cell_states = pca.fit_transform(X)
    
    # 3. Treatment assignment (confounded by cell state)
    # Cells in certain states more likely to receive treatment
    propensity_logit = 0.5 * cell_states[:, 0] - 0.3 * cell_states[:, 1]
    propensity = 1 / (1 + np.exp(-propensity_logit))
    T = rng.binomial(1, propensity)
    
    # 4. Outcome model with heterogeneous treatment effects
    # Base outcome depends on informative genes
    beta = rng.randn(n_informative)
    Y_base = X[:, :n_informative] @ beta
    
    # Treatment effect varies by cell state
    # tau(x) = base_effect + heterogeneity * f(cell_state)
    base_effect = 2.0
    tau_true = base_effect + effect_heterogeneity * (
        cell_states[:, 0] ** 2 + cell_states[:, 1]
    )
    
    # Observed outcome
    Y = Y_base + T * tau_true + rng.normal(0, noise_level, n_samples)
    
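    # Standardize
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
    # Keep the CATE on the same scale as the standardized outcome:
    # centering Y leaves treatment effects unchanged, and dividing Y by its
    # standard deviation divides the effects by the same factor.
    y_std = Y.std()
    Y = (Y - Y.mean()) / y_std
    tau_true = tau_true / y_std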
    
    return X, T, Y, tau_true, cell_states

In [None]:
# Generate data
print("Generating synthetic gene expression data...")
X, T, Y, tau_true, cell_states = generate_gene_expression_data(
    n_samples=2000,
    n_genes=1000,  # Use 1000 for faster training, can increase to 20000
    n_informative=50,
    effect_heterogeneity=0.5,
    noise_level=0.1,
)

print(f"Data shape: X={X.shape}, T={T.shape}, Y={Y.shape}")
print(f"Treatment balance: {T.mean():.2%} treated")
print(f"True CATE range: [{tau_true.min():.2f}, {tau_true.max():.2f}]")
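
Since treatment assignment follows a logistic model in the cell-state scores, it can be worth checking that propensities are not concentrated near 0 or 1, where treated and control cells barely overlap. The helper below is a hypothetical sketch: `overlap_summary` and its `eps` threshold are not part of the notebook, and the generator above would need to return `propensity` for the check to be applied to it directly.

```python
import numpy as np


def overlap_summary(propensity, eps=0.05):
    """Summarize how many units have propensities close to 0 or 1.

    `eps` is an arbitrary illustrative threshold, not a recommendation.
    """
    p = np.asarray(propensity, dtype=float)
    return {
        "min": float(p.min()),
        "max": float(p.max()),
        "frac_below_eps": float(np.mean(p < eps)),
        "frac_above_1_minus_eps": float(np.mean(p > 1.0 - eps)),
    }


# Illustrative logits on two scales: a larger scale pushes propensities
# toward 0 and 1 and so weakens overlap.
rng = np.random.default_rng(0)
score = rng.normal(size=10_000)
for scale in (0.5, 5.0):
    p = 1.0 / (1.0 + np.exp(-scale * score))
    print(scale, overlap_summary(p))
```

The notebook builds its logit from unscaled PCA scores, whose magnitude depends on the number of genes and the expression variance, so the fraction of extreme propensities may change noticeably with `n_genes`.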

In [None]:
# Visualize data
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Cell states colored by treatment
axes[0].scatter(cell_states[T==0, 0], cell_states[T==0, 1], 
                alpha=0.5, label='Control', s=20)
axes[0].scatter(cell_states[T==1, 0], cell_states[T==1, 1], 
                alpha=0.5, label='Treated', s=20)
axes[0].set_xlabel('PC1')
axes[0].set_ylabel('PC2')
axes[0].set_title('Cell States (Confounding)')
axes[0].legend()

# True CATE distribution
axes[1].hist(tau_true, bins=30, alpha=0.7, edgecolor='black')
axes[1].axvline(tau_true.mean(), color='r', linestyle='--', 
                label=f'Mean = {tau_true.mean():.2f}')
axes[1].set_xlabel('True CATE')
axes[1].set_ylabel('Frequency')
axes[1].set_title('True Treatment Effect Distribution')
axes[1].legend()

# CATE vs cell state
scatter = axes[2].scatter(cell_states[:, 0], cell_states[:, 1], 
                          c=tau_true, cmap='RdYlBu_r', s=20, alpha=0.6)
axes[2].set_xlabel('PC1')
axes[2].set_ylabel('PC2')
axes[2].set_title('True CATE by Cell State')
plt.colorbar(scatter, ax=axes[2], label='True CATE')

plt.tight_layout()
plt.show()

---

## Part 2: Train-Test Split

In [None]:
# Split data
X_train, X_test, T_train, T_test, Y_train, Y_test, tau_train, tau_test = train_test_split(
    X, T, Y, tau_true, test_size=0.3, random_state=42
)

print(f"Train: {len(X_train)} samples")
print(f"Test: {len(X_test)} samples")

---

## Part 3: Benchmark Methods

We compare:
1. **S-Learner** (baseline): Single model
2. **T-Learner** (baseline): Two independent models
3. **Siamese T-Learner** (ours): Shared encoder + separate heads
4. **JEPA-Causal** (ours): Joint embedding with causal regularization
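
The difference between the two baselines can be sketched with plain least squares as the base learner. This is a minimal illustration under stated assumptions (randomized treatment, one covariate, linear models without interaction terms); it does not reproduce the `SLearner` or `TLearner` classes from `causalbiolab`, whose base models are not specified here.

```python
import numpy as np


def fit_ols(features, y):
    """Least-squares fit with an intercept; returns the coefficient vector."""
    design = np.column_stack([np.ones(len(features)), features])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef


def predict_ols(coef, features):
    design = np.column_stack([np.ones(len(features)), features])
    return design @ coef


def rmse(estimate, truth):
    return float(np.sqrt(np.mean((estimate - truth) ** 2)))


rng = np.random.default_rng(0)
n = 5_000
x = rng.normal(size=(n, 1))
t = rng.binomial(1, 0.5, size=n)   # randomized treatment, for simplicity
tau = 1.0 + 2.0 * x[:, 0]          # heterogeneous effect
y = x[:, 0] + t * tau + rng.normal(scale=0.1, size=n)

# T-learner: one model per arm, CATE = mu1(x) - mu0(x).
coef0 = fit_ols(x[t == 0], y[t == 0])
coef1 = fit_ols(x[t == 1], y[t == 1])
tau_t = predict_ols(coef1, x) - predict_ols(coef0, x)

# S-learner: one model on [x, t], CATE = mu(x, 1) - mu(x, 0).
coef_s = fit_ols(np.column_stack([x, t]), y)
tau_s = (predict_ols(coef_s, np.column_stack([x, np.ones(n)]))
         - predict_ols(coef_s, np.column_stack([x, np.zeros(n)])))

print(f"T-learner RMSE: {rmse(tau_t, tau):.3f}")
print(f"S-learner RMSE: {rmse(tau_s, tau):.3f}")
```

With a linear base learner and no interaction term, the S-learner can only express a constant effect, so it misses the heterogeneity that the T-learner recovers. A more flexible base model would narrow that gap.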

In [None]:
# Initialize models
input_dim = X_train.shape[1]

models = {
    'S-Learner': SLearner(),
    'T-Learner': TLearner(),
    'Siamese-T': SiameseTLearner(
        input_dim=input_dim,
        hidden_dims=[256, 128, 64],
        learning_rate=1e-3,
        batch_size=128,
        n_epochs=100,
        device=device,
        verbose=True,
    ),
    'JEPA-Causal': JEPACausalLearner(
        input_dim=input_dim,
        context_dim=128,
        treatment_dim=32,
        target_dim=64,
        learning_rate=1e-3,
        batch_size=128,
        n_epochs=100,
        lambda_inv=0.1,
        device=device,
        verbose=True,
    ),
}
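
The notebook does not spell out what `lambda_inv` weights. One plausible reading, offered purely as an assumption, is that it scales a penalty encouraging treated and control embeddings to look alike. The sketch below uses the squared distance between group-mean embeddings (a linear-kernel MMD) as that penalty; `invariance_penalty` and `total_loss` are hypothetical names, and `JEPACausalLearner` may use a different regularizer.

```python
import numpy as np


def invariance_penalty(z, t):
    """Squared distance between treated and control mean embeddings.

    Assumed form for illustration only; requires both arms to be present.
    """
    z = np.asarray(z, dtype=float)
    t = np.asarray(t).astype(bool)
    gap = z[t].mean(axis=0) - z[~t].mean(axis=0)
    return float(gap @ gap)


def total_loss(y, y_pred, z, t, lambda_inv=0.1):
    """Outcome MSE plus the weighted invariance penalty."""
    mse = float(np.mean((np.asarray(y, dtype=float) - np.asarray(y_pred, dtype=float)) ** 2))
    return mse + lambda_inv * invariance_penalty(z, t)


rng = np.random.default_rng(0)
t = rng.binomial(1, 0.5, size=200)
z_balanced = rng.normal(size=(200, 4))
z_shifted = z_balanced + 2.0 * t[:, None]  # embedding that leaks treatment
print(invariance_penalty(z_balanced, t), invariance_penalty(z_shifted, t))
```

Under this reading, a larger `lambda_inv` trades outcome fit for embeddings that carry less information about treatment assignment.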

In [None]:
# Train all models and collect results
results = {}

for name, model in models.items():
    print(f"\n{'='*60}")
    print(f"Training {name}")
    print('='*60)
    
    # Fit
    model.fit(X_train, T_train, Y_train)
    
    # Predict on test set
    tau_pred = model.predict(X_test)
    
    # Evaluate
    rmse = np.sqrt(mean_squared_error(tau_test, tau_pred))
    r2 = r2_score(tau_test, tau_pred)
    mae = np.mean(np.abs(tau_test - tau_pred))
    
    # Store results
    results[name] = {
        'tau_pred': tau_pred,
        'rmse': rmse,
        'r2': r2,
        'mae': mae,
    }
    
    print("\nTest Set Performance:")
    print(f"  RMSE: {rmse:.4f}")
    print(f"  R²: {r2:.4f}")
    print(f"  MAE: {mae:.4f}")
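
The three metrics above can also be collected in one helper that flattens its inputs first. This guards against a NumPy pitfall: if a model returned predictions with shape `(n, 1)`, the expression `tau_test - tau_pred` would broadcast to an `(n, n)` array and silently distort the MAE. Whether any of the models here return that shape is not known, so this is a defensive sketch; `cate_metrics` is a hypothetical name.

```python
import numpy as np


def cate_metrics(tau_true, tau_pred):
    """RMSE, MAE, R^2 and mean error between true and predicted CATE."""
    tau_true = np.asarray(tau_true, dtype=float).ravel()
    tau_pred = np.asarray(tau_pred, dtype=float).ravel()
    if tau_true.shape != tau_pred.shape:
        raise ValueError("tau_true and tau_pred must have the same number of elements")
    err = tau_pred - tau_true
    ss_res = float(np.sum(err ** 2))
    ss_tot = float(np.sum((tau_true - tau_true.mean()) ** 2))
    return {
        "rmse": float(np.sqrt(np.mean(err ** 2))),  # square root of the PEHE
        "mae": float(np.mean(np.abs(err))),
        "r2": 1.0 - ss_res / ss_tot,
        "ate_bias": float(err.mean()),              # error in the average effect
    }


print(cate_metrics([1.0, 2.0, 3.0], [[1.0], [2.0], [4.0]]))
```

The `ate_bias` entry separates a uniform shift in the predictions from errors in the shape of the effect, which RMSE alone does not distinguish.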

---

## Part 4: Results Comparison

In [None]:
# Create comparison table
comparison_df = pd.DataFrame({
    'Method': list(results.keys()),
    'RMSE': [r['rmse'] for r in results.values()],
    'R²': [r['r2'] for r in results.values()],
    'MAE': [r['mae'] for r in results.values()],
})

# Sort by RMSE
comparison_df = comparison_df.sort_values('RMSE')

print("\n" + "="*70)
print("CATE Estimation Benchmark Results")
print("="*70)
print(comparison_df.to_string(index=False))
print("="*70)

In [None]:
# Visualize results
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Predicted vs True CATE
for name, result in results.items():
    axes[0, 0].scatter(tau_test, result['tau_pred'], alpha=0.5, label=name, s=20)

axes[0, 0].plot([tau_test.min(), tau_test.max()], 
                [tau_test.min(), tau_test.max()], 
                'k--', label='Perfect prediction')
axes[0, 0].set_xlabel('True CATE')
axes[0, 0].set_ylabel('Predicted CATE')
axes[0, 0].set_title('Predicted vs True CATE')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Error distribution
for name, result in results.items():
    errors = result['tau_pred'] - tau_test
    axes[0, 1].hist(errors, bins=30, alpha=0.5, label=name, edgecolor='black')

axes[0, 1].axvline(0, color='k', linestyle='--')
axes[0, 1].set_xlabel('Prediction Error')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Error Distribution')
axes[0, 1].legend()

# 3. Performance metrics
metrics = ['RMSE', 'R²', 'MAE']
x = np.arange(len(results))
width = 0.25

for i, metric in enumerate(metrics):
    values = [results[name][metric.lower().replace('²', '2')] for name in results.keys()]
    axes[1, 0].bar(x + i * width, values, width, label=metric)

axes[1, 0].set_xlabel('Method')
axes[1, 0].set_ylabel('Score')
axes[1, 0].set_title('Performance Metrics Comparison')
axes[1, 0].set_xticks(x + width)
axes[1, 0].set_xticklabels(list(results.keys()), rotation=45, ha='right')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3, axis='y')

# 4. Absolute error by true CATE
for name, result in results.items():
    abs_errors = np.abs(result['tau_pred'] - tau_test)
    axes[1, 1].scatter(tau_test, abs_errors, alpha=0.5, label=name, s=20)

axes[1, 1].set_xlabel('True CATE')
axes[1, 1].set_ylabel('Absolute Error')
axes[1, 1].set_title('Error vs True CATE')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## Part 5: Representation Analysis

Visualize learned representations from Siamese T-Learner and JEPA-Causal.

In [None]:
# Get representations
siamese_model = models['Siamese-T']
jepa_model = models['JEPA-Causal']

z_siamese = siamese_model.get_representations(X_test)
z_jepa = jepa_model.get_causal_representations(X_test)

print(f"Siamese representations: {z_siamese.shape}")
print(f"JEPA representations: {z_jepa.shape}")

In [None]:
# Visualize representations with t-SNE
from sklearn.manifold import TSNE

# Reduce to 2D for visualization
tsne = TSNE(n_components=2, random_state=42)
z_siamese_2d = tsne.fit_transform(z_siamese)

tsne = TSNE(n_components=2, random_state=42)
z_jepa_2d = tsne.fit_transform(z_jepa)

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Siamese representations colored by treatment
axes[0, 0].scatter(z_siamese_2d[T_test==0, 0], z_siamese_2d[T_test==0, 1],
                   alpha=0.5, label='Control', s=20)
axes[0, 0].scatter(z_siamese_2d[T_test==1, 0], z_siamese_2d[T_test==1, 1],
                   alpha=0.5, label='Treated', s=20)
axes[0, 0].set_title('Siamese T-Learner Representations (by Treatment)')
axes[0, 0].legend()

# Siamese representations colored by true CATE
scatter = axes[0, 1].scatter(z_siamese_2d[:, 0], z_siamese_2d[:, 1],
                             c=tau_test, cmap='RdYlBu_r', s=20, alpha=0.6)
axes[0, 1].set_title('Siamese T-Learner Representations (by True CATE)')
plt.colorbar(scatter, ax=axes[0, 1], label='True CATE')

# JEPA representations colored by treatment
axes[1, 0].scatter(z_jepa_2d[T_test==0, 0], z_jepa_2d[T_test==0, 1],
                   alpha=0.5, label='Control', s=20)
axes[1, 0].scatter(z_jepa_2d[T_test==1, 0], z_jepa_2d[T_test==1, 1],
                   alpha=0.5, label='Treated', s=20)
axes[1, 0].set_title('JEPA-Causal Representations (by Treatment)')
axes[1, 0].legend()

# JEPA representations colored by true CATE
scatter = axes[1, 1].scatter(z_jepa_2d[:, 0], z_jepa_2d[:, 1],
                             c=tau_test, cmap='RdYlBu_r', s=20, alpha=0.6)
axes[1, 1].set_title('JEPA-Causal Representations (by True CATE)')
plt.colorbar(scatter, ax=axes[1, 1], label='True CATE')

plt.tight_layout()
plt.show()
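
The treatment-colored panels give a visual impression of overlap, but t-SNE distorts distances, so a numeric check on the original embeddings can complement them. One simple option, sketched here as an assumption rather than as part of the notebook's pipeline, is the largest absolute standardized mean difference (SMD) between treated and control embeddings; `max_abs_smd` is a hypothetical helper.

```python
import numpy as np


def max_abs_smd(z, t):
    """Largest absolute standardized mean difference across embedding dimensions."""
    z = np.asarray(z, dtype=float)
    t = np.asarray(t).astype(bool)
    z1, z0 = z[t], z[~t]
    pooled_sd = np.sqrt(0.5 * (z1.var(axis=0, ddof=1) + z0.var(axis=0, ddof=1)))
    smd = (z1.mean(axis=0) - z0.mean(axis=0)) / (pooled_sd + 1e-12)
    return float(np.max(np.abs(smd)))


# Synthetic check: an embedding independent of treatment versus one
# whose first dimension is shifted for treated units.
rng = np.random.default_rng(0)
t_demo = rng.binomial(1, 0.5, size=2_000)
z_bal = rng.normal(size=(2_000, 8))
z_shift = z_bal.copy()
z_shift[:, 0] += 1.0 * t_demo
print(max_abs_smd(z_bal, t_demo), max_abs_smd(z_shift, t_demo))
```

In the notebook this could be applied as `max_abs_smd(z_jepa, T_test)` and `max_abs_smd(z_siamese, T_test)`; values near zero would be consistent with treatment-invariant representations, although matching means alone does not guarantee matching distributions.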

---

## Summary and Key Findings

### Expected Results:

1. **Siamese T-Learner** should outperform standard T-Learner when:
   - Data is high-dimensional (many genes)
   - Treatment groups are imbalanced
   - Sample size is limited

2. **JEPA-Causal** should learn representations that:
   - Are invariant to treatment assignment (control/treated overlap)
   - Capture effect modifiers (CATE structure visible)
   - Provide better generalization

3. **Baseline methods** (S-Learner and T-Learner) are strong reference points, but they:
   - May struggle with high dimensions
   - Don't leverage representation learning
   - Offer no learned latent space to inspect

### Next Steps:

1. **Real data validation**: Test on Perturb-seq or drug response data
2. **Architecture search**: Optimize encoder depth, width, regularization
3. **Contrastive learning**: Add contrastive loss for better representations
4. **Uncertainty quantification**: Add Bayesian layers or ensembles