# Notebook 2: Cross-Validation Strategies

This notebook covers:
1. Evaluation metric: weighted IPCW C-index
2. Per-fold vs Global OOF evaluation
3. Stratified K-Fold CV

In [None]:
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sksurv.metrics import concordance_index_ipcw
from sksurv.util import Surv

# Paths: set the SURVIVAL_DATA_DIR environment variable to point at the data
# directory; otherwise the placeholder below is used (edit it for your machine).
TRAIN_PATH = os.environ.get('SURVIVAL_DATA_DIR', '/your_path/SurvivalPrediction/data')

## 1. Evaluation Metric

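The evaluation metric is a weighted combination of IPCW C-indices: 30% on the whole cohort, 40% on the "test-like" subgroup (at least one clinical risk factor) and 30% on the high-risk subgroup (at least two risk factors). See the README for details.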

In [None]:
def define_risk_groups(X):
    """Group patients by the number of clinical risk factors they carry.

    A row scores one point for each of: BM_BLAST > 10, has_TP53 > 0,
    HB < 10, PLT < 50 and cyto_risk_score >= 3. Comparisons against
    missing values evaluate to False, so a NaN never adds a point.

    Args:
        X: DataFrame containing the five columns above. The thresholds are
            applied to the values as given (this notebook passes the
            unscaled features).

    Returns:
        Dict of boolean Series indexed like ``X``:
            'test_like': at least one risk factor present
            'high_risk': at least two risk factors present
    """
    factor_flags = (
        X['BM_BLAST'].gt(10),
        X['has_TP53'].gt(0),
        X['HB'].lt(10),
        X['PLT'].lt(50),
        X['cyto_risk_score'].ge(3),
    )
    factor_count = sum(flag.astype(int) for flag in factor_flags)

    return {'test_like': factor_count >= 1, 'high_risk': factor_count >= 2}

def _subgroup_cindex(risk, y_surv, mask, tau, name):
    """IPCW C-index evaluated only on the rows selected by ``mask``.

    The full ``y_surv`` is passed as the training set, so the censoring
    weights come from the whole cohort; only the evaluated pairs are
    restricted to the subgroup.

    Raises:
        ValueError: if ``mask`` has a different length than ``risk`` or
            selects no rows.
    """
    mask = np.asarray(mask, dtype=bool)
    if len(mask) != len(risk):
        raise ValueError(
            f"risk group {name!r} has {len(mask)} entries, expected {len(risk)}"
        )
    if not mask.any():
        raise ValueError(f"risk group {name!r} is empty")
    return concordance_index_ipcw(y_surv, y_surv[mask], risk[mask], tau=tau)[0]


def weighted_cindex(risk, y_surv, risk_groups, tau=7.0, weights=(0.3, 0.4, 0.3)):
    """
    Compute weighted IPCW C-index across risk groups.

    Uses sksurv's concordance_index_ipcw directly. The full cohort ``y_surv``
    is always passed as the training set, so all three C-indices share the
    same censoring weights and differ only in which samples are evaluated.

    Default weights: 30% overall + 40% test-like + 30% high-risk

    Args:
        risk: Risk scores (higher = worse prognosis); array-like of length n
        y_surv: Structured survival array from Surv.from_arrays(), length n
        risk_groups: Dict with 'test_like' and 'high_risk' boolean masks
            (numpy arrays or pandas Series), positionally aligned with ``risk``
        tau: Truncation time
        weights: (overall, test_like, high_risk) weights of the combined score

    Returns:
        Dict with C-index metrics for each group and weighted score

    Raises:
        ValueError: if a risk-group mask has the wrong length or selects no rows
    """
    risk = np.asarray(risk)

    # Overall C-index
    c_overall = concordance_index_ipcw(y_surv, y_surv, risk, tau=tau)[0]
    # Subgroup C-indices (masks may be numpy arrays or pandas Series)
    c_test = _subgroup_cindex(risk, y_surv, risk_groups['test_like'], tau, 'test_like')
    c_high = _subgroup_cindex(risk, y_surv, risk_groups['high_risk'], tau, 'high_risk')

    w_overall, w_test, w_high = weights
    weighted = w_overall * c_overall + w_test * c_test + w_high * c_high

    return {
        'overall': c_overall,
        'test_like': c_test,
        'high_risk': c_high,
        'weighted': weighted
    }

# Confirmation that the definitions in this cell executed without error.
print("Risk groups and weighted C-index functions defined.")

Risk groups and weighted C-index functions defined.


## 2. Per-Fold vs Global OOF Evaluation

### Method 1: Per-Fold Averaging (Less Accurate)
```
For each fold k:
    1. Train on train_k
    2. Predict on val_k (~624 samples)
    3. Z-score normalize WITHIN fold
    4. Compute C-index on fold
Final score = Average of 5 fold C-indices
```

### Method 2: Global OOF (Recommended)
```
For each fold k:
    1. Train on train_k
    2. Store predictions at val_k indices → OOF array

After all folds:
    3. Z-score normalize ALL 3120 OOF predictions
    4. Compute single C-index on all samples
```

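**Why Global OOF is better:**
- More statistical power: the C-index is computed over the comparable pairs among all 3120 samples (on the order of 3120² pairs), instead of five separate sets of about 624² pairs. The estimate therefore has lower variance.
- Avoids the "averaging non-linear metrics" problem: the mean of per-fold C-indices is not, in general, the C-index of the pooled predictions.
- Caveat: pooling compares scores produced by different fold models, so it assumes those scores are on a comparable scale. Global z-scoring is a monotone transform. It preserves the ordering of predictions, which is all the C-index uses, and it does not correct scale differences between folds.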

In [14]:
# Load data for demonstration.
# The scaled features are the model inputs. The unscaled copy feeds
# define_risk_groups and the TP53 stratification label, so clinical thresholds
# are applied on the original scale.
X_train = pd.read_csv(f'{TRAIN_PATH}/X_train_128features_clean_fixed_scaled.csv')
X_train_unscaled = pd.read_csv(f'{TRAIN_PATH}/X_train_128features_clean_fixed.csv')
target = pd.read_csv(f'{TRAIN_PATH}/target_train_clean_aligned.csv')

# The three files are matched row-by-row (positionally) downstream, so they must line up.
assert len(X_train) == len(X_train_unscaled) == len(target), "feature/target row counts differ"
# NaN targets would silently turn into events via astype(bool) below.
assert target[['OS_YEARS', 'OS_STATUS']].notna().all().all(), "missing survival targets"

y_time = target['OS_YEARS'].values
y_event = target['OS_STATUS'].values
n_samples = len(X_train)

assert set(np.unique(y_event)) <= {0, 1}, "OS_STATUS must be coded 0/1"

# Create structured survival array for IPCW C-index
y_surv = Surv.from_arrays(event=y_event.astype(bool), time=y_time)

# Risk groups (thresholds need the unscaled values)
risk_groups = define_risk_groups(X_train_unscaled)

print(f"Samples: {n_samples}")
# int() so the count prints as 1600 rather than 1600.0 when OS_STATUS is stored as float
print(f"Events: {int(y_event.sum())} ({y_event.mean()*100:.1f}%)")
print(f"Survival array created: {type(y_surv)}")

Samples: 3120
Events: 1600.0 (51.3%)
Survival array created: <class 'numpy.ndarray'>


## 3. Stratified K-Fold CV

We stratify by event status AND TP53 mutation status to ensure:
- Balanced event rates in each fold
- Balanced high-risk patients (TP53 is a major risk factor)

In [15]:
# Stratify on the joint (event, TP53) label, e.g. "1_0" = event without TP53,
# so each fold mirrors the cohort's event rate and TP53 prevalence.
has_tp53 = X_train_unscaled['has_TP53'].gt(0).astype(int).to_numpy()
strat_var = pd.Series(['%d_%d' % (int(event), int(tp53)) for event, tp53 in zip(y_event, has_tp53)])

print("Stratification groups:", strat_var.value_counts(), sep="\n")

Stratification groups:
0_0    1430
1_0    1329
1_1     271
0_1      90
Name: count, dtype: int64


In [16]:
def global_oof_cv_template(model_fn, X, y_time, y_event, strat_var, risk_groups,
                           n_splits=5, seed=42, tau=7.0):
    """
    Template for Global OOF Cross-Validation.

    Each fold's model predicts only its held-out rows. The predictions are
    collected into one out-of-fold (OOF) vector and scored once on all samples
    with weighted_cindex (sksurv's concordance_index_ipcw).

    Args:
        model_fn: Function(X_train, y_time_train, y_event_train, X_val) -> predictions
            (one risk score per row of X_val; higher = worse prognosis)
        X: Feature matrix (DataFrame or array); rows are selected positionally
        y_time: Survival times (array-like of length n)
        y_event: Event indicators (array-like of 0/1 or bool, length n)
        strat_var: Stratification labels passed to StratifiedKFold.split
        risk_groups: Dict with 'test_like' and 'high_risk' boolean arrays
        n_splits: Number of CV folds
        seed: Random seed for the fold shuffle
        tau: Truncation time forwarded to weighted_cindex

    Returns:
        Dict with C-index metrics (keys as returned by weighted_cindex)

    Raises:
        ValueError: if model_fn returns a different number of predictions than
            validation rows
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    X_arr = X.values if hasattr(X, 'values') else X
    # Convert to positional arrays: indexing a pandas Series with train_idx
    # would otherwise be label-based and misalign with X_arr.
    y_time = np.asarray(y_time)
    y_event = np.asarray(y_event)
    oof_preds = np.zeros(len(X_arr))

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_arr, strat_var)):
        X_tr, X_val = X_arr[train_idx], X_arr[val_idx]

        # Train model and predict; ravel() accepts (n_val, 1)-shaped outputs too
        preds = np.asarray(
            model_fn(X_tr, y_time[train_idx], y_event[train_idx], X_val)
        ).ravel()
        if preds.shape[0] != len(val_idx):
            raise ValueError(
                f"fold {fold_idx+1}: model_fn returned {preds.shape[0]} predictions "
                f"for {len(val_idx)} validation rows"
            )
        oof_preds[val_idx] = preds

        print(f"  Fold {fold_idx+1}: {len(train_idx)} train, {len(val_idx)} val")

    # Global z-score. This is monotone, so it leaves the ordering of predictions
    # (and hence the C-index) unchanged. It does NOT reconcile scale differences
    # between fold models.
    oof_normalized = (oof_preds - oof_preds.mean()) / (oof_preds.std() + 1e-8)

    # Structured survival array for all samples
    y_surv_all = Surv.from_arrays(event=y_event.astype(bool), time=y_time)

    # Compute metrics on ALL samples using IPCW C-index
    return weighted_cindex(oof_normalized, y_surv_all, risk_groups, tau=tau)

# Confirmation that the template definition in this cell executed without error.
print("Global OOF CV template defined.")

Global OOF CV template defined.
