In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold, StratifiedShuffleSplit, train_test_split
from sklearn.utils import compute_sample_weight
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, cumulative_dynamic_auc
import matplotlib.pyplot as plt
import warnings
# Silence every warning for the rest of the kernel session (this also hides
# any library warnings raised by later cells).
warnings.filterwarnings("ignore")

# Read the four cleaned cohorts, in the order India, Jordan, Florida, California.
india, jordan, florida, california = (
    pd.read_csv(csv_path)
    for csv_path in (
        'india_clean.csv',
        'jordan_clean.csv',
        'florida_clean.csv',
        'california_clean.csv',
    )
)

# Length of stay ("los") is later used as the survival time; floor it at 1.
# Cohorts without a "los" column are left unchanged.
for cohort in (india, jordan, florida, california):
    if "los" not in cohort.columns:
        continue
    cohort["los"] = cohort["los"].clip(lower=1)

# Display name -> cohort frame. Insertion order determines the order in which
# source/target pairs are enumerated downstream.
dataset_registry = dict(
    India=india,
    Jordan=jordan,
    Florida=florida,
    California=california,
)


In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from sklearn.utils import compute_sample_weight
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, cumulative_dynamic_auc
from scipy.stats import wilcoxon
import matplotlib.pyplot as plt

# --- Utility functions ---
def to_sksurv_y(df):
    """Build the structured ``(event, time)`` target array used by scikit-survival.

    ``df['event']`` is coerced element-wise with ``bool()`` into the ``event``
    field and ``df['los']`` becomes the float ``time`` field.
    """
    n_rows = len(df)
    target = np.empty(n_rows, dtype=[('event', bool), ('time', float)])
    target['event'] = np.fromiter(map(bool, df['event']), dtype=bool, count=n_rows)
    target['time'] = np.fromiter(df['los'], dtype=float, count=n_rows)
    return target

def to_X(df):
    """Return the model features ``age``, ``sex``, ``gcs`` (in that column order) as a NumPy array."""
    feature_cols = ['age', 'sex', 'gcs']
    return df.loc[:, feature_cols].to_numpy()

def compute_dynamic_auc_v(y_train, y_test, risk, time_points):
    """Time-dependent AUC of ``risk`` on ``y_test``, one value per entry of ``time_points``.

    ``cumulative_dynamic_auc`` is called once per time point. The value kept
    is its second return value (the summary ``mean_auc``), which for a
    single-point grid is presumably the AUC at that time; confirm against the
    installed scikit-survival version. ``y_train`` is passed through as the
    survival data sksurv uses to estimate the censoring distribution.
    """
    return np.array([
        cumulative_dynamic_auc(y_train, y_test, risk, times=t)[1]
        for t in time_points
    ])

def print_cindex_summary(name, scores):
    """Print one aligned line: ``name`` (left-justified to 25 chars), then mean ± population std of ``scores``."""
    mean_c = np.mean(scores)
    std_c = np.std(scores)
    line = f"{name:<25} → C-index: {mean_c:.3f} ± {std_c:.3f}"
    print(line)

def compare_models(cindex_a, cindex_b, name_a="A", name_b="B", alpha=0.05):
    """Compare two models' paired per-fold C-index scores with a Wilcoxon signed-rank test.

    Prints the mean paired difference (``b - a``), the two-sided p-value and a
    significance verdict at level ``alpha``.

    Parameters
    ----------
    cindex_a, cindex_b : sequence of float
        Scores of model A and model B on the same folds, in the same order.
    name_a, name_b : str
        Labels used in the printed header.
    alpha : float, default 0.05
        Significance threshold for the verdict.

    Returns
    -------
    tuple of (float, float)
        ``(mean_diff, p)``. ``p`` is reported as 1.0 when every paired
        difference is exactly zero (nothing to rank).

    Raises
    ------
    ValueError
        If the two score sequences have different shapes.

    Notes
    -----
    The exact two-sided signed-rank test cannot produce p below ``2 / 2**n``
    (n = number of non-zero differences). With 5 CV folds that floor is
    0.0625, so ``p < 0.05`` is unreachable; a note is printed in that case.
    """
    a = np.asarray(cindex_a, dtype=float)
    b = np.asarray(cindex_b, dtype=float)
    if a.shape != b.shape:
        raise ValueError(
            f"paired scores must have the same shape, got {a.shape} and {b.shape}")

    diff = b - a
    mean_diff = float(np.mean(diff))
    n_nonzero = int(np.count_nonzero(diff))

    print(f"\n Wilcoxon: {name_a} vs. {name_b}")
    print(f"Mean difference: {mean_diff:.3f}")

    if n_nonzero == 0:
        # The default zero_method ('wilcox') discards zero differences, leaving
        # nothing to rank; some SciPy versions raise ValueError in this case,
        # which would abort the whole experiment loop.
        p = 1.0
        print(f"p-value: {p:.4f} (all paired differences are zero)")
    else:
        p = float(wilcoxon(a, b).pvalue)
        print(f"p-value: {p:.4f}")

    if p < alpha:
        print("Statistically significant")
    else:
        print("Not statistically significant")

    # Smallest p the exact two-sided test can return for this many pairs.
    p_floor = 2.0 / 2 ** n_nonzero if n_nonzero else 1.0
    if p_floor >= alpha:
        print(f"Note: with {n_nonzero} non-zero paired differences the smallest attainable "
              f"exact two-sided p-value is {p_floor:.4f}, so p < {alpha} cannot be reached.")

    return mean_diff, p

# --- Few-shot TL Evaluation ---
from sklearn.model_selection import KFold

def _fit_fewshot_models(X_source, y_source, X_fewshot, y_fewshot, seed, event_weight):
    """Fit the baseline, standard-TL and weighted-TL survival models for one CV fold.

    Returns ``(model_b, model_s, model_w)``:
      * baseline    - 100 boosting stages on the few-shot target sample only;
      * standard TL - 100 stages on the full source cohort, then 50 more
                      (warm start) on the few-shot sample;
      * weighted TL - as standard TL, but the target stages use sample weights
                      of ``event_weight`` for event rows and 1 for censored rows.
    """
    from copy import deepcopy

    n_source_stages, n_total_stages = 100, 150

    model_b = GradientBoostingSurvivalAnalysis(n_estimators=n_source_stages, random_state=seed)
    model_b.fit(X_fewshot, y_fewshot)

    # Pre-train once on the source cohort; both TL variants continue from an
    # identical copy. With warm_start=True, raising n_estimators and calling
    # fit again appends stages to the existing ensemble instead of refitting.
    pretrained = GradientBoostingSurvivalAnalysis(
        n_estimators=n_source_stages, warm_start=True, random_state=seed)
    pretrained.fit(X_source, y_source)

    model_s = deepcopy(pretrained)
    model_s.set_params(n_estimators=n_total_stages)
    model_s.fit(X_fewshot, y_fewshot)

    model_w = deepcopy(pretrained)
    model_w.set_params(n_estimators=n_total_stages)
    sample_weight = compute_sample_weight(
        class_weight={1: event_weight, 0: 1}, y=y_fewshot["event"].astype(int))
    model_w.fit(X_fewshot, y_fewshot, sample_weight=sample_weight)

    return model_b, model_s, model_w


def _common_eval_times(y_target, folds, n_points=50):
    """Evenly spaced evaluation times inside every held-out fold's follow-up window.

    ``cumulative_dynamic_auc`` requires evaluation times within the observed
    follow-up range of its test data (strictly below the maximum), so the grid
    spans the intersection of those ranges across folds. Returns an empty
    array if the intersection is empty.
    """
    fold_times = [y_target["time"][test_idx] for _, test_idx in folds]
    t_lo = max(t.min() for t in fold_times)
    t_hi = min(t.max() for t in fold_times) - 1e-6  # stay strictly below each fold's max
    if t_hi < t_lo:
        return np.array([])
    return np.linspace(t_lo, t_hi, n_points)


def _plot_mean_auc(eval_times, auc_curves, title):
    """Plot, for each model label, the fold-averaged held-out dynamic AUC curve."""
    fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
    for (label, curves), marker in zip(auc_curves.items(), ("o", "s", "^")):
        # nanmean: ignore folds whose AUC is NaN (undefined) at a given time.
        ax.plot(eval_times, np.nanmean(np.vstack(curves), axis=0), label=label, marker=marker)
    ax.axhline(0.5, ls='--', color='gray')
    ax.set(title=title, xlabel="Time",
           ylabel="Dynamic AUC (mean over held-out folds)", ylim=(0.3, 1.0))
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # many figures are produced in a loop; release them


def run_fewshot_transfer_cv(source_df, target_df, source_name, target_name,
                            ratios=(0.05, 0.10, 0.20), n_splits=5,
                            event_weight=5, random_state=42):
    """Few-shot transfer-learning benchmark with K-fold CV on the target cohort.

    For every few-shot ``ratio`` and every target fold:
      1. draw a ``ratio`` fraction of the fold's training rows, stratified by
         the event indicator, as the few-shot sample;
      2. fit baseline / standard-TL / weighted-TL models
         (``_fit_fewshot_models``);
      3. score each model on the fold's held-out rows only: Harrell's C-index
         (``concordance_index_censored``) and a time-dependent AUC curve.
    Then print mean ± std C-index per model, pairwise Wilcoxon tests over the
    folds, and plot the fold-averaged held-out dynamic AUC curves.

    Note: with the default ``n_splits=5``, an exact two-sided Wilcoxon test on
    5 paired values cannot reach p < 0.05.

    Parameters
    ----------
    source_df, target_df : pandas.DataFrame
        Must provide the columns read by ``to_X`` and ``to_sksurv_y``.
    source_name, target_name : str
        Labels used in printed output and plot titles.
    ratios : sequence of float, default (0.05, 0.10, 0.20)
        Few-shot fractions of each fold's training rows.
    n_splits : int, default 5
        Number of KFold splits of the target cohort.
    event_weight : float, default 5
        Sample weight of event rows (censored rows get 1) for the weighted TL model.
    random_state : int, default 42
        Seed of the KFold shuffle.

    Returns
    -------
    dict
        ``{ratio: {"Baseline": [...], "Standard TL": [...], "Weighted TL": [...]}}``
        with one held-out C-index per fold.
    """
    print(f"\n FEWSHOT TRANSFER ({n_splits}-fold CV): {source_name} → {target_name}")

    X_source, y_source = to_X(source_df), to_sksurv_y(source_df)
    X_target, y_target = to_X(target_df), to_sksurv_y(target_df)

    # Fold assignment does not depend on the ratio, so compute it once.
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    folds = list(kf.split(X_target))
    eval_times = _common_eval_times(y_target, folds)

    model_labels = ("Baseline", "Standard TL", "Weighted TL")
    results = {}

    for ratio in ratios:
        print(f"\n Ratio: {int(ratio*100)}%")
        cindex = {label: [] for label in model_labels}
        auc_curves = {label: [] for label in model_labels}

        for fold, (train_idx, test_idx) in enumerate(folds):
            X_train, X_test = X_target[train_idx], X_target[test_idx]
            y_train, y_test = y_target[train_idx], y_target[test_idx]

            # Few-shot sample, stratified by event indicator.
            sss = StratifiedShuffleSplit(n_splits=1, train_size=ratio, random_state=fold)
            few_idx, _ = next(sss.split(X_train, y_train["event"].astype(int)))
            X_fewshot, y_fewshot = X_train[few_idx], y_train[few_idx]

            models = _fit_fewshot_models(X_source, y_source, X_fewshot, y_fewshot,
                                         seed=fold, event_weight=event_weight)

            for label, model in zip(model_labels, models):
                risk = model.predict(X_test)
                cindex[label].append(
                    concordance_index_censored(y_test["event"], y_test["time"], risk)[0])
                if eval_times.size:
                    # Score on held-out rows only. The full target cohort is
                    # passed as survival_train, which sksurv uses only to
                    # estimate the censoring distribution (no model output).
                    auc_curves[label].append(
                        compute_dynamic_auc_v(y_target, y_test, risk, eval_times))

        for label in model_labels:
            print_cindex_summary(label, cindex[label])

        compare_models(cindex["Baseline"], cindex["Standard TL"], "Baseline", "Standard TL")
        compare_models(cindex["Baseline"], cindex["Weighted TL"], "Baseline", "Weighted TL")
        compare_models(cindex["Standard TL"], cindex["Weighted TL"], "Standard TL", "Weighted TL")

        if eval_times.size:
            _plot_mean_auc(
                eval_times, auc_curves,
                title=f"Few-Shot Transfer (CV) @ {int(ratio*100)}%: {source_name} → {target_name}")
        else:
            print("Skipping dynamic AUC plot: held-out folds share no common follow-up window.")

        results[ratio] = cindex

    return results


In [None]:
import itertools

# Every ordered (source, target) pairing of two distinct cohorts, following
# the insertion order of dataset_registry.
names = list(dataset_registry)
all_pairs = list(itertools.permutations(names, 2))

# Run the few-shot transfer benchmark once per pairing.
for source_name, target_name in all_pairs:
    source_df = dataset_registry[source_name]
    target_df = dataset_registry[target_name]
    run_fewshot_transfer_cv(source_df, target_df, source_name, target_name)



 FEWSHOT TRANSFER (5-fold CV): India → Jordan

 Ratio: 5%
