In [None]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels
from scipy import stats
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.metrics import (classification_report, confusion_matrix, 
                            roc_auc_score, roc_curve, precision_recall_curve,
                            f1_score, accuracy_score)

In [None]:
# =============================================================================
# FEATURE EXTRACTION HELPERS
# =============================================================================
def extract_physiology_features(method_data, min_valid_samples=10, edge_window=10):
    """Summarise one trial's baseline-corrected pupil trace as a feature dict.

    Parameters
    ----------
    method_data : dict
        Entry of ``trial_data['baseline_methods']``. The keys read here are
        'pupil_avg_baselined', 'pupil_L_baselined', 'pupil_R_baselined',
        'time_aligned', 'baseline_L' and 'baseline_R'.
    min_valid_samples : int
        Minimum number of non-NaN averaged samples required (default 10).
        Values below 2 are treated as 2 because the velocity and slope
        features need at least two points.
    edge_window : int
        Number of leading / trailing samples averaged for 'pupil_initial'
        and 'pupil_final' (default 10).

    Returns
    -------
    dict or None
        Feature name -> value, or None when the trace has too few valid
        samples and the trial should be skipped.
    """
    # dtype=float converts JSON nulls (None) to NaN so np.isnan works on them.
    pupil_avg = np.array(method_data['pupil_avg_baselined'], dtype=float)
    pupil_L = np.array(method_data['pupil_L_baselined'], dtype=float)
    pupil_R = np.array(method_data['pupil_R_baselined'], dtype=float)
    time = np.array(method_data['time_aligned'], dtype=float)

    # Remove NaN values (validity is defined by the averaged trace)
    valid_mask = ~np.isnan(pupil_avg)
    pupil_avg_clean = pupil_avg[valid_mask]
    time_clean = time[valid_mask]

    if len(pupil_avg_clean) < max(min_valid_samples, 2):
        return None

    # Pupil dilation velocity: absolute sample-to-sample change
    abs_velocity = np.abs(np.diff(pupil_avg_clean))

    # Each statistic is computed once and reused by the derived features.
    pupil_mean = np.mean(pupil_avg_clean)
    pupil_std = np.std(pupil_avg_clean)
    pupil_min = np.min(pupil_avg_clean)
    pupil_max = np.max(pupil_avg_clean)
    pupil_initial = np.mean(pupil_avg_clean[:edge_window])
    pupil_final = np.mean(pupil_avg_clean[-edge_window:])

    # A single eye can be NaN where the average is valid; ignore those samples
    # instead of letting one NaN turn the whole feature into NaN.
    eye_diff = np.abs(pupil_L[valid_mask] - pupil_R[valid_mask])
    eye_diff = eye_diff[~np.isnan(eye_diff)]
    eye_asymmetry = np.mean(eye_diff) if eye_diff.size else np.nan

    return {
        # Basic statistics
        'pupil_mean': pupil_mean,
        'pupil_std': pupil_std,
        'pupil_min': pupil_min,
        'pupil_max': pupil_max,
        'pupil_range': pupil_max - pupil_min,

        # Peak dilation metrics
        # NOTE(review): 'pupil_peak_dilation' is identical to 'pupil_max'; it is
        # kept so the feature columns stay unchanged for downstream code.
        'pupil_peak_dilation': pupil_max,
        'time_to_peak': time_clean[np.argmax(pupil_avg_clean)],

        # Temporal dynamics
        'pupil_slope': np.polyfit(time_clean, pupil_avg_clean, 1)[0],
        'pupil_initial': pupil_initial,
        'pupil_final': pupil_final,
        'pupil_change': pupil_final - pupil_initial,

        # Variability metrics
        'pupil_cv': pupil_std / np.abs(pupil_mean) if pupil_mean != 0 else 0,
        'pupil_velocity_mean': np.mean(abs_velocity),
        'pupil_velocity_max': np.max(abs_velocity),

        # Baseline values
        'baseline_L': method_data['baseline_L'],
        'baseline_R': method_data['baseline_R'],

        # Eye asymmetry
        'eye_asymmetry': eye_asymmetry,
    }


def extract_behavior_features(raw_trial):
    """Derive timing / gamble features and the choice label for one raw trial.

    Parameters
    ----------
    raw_trial : dict
        Raw trial record. The keys read here are 'gamble parameters', 'lct'
        (a list of events with 'event' and 'time') and 'choices'.

    Returns
    -------
    tuple(dict, int) or None
        ``(behavior_features, chose_invest)`` where ``chose_invest`` is 1 when
        the last entry of 'choices' is 'INVEST' and 0 otherwise. None when the
        trial has no 'show screen' or no 'submit' event.
    """
    gamble_params = raw_trial['gamble parameters']

    # Extract timing information; when an event type occurs more than once,
    # the last occurrence wins.
    show_screen_time = None
    submit_time = None
    click_time = None
    for event in raw_trial['lct']:
        event_name = event['event']
        if 'show screen' in event_name:
            show_screen_time = event['time']
        elif 'gamble clicked' in event_name:
            click_time = event['time']
        elif 'submit' in event_name:
            submit_time = event['time']

    if show_screen_time is None or submit_time is None:
        return None

    # Timing metrics; the /1000 presumably converts ms to seconds - TODO confirm units.
    decision_time = (submit_time - show_screen_time) / 1000
    # `is not None` rather than truthiness: a click timestamp of 0 is a valid time.
    reaction_time = (click_time - show_screen_time) / 1000 if click_time is not None else np.nan

    # Expected value of investing vs. the sure amount
    invest_ev = (gamble_params['invest amount 1'] * gamble_params['invest probability 1'] +
                 gamble_params['invest amount 2'] * gamble_params['invest probability 2'])
    keep_ev = gamble_params['keep amount']

    # Choice information
    # NOTE(review): a trial with an empty 'choices' list is labelled KEEP (0),
    # as in the original code - confirm that this is the intended default.
    choices = raw_trial['choices']
    final_choice = choices[-1] if len(choices) > 0 else None
    chose_invest = 1 if final_choice == 'INVEST' else 0

    behavior_features = {
        # Timing (falls back to decision time when no click was recorded)
        'reaction_time': decision_time if np.isnan(reaction_time) else reaction_time,
        'decision_time': decision_time,

        # Gamble parameters
        # NOTE(review): 'keep_ev' duplicates 'keep_amount'; both are kept so the
        # feature columns stay unchanged for downstream code.
        'keep_amount': gamble_params['keep amount'],
        'invest_ev': invest_ev,
        'keep_ev': keep_ev,
        'ev_difference': invest_ev - keep_ev,
        'ambiguity': gamble_params['ambiguity'],
        'condition_social': 1 if gamble_params['condition'] == 'social' else 0,

        # Risk metrics
        'invest_variance': ((gamble_params['invest amount 1'] - invest_ev)**2 * gamble_params['invest probability 1'] +
                           (gamble_params['invest amount 2'] - invest_ev)**2 * gamble_params['invest probability 2']),
    }
    return behavior_features, chose_invest


# =============================================================================
# LOAD ALL SUBJECTS' DATA
# =============================================================================
# Update these paths to your data directories
preprocessed_dir = Path('path/to/preprocessed/files')
raw_dir = Path('path/to/raw/files')

# Get all preprocessed JSON files. glob() order is filesystem-dependent, so the
# list is sorted to keep row order (and the seeded splits that depend on it)
# reproducible across machines.
preprocessed_files = sorted(preprocessed_dir.glob('preprocessing_*.json'))
print(f"Found {len(preprocessed_files)} preprocessed files")

# =============================================================================
# CHOOSE BASELINE METHOD
# =============================================================================
baseline_method = 't3_stable_pre_decision'  # or 't0_initial_fixation', 't1_post_stabilization', 't2_early_post_stimulus'

# =============================================================================
# EXTRACT FEATURES FOR ALL SUBJECTS
# =============================================================================
# These five lists are appended to together, so index i refers to the same trial in each.
all_physiology_features = []
all_behavior_features = []
all_outcomes = []
all_subject_ids = []
all_trial_ids = []

for preprocessed_file in preprocessed_files:

    # Load preprocessed data
    with open(preprocessed_file, 'r') as f:
        preprocessed = json.load(f)

    subject_id = preprocessed['subject_id']
    print(f"\nProcessing subject: {subject_id}")

    # Find corresponding raw data file
    # Adjust this pattern to match your raw file naming convention
    raw_file = raw_dir / f"{subject_id}.json"  # or however your raw files are named

    if not raw_file.exists():
        print(f"  Warning: Raw file not found for {subject_id}, skipping...")
        continue

    with open(raw_file, 'r') as f:
        raw_data = json.load(f)

    subject_trial_count = 0

    for trial_id, trial_data in preprocessed['trials'].items():

        # Skip if preprocessing failed
        if trial_data['preprocessing_status'] != 'success':
            continue

        # Get the baseline-corrected pupil data for chosen method
        method_data = trial_data['baseline_methods'][baseline_method]
        if method_data['status'] != 'success':
            continue

        # Get raw trial data for behavioral features
        trial_idx = int(trial_id.split('_')[1])  # Extract trial number from 'trial_0', 'trial_1', etc.
        raw_trial = raw_data['trials'][trial_idx]

        # Skip trials that weren't submitted
        if not raw_trial['submitted']:
            continue

        # Physiology features (from preprocessed pupil data)
        physiology_features = extract_physiology_features(method_data)
        if physiology_features is None:
            continue

        # Behavior features and outcome (from raw trial data)
        behavior_result = extract_behavior_features(raw_trial)
        if behavior_result is None:
            continue
        behavior_features, outcome = behavior_result  # outcome: INVEST (1) vs KEEP (0)

        # Store everything
        all_physiology_features.append(physiology_features)
        all_behavior_features.append(behavior_features)
        all_outcomes.append(outcome)
        all_subject_ids.append(subject_id)
        all_trial_ids.append(f"{subject_id}_{trial_id}")

        subject_trial_count += 1

    print(f"  Extracted {subject_trial_count} valid trials")

# =============================================================================
# CREATE DATAFRAMES
# =============================================================================
# Convert the per-trial accumulators into tabular / array form. The source lists
# were appended to together, so row i of every container is the same trial.
outcomes = np.array(all_outcomes)
subjects = np.array(all_subject_ids)
behavior_df = pd.DataFrame(all_behavior_features)
physiology_df = pd.DataFrame(all_physiology_features)

banner = '=' * 80
print(f"\n{banner}")
print("DATA SUMMARY (ALL SUBJECTS)")
print(banner)
print(f"Total subjects: {len(np.unique(subjects))}")
print(f"Total valid trials: {len(outcomes)}")
print(f"Physiology features: {physiology_df.shape[1]}")
print(f"Behavior features: {behavior_df.shape[1]}")

# Count and share of each outcome class
print("\nOutcome distribution:")
for class_code, class_label in ((0, 'KEEP'), (1, 'INVEST')):
    in_class = outcomes == class_code
    print(f"  {class_label} ({class_code}): {np.sum(in_class)} ({np.mean(in_class):.1%})")

# Per-subject breakdown: number of trials and mean outcome (share of INVEST)
print("\nPer-subject trial counts:")
subject_counts = pd.DataFrame({'subject': subjects, 'outcome': outcomes})
print(subject_counts.groupby('subject')['outcome'].agg(['count', 'mean']))

# Check for missing values: list only the columns that contain any
missing_physio = physiology_df.isna().sum()
missing_behavior = behavior_df.isna().sum()
for feature_group, missing_counts in (('physiology', missing_physio),
                                      ('behavior', missing_behavior)):
    print(f"\nMissing values in {feature_group} features:")
    if missing_counts.sum() > 0:
        print(missing_counts[missing_counts > 0])
    else:
        print("None!")

# =============================================================================
# CHECK MULTICOLLINEARITY IN PHYSIOLOGY FEATURES
# =============================================================================
print(f"\n{'='*80}")
print(f"MULTICOLLINEARITY ANALYSIS - PHYSIOLOGY FEATURES")
print(f"{'='*80}")

# Calculate correlation matrix
physiology_corr = physiology_df.corr()

# Find highly correlated pairs (|r| > 0.9). Only the upper triangle is scanned,
# so each unordered pair is reported once and the diagonal is skipped.
high_corr_pairs = []
n_corr_features = len(physiology_corr.columns)
for i in range(n_corr_features):
    for j in range(i+1, n_corr_features):
        pair_corr = physiology_corr.iloc[i, j]
        if abs(pair_corr) > 0.9:
            high_corr_pairs.append({
                'feature_1': physiology_corr.columns[i],
                'feature_2': physiology_corr.columns[j],
                'correlation': pair_corr
            })

print(f"\n⚠️  Highly correlated feature pairs (|r| > 0.9):")
if not high_corr_pairs:
    print("  None")
for pair in high_corr_pairs:
    print(f"  {pair['feature_1']} <-> {pair['feature_2']}: {pair['correlation']:.3f}")

# Visualize correlation matrix (drawn explicitly on `ax`, not the implicit current axes)
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(physiology_corr, annot=True, fmt='.2f', cmap='coolwarm', center=0,
            square=True, linewidths=0.5, cbar_kws={"shrink": 0.8}, ax=ax)
ax.set_title('Physiology Features Correlation Matrix', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Calculate VIF (Variance Inflation Factor) for multicollinearity
from statsmodels.stats.outliers_influence import variance_inflation_factor

X_physio_for_vif = physiology_df.fillna(physiology_df.mean())  # VIF needs no NaN

# variance_inflation_factor() regresses column i on the remaining columns of
# the matrix it is given and does not add an intercept itself. Without one the
# R^2 is uncentered and VIFs are inflated for features with non-zero mean, so a
# constant column is appended; VIF is reported for the real features only.
vif_design = X_physio_for_vif.assign(vif_intercept=1.0).values
vif_data = pd.DataFrame({
    "Feature": X_physio_for_vif.columns,
    "VIF": [variance_inflation_factor(vif_design, i)
            for i in range(X_physio_for_vif.shape[1])],
})
vif_data = vif_data.sort_values('VIF', ascending=False)

print(f"\nVariance Inflation Factor (VIF):")
print(vif_data)
print(f"\nNote: VIF > 10 indicates high multicollinearity")
print(f"      VIF > 5 indicates moderate multicollinearity")
# Features that are exact copies or exact linear combinations of others (the
# extraction step defines 'pupil_peak_dilation' with the same expression as
# 'pupil_max', and 'pupil_range' as max - min) will show an infinite VIF.

# =============================================================================
# SHARED HELPERS FOR TASK 2 AND TASK 3
# =============================================================================
CLASS_NAMES = ['KEEP', 'INVEST']  # display labels for outcome 0 and 1


def fit_and_score_random_forest(X, y, feature_names, test_size=0.3, seed=42, n_folds=5):
    """Fit a class-balanced random forest on a stratified split and score it.

    Parameters
    ----------
    X : array-like
        Feature matrix, one row per trial.
    y : array-like
        Binary outcome labels aligned with the rows of ``X``.
    feature_names : sequence of str
        Column names aligned with the columns of ``X``.
    test_size : float
        Fraction of trials held out for testing (default 0.3).
    seed : int
        ``random_state`` used for both the split and the forest (default 42).
    n_folds : int
        Number of cross-validation folds (default 5).

    Returns
    -------
    dict
        Keys: 'X_train', 'X_test', 'y_train', 'y_test', 'model', 'y_pred',
        'y_prob' (probability of class 1), 'accuracy' (hold-out accuracy),
        'cv_scores' (per-fold accuracy on the full data) and 'importance'
        (DataFrame of feature / importance, sorted descending).

    NOTE(review): neither the split nor the cross-validation groups trials by
    subject, so trials of one subject can appear on both sides of a split.
    Confirm this pooled design is intended before treating the scores as
    generalisation to unseen subjects.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )

    model = RandomForestClassifier(
        n_estimators=100,
        random_state=seed,
        max_depth=5,
        min_samples_split=10,
        min_samples_leaf=5,
        class_weight='balanced'  # Handle class imbalance
    )
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]

    # Cross-validation on the full data, in addition to the single hold-out split
    # above. cross_val_score fits clones, so `model` stays fitted on X_train only.
    cv_scores = cross_val_score(model, X, y, cv=n_folds, scoring='accuracy')

    importance = pd.DataFrame({
        'feature': feature_names,
        'importance': model.feature_importances_
    }).sort_values('importance', ascending=False)

    return {
        'X_train': X_train, 'X_test': X_test,
        'y_train': y_train, 'y_test': y_test,
        'model': model,
        'y_pred': y_pred, 'y_prob': y_prob,
        'accuracy': accuracy_score(y_test, y_pred),
        'cv_scores': cv_scores,
        'importance': importance,
    }


def print_classification_summary(y_true, y_pred, accuracy, cv_scores):
    """Print hold-out accuracy, the per-class report and the CV accuracy."""
    print(f"\nAccuracy: {accuracy:.3f}")
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=CLASS_NAMES))
    print(f"\n{len(cv_scores)}-Fold Cross-Validation Accuracy: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")


def plot_feature_importance(importance, title, color, figsize, top_n=None):
    """Draw a horizontal bar chart of feature importances, most important on top.

    ``importance`` must already be sorted descending; ``top_n`` limits the
    chart to the first rows (None shows all). Returns ``(fig, ax)``.
    """
    shown = importance if top_n is None else importance.head(top_n)
    bar_positions = range(len(shown))

    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(bar_positions, shown['importance'], color=color, edgecolor='black')
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(shown['feature'])
    ax.set_xlabel('Importance', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.invert_yaxis()  # first (largest) bar at the top
    plt.tight_layout()
    plt.show()
    return fig, ax


def plot_outcome_confusion(cm, title, cmap):
    """Draw an annotated KEEP/INVEST confusion-matrix heatmap; returns ``(fig, ax)``."""
    fig, ax = plt.subplots(figsize=(8, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax, cbar_kws={'label': 'Count'},
                xticklabels=CLASS_NAMES,
                yticklabels=CLASS_NAMES,
                annot_kws={'size': 14, 'weight': 'bold'})
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12, fontweight='bold')
    ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()
    return fig, ax


# =============================================================================
# TASK 2: PHYSIOLOGY → OUTCOME (Direct)
# =============================================================================
print(f"\n{'='*80}")
print(f"TASK 2: PHYSIOLOGY → OUTCOME (Direct)")
print(f"{'='*80}")

# Prepare data
X_physio = physiology_df.values
y = outcomes

# Check class balance
print(f"\nClass balance check:")
print(f"  Minimum class size: {min(np.bincount(y))}")
print(f"  Class ratio: {np.mean(y):.3f}")

# Stratified split, random forest fit, hold-out predictions and 5-fold CV
physio_fit = fit_and_score_random_forest(X_physio, y, physiology_df.columns)

# Unpack into the notebook-level names used below and by later cells
X_train_physio, X_test_physio = physio_fit['X_train'], physio_fit['X_test']
y_train, y_test = physio_fit['y_train'], physio_fit['y_test']
rf_physio = physio_fit['model']
y_pred_physio = physio_fit['y_pred']
y_prob_physio = physio_fit['y_prob']
accuracy_physio = physio_fit['accuracy']
cv_scores = physio_fit['cv_scores']
feature_importance_physio = physio_fit['importance']

print(f"\nTraining samples: {len(y_train)} (KEEP: {np.sum(y_train==0)}, INVEST: {np.sum(y_train==1)})")
print(f"Test samples: {len(y_test)} (KEEP: {np.sum(y_test==0)}, INVEST: {np.sum(y_test==1)})")

# Evaluation
print_classification_summary(y_test, y_pred_physio, accuracy_physio, cv_scores)

# Feature importance
print("\nTop 10 Most Important Physiology Features:")
print(feature_importance_physio.head(10))

top_n = min(15, len(feature_importance_physio))
fig, ax = plot_feature_importance(
    feature_importance_physio,
    title=f'Top {top_n} Physiology Features for Outcome Prediction\n(All Subjects Combined)',
    color='steelblue',
    figsize=(10, 8),
    top_n=top_n,
)

# Confusion matrix
cm_physio = confusion_matrix(y_test, y_pred_physio)
fig, ax = plot_outcome_confusion(
    cm_physio,
    title=f'Physiology → Outcome\nAccuracy: {accuracy_physio:.3f} | CV: {cv_scores.mean():.3f}',
    cmap='Blues',
)

# =============================================================================
# TASK 3: BEHAVIOR → OUTCOME
# =============================================================================
print(f"\n{'='*80}")
print(f"TASK 3: BEHAVIOR → OUTCOME")
print(f"{'='*80}")

# Prepare data
X_behavior = behavior_df.values
y = outcomes

# Same split settings and seed as Task 2: with identical y and row order this
# selects the same trials for train and test, so the two tasks are comparable.
behavior_fit = fit_and_score_random_forest(X_behavior, y, behavior_df.columns)

X_train_behavior, X_test_behavior = behavior_fit['X_train'], behavior_fit['X_test']
y_train, y_test = behavior_fit['y_train'], behavior_fit['y_test']
rf_behavior = behavior_fit['model']
y_pred_behavior = behavior_fit['y_pred']
y_prob_behavior = behavior_fit['y_prob']
accuracy_behavior = behavior_fit['accuracy']
cv_scores_behavior = behavior_fit['cv_scores']
feature_importance_behavior = behavior_fit['importance']

# Evaluation
print_classification_summary(y_test, y_pred_behavior, accuracy_behavior, cv_scores_behavior)

# Feature importance
print("\nBehavior Feature Importance:")
print(feature_importance_behavior)

fig, ax = plot_feature_importance(
    feature_importance_behavior,
    title='Behavior Features for Outcome Prediction\n(All Subjects Combined)',
    color='seagreen',
    figsize=(10, 6),
)

# Confusion matrix
cm_behavior = confusion_matrix(y_test, y_pred_behavior)
fig, ax = plot_outcome_confusion(
    cm_behavior,
    title=f'Behavior → Outcome\nAccuracy: {accuracy_behavior:.3f} | CV: {cv_scores_behavior.mean():.3f}',
    cmap='Greens',
)