In [None]:
import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import os
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index


def compute_baseline_survival(durations, events, risk_preds, time_point):
    """Estimate a baseline survival curve from log-risk scores.

    For every distinct observed time ``t`` the cumulative baseline hazard is
    increased by ``(sum of events at t) / sum(exp(risk_preds))``, where the
    denominator runs over all subjects with ``duration >= t`` (a Breslow-type
    estimate).  The baseline survival is ``S0(t) = exp(-H0(t))``.

    Args:
        durations: 1-D array-like of observed times.
        events: 1-D array-like of event indicators; values are summed per
            time point (presumably 1 = event, 0 = censored).
        risk_preds: 1-D array-like of log-risk scores; exponentiated here.
        time_point: Time at which the baseline survival is looked up.

    Returns:
        tuple: ``(baseline_survival, S0_at_time_point)`` where
        ``baseline_survival`` is a DataFrame with columns ``'time'`` and
        ``'S0'`` (one row per distinct observed time) and
        ``S0_at_time_point`` is ``S0`` at the last observed time that is
        ``<= time_point``.  If ``time_point`` precedes every observed time
        (or the input is empty) the value is ``1.0``, since no hazard has
        accumulated yet.
    """
    # Accept lists / pandas Series as well as ndarrays; positional indexing
    # below must not be confused with label-based indexing.
    durations = np.asarray(durations)
    events = np.asarray(events)
    risk_preds = np.asarray(risk_preds)

    # Order all three arrays by observed time.
    sorted_indices = np.argsort(durations)
    durations = durations[sorted_indices]
    events = events[sorted_indices]
    risk_preds = risk_preds[sorted_indices]

    H0 = 0
    H0_t = []
    unique_times = np.unique(durations)

    for t in unique_times:
        # Risk set: everyone still under observation at time t.
        at_risk = durations >= t
        risk_sum = np.sum(np.exp(risk_preds[at_risk]))
        event_at_t = np.sum(events[durations == t])
        H0 += event_at_t / risk_sum
        H0_t.append(H0)

    S0_t = np.exp(-np.array(H0_t, dtype=float))
    baseline_survival = pd.DataFrame({'time': unique_times, 'S0': S0_t})

    # Step-function lookup: last estimate at or before time_point.
    reached = baseline_survival.loc[baseline_survival['time'] <= time_point, 'S0']
    if reached.empty:
        # Before the first observed time nothing has happened yet: S0 = 1.
        S0_at_time_point = 1.0
    else:
        S0_at_time_point = reached.iloc[-1]
    return baseline_survival, S0_at_time_point

# --- Experiment configuration ------------------------------------------------
data_num = 148
fold_num = 4
# NOTE: despite the name, this horizon is 1.5 years (1.5 * 365 = 547.5 days).
three_year_days = 1.5 * 365
threshold = 0.2  # death-probability cut-off separating high from low risk
seed_list = range(83, 84)
level = 1

# Number of seeds whose log-rank p-value is below 0.05 / 0.10.
count_005 = 0
count_010 = 0

# Every seed's Kaplan-Meier figure is appended to one shared PDF.
KM_path = f"data/KM_curve/data_num{data_num}/level{level}/fold_num_{fold_num}"
pdf_name = f"time:{three_year_days} days, p threshold:{threshold}.pdf"

pdf_path = os.path.join(KM_path, pdf_name)

os.makedirs(KM_path, exist_ok=True)

# Root directories of the per-fold test / train result CSVs.
test_root = f"data/test_results/level{level}/slide_num{data_num}/fold_num_{fold_num}"
train_root = f"data/train_results/level{level}/slide_num{data_num}/fold_num_{fold_num}"

test_concat_path = f"{test_root}/test_concat"
os.makedirs(test_concat_path, exist_ok=True)

c_index_list = []
with PdfPages(pdf_path) as pdf:
    for seed in seed_list:  # one experiment per seed
        # Per-fold frames are gathered in lists and concatenated once, which
        # avoids repeatedly copying a growing DataFrame and avoids
        # concatenating onto an empty placeholder frame.
        patient_frames = []
        slide_frames = []

        # Loop through each fold for the current seed
        for fold in range(1, fold_num + 1):
            test_csv_path = f"{test_root}/test_fold/test_results_seed{seed}_fold_{fold}.csv"
            test_re = pd.read_csv(test_csv_path)

            risks = test_re['risk']
            days = test_re['survival_time']
            events = test_re['event'].astype(int)
            patient_id = test_re['patient_id']

            train_csv_path = f"{train_root}/train_results_seed{seed}_fold_{fold}.csv"
            train_re = pd.read_csv(train_csv_path)
            train_durations = train_re['durations'].values
            train_events = train_re['events'].values
            train_risk_preds = train_re['risk_preds'].values

            # Alternative (disabled): take the baseline from a lifelines Cox fit.
            # cph = CoxPHFitter()
            # cph.fit(train_re, duration_col='durations', event_col='events', show_progress=True)
            # baseline_survival = cph.baseline_survival_.reset_index()
            # S0_3_year = baseline_survival.loc[baseline_survival['index'] <= three_year_days, 'baseline survival'].iloc[-1]

            # Baseline survival at the horizon, estimated on this fold's
            # training predictions.
            _, S0_3_year = compute_baseline_survival(
                train_durations, train_events, train_risk_preds, three_year_days
            )

            # Calculate death probabilities: S(t | x) = S0(t) ** exp(risk)
            survival_probs = S0_3_year ** np.exp(risks)
            death_probs = 1 - survival_probs

            patient_frames.append(pd.DataFrame({
                'patient_id': patient_id,
                'days': days,
                'events': events,
                'death_probs': death_probs,
                'risk': risks
            }))
            slide_frames.append(pd.DataFrame({
                'slide_id': test_re['slide_id'],
                'days': days,
                'events': events,
                'death_probs': death_probs,
                'risk': risks
            }))

        all_slide_re = pd.concat(slide_frames)

        # Collapse to one row per patient: keep the largest death probability
        # and risk among that patient's rows, and the first days / events.
        combined_results = pd.concat(patient_frames).groupby('patient_id').agg({
            'death_probs': 'max',
            'days': 'first',
            'events': 'first',
            'risk': 'max'
        }).reset_index()

        # Classify high and low risk based on the threshold (for the whole seed, across folds)
        high_risk = combined_results['death_probs'] > threshold
        low_risk = combined_results['death_probs'] <= threshold

        high_count = np.sum(high_risk)
        low_count = np.sum(low_risk)

        # Kaplan-Meier fitting
        kmf_high = KaplanMeierFitter()
        kmf_low = KaplanMeierFitter()

        kmf_high.fit(combined_results['days'][high_risk], event_observed=combined_results['events'][high_risk], label="High Risk")
        kmf_low.fit(combined_results['days'][low_risk], event_observed=combined_results['events'][low_risk], label="Low Risk")

        # Compute C-index. The score is negated because a higher death
        # probability should correspond to a shorter survival time.
        c_index = concordance_index(combined_results['days'], -combined_results['death_probs'], combined_results['events'])
        print(f"Seed {seed}, data_num {data_num}, fold_num{fold_num}, C-index: {c_index:.4f}")
        c_index_list.append(c_index)

        # Perform Log-Rank test between the two risk groups
        results = logrank_test(
            combined_results['days'][high_risk], combined_results['days'][low_risk],
            event_observed_A=combined_results['events'][high_risk],
            event_observed_B=combined_results['events'][low_risk]
        )

        p_value = results.p_value
        if p_value < 0.05:
            count_005 += 1
        if p_value < 0.1:
            count_010 += 1

        # Plot the Kaplan-Meier curve for the current seed
        fig, ax = plt.subplots(figsize=(8, 6))
        kmf_high.plot(ax=ax)
        kmf_low.plot(ax=ax)

        ax.text(0.6, 0.2, f"Log-Rank p-value: {p_value:.4e}\nHigh Risk: {high_count}\nLow Risk: {low_count}",
                transform=ax.transAxes, fontsize=12)
        ax.set_title(f"Kaplan-Meier Curves (seed{seed})")
        ax.set_ylim(0.2, 1.01)
        ax.set_xlabel("Days")
        ax.set_ylabel("Survival Probability")

        plt.tight_layout()
        pdf.savefig(fig)  # Save current figure to the PDF
        plt.close(fig)  # Close the figure to release memory

        # Persist the per-patient and per-slide tables for this seed.
        combined_results.to_csv(f"{test_concat_path}/concat_seed{seed}.csv", encoding="utf-8-sig")
        all_slide_re.sort_values(by='slide_id').to_csv(f"{test_concat_path}/slideconcat_seed{seed}.csv", encoding="utf-8-sig")

# pd.DataFrame(c_index_list).to_csv("GCN_max.csv", encoding="utf-8-sig")

# Calculate the total test count
test_count = len(seed_list)  # number of seeds evaluated
print(f"p-value < 0.05: {count_005} / {test_count}")
print(f"p-value < 0.10: {count_010} / {test_count}")


Seed 83, data_num 148, fold_num4, C-index: 0.9054
p-value < 0.05: 1 / 1
p-value < 0.10: 1 / 1
