# TTE-v2: Simplified Target Trial Emulation with Clustering

This notebook extends the target trial emulation framework by integrating a clustering mechanism. In this version, we:

- Load and preview the dummy data
- Apply KMeans clustering on baseline characteristics (age, x1, x2, x3) to capture latent patient subgroups
- Estimate switching and censoring weights using logistic regression models, while adjusting for the cluster assignment
- Combine the weights and fit an outcome model using weighted least squares (WLS) that also adjusts for clusters
- Expand the dataset to simulate follow-up over time
- Fit a marginal structural model (MSM) incorporating clusters
- Generate predictions and plot the predicted survival difference over follow-up for each cluster

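This approach gives a first look at heterogeneity across clusters. Note that clusters enter the models as main effects only, so they adjust for differences in outcome level between subgroups; comparing treatment effects between clusters would additionally require treatment-by-cluster interaction terms.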

In [None]:
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt

# Configure matplotlib to display plots inline
%matplotlib inline


## Step 1: Load the Dummy Data

We load the dummy data from a CSV file named `data_censored.csv`. This dataset includes patient-level data with variables such as treatment, outcome, and several covariates.

In [None]:
# Load the data into a pandas DataFrame
data = pd.read_csv("data_censored.csv")

# Preview the first few rows to verify successful data load
print("Data preview:")
print(data.head())

## Step 2: Apply Clustering on Baseline Characteristics

We use the KMeans algorithm to cluster patients based on key baseline characteristics (`age`, `x1`, `x2`, and `x3`). The resulting cluster assignments are then used as an additional categorical variable (denoted as `C(cluster)`) in our subsequent regression models. This adjustment helps account for latent heterogeneity in the population.

In [None]:
# Select baseline features for clustering
features = data[["age", "x1", "x2", "x3"]]

# Initialize and fit the KMeans clustering algorithm
kmeans = KMeans(n_clusters=3, random_state=42)
data["cluster"] = kmeans.fit_predict(features)

# Check the distribution of cluster assignments
print("Cluster distribution:")
print(data["cluster"].value_counts())
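
KMeans assigns points by Euclidean distance, so a feature measured on a larger scale (plausibly `age` here) can dominate the cluster assignment. The following is a minimal sketch, on synthetic data rather than `data_censored.csv`, of z-scoring the features first. The column scales and the helper names `zscore` and `variance_share` are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for age, x1, x2, x3 (scales are illustrative assumptions)
n = 500
features = np.column_stack([
    rng.normal(50, 12, n),      # age-like: large scale
    rng.integers(0, 2, n),      # x1-like: binary
    rng.normal(0, 1, n),        # x2-like: unit scale
    rng.integers(0, 2, n),      # x3-like: binary
]).astype(float)

def zscore(a):
    """Center each column and scale it to unit standard deviation."""
    return (a - a.mean(axis=0)) / a.std(axis=0)

def variance_share(a):
    """Fraction of the total variance contributed by each column."""
    v = a.var(axis=0)
    return v / v.sum()

raw_share = variance_share(features)
scaled = zscore(features)
scaled_share = variance_share(scaled)

print("Share of variance before scaling:", np.round(raw_share, 3))
print("Share of variance after scaling: ", np.round(scaled_share, 3))
```

A standardized array like `scaled` could be passed to `kmeans.fit_predict` in place of the raw features. Whether that changes the cluster assignments depends on the data.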

## Step 3: Fit Switching Weight Models with Cluster Adjustment

Next, we estimate the switching weights while adjusting for cluster membership. Two logistic regression models are fit:

- **Numerator Model:** Predicts treatment using `age` and the cluster indicator (`C(cluster)`).
- **Denominator Model:** Predicts treatment using `age`, `x1`, `x3`, and the cluster indicator.

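The switching weight is computed as the ratio of the predicted probability of treatment from the numerator model to that from the denominator model. As a simplification, the same ratio is applied to treated and untreated rows alike.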

In [None]:
# Fit the switching weight models with cluster adjustment

# Numerator: treatment ~ age + C(cluster)
switch_model_numer = smf.logit("treatment ~ age + C(cluster)", data=data).fit(disp=False)

# Denominator: treatment ~ age + x1 + x3 + C(cluster)
switch_model_denom = smf.logit("treatment ~ age + x1 + x3 + C(cluster)", data=data).fit(disp=False)

# Compute predicted probabilities for both models
data["switch_prob_numer"] = switch_model_numer.predict(data)
data["switch_prob_denom"] = switch_model_denom.predict(data)

# Calculate the switching weight
data["switch_weight"] = data["switch_prob_numer"] / data["switch_prob_denom"]

# Display a preview of the switching weights
print("Switching weights preview:")
print(data[["switch_weight"]].head())
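
The ratio above uses the predicted probability of `treatment = 1` for every row. In the usual stabilized inverse-probability weight, each row instead uses the probability of the treatment it actually received. A minimal sketch with hypothetical probability arrays follows; the function name `stabilized_weight` is illustrative and not part of the notebook.

```python
import numpy as np

def stabilized_weight(treated, p_numer, p_denom):
    """Stabilized weight using the probability of the observed treatment.

    treated : 0/1 array of observed treatment
    p_numer : P(treatment=1 | numerator covariates)
    p_denom : P(treatment=1 | denominator covariates)
    """
    treated = np.asarray(treated)
    p_numer = np.asarray(p_numer, dtype=float)
    p_denom = np.asarray(p_denom, dtype=float)
    # Treated rows use p; untreated rows use 1 - p
    numer = np.where(treated == 1, p_numer, 1.0 - p_numer)
    denom = np.where(treated == 1, p_denom, 1.0 - p_denom)
    return numer / denom

# Hypothetical inputs for four rows
treated = np.array([1, 0, 1, 0])
p_numer = np.array([0.5, 0.5, 0.4, 0.4])
p_denom = np.array([0.8, 0.8, 0.2, 0.2])

weights = stabilized_weight(treated, p_numer, p_denom)
print(weights)
```

In the notebook's terms, the call would take `data["treatment"]`, `data["switch_prob_numer"]`, and `data["switch_prob_denom"]`. Using it would change the weights and therefore the downstream coefficients.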

## Step 4: Fit Censoring Weight Models with Cluster Adjustment

We then estimate censoring weights, again adjusting for cluster membership. Two models are fitted:

- **Numerator Model:** Predicts the censoring indicator (`censored`) using `x2` and `C(cluster)`.
- **Denominator Model:** Predicts `censored` using `x2`, `x1`, and `C(cluster)`.

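The censoring weight is the ratio of the predicted probability of censoring from the numerator model to that from the denominator model. This is a simplification: conventional inverse-probability-of-censoring weights are built from the probability of remaining uncensored.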

In [None]:
# Fit the censoring weight models with cluster adjustment

# Numerator: censored ~ x2 + C(cluster)
censor_model_numer = smf.logit("censored ~ x2 + C(cluster)", data=data).fit(disp=False)

# Denominator: censored ~ x2 + x1 + C(cluster)
censor_model_denom = smf.logit("censored ~ x2 + x1 + C(cluster)", data=data).fit(disp=False)

# Compute predicted probabilities for censoring
data["cens_prob_numer"] = censor_model_numer.predict(data)
data["cens_prob_denom"] = censor_model_denom.predict(data)

# Calculate the censoring weight
data["censor_weight"] = data["cens_prob_numer"] / data["cens_prob_denom"]

# Display a preview of the censoring weights
print("Censoring weights preview:")
print(data[["censor_weight"]].head())
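
For comparison, here is a small sketch of the conventional form, which takes the ratio of the probabilities of remaining uncensored. It assumes `censored == 1` marks a censored row, and the probability values are hypothetical.

```python
import numpy as np

# Hypothetical predicted probabilities of being censored (censored == 1)
cens_prob_numer = np.array([0.10, 0.20, 0.05])
cens_prob_denom = np.array([0.25, 0.20, 0.50])

# Weight built from the probability of remaining uncensored
uncensored_weight = (1.0 - cens_prob_numer) / (1.0 - cens_prob_denom)

# Ratio of censoring probabilities, as computed in the cell above
censored_ratio = cens_prob_numer / cens_prob_denom

print("Uncensored-based weight:    ", np.round(uncensored_weight, 3))
print("Censoring-probability ratio:", np.round(censored_ratio, 3))
```

The two quantities agree only when the numerator and denominator probabilities are equal, so the choice between them can change the combined weights noticeably.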

## Step 5: Combine Weights

The overall weight for each observation is obtained by multiplying the switching weight and the censoring weight. This combined weight will be used in the outcome and MSM models.

In [None]:
# Combine the switching and censoring weights
data["weight"] = data["switch_weight"] * data["censor_weight"]

# Display a preview of the combined weights
print("Combined weights preview:")
print(data[["weight"]].head())
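
Products of weights can be heavy-tailed, and a few large values can dominate a weighted fit. One common check is to inspect the extremes and the effective sample size, and optionally to truncate the weights at chosen percentiles. The sketch below runs on simulated weights. The 1st and 99th percentile cut-offs and the helper names are assumptions for illustration, not choices made in this notebook.

```python
import numpy as np

rng = np.random.default_rng(1)

# Simulated heavy-tailed weights standing in for data["weight"]
weight = rng.lognormal(mean=0.0, sigma=0.8, size=1000)

def truncate_weights(w, lower_pct=1, upper_pct=99):
    """Clip weights to the given lower and upper percentiles."""
    lo, hi = np.percentile(w, [lower_pct, upper_pct])
    return np.clip(w, lo, hi)

def effective_sample_size(w):
    """Kish effective sample size: (sum w)^2 / sum(w^2)."""
    return w.sum() ** 2 / (w ** 2).sum()

truncated = truncate_weights(weight)

print("max before:", weight.max(), "after:", truncated.max())
print("ESS before:", effective_sample_size(weight),
      "after:", effective_sample_size(truncated))
```

An effective sample size far below the row count suggests that a small number of observations carry most of the weight.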

## Step 6: Fit Outcome Model with Cluster Adjustment

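We now fit an outcome model using weighted least squares (WLS). In addition to `treatment` and `x2`, we adjust for cluster membership using `C(cluster)`. This model estimates the effect of treatment on the outcome while controlling for cluster differences. If `outcome` is a binary indicator, this is a linear probability model, so the treatment coefficient is read as a difference in outcome probability.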

In [None]:
# Fit the outcome model using WLS with cluster adjustment
outcome_model = smf.wls("outcome ~ treatment + x2 + C(cluster)", data=data, weights=data["weight"]).fit()

# Print the model coefficients
print("\nSimplified Outcome Model Coefficients with Clustering:")
print(outcome_model.params)

## Step 7: Expand Data for Follow-Up

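To simulate follow-up over time, we expand the dataset by creating one copy of every row for each follow-up time from 0 to 10. This mimics the layout of a sequential trial design, in which each patient contributes multiple time points. The copies are identical apart from `followup_time`, however: the outcome, covariates, and weights do not change across them.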

In [None]:
# Define follow-up times from 0 to 10
followup_times = np.arange(0, 11)

# Expand the dataset by creating copies of the data for each follow-up time
expanded = pd.concat([data.assign(followup_time=t) for t in followup_times], ignore_index=True)

# Preview the expanded dataset
print("Expanded data preview:")
print(expanded.head())

## Step 8: Fit a Marginal Structural Model (MSM) Including Cluster as a Factor

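Using the expanded data, we fit a marginal structural model (MSM) via weighted least squares. In this model, we include treatment, follow-up time, `x2`, and the cluster factor (`C(cluster)`). The model is intended to describe the causal effect of treatment over time while accounting for cluster-level differences. Because the expanded rows are exact copies, the fit reuses the same information as the outcome model in Step 6, and its default standard errors are understated by the inflated row count.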

In [None]:
# Fit the MSM using the expanded dataset with cluster adjustment
msm_model = smf.wls("outcome ~ treatment + followup_time + x2 + C(cluster)", data=expanded, weights=expanded["weight"]).fit()

# Print the MSM model coefficients
print("\nSimplified MSM Model Coefficients with Clustering:")
print(msm_model.params)
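
Because Step 7 copies each row unchanged for every follow-up time, `outcome` does not vary with `followup_time`. The sketch below is a numeric check on synthetic data, not on the notebook's dataset. It shows that under this construction a weighted fit on the expanded rows reproduces the unexpanded coefficients and gives follow-up time a coefficient of zero. `wls_fit` is a hypothetical helper that implements weighted least squares with NumPy.

```python
import numpy as np

rng = np.random.default_rng(2)

def wls_fit(X, y, w):
    """Weighted least squares via rescaling rows by sqrt(weight)."""
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return coef

# Synthetic baseline data: intercept, treatment, one covariate
n = 200
treatment = rng.integers(0, 2, n).astype(float)
x2 = rng.normal(size=n)
y = 0.02 - 0.01 * treatment + 0.005 * x2 + rng.normal(scale=0.1, size=n)
w = rng.uniform(0.5, 2.0, n)
X = np.column_stack([np.ones(n), treatment, x2])

base_coef = wls_fit(X, y, w)

# Expand: copy every row for follow-up times 0..10, as in Step 7
times = np.arange(0, 11)
X_exp = np.vstack([np.column_stack([X, np.full(n, t)]) for t in times])
y_exp = np.tile(y, len(times))
w_exp = np.tile(w, len(times))

exp_coef = wls_fit(X_exp, y_exp, w_exp)

print("Unexpanded coefficients:   ", base_coef)
print("Expanded coefficients:     ", exp_coef[:3])
print("Follow-up time coefficient:", exp_coef[3])
```

A time trend can only be estimated when the outcome is actually recorded per follow-up period, which this duplicated layout does not provide.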

## Step 9: Predict and Plot Outcomes Over Follow-Up by Cluster

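We now generate predictions for each follow-up time and plot them separately for each cluster. For each cluster and time, the plotted value is the weighted average of the MSM predictions at each patient's observed treatment. Despite the axis label, it is therefore a predicted outcome level rather than a contrast between treatment arms. Placeholder bounds of ±0.1 are drawn for illustration only; they are not confidence intervals.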

In [None]:
# Define prediction follow-up times
pred_times = np.arange(0, 11)

# Initialize the plot
plt.figure(figsize=(8,6))

# Loop over each unique cluster to generate and plot predictions
for cl in sorted(data["cluster"].unique()):
    # Filter data for the current cluster
    cluster_data = data[data["cluster"] == cl]
    predictions = []
    
    # Loop through each follow-up time
    for t in pred_times:
        temp = cluster_data.copy()
        temp["followup_time"] = t
        
        # Predict outcomes using the MSM
        pred = msm_model.predict(temp)
        
        # Calculate the weighted average prediction for this follow-up time
        predictions.append(np.average(pred, weights=temp["weight"]))
    
    # Create dummy lower and upper bounds (±0.1) for demonstration
    lower_bound = [p - 0.1 for p in predictions]
    upper_bound = [p + 0.1 for p in predictions]
    
    # Plot the predictions; draw the placeholder bounds dashed in the same color as the cluster's line
    (line,) = plt.plot(pred_times, predictions, label=f"Cluster {cl}")
    plt.plot(pred_times, lower_bound, "--", color=line.get_color())
    plt.plot(pred_times, upper_bound, "--", color=line.get_color())

# Add labels, title, and legend to the plot
plt.xlabel("Follow-up Time")
plt.ylabel("Survival Difference")
plt.title("Predicted Survival Difference Over Follow-up by Cluster")
plt.legend()
plt.show()
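
The ±0.1 bounds are placeholders. One way to obtain data-driven bounds is a percentile bootstrap: resample rows with replacement, recompute the weighted mean, and take empirical quantiles of the results. The sketch below uses simulated predictions and weights. The function name, the 500 replicates, and the 95% level are assumptions, and a fuller analysis would also refit the weight and outcome models within each resample, which this sketch does not do.

```python
import numpy as np

def bootstrap_weighted_mean_ci(values, weights, n_boot=500, alpha=0.05, seed=0):
    """Percentile bootstrap interval for a weighted mean."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    n = len(values)
    stats = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, n)  # resample row indices with replacement
        stats[b] = np.average(values[idx], weights=weights[idx])
    lower, upper = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return lower, upper

rng = np.random.default_rng(3)
pred = rng.normal(0.02, 0.01, 300)      # simulated model predictions
weight = rng.uniform(0.5, 2.0, 300)     # simulated combined weights

point = np.average(pred, weights=weight)
lower, upper = bootstrap_weighted_mean_ci(pred, weight)
print(f"estimate={point:.4f}, 95% interval=({lower:.4f}, {upper:.4f})")
```

In the plotting loop, an interval of this kind could replace `lower_bound` and `upper_bound` for each cluster and follow-up time.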

## Insights from TTE-v2 Results

Based on the output of the analysis:

- **Data Preview:** The dataset shows the expected columns (such as `id`, `period`, `treatment`, `x1`, and `x2`) and includes the newly added cluster assignments.

- **Outcome Model with Clustering:** The coefficients indicate an intercept of about **0.02006**. The dummy variables for the clusters (e.g., `C(cluster)[T.1]` and `C(cluster)[T.2]`) have small positive coefficients (~0.00459 and ~0.00419, respectively), suggesting slight differences across clusters. The treatment effect is approximately **-0.01422** and the effect of `x2` is about **0.00581**.

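- **Marginal Structural Model (MSM):** The MSM coefficients with clustering are very similar to those of the outcome model, and the follow-up time coefficient is nearly zero. Both are expected by construction: the expanded rows are identical copies that differ only in `followup_time`, so the outcome cannot vary with time. The model also has no treatment-by-time interaction, so the near-zero coefficient says nothing about whether the treatment effect changes over follow-up.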

### Overall Interpretation

The incorporation of clustering into the analysis helps capture underlying heterogeneity among patients. Although the cluster adjustments introduce small differences in the intercept, the estimated treatment effect remains similar to the model without clustering. This suggests that while there is some latent subgroup variability, the overall treatment effect estimate is robust to adjusting for it. The MSM reproduces this estimate but, because it is fit to duplicated rows, it does not provide independent evidence that the effect is constant over time.