In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold, StratifiedShuffleSplit, train_test_split
from sklearn.utils import compute_sample_weight
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, cumulative_dynamic_auc
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# Read the four cleaned cohort files, one DataFrame per site.
india = pd.read_csv('india_clean.csv')
jordan = pd.read_csv('jordan_clean.csv')
florida = pd.read_csv('florida_clean.csv')
california = pd.read_csv('california_clean.csv')

# Floor length of stay at 1 in every cohort that has a "los" column
# (the frames are modified in place, so the names above see the change).
for cohort in (india, jordan, florida, california):
    if "los" not in cohort.columns:
        continue
    cohort["los"] = cohort["los"].clip(lower=1)

# Display name -> cohort DataFrame; insertion order fixes the experiment order later on.
dataset_registry = {"India": india, "Jordan": jordan, "Florida": florida, "California": california}


In [2]:
def to_sksurv_y(df):
    """Build the scikit-survival outcome array from a cohort DataFrame.

    Returns a structured NumPy array with fields ``event`` (bool, from the
    ``event`` column's truthiness) and ``time`` (float, from the ``los`` column),
    one record per row of ``df``.
    """
    outcome = np.empty(len(df), dtype=[('event', bool), ('time', float)])
    outcome['event'] = df['event'].to_numpy().astype(bool)
    outcome['time'] = df['los'].to_numpy()
    return outcome

def to_X(df):
    """Return the model feature matrix: the age, sex and gcs columns, in that order, as a NumPy array."""
    feature_cols = ['age', 'sex', 'gcs']
    return df.loc[:, feature_cols].to_numpy()

def compute_dynamic_auc_v(y_train, y_test, risk, time_points):
    """Cumulative/dynamic AUC of ``risk`` evaluated separately at each entry of ``time_points``.

    ``cumulative_dynamic_auc`` is called once per time so that the result lines up
    one-to-one with ``time_points`` (the library itself de-duplicates and sorts
    the times it is given).

    Parameters
    ----------
    y_train : structured array with ``event``/``time`` fields
        Passed to sksurv, which estimates the censoring distribution from it.
    y_test : structured array with ``event``/``time`` fields
        Outcomes of the samples ``risk`` was predicted for.
    risk : array-like, shape (len(y_test),)
        Predicted risk scores for the ``y_test`` samples.
    time_points : iterable of numbers
        Evaluation times.

    Returns
    -------
    numpy.ndarray of float, same length as ``time_points``. sksurv rejects times
    outside the test follow-up window ``[min(y_test['time']), max(y_test['time']))``;
    such times are reported as NaN here instead of aborting the whole curve.
    """
    test_time = np.asarray(y_test['time'], dtype=float)
    if test_time.size:
        t_min, t_max = test_time.min(), test_time.max()
    else:
        # No test samples: there is no valid window, so every time maps to NaN.
        t_min, t_max = np.inf, -np.inf

    curve = []
    for t in time_points:
        if t_min <= t < t_max:
            # For a single time point sksurv's second return value (mean_auc)
            # equals the AUC at that time.
            _, auc = cumulative_dynamic_auc(y_train, y_test, risk, times=t)
        else:
            auc = np.nan
        curve.append(auc)
    return np.array(curve, dtype=float)

def print_summary_stats(name, cindices):
    """Print one aligned line with the mean and population SD (ddof=0) of ``cindices``."""
    scores = np.asarray(cindices)
    mean_score, sd_score = scores.mean(), scores.std()
    print(f"{name:<25} → Mean: {mean_score:.3f}, SD: {sd_score:.3f}")


In [3]:
def _fine_tune_from_source(X_source, y_source, X_tune, y_tune, sample_weight=None):
    """Fit 100 boosting stages on the source cohort, then warm-start 50 more on (X_tune, y_tune)."""
    model = GradientBoostingSurvivalAnalysis(n_estimators=100, warm_start=True, random_state=0)
    model.fit(X_source, y_source)
    # warm_start=True keeps the source-fitted stages; raising n_estimators appends new ones.
    model.set_params(n_estimators=150)
    model.fit(X_tune, y_tune, sample_weight=sample_weight)
    return model


def _event_sample_weights(y):
    """Per-sample weights: 5 for rows with event == 1, 1 for censored rows."""
    return compute_sample_weight(class_weight={1: 5, 0: 1}, y=y['event'].astype(int))


def _held_out_auc_curve(y_censoring, y_test, risk, time_points):
    """Dynamic AUC on one test fold; NaN at times outside that fold's follow-up window.

    sksurv rejects evaluation times outside [min, max) of the test times, and a
    single CV fold can cover a narrower range than the whole cohort.
    """
    test_time = y_test['time']
    in_range = (time_points >= test_time.min()) & (time_points < test_time.max())
    curve = np.full(len(time_points), np.nan)
    if in_range.any():
        curve[in_range] = compute_dynamic_auc_v(y_censoring, y_test, risk, time_points[in_range])
    return curve


def run_transfer_experiment(source_name, target_name, plot_auc=True, show_stats=True, time_points=None):
    """Compare three ways of transferring a survival model from one cohort to another.

    Strategies (GradientBoostingSurvivalAnalysis on the ``to_X`` features):
      * Direct Transfer    - 150 stages fit on the source cohort only.
      * Standard Fine-Tune - 100 stages on the source, warm-started to 150 on the
                             target training fold.
      * Weighted Fine-Tune - as Standard, with target rows whose event == 1
                             weighted 5x (censored rows weight 1).

    All strategies share one 5-fold split of the target cohort. For each fold the
    C-index and a dynamic AUC curve are computed on the held-out rows only; the
    AUC curves are then averaged across folds. Fine-tuned models are therefore
    never scored on rows they were trained on.

    Parameters
    ----------
    source_name, target_name : str
        Keys into ``dataset_registry``.
    plot_auc : bool
        Draw the fold-averaged AUC curves.
    show_stats : bool
        Print mean/SD of the per-fold C-indices.
    time_points : array-like, optional
        Evaluation times for the dynamic AUC; defaults to 1..50. A time outside a
        fold's follow-up window contributes NaN for that fold and is ignored in
        the average (NaN overall if no fold covers it).

    Returns
    -------
    dict with keys ``source``, ``target``, ``direct_mean``, ``standard_mean`` and
    ``weighted_mean`` (mean per-fold C-index of each strategy).
    """
    print(f"\n Transfer: {source_name} → {target_name}")

    source_df = dataset_registry[source_name].copy()
    target_df = dataset_registry[target_name].copy()
    # Keeps only rows with a positive LOS (this also drops NaN LOS).
    source_df = source_df[source_df['los'] > 0]
    target_df = target_df[target_df['los'] > 0]

    X_source, y_source = to_X(source_df), to_sksurv_y(source_df)
    X_target, y_target = to_X(target_df), to_sksurv_y(target_df)

    time_points = np.arange(1, 51) if time_points is None else np.asarray(time_points)

    # The direct model never sees target data, so a single fit serves every fold.
    model_direct = GradientBoostingSurvivalAnalysis(n_estimators=150, random_state=0)
    model_direct.fit(X_source, y_source)

    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    strategies = ("direct", "standard", "weighted")
    c_index = {key: [] for key in strategies}
    auc_by_fold = {key: [] for key in strategies}

    for train_idx, test_idx in kf.split(X_target):
        X_train, y_train = X_target[train_idx], y_target[train_idx]
        X_test, y_test = X_target[test_idx], y_target[test_idx]

        fold_models = {
            "direct": model_direct,
            "standard": _fine_tune_from_source(X_source, y_source, X_train, y_train),
            "weighted": _fine_tune_from_source(
                X_source, y_source, X_train, y_train,
                sample_weight=_event_sample_weights(y_train),
            ),
        }
        for key, model in fold_models.items():
            preds = model.predict(X_test)
            c_index[key].append(
                concordance_index_censored(y_test['event'], y_test['time'], preds)[0]
            )
            # The censoring distribution (IPCW) is estimated from all target outcomes;
            # it uses no model output, and it guarantees every test time lies in its support.
            auc_by_fold[key].append(_held_out_auc_curve(y_target, y_test, preds, time_points))

    if show_stats:
        print_summary_stats("Direct Transfer", c_index["direct"])
        print_summary_stats("Standard Fine-Tuning", c_index["standard"])
        print_summary_stats("Weighted Fine-Tuning", c_index["weighted"])

    # Fold-averaged held-out AUC; times covered by no fold stay NaN and plot as gaps.
    mean_auc = {key: np.nanmean(np.vstack(curves), axis=0) for key, curves in auc_by_fold.items()}

    if plot_auc:
        fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
        ax.plot(time_points, mean_auc["direct"], label='Direct Transfer', marker='o')
        ax.plot(time_points, mean_auc["standard"], label='Standard Fine-Tune', marker='s')
        ax.plot(time_points, mean_auc["weighted"], label='Weighted Fine-Tune', marker='^')
        ax.axhline(0.5, ls='--', color='gray')
        ax.set_title(f"AUC Curves: {source_name} → {target_name}")
        ax.set_xlabel("Time")
        ax.set_ylabel("AUC (mean over 5 held-out folds)")
        ax.set_ylim(0.3, 1.0)
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()
        plt.close(fig)

    return {
        "source": source_name,
        "target": target_name,
        "direct_mean": np.mean(c_index["direct"]),
        "standard_mean": np.mean(c_index["standard"]),
        "weighted_mean": np.mean(c_index["weighted"]),
    }

def run_fewshot_dynamic_auc(source_df, target_df, source_name, target_name,
                            ratios=(0.05, 0.10, 0.25, 0.50, 0.75, 0.90)):
    """Plot held-out dynamic AUC of target-only vs. transfer models as the target training set grows.

    The target cohort is split once into an 80% training pool and a 20% test set
    (``random_state=42``). For each fraction in ``ratios``, an event-stratified
    subsample of the training pool is drawn and three models are fit:

    * baseline    - 100 stages on the few-shot target subsample only;
    * standard TL - 100 stages on the source cohort, warm-started to 150 on the subsample;
    * weighted TL - as standard TL, with subsample rows whose event == 1 weighted 5x.

    Every model is scored only on the held-out 20% test set. Scores are taken at
    50 evenly spaced times within that set's follow-up range, and one figure is
    drawn per ratio.

    Parameters
    ----------
    source_df, target_df : pandas.DataFrame
        Cohorts with the ``to_X`` feature columns plus ``event`` and ``los``.
        Rows whose ``los`` is not > 0 (including NaN) are dropped, matching
        ``run_transfer_experiment``.
    source_name, target_name : str
        Labels for the printed header and plot legends.
    ratios : sequence of float
        Fractions of the target training pool used for few-shot fitting.
    """
    print(f"\n Few-Shot Dynamic AUC: {source_name} → {target_name}")

    source_df = source_df[source_df['los'] > 0]
    target_df = target_df[target_df['los'] > 0]

    X_source, y_source = to_X(source_df), to_sksurv_y(source_df)
    X_target, y_target = to_X(target_df), to_sksurv_y(target_df)

    X_train_full, X_test, y_train_full, y_test = train_test_split(
        X_target, y_target, test_size=0.2, random_state=42
    )
    y_events = y_train_full["event"].astype(int)

    # sksurv only accepts evaluation times in [min, max) of the test follow-up times.
    min_time = y_test["time"].min()
    max_time = y_test["time"].max()
    if max_time <= min_time:
        print("  Skipped: held-out test set has no spread in follow-up time.")
        return
    eval_times = np.linspace(min_time, max_time - 1e-6, num=50)

    def fine_tune(X_tune, y_tune, sample_weight=None):
        """Fit 100 stages on the source cohort, then warm-start 50 more on (X_tune, y_tune)."""
        model = GradientBoostingSurvivalAnalysis(n_estimators=100, warm_start=True, random_state=42)
        model.fit(X_source, y_source)
        model.set_params(n_estimators=150)
        model.fit(X_tune, y_tune, sample_weight=sample_weight)
        return model

    for ratio in ratios:
        sss = StratifiedShuffleSplit(n_splits=1, train_size=ratio, random_state=42)
        train_idx, _ = next(sss.split(X_train_full, y_events))
        X_fewshot, y_fewshot = X_train_full[train_idx], y_train_full[train_idx]

        model_baseline = GradientBoostingSurvivalAnalysis(n_estimators=100, random_state=42)
        model_baseline.fit(X_fewshot, y_fewshot)
        model_transfer = fine_tune(X_fewshot, y_fewshot)
        sw = compute_sample_weight(class_weight={1: 5, 0: 1}, y=y_fewshot["event"].astype(int))
        model_transfer_w = fine_tune(X_fewshot, y_fewshot, sample_weight=sw)

        # Score on held-out test rows only: the few-shot rows come from the training
        # pool and would inflate the AUC. The censoring distribution (IPCW) is
        # estimated from all target outcomes; it uses no model output, and it keeps
        # every test time inside its support.
        auc_base = compute_dynamic_auc_v(y_target, y_test, model_baseline.predict(X_test), eval_times)
        auc_trans = compute_dynamic_auc_v(y_target, y_test, model_transfer.predict(X_test), eval_times)
        auc_weighted = compute_dynamic_auc_v(y_target, y_test, model_transfer_w.predict(X_test), eval_times)

        fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
        ax.plot(eval_times, auc_base, label=f"{target_name}-only Baseline", marker='o')
        ax.plot(eval_times, auc_trans, label=f"Standard TL ({source_name} → {target_name})", marker='s')
        ax.plot(eval_times, auc_weighted, label=f"Weighted TL ({source_name} → {target_name})", marker='^')
        ax.axhline(0.5, ls='--', color='gray')
        ax.set_title(f'Dynamic AUC Comparison at Ratio {ratio:.2f}')
        ax.set_xlabel('Time (days)')
        ax.set_ylabel('AUC')
        ax.set_ylim(0.3, 1.0)
        ax.grid(True)
        ax.legend()
        fig.tight_layout()
        plt.show()
        plt.close(fig)


In [None]:
names = list(dataset_registry)
results_all = []

# Every ordered (source, target) pair of distinct cohorts, in registry order.
transfer_pairs = [(s, t) for s in names for t in names if s != t]
for source, target in transfer_pairs:
    result = run_transfer_experiment(source, target, plot_auc=True)
    run_fewshot_dynamic_auc(dataset_registry[source], dataset_registry[target], source, target)
    results_all.append(result)

# Final comparison table: one row per (source, target) pair.
all_transfer_df = pd.DataFrame(results_all)
all_transfer_df



 Transfer: India → Jordan
