In [1]:
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output
from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt
from lifelines.statistics import logrank_test

# Load your data
data = pd.read_csv('ALL_MICE.csv')  # Adjust the path to your CSV file

# Keep only rows that have a recorded lifespan and whose 'Censored' value is
# neither 1 nor 2. Both conditions are applied in a single pass; the explicit
# copy makes `data` an independent frame rather than a view of the raw table.
# NOTE(review): the meaning of Censored codes 1 and 2 is not defined in this
# file — confirm against the data dictionary.
has_lifespan = data['Lifespan (days)'].notna()
excluded_by_censoring = data['Censored'].isin([1, 2])
data = data[has_lifespan & ~excluded_by_censoring].copy()


# Function to update code options based on selected batch and gender
def update_code_options(batch_dropdown, gender_dropdown, code_dropdown):
    """Repopulate ``code_dropdown`` from the current batch and gender choices.

    While either selector is still on 'All', only the 'All' entry is offered.
    Otherwise the distinct 'Code' values of the rows in the module-level
    ``data`` frame that match both selections are listed after 'All'.
    The code selection itself is always reset to 'All'.
    """
    selected_batch = batch_dropdown.value
    selected_gender = gender_dropdown.value

    choices = ['All']
    if selected_gender != 'All' and selected_batch != 'All':
        row_mask = (data['Batch'] == selected_batch) & (data['Gender'] == selected_gender)
        choices.extend(data[row_mask]['Code'].unique())

    # Assign the options before the value so 'All' is a valid choice when set.
    code_dropdown.options = choices
    code_dropdown.value = 'All'

# Function to create a condition set
def create_condition_set(description_suffix=''):
    """Build one linked (batch, gender, code) dropdown trio.

    The batch dropdown lists every distinct 'Batch' value in the module-level
    ``data`` frame; the gender and code dropdowns start with only 'All'.
    Changing the batch refills the gender choices for that batch, and any
    batch or gender change refreshes the code choices.

    Args:
        description_suffix: Text appended to each dropdown's description,
            placed before the trailing colon.

    Returns:
        The tuple ``(batch_dropdown, gender_dropdown, code_dropdown)``.
    """
    def make_dropdown(title, choices):
        # All three selectors share the same shape and default to 'All'.
        return widgets.Dropdown(options=choices, description=f'{title}{description_suffix}:', value='All')

    batch_dropdown = make_dropdown('Batch', ['All'] + list(data['Batch'].unique()))
    gender_dropdown = make_dropdown('Gender', ['All'])
    code_dropdown = make_dropdown('Code', ['All'])

    def on_gender_change(change):
        update_code_options(batch_dropdown, gender_dropdown, code_dropdown)

    def on_batch_change(change):
        chosen_batch = change['new']
        if chosen_batch == 'All':
            gender_choices = ['All']
        else:
            batch_rows = data[data['Batch'] == chosen_batch]
            gender_choices = ['All'] + list(batch_rows['Gender'].unique())
        gender_dropdown.options = gender_choices
        gender_dropdown.value = 'All'
        # Refresh explicitly: the codes must be recomputed for the new batch
        # regardless of whether the gender assignments above notified anyone.
        update_code_options(batch_dropdown, gender_dropdown, code_dropdown)

    batch_dropdown.observe(on_batch_change, names='value')
    gender_dropdown.observe(on_gender_change, names='value')

    return batch_dropdown, gender_dropdown, code_dropdown

def calculate_auc(kmf):
    """Return the area under a fitted survival curve.

    The curve is integrated as a step function: each survival estimate is
    held constant from its own time point up to the next one, and the
    resulting rectangles are summed.

    Args:
        kmf: A fitted estimator exposing ``survival_function_``, a
            single-column frame of survival probabilities indexed by time.

    Returns:
        The area as a float; ``0.0`` when the curve has fewer than two
        time points.
    """
    survival_function = kmf.survival_function_
    times = survival_function.index.to_numpy()
    # Select the curve by position rather than by name: the column is named
    # after the label given at fit time and is "KM_estimate" only by default.
    probabilities = survival_function.iloc[:, 0].to_numpy()

    # Width of each interval times the estimate at its left edge.
    interval_widths = times[1:] - times[:-1]
    return float((interval_widths * probabilities[:-1]).sum())

# Plot_survival_curve function to calculate and display AUC, Delta AUC, and p-value
def plot_survival_curve(b):
    """Redraw the survival-curve comparison for the current selections.

    For every entry in the module-level ``condition_sets`` the module-level
    ``data`` frame is filtered by the chosen batch, gender and code, a
    Kaplan-Meier curve is fitted and plotted, and its AUC is computed with
    ``calculate_auc``. Selections that match no rows are skipped. When
    exactly two curves were drawn, a log-rank p-value and the relative AUC
    difference (as a percentage of the first curve's AUC) are added to the
    title. Everything is rendered inside ``plot_output``.

    Args:
        b: The argument supplied by the button's click callback; unused.
    """
    with plot_output:
        clear_output(wait=True)
        plt.figure(figsize=(10, 6))

        datasets = []  # Filtered frames, kept for the statistical comparison
        auc_values = []  # AUC of each plotted curve
        labels = []  # Legend label of each plotted curve

        # Loop through each condition set
        for idx, condition_set in enumerate(condition_sets, start=1):
            batch, gender, code = (dropdown.value for dropdown in condition_set)
            selections = {'Batch': batch, 'Gender': gender, 'Code': code}

            # Compare against the 'All' sentinel explicitly instead of relying
            # on truthiness, so a falsy real value (e.g. a code of 0) still
            # filters the data.
            filtered_data = data
            for column, chosen in selections.items():
                if chosen != 'All':
                    filtered_data = filtered_data[filtered_data[column] == chosen]

            if filtered_data.empty:
                continue

            datasets.append(filtered_data)
            # Fit the model and plot
            kmf = KaplanMeierFitter()
            kmf.fit(durations=filtered_data['Lifespan (days)'], event_observed=filtered_data['Dead'])
            auc = calculate_auc(kmf)
            auc_values.append(auc)
            label = f'Cond {idx}: {batch} - {gender} - {code} (AUC: {auc:.4f})'
            labels.append(label)
            kmf.plot_survival_function(ci_show=False, label=label)

        # The log-rank test and Delta AUC are only defined for a pair of curves
        p_value = None
        delta_auc = None
        if len(datasets) == 2:
            first, second = datasets
            results = logrank_test(first['Lifespan (days)'], second['Lifespan (days)'],
                                   event_observed_A=first['Dead'], event_observed_B=second['Dead'])
            p_value = results.p_value

            # A relative difference is undefined against a zero baseline, so
            # Delta AUC is omitted in that case rather than shown as inf/nan.
            baseline_auc, other_auc = auc_values
            if baseline_auc != 0:
                delta_auc = abs(other_auc - baseline_auc) / baseline_auc * 100  # Delta AUC in percentage

        # Title and legend adjustments
        title = 'Survival Curve Comparison'
        if p_value is not None:
            title += f' - p-value: {p_value:.4f}'
        if delta_auc is not None:
            title += f' - Delta AUC: {delta_auc:.2f}%'
        plt.title(title)

        plt.xlabel('Lifespan (days)')
        plt.ylabel('Survival Probability')
        plt.ylim(0, 1)
        plt.xlim(0, 1200)
        plt.legend()
        plt.show()

        # Echo the legend labels (which embed the AUC) and Delta AUC as text
        for label in labels:
            print(label)
        if delta_auc is not None:
            print(f"Delta AUC (percentage): {delta_auc:.2f}%")

# Button that triggers a redraw, and the output area the redraw renders into
plot_button = widgets.Button(description='Plot Survival Curve')
plot_button.on_click(plot_survival_curve)
plot_output = widgets.Output()

# Two independent selector groups, one per curve to compare
condition_sets = [create_condition_set(f' {number}') for number in (1, 2)]
condition_set_1, condition_set_2 = condition_sets

# Show the selectors first, then the button with its output area beneath
for condition_set in condition_sets:
    display(*condition_set)
display(plot_button, plot_output)


Dropdown(description='Batch 1:', options=('All', 'E_2022_42_CEB', 'E_2022_45_CEB', 'E_2021_27_CEB', 'E_2023_15…

Dropdown(description='Gender 1:', options=('All',), value='All')

Dropdown(description='Code 1:', options=('All',), value='All')

Dropdown(description='Batch 2:', options=('All', 'E_2022_42_CEB', 'E_2022_45_CEB', 'E_2021_27_CEB', 'E_2023_15…

Dropdown(description='Gender 2:', options=('All',), value='All')

Dropdown(description='Code 2:', options=('All',), value='All')

Button(description='Plot Survival Curve', style=ButtonStyle())

Output()