# Lecture 12: Capstone - Advanced Causal Methods

[!["Open In Colab"](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/<ORG>/<REPO>/blob/main/lectures/L12_Capstone/L12_Capstone_student.ipynb)

## Learning Objectives
1. Explain the problem of **time-varying confounding**.
2. Understand the concept of **Marginal Structural Models (MSMs)**.
3. Perform **target trial emulation** to define clear causal protocols.
4. Implement **discrete-time survival analysis** using pooled logistic regression.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.formula.api as smf
from phs564_ci.datasets import load_data
from phs564_ci.estimators.msm import fit_msm_pooled_logistic

# Load longitudinal survival data
df = load_data("l12_survival_data.csv")
df.head()

--- 
### 1. Discrete-Time Survival Analysis
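Discrete-time survival analysis models the hazard: the probability of the event in an interval, given survival to the start of that interval. The data must be in **person-period** format (one row per person per time interval).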
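For readers whose own data is not yet in person-period format, the expansion can be sketched as follows. This is a minimal illustration on a made-up three-person table; the column names `id`, `followup`, and `event` are assumptions and are not taken from `l12_survival_data.csv`.

```python
import pandas as pd

# Hypothetical wide data: one row per person
wide = pd.DataFrame({
    "id": [1, 2, 3],
    "A": [1, 0, 1],
    "followup": [3, 2, 4],   # number of intervals each person was observed
    "event": [1, 0, 0],      # 1 if the event occurred in the last observed interval
})

# Repeat each person's row once per observed interval
person_period = wide.loc[wide.index.repeat(wide["followup"])].reset_index(drop=True)
person_period["time"] = person_period.groupby("id").cumcount()

# The event indicator is 1 only in the final interval of a person who had the event
is_last = person_period["time"] == person_period["followup"] - 1
person_period["event"] = (is_last & (person_period["event"] == 1)).astype(int)

person_period = person_period[["id", "time", "A", "event"]]
print(person_period)
```

Each person contributes one row per interval at risk, so a pooled logistic regression of `event` on `time` and covariates estimates the discrete-time hazard.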

In [None]:
# The dataset is already in person-period format.
# Let's fit a pooled logistic model for the hazard.
hazard_model = smf.logit("event ~ time + A + L", data=df).fit()
print(hazard_model.summary().tables[1])

--- 
### 🖼️ Figure Generation: Survival Curves (Slide 23)
We transform the predicted hazards into survival curves.
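The hazard-to-survival step is a cumulative product: survival through interval t is the product of (1 - hazard) over all intervals up to t. A small numeric sketch with made-up hazards (not values from the course dataset):

```python
import numpy as np

# Hypothetical discrete-time hazards for three intervals
hazards = np.array([0.10, 0.20, 0.25])

# Survival through interval t = product of (1 - hazard) over intervals up to t
survival = np.cumprod(1 - hazards)

# Prepend 1.0 because everyone is event-free at time zero
survival = np.concatenate([[1.0], survival])
print(survival)
```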

In [None]:
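def calculate_survival(baseline, treatment_val, n_times=10):
    """Standardized survival curve with treatment fixed at `treatment_val`."""
    n = len(baseline)
    # Expand each baseline row to one row per time point (time = 0..n_times-1);
    # L is held at its baseline value for this prediction.
    temp = baseline.loc[baseline.index.repeat(n_times)].copy()
    temp['time'] = np.tile(np.arange(n_times), n)
    temp['A'] = treatment_val
    # Predicted hazards: one row per person, one column per time point
    hazards = np.asarray(hazard_model.predict(temp)).reshape(n, n_times)
    # Individual survival is the cumulative product of (1 - hazard);
    # average over people to get the standardized curve.
    survival = np.cumprod(1 - hazards, axis=1).mean(axis=0)
    return np.concatenate([[1], survival])

baseline_rows = df[df['time'] == 0]
surv_a1 = calculate_survival(baseline_rows, 1)
surv_a0 = calculate_survival(baseline_rows, 0)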

plt.figure(figsize=(10, 6))
plt.step(range(11), surv_a1, label='Always Treated (A=1)', where='post')
plt.step(range(11), surv_a0, label='Never Treated (A=0)', where='post')
plt.ylim(0, 1.05)
plt.title("Predicted Survival Curves (Pooled Logistic Model)")
plt.xlabel("Time")
plt.ylabel("Survival Probability")
plt.legend()
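# Create the output directory first so savefig does not fail on a fresh checkout
import os
os.makedirs("figures/L12", exist_ok=True)
plt.savefig("figures/L12/survival_curves.png")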
plt.show()

--- 
## 🛑 Activity 1: Protocol v2 Workshop (Slide 17)

Re-evaluate your research protocol using the **Target Trial** framework:
1. **Eligibility:** Are criteria defined strictly at baseline?
2. **Time Zero:** Does follow-up start exactly when treatment status is determined?
3. **Immortal Time:** Is there any way a subject could be "guaranteed" to survive a certain period based on your treatment definition?
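The time-zero and immortal-time questions above can also be checked mechanically once the relevant dates are in a table. The sketch below uses a made-up four-person cohort; the column names (`eligible_day`, `followup_start_day`, `treatment_start_day`) are hypothetical and would need to be mapped to your own protocol.

```python
import numpy as np
import pandas as pd

# Hypothetical cohort, with days measured from the date eligibility is met
cohort = pd.DataFrame({
    "id": [1, 2, 3, 4],
    "eligible_day": [0, 0, 0, 0],
    "followup_start_day": [0, 0, 0, 0],
    "treatment_start_day": [0, 30, np.nan, 0],  # NaN = never treated
})

# Time zero check: follow-up should start on the day eligibility is met
cohort["time_zero_aligned"] = cohort["followup_start_day"] == cohort["eligible_day"]

# Immortal time check: if a person is classified as treated from time zero but
# treatment starts later, they must have survived the gap to be classified at all
gap = cohort["treatment_start_day"] - cohort["followup_start_day"]
cohort["immortal_days"] = gap.clip(lower=0).fillna(0)

print(cohort[["id", "time_zero_aligned", "immortal_days"]])
```

Here person 2 would contribute 30 days of immortal time if they were analyzed as treated from time zero.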

--- 
### 2. Marginal Structural Model (MSM)
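When treatment `A` and confounder `L` both change over time, and past treatment affects later values of `L`, adjusting for `L` directly in the outcome model can bias the estimate. MSMs handle this by weighting each person-period with stabilized inverse probability of treatment weights instead.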

In [None]:
# Using our helper to fit a weighted pooled logistic model (MSM)
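# Note: to keep this demo simple, we use baseline weights.
msm_res = fit_msm_pooled_logistic(df, 'event', 'A', ['L'], 'time')
# Coefficients are on the log-odds scale; exp(coef for A) is the odds ratio,
# which approximates the hazard ratio when the per-interval hazard is small.
print("MSM Results (log-odds coefficients; exponentiate A for the approximate hazard ratio):")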
print(msm_res.summary().tables[1])
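The demo above uses baseline weights. A minimal sketch of time-varying stabilized weights is shown below. It is not the implementation inside `fit_msm_pooled_logistic`. The toy data, the column name `A_prev`, and the use of stratum means in place of fitted logistic models are all assumptions made so that the example runs with pandas alone.

```python
import numpy as np
import pandas as pd

# Hypothetical toy person-period data with treatment-confounder feedback:
# L depends on earlier treatment, and A depends on current L and earlier treatment.
rng = np.random.default_rng(0)
rows = []
for pid in range(500):
    a_prev = 0
    for t in range(3):
        l = rng.binomial(1, 0.3 + 0.3 * a_prev)
        a = rng.binomial(1, 0.2 + 0.4 * l + 0.2 * a_prev)
        rows.append({"id": pid, "time": t, "L": l, "A_prev": a_prev, "A": a})
        a_prev = a
toy = pd.DataFrame(rows)

# Denominator model: P(A_t = 1 | A_{t-1}, L_t). Numerator model: P(A_t = 1 | A_{t-1}).
# With binary covariates, stratum means stand in for logistic regressions.
p_den = toy.groupby(["A_prev", "L"])["A"].transform("mean")
p_num = toy.groupby("A_prev")["A"].transform("mean")

# Probability of the treatment each person actually received at each time
den = np.where(toy["A"] == 1, p_den, 1 - p_den)
num = np.where(toy["A"] == 1, p_num, 1 - p_num)

# Stabilized weight at time t = cumulative product of the ratios up to t, within person
toy["ratio"] = num / den
toy["sw"] = toy.groupby("id")["ratio"].cumprod()

print(toy["sw"].describe())
```

A common diagnostic is that stabilized weights should have a mean close to 1. A mean far from 1, or extreme values, suggests a misspecified treatment model or near-violations of positivity.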

### 3. Summary
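- Target trial emulation helps avoid design-stage biases such as immortal time bias by aligning eligibility, time zero, and treatment assignment.
- Discrete-time survival analysis using pooled logistic regression is a standard epidemiologic tool; its odds ratio approximates the hazard ratio when the hazard in each interval is small.
- MSMs combine survival models with time-varying IPTW to handle confounders that are themselves affected by earlier treatment.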