# Experiment 2: LOCO vs CMI on OSRCT Data

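This notebook compares two conditional-independence (CI) testing approaches, LOCO (leave-one-covariate-out) and CMI (conditional mutual information), for ranking candidate instruments on real OSRCT data.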

## Objectives
1. Compare EHS scores across methods
2. Evaluate rank correlation (Spearman)
3. Check top-instrument (top-1) agreement per site
4. Compare runtime on real data
5. Analyze downstream impact on bounds (not yet covered by the cells below)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import time
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.insert(0, '..')

from causal_grounding import (
    create_ci_engine,
    rank_covariates,
    rank_covariates_across_sites,
    select_best_instrument,
    create_train_target_split,
    load_rct_data,
    discretize_covariates
)

print("Imports successful!")

## 1. Load OSRCT Data

In [None]:
# Locations of the confounded OSRCT sample and the cleaned RCT data,
# both given relative to the notebook's working directory.
osrct_path = Path('../confounded_datasets/anchoring1/anchoring1_age_beta_0.3.pkl')
rct_path = Path('../ManyLabs1/ML1_data_cleaned_discretized.pkl')

print(f"OSRCT path exists: {osrct_path.exists()}")
print(f"RCT path exists: {rct_path.exists()}")

# Load only when both inputs are present; later cells check for the
# resulting variables before running.
if not (osrct_path.exists() and rct_path.exists()):
    print("Data files not found. Please check the paths.")
    print("You may need to run the data preprocessing scripts first.")
else:
    osrct_data = pd.read_pickle(osrct_path)
    rct_data = load_rct_data('anchoring1', str(rct_path))

    print(f"\nOSRCT shape: {osrct_data.shape}")
    print(f"RCT shape: {rct_data.shape}")
    print(f"\nColumns: {list(osrct_data.columns)}")

In [None]:
# Site passed to create_train_target_split as the target.
target_site = 'mturk'

# Run the split only if the loading cell produced both datasets.
if {'osrct_data', 'rct_data'}.issubset(dir()):
    training_data, target_data = create_train_target_split(
        osrct_data,
        rct_data,
        target_site=target_site,
    )

    print(f"Training sites: {list(training_data.keys())}")
    print(f"Target site: {target_site}")
    print(f"\nSamples per training site:")
    # Row count per site, split by the regime column 'F' ('idle' vs 'on').
    for site, df in training_data.items():
        print(f"  {site}: {len(df)} (idle: {(df['F']=='idle').sum()}, on: {(df['F']=='on').sum()})")
    print(f"\nTarget samples: {len(target_data)}")

## 2. Define Variables

In [None]:
# Column names handed to every scoring call further down.
treatment = 'iv'  # Intervention variable
outcome = 'dv'    # Dependent variable

# Candidate instruments (discretized covariates); each one is scored in turn.
covariates = ['age_d', 'polideo_d', 'gender']

# Sanity-check the column names against the first training site.
if 'training_data' in dir():
    sample_site = list(training_data)[0]
    sample_df = training_data[sample_site]
    available_cols = list(sample_df.columns)

    print(f"Treatment: {treatment} (exists: {treatment in available_cols})")
    print(f"Outcome: {outcome} (exists: {outcome in available_cols})")
    print(f"\nCovariates:")
    for cov in covariates:
        exists = cov in available_cols
        print(f"  {cov}: exists={exists}")
        if not exists:
            continue
        # Preview at most five distinct values of the covariate.
        print(f"    Values: {sample_df[cov].unique()[:5]}...")

## 3. Create CI Engines

In [None]:
# Build one CI engine per method so both can be run on the same inputs.
# NOTE(review): the seed keyword differs between the two calls
# (random_seed vs random_state) — confirm both are honoured by create_ci_engine.
cmi_engine = create_ci_engine(
    'cmi',
    n_permutations=500,
    random_seed=42,
)
loco_engine = create_ci_engine(
    'loco',
    function_class='gbm',
    n_estimators=100,
    max_depth=3,
    random_state=42,
)

# Report the concrete class of each engine that was returned.
print(
    "Engines created:",
    f"  CMI: {type(cmi_engine).__name__}",
    f"  LOCO: {type(loco_engine).__name__}",
    sep="\n",
)

## 4. Run EHS Scoring with Both Methods

In [None]:
def score_site_with_engine(site_data, engine, covariates, treatment, outcome):
    """Score every covariate as a candidate instrument for one site, with timing.

    Only rows whose ``'F'`` column equals ``'idle'`` are scored. For each
    covariate ``z_a`` the remaining covariates form the conditioning set
    ``z_b`` and ``engine.score_ehs_criteria(idle_data, z_a, z_b, treatment,
    outcome)`` is called.

    Args:
        site_data: DataFrame for a single site; must contain an ``'F'`` column.
        engine: Object exposing ``score_ehs_criteria``.
        covariates: Covariate column names to score.
        treatment: Treatment column name, passed through to the engine.
        outcome: Outcome column name, passed through to the engine.

    Returns:
        Tuple ``(results_df, elapsed)``: a DataFrame built from the collected
        engine results (one row per covariate that scored without error) and
        the elapsed time in seconds for the whole scoring loop.
    """
    # Filter to idle regime; copy so the engine never sees a view of site_data.
    idle_data = site_data[site_data['F'] == 'idle'].copy()

    results = []
    # perf_counter is monotonic and high-resolution, so the measured runtime
    # cannot be distorted by system clock adjustments the way time.time() can.
    start = time.perf_counter()

    for z_a in covariates:
        z_b = [z for z in covariates if z != z_a]
        try:
            score_result = engine.score_ehs_criteria(
                idle_data, z_a, z_b, treatment, outcome
            )
        except Exception as e:
            # Best effort: report the failure and keep scoring the others.
            print(f"  Error scoring {z_a}: {e}")
        else:
            results.append(score_result)

    elapsed = time.perf_counter() - start
    return pd.DataFrame(results), elapsed

# Score every training site with both engines, keeping the result table and
# the elapsed time per site for each method.
if 'training_data' in dir():
    cmi_results, cmi_times = {}, {}
    loco_results, loco_times = {}, {}

    for site, site_data in training_data.items():
        print(f"\nProcessing site: {site}")

        # CMI pass
        print("  Running CMI...")
        cmi_df, cmi_time = score_site_with_engine(
            site_data, cmi_engine, covariates, treatment, outcome
        )
        cmi_results[site], cmi_times[site] = cmi_df, cmi_time
        print(f"    CMI time: {cmi_time:.2f}s")

        # LOCO pass
        print("  Running LOCO...")
        loco_df, loco_time = score_site_with_engine(
            site_data, loco_engine, covariates, treatment, outcome
        )
        loco_results[site], loco_times[site] = loco_df, loco_time
        print(f"    LOCO time: {loco_time:.2f}s")

    print("\n=== Scoring Complete ===")

## 5. Compare Rankings

In [None]:
def compare_rankings(cmi_df, loco_df, site_name):
    """Compare the covariate rankings produced by CMI and LOCO for one site.

    Both inputs must provide the columns ``'z_a'``, ``'score'`` and
    ``'passes_ehs'``. Each frame is ranked by descending ``'score'``
    (rank 1 = highest score). The comparison table follows the CMI ordering;
    LOCO values are looked up by covariate name, so a covariate that is absent
    from ``loco_df`` gets NaN in the LOCO columns.

    Args:
        cmi_df: Scoring results from the CMI engine.
        loco_df: Scoring results from the LOCO engine.
        site_name: Site label; accepted for call-site symmetry, not used here.

    Returns:
        Tuple ``(comparison, spearman_corr, spearman_p)``. The correlation is
        computed over covariates ranked by both methods only; it is NaN when
        fewer than two such covariates exist.
    """
    # Sort by score
    cmi_ranked = cmi_df.sort_values('score', ascending=False).reset_index(drop=True)
    loco_ranked = loco_df.sort_values('score', ascending=False).reset_index(drop=True)

    # Create comparison, ordered by CMI rank
    comparison = pd.DataFrame({
        'z_a': cmi_ranked['z_a'],
        'cmi_rank': range(1, len(cmi_ranked) + 1),
        'cmi_score': cmi_ranked['score'],
        'cmi_passes_ehs': cmi_ranked['passes_ehs'],
    })

    # Lookup tables keyed by covariate name, built in a single pass each.
    loco_names = loco_ranked['z_a'].tolist()
    loco_rank_map = {name: pos for pos, name in enumerate(loco_names, start=1)}
    loco_score_map = dict(zip(loco_names, loco_ranked['score'].tolist()))
    loco_ehs_map = dict(zip(loco_names, loco_ranked['passes_ehs'].tolist()))

    comparison['loco_rank'] = comparison['z_a'].map(loco_rank_map)
    comparison['loco_score'] = comparison['z_a'].map(loco_score_map)
    comparison['loco_passes_ehs'] = comparison['z_a'].map(loco_ehs_map)

    # Correlate only covariates both methods ranked: a single NaN rank would
    # otherwise make spearmanr return NaN for the whole site.
    paired = comparison.dropna(subset=['loco_rank'])
    if len(paired) < 2:
        return comparison, float('nan'), float('nan')

    spearman_corr, spearman_p = stats.spearmanr(
        paired['cmi_rank'], paired['loco_rank']
    )

    return comparison, spearman_corr, spearman_p

# Build the per-site comparison tables and rank correlations, printing each.
if {'cmi_results', 'loco_results'}.issubset(dir()):
    comparisons, correlations = {}, {}

    for site in cmi_results:
        comp, rho, pval = compare_rankings(cmi_results[site], loco_results[site], site)
        correlations[site] = {'spearman_rho': rho, 'p_value': pval}
        comparisons[site] = comp

        print(f"\n=== Site: {site} ===")
        print(comp.to_string(index=False))
        print(f"\nSpearman correlation: rho={rho:.3f}, p={pval:.3f}")

## 6. Visualize Score Comparison

In [None]:
# Scatter CMI score against LOCO score, one panel per site.
if 'comparisons' in dir() and len(comparisons) > 0:
    n_sites = len(comparisons)
    # squeeze=False always returns a 2-D array of axes, so a single site needs
    # no special-casing; row 0 holds one axis per site.
    fig, axes = plt.subplots(1, n_sites, figsize=(5*n_sites, 5), squeeze=False)

    for ax, (site, comp) in zip(axes[0], comparisons.items()):
        # Scatter plot of scores
        ax.scatter(comp['cmi_score'], comp['loco_score'], s=100)

        # Label every point with its covariate name
        for _, row in comp.iterrows():
            ax.annotate(row['z_a'],
                        (row['cmi_score'], row['loco_score']),
                        textcoords="offset points",
                        xytext=(5, 5))

        # y = x reference line spanning the combined range of both scores
        min_val = min(comp['cmi_score'].min(), comp['loco_score'].min())
        max_val = max(comp['cmi_score'].max(), comp['loco_score'].max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.5)

        ax.set_xlabel('CMI Score')
        ax.set_ylabel('LOCO Score')
        rho = correlations[site]['spearman_rho']
        ax.set_title(f'{site} (rho={rho:.2f})')

    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists
    # so the figure is not lost after the scoring has already run.
    Path('../results').mkdir(parents=True, exist_ok=True)
    plt.savefig('../results/loco_vs_cmi_osrct_scores.png', dpi=150, bbox_inches='tight')
    plt.show()

## 7. Runtime Comparison

In [None]:
# Tabulate per-site runtimes for both methods and their ratio.
if {'cmi_times', 'loco_times'}.issubset(dir()):
    runtime_comparison = pd.DataFrame({
        'site': list(cmi_times),
        'cmi_time': list(cmi_times.values()),
        'loco_time': list(loco_times.values()),
    })
    # speedup > 1 means CMI took longer than LOCO on that site.
    runtime_comparison['speedup'] = runtime_comparison['cmi_time'].div(
        runtime_comparison['loco_time']
    )

    print("Runtime Comparison (seconds):")
    print(runtime_comparison.to_string(index=False))
    print(f"\nMean CMI time: {runtime_comparison['cmi_time'].mean():.2f}s")
    print(f"Mean LOCO time: {runtime_comparison['loco_time'].mean():.2f}s")
    print(f"Mean speedup (CMI/LOCO): {runtime_comparison['speedup'].mean():.2f}x")

In [None]:
# Plot runtime comparison as grouped bars: one CMI/LOCO pair per site.
if 'runtime_comparison' in dir():
    fig, ax = plt.subplots(figsize=(10, 5))

    x = np.arange(len(runtime_comparison))
    width = 0.35

    # Offset each bar by half a width so the pair is centred on the site's tick.
    ax.bar(x - width/2, runtime_comparison['cmi_time'], width, label='CMI')
    ax.bar(x + width/2, runtime_comparison['loco_time'], width, label='LOCO')

    ax.set_xlabel('Site')
    ax.set_ylabel('Runtime (seconds)')
    ax.set_title('Runtime Comparison: CMI vs LOCO')
    ax.set_xticks(x)
    ax.set_xticklabels(runtime_comparison['site'], rotation=45, ha='right')
    ax.legend()

    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists.
    Path('../results').mkdir(parents=True, exist_ok=True)
    plt.savefig('../results/loco_vs_cmi_osrct_runtime.png', dpi=150, bbox_inches='tight')
    plt.show()

## 8. Top Instrument Agreement

In [None]:
# Per site, check whether CMI and LOCO pick the same rank-1 covariate.
# The non-empty guard matches the plotting cell and avoids dividing by zero
# in the overall percentage when no site was compared.
if 'comparisons' in dir() and len(comparisons) > 0:
    print("Top Instrument Selection:")
    print("=" * 50)

    agreement_count = 0
    total_sites = len(comparisons)

    for site, comp in comparisons.items():
        cmi_top_rows = comp.loc[comp['cmi_rank'] == 1, 'z_a']
        loco_top_rows = comp.loc[comp['loco_rank'] == 1, 'z_a']
        # A method has no rank-1 row here when its top covariate is missing
        # from the comparison table; report None instead of raising.
        cmi_top = cmi_top_rows.iloc[0] if len(cmi_top_rows) else None
        loco_top = loco_top_rows.iloc[0] if len(loco_top_rows) else None

        agrees = cmi_top is not None and cmi_top == loco_top
        if agrees:
            agreement_count += 1

        print(f"{site}:")
        print(f"  CMI top: {cmi_top}")
        print(f"  LOCO top: {loco_top}")
        print(f"  Agreement: {'YES' if agrees else 'NO'}")

    print(f"\nOverall Agreement: {agreement_count}/{total_sites} sites ({100*agreement_count/total_sites:.1f}%)")

## 9. EHS Criteria Agreement

In [None]:
# For every (site, covariate) pair, record whether both methods agree on the
# passes_ehs flag.
if 'comparisons' in dir():
    print("EHS Criteria Agreement:")
    print("=" * 50)

    ehs_agreement = [
        {
            'site': site,
            'z_a': row['z_a'],
            'cmi_passes': row['cmi_passes_ehs'],
            'loco_passes': row['loco_passes_ehs'],
            'agrees': row['cmi_passes_ehs'] == row['loco_passes_ehs'],
        }
        for site, comp in comparisons.items()
        for _, row in comp.iterrows()
    ]

    # Explicit columns keep 'agrees' present even when no rows were collected,
    # so an empty result no longer raises KeyError.
    ehs_df = pd.DataFrame(
        ehs_agreement,
        columns=['site', 'z_a', 'cmi_passes', 'loco_passes', 'agrees'],
    )
    if ehs_df.empty:
        print("No EHS results to compare.")
    else:
        print(ehs_df.to_string(index=False))
        print(f"\nOverall EHS Agreement: {ehs_df['agrees'].mean()*100:.1f}%")

## 10. Summary and Recommendations

In [None]:
# Print the headline numbers (rank correlation, runtime, EHS agreement)
# followed by fixed usage recommendations.
if 'correlations' in dir() and 'runtime_comparison' in dir():
    print("\n" + "="*60)
    print("SUMMARY: LOCO vs CMI on OSRCT Data")
    print("="*60)

    # Rank correlation: average only over sites with a defined rho. A NaN
    # would otherwise fail both threshold tests and be reported as "weak".
    site_rhos = np.asarray(
        [c['spearman_rho'] for c in correlations.values()], dtype=float
    )
    valid_rhos = site_rhos[~np.isnan(site_rhos)]
    mean_rho = valid_rhos.mean() if valid_rhos.size else float('nan')
    print(f"\n1. Rank Correlation (Spearman):")
    print(f"   Mean rho = {mean_rho:.3f}")
    if valid_rhos.size < site_rhos.size:
        n_undefined = site_rhos.size - valid_rhos.size
        print(f"   (rho undefined for {n_undefined} of {site_rhos.size} sites; excluded from the mean)")
    if not valid_rhos.size:
        print("   -> Rank correlation undefined - no site produced a valid rho")
    elif mean_rho > 0.7:
        print("   -> Strong agreement between methods")
    elif mean_rho > 0.4:
        print("   -> Moderate agreement between methods")
    else:
        print("   -> Weak agreement - methods may produce different rankings")

    # Runtime: speedup is cmi_time / loco_time, so > 1 means CMI is slower.
    mean_speedup = runtime_comparison['speedup'].mean()
    print(f"\n2. Runtime:")
    print(f"   Mean CMI time: {runtime_comparison['cmi_time'].mean():.2f}s")
    print(f"   Mean LOCO time: {runtime_comparison['loco_time'].mean():.2f}s")
    if mean_speedup > 1:
        print(f"   -> CMI is {mean_speedup:.1f}x slower than LOCO")
    else:
        print(f"   -> LOCO is {1/mean_speedup:.1f}x slower than CMI")

    # EHS Agreement (only if the agreement cell has been run)
    if 'ehs_df' in dir():
        ehs_agree_rate = ehs_df['agrees'].mean()
        print(f"\n3. EHS Criteria Agreement: {ehs_agree_rate*100:.1f}%")

    print("\n" + "="*60)
    print("RECOMMENDATIONS")
    print("="*60)
    print("\nUse CMI when:")
    print("  - Data is discrete/categorical")
    print("  - Information-theoretic interpretation needed")
    print("\nUse LOCO when:")
    print("  - Data has continuous covariates")
    print("  - High-dimensional conditioning sets")
    print("  - Predictive interpretation preferred")

In [None]:
# Save all results as CSV files under ../results.
# Skipped when no site was compared, because pd.concat raises on an empty list.
if 'comparisons' in dir() and len(comparisons) > 0:
    # to_csv does not create missing directories; make sure the target exists.
    Path('../results').mkdir(parents=True, exist_ok=True)

    # One long table: every site's comparison rows tagged with the site name.
    all_comparisons = pd.concat([df.assign(site=site) for site, df in comparisons.items()])
    all_comparisons.to_csv('../results/loco_vs_cmi_osrct_rankings.csv', index=False)

    correlations_df = pd.DataFrame([
        {'site': site, **corr} for site, corr in correlations.items()
    ])
    correlations_df.to_csv('../results/loco_vs_cmi_osrct_correlations.csv', index=False)

    # Runtime table exists only if the runtime cell has been run.
    if 'runtime_comparison' in dir():
        runtime_comparison.to_csv('../results/loco_vs_cmi_osrct_runtime.csv', index=False)

    print("Results saved to results/ directory")