# TTE-V2 Python Code + Clustering

# In [None]:
import os
import pandas as pd
import statsmodels.formula.api as smf
import numpy as np
from sklearn.cluster import KMeans
from scipy.stats import multivariate_normal
from patsy import dmatrix
import matplotlib.pyplot as plt

# Step 1: Setup directories (one output directory per estimand, created if missing)
trial_pp_dir = os.path.join(os.getcwd(), "trial_pp")
trial_itt_dir = os.path.join(os.getcwd(), "trial_itt")
os.makedirs(trial_pp_dir, exist_ok=True)
os.makedirs(trial_itt_dir, exist_ok=True)

trial_pp = {"estimand": "PP", "dir": trial_pp_dir}
trial_itt = {"estimand": "ITT", "dir": trial_itt_dir}

# Step 2: Data Preparation
data_censored = pd.read_csv("data_censored.csv")  # Replace with your data file
# Logical role -> column name in data_censored.
columns = {
    "id": "id",
    "period": "period",
    "treatment": "treatment",
    "outcome": "outcome",
    "eligible": "eligible"
}
# Each trial gets its own copy of the data AND of the column mapping, so
# changing one trial's setup cannot silently alter the other trial.
trial_pp["data"] = data_censored.copy()
trial_pp["columns"] = dict(columns)
trial_itt["data"] = data_censored.copy()
trial_itt["columns"] = dict(columns)

# Step 3: Clustering Integration
def perform_clustering(data, baseline_vars, n_clusters=3, random_state=1234, standardize=False):
    """
    Cluster individuals on their period-0 covariates with K-means and attach
    the resulting label to every row of ``data``.

    Each ``id`` contributes one feature vector: its first row with
    ``period == 0``. Labels are attached with a left merge on ``id``, so ids
    without a period-0 row end up with a missing (NaN) ``cluster``.

    Args:
        data (pd.DataFrame): Long-format dataset containing ``id``, ``period``
            and every column listed in ``baseline_vars``.
        baseline_vars (list): Columns used as clustering features.
        n_clusters (int): Number of clusters (default=3).
        random_state (int): Seed passed to ``KMeans`` (default=1234).
        standardize (bool): If True, z-score each feature before clustering so
            wide-range variables (e.g. age) do not dominate the Euclidean
            distance. Default False clusters on the raw scale.

    Returns:
        pd.DataFrame: New DataFrame equal to ``data`` plus a ``cluster``
            column (any pre-existing ``cluster`` column is replaced).
        KMeans: The fitted K-means model.

    Raises:
        ValueError: If any baseline variable is missing at period 0
            (K-means cannot handle NaN).
    """
    baseline_vars = list(baseline_vars)
    # .copy() so adding the label column below writes to a real frame, not a
    # view of ``data``; one feature row per id so the merge cannot duplicate rows.
    baseline = (
        data.loc[data["period"] == 0, ["id"] + baseline_vars]
        .drop_duplicates(subset="id", keep="first")
        .copy()
    )
    features = baseline[baseline_vars]
    missing = features.columns[features.isna().any()].tolist()
    if missing:
        raise ValueError(f"Baseline variables contain missing values at period 0: {missing}")
    if standardize:
        # Constant columns would divide by zero; leave them centred at 0 instead.
        spread = features.std(ddof=0).replace(0, 1)
        features = (features - features.mean()) / spread

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    baseline["cluster"] = kmeans.fit_predict(features)

    # Drop a stale label first; otherwise the merge would create cluster_x / cluster_y.
    data = data.drop(columns="cluster", errors="ignore")
    data = data.merge(baseline[["id", "cluster"]], on="id", how="left")
    return data, kmeans

# Baseline covariates that define the clusters; edit this list to match the
# columns available in your dataset.
baseline_vars = [
    "age",
    "x1",
    "x2",
    "x3",
    "x4",
]
trial_itt["data"], kmeans_model = perform_clustering(
    trial_itt["data"],
    baseline_vars=baseline_vars,
)

# Step 4: Specify Outcome Model with Clustering
def set_outcome_model(trial, adjustment_terms=None, cluster_as_categorical=True):
    """
    Build the outcome-model formula with a treatment-by-cluster interaction and
    store it in ``trial["outcome_model_formula"]``.

    The formula is
    ``<outcome> ~ assigned_treatment * C(cluster) + followup_time [+ terms]``.
    K-means labels are arbitrary category IDs, so by default they are wrapped
    in ``C(...)`` to get one indicator per cluster. Entering the raw label as a
    number would fit a linear trend across meaningless label values.

    Args:
        trial (dict): Trial dictionary; ``trial["columns"]["outcome"]`` names
            the outcome column.
        adjustment_terms (list): Extra patsy terms appended with ``+`` (optional).
        cluster_as_categorical (bool): Wrap ``cluster`` in ``C(...)`` (default
            True). Pass False for the numeric-label formula.

    Returns:
        str: The formula that was stored in the trial.
    """
    outcome = trial["columns"]["outcome"]
    cluster_term = "C(cluster)" if cluster_as_categorical else "cluster"
    terms = [f"assigned_treatment * {cluster_term}", "followup_time", *(adjustment_terms or [])]
    formula = f"{outcome} ~ " + " + ".join(terms)
    trial["outcome_model_formula"] = formula
    return formula

# Outcome model for the ITT trial, additionally adjusting for x2.
set_outcome_model(trial_itt, ["x2"])

# Step 5: Fit Marginal Structural Model (MSM)
def fit_msm(trial, robust=True):
    """
    Fit the pooled-logistic outcome model ``trial["outcome_model_formula"]``
    on ``trial["data"]`` and store it in ``trial["outcome_model"]``.

    Each person contributes several rows (one per follow-up time), so rows
    are not independent. With ``robust=True`` (default) the coefficient
    covariance is a cluster-robust sandwich estimate grouped by person id.
    The downstream confidence intervals draw from this covariance. Plain
    model-based standard errors ignore the within-person correlation.

    No weights are applied. NOTE(review): this is only appropriate when the
    data needs no inverse-probability weights (e.g. an unweighted ITT
    analysis). Add a weighted fit if this is used for the PP estimand.

    Args:
        trial (dict): Trial dictionary with ``data``, ``outcome_model_formula``
            and optionally ``columns["id"]`` (defaults to ``"id"``).
        robust (bool): Use the id-clustered sandwich covariance (default True).
            Pass False for the model-based covariance.

    Returns:
        The statsmodels results object returned by ``smf.logit(...).fit``.
    """
    # A unique index guarantees the .loc lookup below returns exactly one id per row.
    data = trial["data"].reset_index(drop=True)
    formula = trial["outcome_model_formula"]
    model = smf.logit(formula, data=data)
    if robust:
        id_col = trial.get("columns", {}).get("id", "id")
        # patsy drops rows with missing values; take the ids of the rows actually used.
        groups = data.loc[model.data.row_labels, id_col].to_numpy()
        fitted = model.fit(disp=0, cov_type="cluster", cov_kwds={"groups": groups})
    else:
        fitted = model.fit(disp=0)
    trial["outcome_model"] = {"fitted": fitted, "formula": formula}
    return fitted

# Fit the ITT outcome model; the fit is also stored in trial_itt["outcome_model"].
msm_itt = fit_msm(trial=trial_itt)

# Step 6: Prediction with Clustering
def predict_with_clustering(trial, predict_times=None, conf_int=True, samples=100, random_state=None):
    """
    Predict treated-minus-untreated survival differences separately per cluster.

    Method:
    - Build the baseline population: every row of ``trial["data"]`` with
      ``period == 1`` and a non-missing ``cluster``.
    - Replicate each baseline row once per follow-up time ``0..max(predict_times)``.
    - Get a discrete-time hazard per row from the fitted logistic model.
    - Take survival at time ``t`` as the running product of ``1 - hazard``
      over times ``0..t``.
    - Average survival over each cluster's individuals, once with
      ``assigned_treatment = 1`` and once with ``0``, and take the difference.

    The point estimate uses the fitted coefficients. The interval uses
    ``samples`` coefficient vectors drawn from a multivariate normal with the
    model's estimated covariance.

    Args:
        trial (dict): Trial dictionary with ``outcome_model["fitted"]`` (a
            statsmodels formula-based result) and ``outcome_model["formula"]``.
        predict_times (list): Non-negative integer follow-up times to report
            (default: 0-10). They need not be contiguous.
        conf_int (bool): Whether to compute 95% percentile intervals.
        samples (int): Number of coefficient draws for the intervals.
        random_state: Seed or generator forwarded to ``multivariate_normal.rvs``.
            Default None uses NumPy's global random state.

    Returns:
        dict: ``{'difference': {'followup_time': predict_times, 'clusters':
        {'cluster_<label>': {'survival_diff': ..., '2.5%': ..., '97.5%': ...}}}}``.
        Each array has one value per entry of ``predict_times``; the
        interval keys are present only when ``conf_int`` is True.

    Raises:
        KeyError: If no fitted model is stored or ``cluster`` is missing.
        ValueError: For invalid ``predict_times``/``samples``, an empty baseline
            population, or a design matrix that does not match the coefficients.
    """
    if predict_times is None:
        predict_times = list(range(11))
    times = np.asarray(predict_times)
    if times.ndim != 1 or times.size == 0 or (times < 0).any() or (times % 1 != 0).any():
        raise ValueError("predict_times must be a non-empty list of non-negative integers.")
    times = times.astype(int)
    if conf_int and samples < 1:
        raise ValueError("samples must be at least 1 when conf_int=True.")

    if "outcome_model" not in trial or "fitted" not in trial["outcome_model"]:
        raise KeyError("Fitted model not found in trial['outcome_model']['fitted']. Run fit_msm first.")
    model = trial["outcome_model"]["fitted"]
    formula = trial["outcome_model"]["formula"].split("~", 1)[1].strip()

    data = trial["data"]
    if "cluster" not in data.columns:
        raise KeyError("trial['data'] has no 'cluster' column. Run perform_clustering first.")
    # Keep every column rather than a subset parsed from the formula: splitting
    # on " + " yields terms like "assigned_treatment * cluster", not column names.
    baseline = data[(data["period"] == 1) & data["cluster"].notna()].reset_index(drop=True)
    n_base = len(baseline)
    if n_base == 0:
        raise ValueError("No period-1 rows with a cluster label to predict from.")

    # Cluster label -> boolean mask over baseline rows, in sorted label order.
    masks = {
        label: (baseline["cluster"] == label).to_numpy()
        for label in sorted(baseline["cluster"].unique())
    }

    # Hazards are accumulated over every integer time up to the largest requested one.
    grid = np.arange(times.max() + 1)
    stacked = pd.concat([baseline] * len(grid), ignore_index=True)
    # Copy k of the baseline rows is evaluated at time grid[k] (np.repeat, not
    # np.tile, so each individual appears exactly once at every time).
    stacked["followup_time"] = np.repeat(grid, n_base)
    stacked["trial_period"] = 1

    coefs = model.params.to_numpy()
    if conf_int:
        draws = multivariate_normal.rvs(
            mean=coefs, cov=model.cov_params(), size=samples, random_state=random_state
        )
        # Row 0 = fitted coefficients (point estimate); rows 1.. = draws for the CI.
        coef_sets = np.vstack([coefs, np.reshape(draws, (samples, coefs.size))])
    else:
        coef_sets = coefs[np.newaxis, :]

    # Reuse the fitted design (including categorical levels) when available, so
    # the columns line up with model.params even for C(cluster).
    design_info = getattr(getattr(model.model, "data", None), "design_info", None)
    design = formula if design_info is None else design_info

    curves = {}
    for treatment in (0, 1):
        stacked["assigned_treatment"] = treatment
        exog = dmatrix(design, stacked, NA_action="raise", return_type="dataframe")
        if list(exog.columns) != list(model.params.index):
            raise ValueError("Prediction design matrix columns do not match the fitted model coefficients.")
        design_matrix = exog.to_numpy()
        per_cluster = {label: np.empty((times.size, len(coef_sets))) for label in masks}
        for s, beta in enumerate(coef_sets):
            hazard = 1 / (1 + np.exp(-(design_matrix @ beta)))
            # (time, individual) survival, restricted to the requested times.
            survival = np.cumprod(1 - hazard.reshape(len(grid), n_base), axis=0)[times]
            for label, mask in masks.items():
                per_cluster[label][:, s] = survival[:, mask].mean(axis=1)
        curves[treatment] = per_cluster

    clusters_out = {}
    for label in masks:
        diff = curves[1][label] - curves[0][label]
        entry = {'survival_diff': diff[:, 0]}
        if conf_int:
            entry['2.5%'] = np.percentile(diff[:, 1:], 2.5, axis=1)
            entry['97.5%'] = np.percentile(diff[:, 1:], 97.5, axis=1)
        # Labels become floats when the column ever held NaN; show 0, not 0.0.
        text = int(label) if isinstance(label, (float, np.floating)) and float(label).is_integer() else label
        clusters_out[f'cluster_{text}'] = entry

    return {'difference': {'followup_time': predict_times, 'clusters': clusters_out}}

# Step 7: Plotting Cluster-Specific Survival Differences
def plot_cluster_survival_differences(preds):
    """
    Plot the survival difference for each cluster in its own panel.

    Each panel shows the difference as a solid black line. If ``preds``
    contains a 95% interval, its bounds are drawn as dashed red lines.

    Args:
        preds (dict): Prediction results from ``predict_with_clustering``. Uses
            ``preds['difference']['followup_time']`` and, per entry of
            ``preds['difference']['clusters']``, ``'survival_diff'`` plus
            optional ``'2.5%'`` / ``'97.5%'``.

    Raises:
        ValueError: If ``preds`` contains no clusters.
    """
    clusters = preds['difference']['clusters']
    if not clusters:
        raise ValueError("No clusters to plot in preds['difference']['clusters'].")
    times = preds['difference']['followup_time']
    n_clusters = len(clusters)
    # squeeze=False always yields a 2-D array of axes, even for a single cluster.
    _, axes = plt.subplots(1, n_clusters, figsize=(5 * n_clusters, 5), sharey=True, squeeze=False)
    prefix = "cluster_"
    for ax, (key, curves) in zip(axes[0], clusters.items()):
        ax.plot(times, curves['survival_diff'], 'k-', label="Survival difference")
        # Interval bounds exist only when predictions were made with conf_int=True.
        if '2.5%' in curves and '97.5%' in curves:
            ax.plot(times, curves['2.5%'], 'r--', label="95% CI")
            ax.plot(times, curves['97.5%'], 'r--')
        label = key[len(prefix):] if str(key).startswith(prefix) else key
        ax.set_title(f"Cluster {label}")
        ax.set_xlabel("Follow up")
        ax.set_ylabel("Survival difference")
        ax.grid(True)
        ax.legend()
    plt.tight_layout()
    plt.show()

# Cluster-specific predictions at follow-up times 0..10, then one panel per cluster.
predict_times = list(range(0, 11))
preds_cluster = predict_with_clustering(trial=trial_itt, predict_times=predict_times)
plot_cluster_survival_differences(preds=preds_cluster)

# Step 8: Inference — overall (not cluster-specific) survival predictions
def predict_overall(trial, newdata=None, predict_times=None, type='survival', conf_int=True, samples=100, random_state=None):
    """
    Predict survival (or cumulative incidence) under each treatment arm,
    averaged over all baseline individuals, and their difference.

    Method:
    - The baseline population is every ``period == 1`` row of ``newdata``
      (default: ``trial["data"]``).
    - Each row is replicated once per follow-up time ``0..max(predict_times)``
      and given a discrete-time hazard by the fitted logistic model.
    - Survival at ``t`` is the running product of ``1 - hazard`` over ``0..t``.
    - Survival is averaged over individuals for ``assigned_treatment`` 1 and 0.

    The point estimate uses the fitted coefficients. The interval uses
    ``samples`` coefficient vectors drawn from a multivariate normal with the
    model's estimated covariance.

    Args:
        trial (dict): Trial dictionary with ``outcome_model["fitted"]`` (a
            statsmodels formula-based result) and ``outcome_model["formula"]``.
        newdata (pd.DataFrame, optional): Data for prediction; defaults to trial data.
        predict_times (list): Non-negative integer follow-up times to report
            (default: 0-10). They need not be contiguous.
        type (str): 'survival' or 'cum_inc' (default: 'survival'). For
            'cum_inc' the values stored under the 'survival' /
            'survival_diff' keys are cumulative incidences (1 - survival).
        conf_int (bool): Whether to compute 95% percentile intervals.
        samples (int): Number of coefficient draws for the intervals.
        random_state: Seed or generator forwarded to ``multivariate_normal.rvs``.
            Default None uses NumPy's global random state.

    Returns:
        dict: ``assigned_treatment_0`` / ``assigned_treatment_1`` entries with
        ``followup_time`` and ``survival`` arrays, and a ``difference`` entry
        with ``survival_diff``. The ``difference`` entry also has ``2.5%`` /
        ``97.5%`` when ``conf_int`` is True.

    Raises:
        KeyError: If no fitted model is stored in the trial.
        ValueError: For invalid ``type``, ``predict_times`` or ``samples``, an
            empty baseline population, or a design matrix that does not match
            the coefficients.
    """
    if predict_times is None:
        predict_times = list(range(11))
    times = np.asarray(predict_times)
    if times.ndim != 1 or times.size == 0 or (times < 0).any() or (times % 1 != 0).any():
        raise ValueError("predict_times must be a non-empty list of non-negative integers.")
    times = times.astype(int)
    if type not in ('survival', 'cum_inc'):
        raise ValueError("type must be 'survival' or 'cum_inc'.")
    if conf_int and samples < 1:
        raise ValueError("samples must be at least 1 when conf_int=True.")

    if "outcome_model" not in trial or "fitted" not in trial["outcome_model"]:
        raise KeyError("Fitted model not found in trial['outcome_model']['fitted']. Run fit_msm first.")
    model = trial["outcome_model"]["fitted"]
    formula = trial["outcome_model"]["formula"].split("~", 1)[1].strip()

    source = trial["data"] if newdata is None else newdata
    # Keep every column rather than a subset parsed from the formula: splitting
    # on " + " yields terms like "assigned_treatment * cluster", not column names.
    baseline = source[source["period"] == 1].reset_index(drop=True)
    n_base = len(baseline)
    if n_base == 0:
        raise ValueError("No period-1 rows to predict from.")

    # Hazards are accumulated over every integer time up to the largest requested one.
    grid = np.arange(times.max() + 1)
    stacked = pd.concat([baseline] * len(grid), ignore_index=True)
    # Copy k of the baseline rows is evaluated at time grid[k] (np.repeat, not
    # np.tile, so each individual appears exactly once at every time).
    stacked["followup_time"] = np.repeat(grid, n_base)
    stacked["trial_period"] = 1

    coefs = model.params.to_numpy()
    if conf_int:
        draws = multivariate_normal.rvs(
            mean=coefs, cov=model.cov_params(), size=samples, random_state=random_state
        )
        # Row 0 = fitted coefficients (point estimate); rows 1.. = draws for the CI.
        coef_sets = np.vstack([coefs, np.reshape(draws, (samples, coefs.size))])
    else:
        coef_sets = coefs[np.newaxis, :]

    # Reuse the fitted design (including categorical levels) when available, so
    # the columns line up with model.params.
    design_info = getattr(getattr(model.model, "data", None), "design_info", None)
    design = formula if design_info is None else design_info

    curves = {}
    for treatment in (0, 1):
        stacked["assigned_treatment"] = treatment
        exog = dmatrix(design, stacked, NA_action="raise", return_type="dataframe")
        if list(exog.columns) != list(model.params.index):
            raise ValueError("Prediction design matrix columns do not match the fitted model coefficients.")
        design_matrix = exog.to_numpy()
        out = np.empty((times.size, len(coef_sets)))
        for s, beta in enumerate(coef_sets):
            hazard = 1 / (1 + np.exp(-(design_matrix @ beta)))
            survival = np.cumprod(1 - hazard.reshape(len(grid), n_base), axis=0)[times]
            mean_survival = survival.mean(axis=1)
            out[:, s] = mean_survival if type == 'survival' else 1 - mean_survival
        curves[treatment] = out

    diff_preds = curves[1] - curves[0]

    results = {
        'assigned_treatment_0': {
            'followup_time': predict_times,
            'survival': curves[0][:, 0]
        },
        'assigned_treatment_1': {
            'followup_time': predict_times,
            'survival': curves[1][:, 0]
        },
        'difference': {
            'followup_time': predict_times,
            'survival_diff': diff_preds[:, 0]
        }
    }
    if conf_int:
        results['difference']['2.5%'] = np.percentile(diff_preds[:, 1:], 2.5, axis=1)
        results['difference']['97.5%'] = np.percentile(diff_preds[:, 1:], 97.5, axis=1)

    return results

def plot_survival_difference(preds):
    """
    Plot the overall survival difference over follow-up in R style.

    The difference is drawn as a solid black line. If ``preds`` contains a
    95% interval, its bounds are drawn as dashed red lines.

    Args:
        preds (dict): Prediction results from ``predict_overall``. Uses
            ``preds['difference']['followup_time']``, ``['survival_diff']``
            and, when present, ``['2.5%']`` / ``['97.5%']``.
    """
    diff = preds['difference']
    times = diff['followup_time']
    plt.figure(figsize=(8, 6))
    plt.plot(times, diff['survival_diff'], 'k-', label="Survival difference")
    # Interval bounds exist only when predict_overall ran with conf_int=True.
    if '2.5%' in diff and '97.5%' in diff:
        plt.plot(times, diff['2.5%'], 'r--', label="95% CI")
        plt.plot(times, diff['97.5%'], 'r--')
    plt.xlabel("Follow up")
    plt.ylabel("Survival difference")
    plt.grid(True)
    plt.legend()
    plt.show()

# Overall prediction (averaged over all period-1 individuals, not split by
# cluster) at the same follow-up times, then plot the difference.
preds_overall = predict_overall(trial=trial_itt, predict_times=predict_times)
plot_survival_difference(preds=preds_overall)