In [125]:
import os
import pickle
import pandas as pd
import numpy as np
from datetime import datetime
from dateutil.relativedelta import relativedelta
import matplotlib.pyplot as plt

In [2]:
# Path to the BLCS Excel extract, relative to the notebook's working directory.
LOCAL_FILE = './BLCS_Data_Updated_February2023.xlsx'

In [3]:
# Read the workbook at LOCAL_FILE into a DataFrame using the openpyxl engine.
BLCS = pd.read_excel(LOCAL_FILE, engine='openpyxl')
# Reference date used below for the follow-up filter and for censoring.
extrcdate = pd.to_datetime('2023-02-28')

In [4]:
def days_between(a, b):
    """Return the whole-day component of ``b - a`` (negative when ``b`` precedes ``a``)."""
    elapsed = b - a
    return elapsed.days

In [5]:
# Every column whose name contains "date" (case-insensitive) is parsed as a
# datetime; values that cannot be parsed become NaT (errors='coerce').
date_cols = [c for c in BLCS.columns if 'date' in c.lower()]
for col in date_cols:
    BLCS[col] = pd.to_datetime(BLCS[col], errors='coerce')

In [6]:
# Keep only subjects with a known diagnosis date ...
BLCS = BLCS.dropna(subset=['dxdate'])
# ... that lies more than 180 whole days before the extraction date.
# Vectorized datetime arithmetic replaces the row-wise Python `apply`.
BLCS = BLCS[(extrcdate - BLCS['dxdate']).dt.days > 180]

In [7]:
# Drop rows whose clinical stage (`cstage`) is one of the codes listed below.
exclude = {'0', 'SCLC-ES', 'SCLC-LS'}
BLCS = BLCS[~BLCS['cstage'].isin(exclude)]

In [8]:
def compute_pdate_pevent(df):
    """Pick one progression date and event flag per row.

    The date/flag pair is taken from the first source that applies, in this
    order: (``progressiondate``, ``progression``), (``relapsedate``,
    ``relapse``), then (``progdate``, ``prog``) -- the last one only when
    ``progdate`` is strictly later than ``dxdate``. Rows with no usable source
    keep ``NaT``.

    Returns
    -------
    (pdate, pevent) : tuple of pandas.Series indexed like ``df``
        ``pevent`` is recoded so that only a raw value of 1 stays 1; every
        other value (0, 2, missing) becomes 0.
    """
    pdate = pd.Series(pd.NaT, index=df.index)
    pevent = pd.Series(np.nan, index=df.index)

    # Sources that only need a non-missing date to be selected.
    unconditional_sources = (('progressiondate', 'progression'),
                             ('relapsedate', 'relapse'))

    for idx, rec in df.iterrows():
        picked = None
        for date_col, flag_col in unconditional_sources:
            if pd.notna(rec.get(date_col)):
                picked = (date_col, flag_col)
                break
        if picked is None:
            fallback = rec.get('progdate')
            if pd.notna(fallback) and fallback > rec['dxdate']:
                picked = ('progdate', 'prog')
        if picked is not None:
            pdate.at[idx] = rec[picked[0]]
            pevent.at[idx] = rec.get(picked[1], np.nan)

    # Recode pevent: 0 or 2 -> 0; 1 -> 1 (anything that is not 1 ends up 0).
    recoded = pevent.replace({2: 0}).fillna(np.nan)
    return pdate, recoded.where(recoded == 1, 0)

In [9]:
# Store the selected progression date and recoded event flag as new columns.
BLCS['pdate'], BLCS['pevent'] = compute_pdate_pevent(BLCS)

In [10]:
def compute_censor(df):
    """Derive a censoring date and a label describing how it was chosen.

    Reads ``alivedate``, ``deathdate``, ``pdate`` and ``dxdate`` from each row
    and the notebook-level ``extrcdate``. Per row, the first matching rule wins:

    1. ``alivedate`` is present, not after ``dxdate``, and at least one of
       ``pdate`` / ``deathdate`` is present -> ``extrcdate``,
       label ``'Corrected Negative Survival'``.
    2. ``alivedate`` and ``deathdate`` are both present and
       ``alivedate <= deathdate`` -> ``extrcdate``,
       label ``'Corrected Censoring with Death'``.
    3. Otherwise -> the earliest non-missing value among ``alivedate``,
       ``extrcdate`` and ``deathdate``, label ``'Not Corrected'``.

    Returns
    -------
    (censordate, censorcheck) : tuple of pandas.Series indexed like ``df``
    """
    censordate = pd.Series(index=df.index, dtype='datetime64[ns]')
    censorcheck = pd.Series(index=df.index, dtype='object')

    for idx, rec in df.iterrows():
        alive, death, pdate = rec.get('alivedate'), rec.get('deathdate'), rec.get('pdate')
        alive_known = pd.notna(alive)

        if alive_known and alive <= rec['dxdate'] and (pd.notna(pdate) or pd.notna(death)):
            chosen, label = extrcdate, 'Corrected Negative Survival'
        elif alive_known and pd.notna(death) and alive <= death:
            chosen, label = extrcdate, 'Corrected Censoring with Death'
        else:
            available = [d for d in (alive, extrcdate, death) if pd.notna(d)]
            chosen, label = min(available), 'Not Corrected'

        censordate.at[idx] = chosen
        censorcheck.at[idx] = label

    return censordate, censorcheck

In [11]:
# Store the censoring date and the rule label that produced it.
BLCS['censordate'], BLCS['censorcheck'] = compute_censor(BLCS)

In [12]:
# Number of seconds in a 365.25-day year; converts timedeltas to years.
sec_per_year = 365.25 * 24 * 3600
# Years from diagnosis to progression (Ti1), death (Ti2) and censoring (Ci).
# Ti1 / Ti2 are NaN when the corresponding date is missing.
BLCS['Ti1'] = (BLCS['pdate'] - BLCS['dxdate']).dt.total_seconds() / sec_per_year
BLCS['Ti2'] = (BLCS['deathdate'] - BLCS['dxdate']).dt.total_seconds() / sec_per_year
BLCS['Ci'] = (BLCS['censordate'] - BLCS['dxdate']).dt.total_seconds() / sec_per_year
# np.fmin takes the element-wise minimum and ignores a NaN operand, so a
# missing event time falls back to the censoring time.
BLCS['ptime'] = np.fmin(BLCS['Ti1'], BLCS['Ci'])
BLCS['dtime'] = np.fmin(BLCS['Ti2'], BLCS['Ci'])
# NOTE(review): this overwrites the `pevent` column that was assigned from
# compute_pdate_pevent; the recoded flag from that step is no longer used.
# A NaN event time compares False, so it yields 0.
BLCS['pevent'] = np.where(BLCS['Ti1'] <= BLCS['Ci'], 1, 0)
BLCS['devent'] = np.where(BLCS['Ti2'] <= BLCS['Ci'], 1, 0)

In [13]:
# `df` is an alias of BLCS (same object), kept for the column selections below.
df = BLCS
# Observed times (row-wise minimum, NaN entries skipped):
#   Yi1 = min(Ti1, Ti2, Ci)  first of progression / death / censoring
#   Yi2 = min(Ti2, Ci)       first of death / censoring
#   YiS = min(Ti1, Ti2, Ci)  composite (progression-or-death) time
BLCS['Yi1'] = df[['Ti1','Ti2','Ci']].min(axis=1)
BLCS['Yi2'] = df[['Ti2','Ci']].min(axis=1)
BLCS['YiS'] = df[['Ti1','Ti2','Ci']].min(axis=1)
# Event indicators (a NaN event time compares False and therefore gives 0).
BLCS['Di1'] = np.where(df['Ti1'] <= BLCS['Yi2'], 1, 0)
BLCS['Di2'] = np.where(df['Ti2'] <= df['Ci'], 1, 0)
# FIX: the composite indicator must look at the earlier of progression and
# death. The previous column list was ['Ti2','Ti2'], which made DiS a copy of
# Di2 and ignored progression entirely.
BLCS['DiS'] = np.where(df[['Ti1','Ti2']].min(axis=1) <= df['Ci'], 1, 0)

In [None]:
# Predictor Recoding
# Each `.map(...)` below turns numeric codes into labels; codes that are not in
# the mapping become NaN. The cell rewrites BLCS columns in place, so running it
# a second time would map the already-recoded labels to NaN.
BLCS['sex'] = BLCS['sex'].map({1:'Male',2:'Female'}).astype('category')
# Race
race_map = {1:'White/Caucasian',2:'Native American/Alaska Native',3:'Asian',
            4:'Black/African American',5:'Native Hawaiian/Pacific Islander',
            6:'Multiracial',7:'Other'}
BLCS['race'] = BLCS['race'].map(race_map).astype('category')
BLCS['ethnic'] = BLCS['ethnic'].map({0:'Non-Hispanic',1:'Hispanic'}).astype('category')
# Education
edu_map = {
    1:'Some grade school',2:'Some high school',3:'High school graduate',
    4:"Vocational/tech school after high school",5:"Some college or associate's degree",
    6:'College graduate',7:'Graduate or professional school',8:'Other'
}
BLCS['education'] = BLCS['education'].map(edu_map).astype('category')
# Body-mass index: weight divided by squared height (columns `wtkg`, `htm`).
BLCS['bmi'] = BLCS['wtkg'] / (BLCS['htm']**2)
BLCS['smk'] = BLCS['smk'].map({1:'Never smoker',2:'Former smoker',3:'Current smoker',4:'Smoker, status unknown'}).astype('category')
# Treatment
# np.select returns the label of the FIRST condition that holds, so a subject
# with several treatment flags is labelled Surgery > Chemotherapy > Radiation > Other.
# NOTE(review): the default is np.nan while the choices are strings. The feature
# list used later in this notebook contains a 'trt_nan' dummy, which suggests the
# default ends up as the string 'nan' here -- confirm on the installed NumPy
# version before changing it.
BLCS['trt'] = np.select(
    [BLCS['surg']==1, BLCS['chemo']==1, BLCS['radio']==1, BLCS['othertrt']==1],
    ['Surgery','Chemotherapy','Radiation','Other'], default=np.nan
)
# Comorbidities
BLCS['copd'] = BLCS['copd'].map({0:'No',1:'Yes'}).astype('category')
BLCS['asthma'] = BLCS['asthma'].map({0:'No',1:'Yes'}).astype('category')
# For the two mutation columns, unmapped / missing codes are labelled 'Not Tested'.
BLCS['egfr'] = BLCS['egfr'].map({0:'No',1:'Yes'}).fillna('Not Tested').astype('category')
BLCS['kras'] = BLCS['kras'].map({0:'No',1:'Yes'}).fillna('Not Tested').astype('category')

In [15]:
# Analysis columns come first, in this order; every remaining BLCS column is
# appended after them in its existing order.
final_cols = ['ptime','pevent','dtime','devent','Yi1','Yi2','YiS','Di1','Di2','DiS',
               'agedx','sex','race','ethnic','education','bmi','smk','pkyrs','trt',
               'ctype','cstage','copd','asthma','egfr','kras']
BLCS_OUT4 = BLCS[final_cols + [c for c in BLCS.columns if c not in final_cols]]

In [16]:
# Rows with a non-positive observed time in either outcome, kept for inspection.
BLCS_CHECK = BLCS_OUT4[(BLCS_OUT4['Yi1'] <= 0) | (BLCS_OUT4['Yi2'] <= 0)]

In [17]:
# Analysis set: both observed times strictly positive (rows with NaN in either
# time also drop out because the comparison is False).
BLCS_CLEAN = BLCS_OUT4[(BLCS_OUT4['Yi1'] > 0) & (BLCS_OUT4['Yi2'] > 0)]

In [None]:
# Model-Ready Dataset
# Work on a copy so the cell is idempotent: every run starts from BLCS_CLEAN.
BLCS_CLEAN2 = BLCS_CLEAN.copy()
# Binary stage split: the listed stage codes -> 0, every other value -> 1.
BLCS_CLEAN2['stage_splt'] = np.where(BLCS_CLEAN2['cstage'].isin(['1','1A','1B','2','2A','2B','3','3A']),0,1)

# Mean-impute the numeric covariates.
# FIX: the previous `df[col].fillna(..., inplace=True)` form is chained in-place
# assignment; pandas warns about it and, with Copy-on-Write enabled, it leaves the
# frame unchanged. Assigning the result back is the supported spelling.
for col in ('agedx', 'bmi', 'pkyrs'):
    BLCS_CLEAN2[col] = BLCS_CLEAN2[col].fillna(BLCS_CLEAN2[col].mean())

# Add an explicit 'Unknown' level to each categorical covariate and use it for
# missing values (the level has to exist before fillna can assign it).
for col in ('sex', 'race', 'education', 'smk'):
    BLCS_CLEAN2[col] = BLCS_CLEAN2[col].cat.add_categories('Unknown').fillna('Unknown')

# Collapse race to two levels: 'White/Caucasian' versus 'Other' (which also
# absorbs 'Unknown').
BLCS_CLEAN2['race'] = BLCS_CLEAN2['race'].apply(lambda r: 'White/Caucasian' if r=='White/Caucasian' else 'Other')

# The former `.replace({'Other':'Other'})` on education and
# `.replace({'Smoker, status unknown':'Smoker, status unknown'})` on smk mapped
# a label to itself and were removed as no-ops.

# Surgery indicator derived from the single-label treatment column.
BLCS_CLEAN2['surg'] = np.where(BLCS_CLEAN2['trt']=='Surgery',1,0)

In [19]:
# Display the model-ready frame (notebook rich repr).
BLCS_CLEAN2

Unnamed: 0,ptime,pevent,dtime,devent,Yi1,Yi2,YiS,Di1,Di2,DiS,...,alivedate,dead,questdate,pdate,censordate,censorcheck,Ti1,Ti2,Ci,stage_splt
0,0.350445,0,0.350445,1,0.350445,0.350445,0.350445,0,1,1,...,NaT,1.0,1992-12-14,NaT,1993-04-22,Not Corrected,,0.350445,0.350445,0
1,4.377823,0,4.377823,1,4.377823,4.377823,4.377823,0,1,1,...,NaT,1.0,1992-12-14,NaT,1997-05-02,Not Corrected,,4.377823,4.377823,0
2,7.816564,0,7.816564,1,7.816564,7.816564,7.816564,0,1,1,...,NaT,1.0,1992-12-15,NaT,2000-10-11,Not Corrected,,7.816564,7.816564,1
3,2.685832,0,2.685832,1,2.685832,2.685832,2.685832,0,1,1,...,NaT,1.0,1992-12-21,NaT,1995-08-23,Not Corrected,,2.685832,2.685832,0
4,0.451745,0,0.451745,1,0.451745,0.451745,0.451745,0,1,1,...,NaT,1.0,1992-12-22,NaT,1993-05-30,Not Corrected,,0.451745,0.451745,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7686,1.201916,0,1.201916,0,1.201916,1.201916,1.201916,0,0,0,...,2023-01-22,0.0,NaT,NaT,2023-01-22,Not Corrected,,,1.201916,1
7688,0.418891,0,0.418891,0,0.418891,0.418891,0.418891,0,0,0,...,2022-12-08,0.0,2022-12-07,NaT,2022-12-08,Not Corrected,,,0.418891,1
7691,0.123203,0,0.123203,0,0.123203,0.123203,0.123203,0,0,0,...,2022-10-14,0.0,2022-10-12,NaT,2022-10-14,Not Corrected,,,0.123203,0
7692,0.443532,0,0.443532,0,0.443532,0.443532,0.443532,0,0,0,...,2022-11-16,0.0,2022-11-15,NaT,2022-11-16,Not Corrected,,,0.443532,0


In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
from scipy.special import psi, gammaln
from scipy.special import digamma
import pyreadr
import os
import time
from tqdm import tqdm
import pickle

In [31]:
def loss1_fun_stable(h1, d1, y1, gamma):
    """Frailty-weighted negative log partial likelihood for transition 1.

    All inputs are flattened to 1-D. Subject ``j`` belongs to the risk set of
    subject ``i`` when ``y1[i] <= y1[j]``; each risk-set member contributes
    ``gamma[j] * exp(h1[j])``. The maximum of ``h1`` is subtracted before
    exponentiating and added back after the log to avoid overflow.

    Returns the summed loss over subjects with ``d1 == 1`` divided by the
    number of such subjects (plus 1e-8).
    """
    scores, events, times, weights = (v.view(-1) for v in (h1, d1, y1, gamma))

    # at_risk[i, j] == 1 when j is still under observation at i's time.
    at_risk = (times[:, None] <= times[None, :]).float()

    shift = scores.max()
    weighted_exp = weights * torch.exp(scores - shift)

    log_risk = torch.log(at_risk @ weighted_exp + 1e-8) + shift

    per_subject = -events * (scores - log_risk)
    return per_subject.sum() / (events.sum() + 1e-8)

In [32]:
def loss2_fun_stable(h2, d1, d2, y1, gamma):
    """Frailty-weighted negative log partial likelihood for transition 2.

    Events are subjects with ``d1 == 0`` and ``d2 == 1``. The risk set of
    subject ``i`` is every ``j`` with ``y1[i] <= y1[j]``, each weighted by
    ``gamma[j] * exp(h2[j])``. The maximum of ``h2`` is subtracted before the
    exponential and restored after the log for numerical stability.

    Returns the summed loss over event subjects divided by their count
    (plus 1e-8).
    """
    scores = h2.view(-1)
    first_flag = d1.view(-1)
    second_flag = d2.view(-1)
    times = y1.view(-1)
    weights = gamma.view(-1)

    is_event = ((first_flag == 0) & (second_flag == 1)).float()
    at_risk = (times[:, None] <= times[None, :]).float()

    shift = scores.max()
    weighted_exp = weights * torch.exp(scores - shift)
    log_risk = torch.log(at_risk @ weighted_exp + 1e-8) + shift

    per_subject = -is_event * (scores - log_risk)
    return per_subject.sum() / (is_event.sum() + 1e-8)


In [33]:
def loss3_fun_stable(h3, d1, d2, y1, y2, gamma):
    """Frailty-weighted negative log partial likelihood for transition 3.

    Events are subjects with ``d1 == 1`` and ``d2 == 1``. Subject ``j`` is in
    the risk set of subject ``i`` when ``y1[j] < y2[i] <= y2[j]``; only
    subjects with ``d1 == 1`` carry weight (``gamma * d1 * exp(h3)``). The
    maximum of ``h3`` is subtracted before the exponential and restored after
    the log for numerical stability.

    Returns the summed loss over event subjects divided by their count
    (plus 1e-8).
    """
    scores = h3.view(-1)
    first_flag = d1.view(-1)
    second_flag = d2.view(-1)
    entry = y1.view(-1)
    exit_ = y2.view(-1)
    weights = gamma.view(-1)

    is_event = ((first_flag == 1) & (second_flag == 1)).float()

    # Row i / column j: j entered before i's exit time and is still observed at it.
    entered_before = entry[None, :] < exit_[:, None]
    still_observed = exit_[None, :] >= exit_[:, None]
    at_risk = (entered_before & still_observed).float()

    shift = scores.max()
    weighted_exp = weights * first_flag * torch.exp(scores - shift)
    log_risk = torch.log(at_risk @ weighted_exp + 1e-8) + shift

    per_subject = -is_event * (scores - log_risk)
    return per_subject.sum() / (is_event.sum() + 1e-8)

In [None]:
def fit_dnn(formula, data, na_action="na.fail", subset=None,
            dim_layers=[128, 64, 16], lr=0.01, dr=0.1,
            max_epochs=250, max_epochs_theta=100, max_epochs_n=5, verbose=True,
            ll=1, tol=1e-6, theta0=0.5, lr_theta=0.01, batch_size=128):
    """Fit three neural risk functions, baseline hazards and a frailty parameter.

    Each outer iteration runs an E-step (per-subject ``gamma`` and its expected
    log), an M-step (closed-form baseline hazard increments for the three
    transitions) and an N-step (gradient updates of ``theta`` and of the three
    networks on mini-batches).

    Parameters
    ----------
    formula : mapping
        ``formula['X']`` lists the covariate columns of ``data`` fed to all
        three networks.
    data : pandas.DataFrame
        Must contain ``Yi1``, ``Di1``, ``Yi2``, ``Di2`` and the covariates.
    na_action : {"na.fail", "na.omit"}
        Validated only; not otherwise used by this function.
    subset, ll
        Accepted for interface compatibility; not used.
    dim_layers : sequence of three ints
        Widths of the hidden layers of each network (never mutated).
    lr, dr : float
        Adam learning rate and dropout probability of the networks.
    max_epochs : int
        The outer counter starts at 2 and the loop runs while it is below
        ``max_epochs``, i.e. at most ``max_epochs - 2`` outer iterations.
    max_epochs_theta, max_epochs_n : int
        Maximum inner gradient epochs for ``theta`` / for each network.
    verbose : bool
        Print the loss of every inner ``theta`` step.
    tol : float
        Tolerance used when accumulating hazards up to a time point and in the
        outer loop condition.
    theta0, lr_theta : float
        Initial value and Adam learning rate of ``theta``.
    batch_size : int
        Mini-batch size; incomplete final batches are dropped.

    Returns
    -------
    dict
        ``theta`` (float), ``lam01``/``lam02``/``lam03`` (hazard increments at
        the unique event times of each transition), ``h1``/``h2``/``h3``
        (network outputs on ``data``), the three models and ``gamma``.

    Raises
    ------
    ValueError
        If ``na_action`` is not one of the two accepted strings.

    Notes
    -----
    Changes relative to the previous implementation:

    * covariates are read from ``formula['X']`` instead of a notebook global;
    * the per-epoch network loss is the mean over the batches that contained
      an event (it used to be the last batch's loss divided by the batch
      count, and crashed with ``float.item()`` when that batch had no event);
    * an empty loader or zero inner epochs no longer leaves the loss undefined;
    * networks are switched to eval mode (dropout off) for the E-/M-step
      predictions and for the returned ``h1``/``h2``/``h3``; they are left in
      eval mode on return;
    * the three copies of the network training loop share one helper.
    """
    # Assertions
    if na_action not in ["na.fail", "na.omit"]:
        raise ValueError('na_action should be either "na.fail" or "na.omit"')

    # ---------------------------------------------------------------- data
    y1 = data['Yi1'].values
    d1 = data['Di1'].values
    y2 = data['Yi2'].values
    d2 = data['Di2'].values

    # Unique (sorted) failure times by transition
    t1 = np.unique(y1[d1 == 1])
    t2 = np.unique(y2[(d1 == 0) & (d2 == 1)])
    t3 = np.unique(y2[(d1 == 1) & (d2 == 1)])
    tol_val = tol

    # All three transitions share the same covariates.
    X_mat = data[formula['X']].values
    X_all = torch.tensor(X_mat, dtype=torch.float32)

    # ---------------------------------------------------------- parameters
    theta = torch.tensor(theta0, requires_grad=True, dtype=torch.float32)
    optimizer_theta = optim.Adam([theta], lr=lr_theta)

    # Initial baseline hazards: events / number at risk at each event time.
    d1_j = np.array([np.sum(d1 * (y1 == t)) for t in t1])
    n1_j = np.array([np.sum(y1 >= t) for t in t1])
    lam01 = d1_j / n1_j

    d2_j = np.array([np.sum((1 - d1) * d2 * (y2 == t)) for t in t2])
    n2_j = np.array([np.sum((y2 >= t) & (y1 >= t)) for t in t2])
    lam02 = d2_j / n2_j

    d3_j = np.array([np.sum(d1 * d2 * (y2 == t)) for t in t3])
    n3_j = np.array([np.sum(y2[d1==1] >= t) for t in t3])
    lam03 = d3_j / n3_j

    # ------------------------------------------------------------ networks
    input_dim = X_mat.shape[1]

    def build_model(input_dim, dim_layers, dropout):
        return nn.Sequential(
            nn.Linear(input_dim, dim_layers[0], bias=False),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_layers[0], dim_layers[1], bias=False),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_layers[1], dim_layers[2], bias=False),
            nn.Linear(dim_layers[2], 1, bias=False)
        )

    h1_model = build_model(input_dim, dim_layers, dr)
    h2_model = build_model(input_dim, dim_layers, dr)
    h3_model = build_model(input_dim, dim_layers, dr)
    models = [h1_model, h2_model, h3_model]
    optimizers = [optim.Adam(m.parameters(), lr=lr) for m in models]

    def _predict(model):
        """Deterministic (dropout off) network output on the full data, as a tensor."""
        model.eval()
        with torch.no_grad():
            return model(X_all)

    # ------------------------------------------- baseline hazard helpers
    def lam01_fun(t, gamma_val, h1_np):
        numer = np.sum(d1 * (y1 == t))
        denom = np.sum(gamma_val * ((y1 >= t).astype(float)) * np.exp(h1_np))
        return numer/denom if numer > 0 else 0

    def lam02_fun(t, gamma_val, h2_np):
        numer = np.sum((1 - d1) * d2 * (y1 == t))
        denom = np.sum(gamma_val * (((y2 >= t) & (y1 >= t)).astype(float)) * np.exp(h2_np))
        return numer/denom if numer > 0 else 0

    def lam03_fun(t, gamma_val, h3_np):
        numer = np.sum(d1 * d2 * (y2 == t))
        # (y2 >= t) - (y1 >= t) is 1 exactly when y1 < t <= y2.
        denom = np.sum(gamma_val * d1 * ((y2 >= t).astype(float) - (y1 >= t).astype(float)) * np.exp(h3_np))
        return numer/denom if numer > 0 else 0

    # Cumulative hazards: sum of the increments at event times up to t (+ tol).
    def Lam01_fun(t, lam01):
        return np.sum(lam01[t1 - t < tol_val])
    def Lam02_fun(t, lam02):
        return np.sum(lam02[t2 - t < tol_val])
    def Lam03_fun(t, lam03):
        return np.sum(lam03[t3 - t < tol_val])

    def _max_rel_change(old, new):
        """Largest relative change between two hazard vectors (0 if empty)."""
        if old.size == 0:
            return 0.0
        return np.max(np.abs((old - new) / (old + 1e-8)))

    # Theta contribution to the expected log-likelihood (negated for minimisation)
    def loss4(theta_tensor, gamma_tensor, log_gamma_tensor):
        loss = -1/theta_tensor * torch.log(theta_tensor) + (1/theta_tensor - 1)*log_gamma_tensor \
               - 1/theta_tensor * gamma_tensor - torch.lgamma(1/theta_tensor)
        return -torch.sum(loss)

    # Dataset class for the DataLoader
    class CustomDataset(Dataset):
        def __init__(self, inputs, outputs):
            self.inputs = inputs
            self.outputs = outputs
        def __len__(self):
            return self.inputs.shape[0]
        def __getitem__(self, idx):
            return self.inputs[idx], self.outputs[idx]

    # Per-transition batch losses; None means "no event of this type in the batch".
    def _batch_loss_h1(h, b_d1, b_d2, b_y1, b_y2, b_gamma):
        if not bool((b_d1 == 1).any()):
            return None
        return loss1_fun_stable(h, b_d1, b_y1, b_gamma)

    def _batch_loss_h2(h, b_d1, b_d2, b_y1, b_y2, b_gamma):
        if not bool(((b_d1 == 0) & (b_d2 == 1)).any()):
            return None
        return loss2_fun_stable(h, b_d1, b_d2, b_y1, b_gamma)

    def _batch_loss_h3(h, b_d1, b_d2, b_y1, b_y2, b_gamma):
        if not bool(((b_d1 == 1) & (b_d2 == 1)).any()):
            return None
        return loss3_fun_stable(h, b_d1, b_d2, b_y1, b_y2, b_gamma)

    batch_losses = [_batch_loss_h1, _batch_loss_h2, _batch_loss_h3]

    def _run_n_step(model, optimizer, loader, batch_loss):
        """Train one network for up to max_epochs_n passes; return its last epoch loss.

        The epoch loss is the mean over batches that contained an event. The
        passes stop early when the loss changes by at most 1e-4 or has risen
        more than twice. Returns inf if no batch ever contained an event.
        """
        model.train()
        rises = 0
        prev_loss = float('inf')
        epoch_loss = float('inf')
        for _ in range(max_epochs_n):
            total, used = 0.0, 0
            for batch_X, batch_out in loader:
                b_y1, b_d1, b_y2, b_d2, b_gamma = (batch_out[:, j] for j in range(5))
                loss = batch_loss(model(batch_X), b_d1, b_d2, b_y1, b_y2, b_gamma)
                if loss is None:
                    continue
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total += loss.item()
                used += 1
            if used == 0:
                break
            epoch_loss = total / used
            if abs(epoch_loss - prev_loss) <= 1e-4:
                break
            elif epoch_loss > prev_loss:
                rises += 1
                if rises > 2:
                    break
            prev_loss = epoch_loss
        model.eval()
        return epoch_loss

    # ------------------------------------------------------ neural EM loop
    # Loss history, indexed by outer epoch (index 0 and 1 stay at inf).
    loss_h_epoch = [[np.inf]*(max_epochs+1) for _ in range(3)]
    loss_h1_epoch, loss_h2_epoch, loss_h3_epoch = loss_h_epoch
    loss_theta_epoch = [np.inf]*(max_epochs+1)

    # Consecutive outer epochs in which each loss went up; a component whose
    # counter reaches 2 is frozen for the rest of the fit.
    loss_theta_inc = 0
    loss_h_inc = [0, 0, 0]

    gamma_val = np.ones_like(y1)
    log_gamma_val = np.zeros_like(y1)
    diff_EM = 100
    epoch = 2

    while epoch < max_epochs and diff_EM > tol_val:
        # E-STEP
        theta_val = theta.item()
        a = 1/theta_val + d1 + d2
        # Current risk-function values (dropout disabled)
        h1_np = _predict(h1_model).numpy().flatten()
        h2_np = _predict(h2_model).numpy().flatten()
        h3_np = _predict(h3_model).numpy().flatten()
        # Cumulative hazards for each subject
        Lam01 = np.array([Lam01_fun(t, lam01) for t in y1])
        Lam02 = np.array([Lam02_fun(t, lam02) for t in y1])
        Lam03_y2 = np.array([Lam03_fun(t, lam03) for t in y2])
        Lam03_y1 = np.array([Lam03_fun(t, lam03) for t in y1])
        b = 1/theta_val + Lam01 * np.exp(h1_np) + Lam02 * np.exp(h2_np) + d1 * (Lam03_y2 - Lam03_y1) * np.exp(h3_np)
        gamma_val = a / b
        log_gamma_val = digamma(a) - np.log(b)

        print("mean:", np.mean(gamma_val))
        print("var:", np.var(gamma_val))

        # M-STEP
        lam01_new = np.array([lam01_fun(t, gamma_val, h1_np) for t in t1])
        lam02_new = np.array([lam02_fun(t, gamma_val, h2_np) for t in t2])
        lam03_new = np.array([lam03_fun(t, gamma_val, h3_np) for t in t3])
        diff_EM = max(_max_rel_change(lam01, lam01_new),
                      _max_rel_change(lam02, lam02_new),
                      _max_rel_change(lam03, lam03_new))
        lam01 = lam01_new
        lam02 = lam02_new
        lam03 = lam03_new

        # N-STEP: theta
        if loss_theta_inc < 2:
            gamma_t = torch.tensor(gamma_val, dtype=torch.float32)
            log_gamma_t = torch.tensor(log_gamma_val, dtype=torch.float32)
            theta_rises = 0
            prev_loss_theta = float('inf')
            current_loss_theta = float('inf')
            for epoch_theta in range(max_epochs_theta):
                optimizer_theta.zero_grad()
                loss_theta = loss4(theta, gamma_t, log_gamma_t)
                loss_theta.backward()
                optimizer_theta.step()
                current_loss_theta = loss_theta.item()
                if abs(current_loss_theta - prev_loss_theta) <= 1e-4:
                    break
                elif current_loss_theta > prev_loss_theta:
                    theta_rises += 1
                    if theta_rises > 2:
                        break
                if verbose:
                    print(f"Theta Epoch {epoch_theta}, Loss: {current_loss_theta}, Theta: {theta.item()}")
                prev_loss_theta = current_loss_theta

            if current_loss_theta > loss_theta_epoch[epoch-1]:
                loss_theta_inc += 1
            else:
                loss_theta_inc = 0
            loss_theta_epoch[epoch] = current_loss_theta

        # N-STEP: networks. One shuffled loader serves all three transitions;
        # each batch row carries (y1, d1, y2, d2, gamma).
        outputs = np.column_stack((y1, d1, y2, d2, gamma_val))
        dataset = CustomDataset(X_all, torch.tensor(outputs, dtype=torch.float32))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

        loss_vals = [float('inf'), float('inf'), float('inf')]
        for k in range(3):
            if loss_h_inc[k] >= 2:
                continue
            loss_vals[k] = _run_n_step(models[k], optimizers[k], loader, batch_losses[k])
            if loss_vals[k] > loss_h_epoch[k][epoch-1]:
                loss_h_inc[k] += 1
            else:
                loss_h_inc[k] = 0
            loss_h_epoch[k][epoch] = loss_vals[k]
        loss1_val, loss2_val, loss3_val = loss_vals

        # Plot cumulative baseline hazards with a dashed reference line per panel
        plt.figure(figsize=(15, 4))
        panels = ((t1, lam01, 'Baseline Hazards 1', 2),
                  (t2, lam02, 'Baseline Hazards 2', 3),
                  (t3, lam03, 'Baseline Hazards 3', 2))
        for pos, (times, lam, title, ref_slope) in enumerate(panels, start=1):
            plt.subplot(1, 3, pos)
            plt.plot(times, np.cumsum(lam), marker='o')
            plt.title(title)
            plt.xlabel('Failure Time')
            plt.ylabel('Cumulative Baseline Hazards')
            plt.axline((0, 0), slope=ref_slope, linestyle='--')
        plt.show()

        print(f"Epoch {epoch}: loss1: {loss1_val}, loss2: {loss2_val}, loss3: {loss3_val}")
        epoch += 1
        # NOTE(review): resetting diff_EM here means the `diff_EM > tol_val`
        # loop condition can never stop the fit; kept as in the original.
        diff_EM = 100  # Reset diff_EM for next outer iteration

        epochs = range(1, len(loss_h1_epoch))
        plt.figure(figsize=(12, 4))
        plt.plot(epochs, loss_h1_epoch[1:], label="Loss h1")
        plt.plot(epochs, loss_h2_epoch[1:], label="Loss h2")
        plt.plot(epochs, loss_h3_epoch[1:], label="Loss h3")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Loss Curves")
        plt.legend()
        plt.tight_layout()
        plt.show()

        # Stop once every network has been frozen.
        if all(inc >= 2 for inc in loss_h_inc):
            break

    # Final risk-function values on the training data (dropout disabled).
    h1 = _predict(h1_model)
    h2 = _predict(h2_model)
    h3 = _predict(h3_model)

    return {
        "theta": theta.item(),
        "lam01": lam01,
        "lam02": lam02,
        "lam03": lam03,
        "h1": h1,
        "h2": h2,
        "h3": h3,
        "h1_model": h1_model,
        "h2_model": h2_model,
        "h3_model": h3_model,
        "gamma": gamma_val
    }

In [None]:
def run_simulation(dat_sim, form):
    """Fit the model on ``dat_sim`` with this notebook's fixed hyper-parameters.

    Calls ``fit_dnn(form, dat_sim, ...)`` and returns its result dictionary
    with ``h1``/``h2``/``h3`` converted from tensors to NumPy arrays; all
    other entries are passed through unchanged.
    """
    hyper = dict(
        na_action="na.fail", subset=None, dim_layers=[1024, 128, 16],
        lr=0.001, dr=0.3,
        max_epochs=5, max_epochs_theta=20, max_epochs_n=100,
        batch_size=500, verbose=True, ll=1, tol=1e-6,
        theta0=1, lr_theta=0.05,
    )
    fit_00 = fit_dnn(form, dat_sim, **hyper)

    result = {}
    for key in ("theta", "lam01", "lam02", "lam03"):
        result[key] = fit_00[key]
    for key in ("h1", "h2", "h3"):
        result[key] = fit_00[key].detach().numpy()
    for key in ("h1_model", "h2_model", "h3_model", "gamma"):
        result[key] = fit_00[key]
    return result

In [None]:
# Outcome and covariate columns used for modelling.
dat = BLCS_CLEAN2[['Yi1','Yi2','YiS','Di1','Di2','DiS','agedx','sex','race','ethnic','smk','pkyrs','trt','egfr','kras','copd','asthma','stage_splt']]

# One-hot encode the non-numeric columns, dropping the first level of each;
# multiplying by 1 turns boolean dummies into integers.
dat = pd.get_dummies(dat, drop_first=True)*1
# 'Y' names the outcome columns, 'X' the covariate columns selected for the
# networks. The 'X' names must match the dummy column names produced above.
form = {
    'Y': ["Yi1", "Di1", "Yi2", "Di2"],
    'X': ['agedx', 'pkyrs', 'sex_Male',
       'sex_Unknown', 'race_White/Caucasian', 'ethnic_Non-Hispanic',
       'smk_Former smoker', 'smk_Never smoker', 'smk_Smoker, status unknown',
       'smk_Unknown', 'trt_Other', 'trt_Radiation', 'trt_Surgery', 'trt_nan',
       'egfr_Not Tested', 'egfr_Yes', 'kras_Not Tested', 'kras_Yes',
       'copd_Yes', 'asthma_Yes','stage_splt']
    }

In [None]:
# Inspect the encoded column names; every name in form['X'] has to appear here.
dat.columns

In [None]:
# Fit on the full encoded data set; the returned dictionary is the cell output.
run_simulation(dat, form)

In [None]:
def bbs(data, preds, t_pred):
    """Weighted Brier-type score of predicted event-free probabilities.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs ``Yi1``, ``Di1``, ``Yi2``, ``Di2``, ``YiS`` and ``DiS``.
    preds : 2-D array, one row per subject and one column per entry of ``t_pred``
        Predicted probabilities; the score uses ``(0 - p)**2`` for subjects
        in the two event categories and ``(1 - p)**2`` for subjects with both
        observed times beyond the evaluation time.
    t_pred : array-like
        Evaluation times.

    Returns
    -------
    dict
        ``tpred`` (the evaluation times), ``bs`` (score per time) and ``ibs``
        (trapezoidal integral of ``bs`` divided by the time range; for a
        single time, ``bs[0] / t_pred[0]``).

    Notes
    -----
    The weights are reciprocals of a Kaplan-Meier curve fitted to
    ``(YiS, DiS)``, evaluated at each subject's ``YiS`` and at each
    evaluation time; zeros in the curve are replaced by inf so the weight is 0.
    NOTE(review): the curve is fitted with ``DiS`` as the event indicator --
    confirm whether the censoring distribution (``1 - DiS``) was intended.
    """
    time_first = data['Yi1'].values
    event_first = data['Di1'].astype(bool).values
    time_second = data['Yi2'].values
    event_second = data['Di2'].astype(bool).values
    ctime = data['YiS'].values
    cind = data['DiS'].astype(bool).values

    t_pred = np.asarray(t_pred)
    _, n_times = preds.shape

    # Boolean (subject x time) tables: observed time not later than t.
    first_by_t = time_first[:, None] <= t_pred[None, :]
    second_by_t = time_second[:, None] <= t_pred[None, :]
    ordered = (time_first <= time_second)[:, None]
    first_col = event_first[:, None]
    second_col = event_second[:, None]

    # Category 1: first event observed by t.
    cat_first = (first_by_t & first_col & ordered).astype(int)
    # Category 2: second event observed by t without a first event.
    cat_second = (first_by_t & second_by_t & (~first_col) & second_col & ordered).astype(int)
    # Category 3: both observed times are beyond t.
    cat_free = ((~first_by_t) & (~second_by_t)).astype(int)

    kmf = KaplanMeierFitter()
    kmf.fit(durations=ctime, event_observed=cind)

    csurv = kmf.survival_function_at_times(ctime).values
    csurv[csurv == 0] = np.inf

    csurv_pred = kmf.survival_function_at_times(t_pred).values
    # Missing curve values are replaced by the smallest available value.
    floor = np.nanmin(csurv_pred)
    csurv_pred = np.where(np.isnan(csurv_pred), floor, csurv_pred)
    csurv_pred[csurv_pred == 0] = np.inf

    bs = np.zeros(n_times)
    for col in range(n_times):
        p = preds[:, col]
        event_term = (0 - p)**2 * cat_first[:, col] * (1 / csurv)
        death_term = (0 - p)**2 * cat_second[:, col] * (1 / csurv)
        free_term = (1 - p)**2 * cat_free[:, col] * (1 / csurv_pred[col])
        bs[col] = np.mean(event_term + death_term + free_term)

    if n_times > 1:
        widths = np.diff(t_pred)
        midpoints = (bs[:-1] + bs[1:]) / 2
        ibs = np.dot(widths, midpoints) / (t_pred.max() - t_pred.min())
    else:
        ibs = bs[0] / t_pred[0]

    return {'tpred': t_pred, 'bs': bs, 'ibs': ibs}

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from scipy.optimize import minimize
from scipy.stats import t
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

# K-fold cross-validated Brier-type score of the predicted event-free probability.
K = 5
kf = KFold(n_splits=K, shuffle=True, random_state=42)
t_pred = np.linspace(0, 5, 100)

bs_mat = np.full((len(t_pred), K), np.nan)
ibs_vec = np.empty(K)

for k, (train_idx, test_idx) in enumerate(kf.split(dat)):
    train_dat = dat.iloc[train_idx].copy()
    test_dat = dat.iloc[test_idx].copy()

    # Event-time grids of the TRAINING fold, built exactly like fit_dnn builds
    # them (sorted unique times per transition), so that they line up
    # element-for-element with the hazard increments it returns.
    y1_old = train_dat['Yi1'].values
    d1_old = train_dat['Di1'].values
    y2_old = train_dat['Yi2'].values
    d2_old = train_dat['Di2'].values
    t1_old = np.unique(y1_old[d1_old == 1])
    t2_old = np.unique(y2_old[(d1_old == 0) & (d2_old == 1)])

    res = run_simulation(train_dat, form)

    # Post-Estimation
    h1_model = res["h1_model"]
    h2_model = res["h2_model"]
    theta = res["theta"]

    assert len(res["lam01"]) == len(t1_old), "lam01 does not match training event times"
    assert len(res["lam02"]) == len(t2_old), "lam02 does not match training event times"

    # Cumulative baseline hazards as step functions, with a leading 0 so that
    # times before the first event map to a cumulative hazard of 0.
    cum_lam01 = np.concatenate(([0.0], np.cumsum(res["lam01"])))
    cum_lam02 = np.concatenate(([0.0], np.cumsum(res["lam02"])))
    # FIX: look the hazards up on the training event times. The previous code
    # picked the index of the nearest *test-fold* event time and used it to
    # index hazards that are defined on the training event times.
    # searchsorted(..., side="right") counts the training event times <= t.
    Lam01_t = cum_lam01[np.searchsorted(t1_old, t_pred, side="right")]
    Lam02_t = cum_lam02[np.searchsorted(t2_old, t_pred, side="right")]

    # Risk scores for the test fold. eval() switches dropout off so the
    # predictions are deterministic.
    X_test = torch.tensor(test_dat[form['X']].values, dtype=torch.float32)
    h1_model.eval()
    h2_model.eval()
    with torch.no_grad():
        h1_new = h1_model(X_test).numpy().flatten()
        h2_new = h2_model(X_test).numpy().flatten()

    # Event-free probability, shape (len(t_pred), n_test):
    # (1 + theta*Lam01(t)*exp(h1) + theta*Lam02(t)*exp(h2)) ** (-1/theta)
    preds_true = (
        1
        + theta * np.outer(Lam01_t, np.exp(h1_new))
        + theta * np.outer(Lam02_t, np.exp(h2_new))
    ) ** (-1 / theta)

    res_bbs = bbs(test_dat, preds_true.T, t_pred)
    bs_mat[:, k] = res_bbs['bs']
    ibs_vec[k] = res_bbs['ibs']

# Aggregate across folds: mean and t-based 95% interval with K-1 degrees of freedom.
# `t` is scipy.stats.t from the import at the top of this cell; no loop variable
# in this cell rebinds it any more, so the mid-cell re-import was dropped.
mean_bs = np.mean(bs_mat, axis=1)
sd_bs = np.std(bs_mat, axis=1, ddof=1)
se_bs = sd_bs / np.sqrt(K)
t_q = t.ppf(0.975, df=K-1)

ci_lower = mean_bs - t_q * se_bs
ci_upper = mean_bs + t_q * se_bs

# Plot
plt.plot(t_pred, mean_bs, lw=2, label="Mean BBS")
plt.fill_between(t_pred, ci_lower, ci_upper, alpha=0.2, color="blue")
plt.xlabel("Time (yrs)")
plt.ylabel("BBS")
plt.legend()
plt.show()

# Integrated BBS
mean_ibs = np.mean(ibs_vec)
se_ibs = np.std(ibs_vec, ddof=1) / np.sqrt(K)
ci_ibs = mean_ibs + np.array([-1, 1]) * t_q * se_ibs

print(f"iBBS = {mean_ibs:.4f} (95% CI: {ci_ibs[0]:.4f} – {ci_ibs[1]:.4f})")

In [None]:
# Re-draw the mean score with its confidence band.
plt.plot(t_pred, mean_bs, lw=2, label="Mean BBS")
plt.fill_between(t_pred, ci_lower, ci_upper, alpha=0.2, color="blue")
# Collect the curve and its interval into a table for export.
# NOTE(review): the assignments below rebind the name `bbs`, which an earlier
# cell defines as the scoring function called by the cross-validation loop.
# Re-running that loop after this cell would call a DataFrame; consider a
# distinct variable name.
bbs = {'t_pred':t_pred,
       'mean_bs':mean_bs,
       'ci_lower':ci_lower,
       'ci_upper':ci_upper}
bbs = pd.DataFrame(bbs)

In [140]:
# Write the score table to bbs.csv in the working directory, without the index column.
bbs.to_csv("bbs.csv", index=False)

In [None]:
import numpy as np
from tqdm import tqdm

def bootstrap_cumhaz(data, n_boot=50, formula=None):
    """Bootstrap mean and 95% percentile band of the three cumulative baseline hazards.

    Parameters
    ----------
    data : pandas.DataFrame
        Model-ready data containing ``Yi1``, ``Di1``, ``Yi2``, ``Di2`` and the
        covariates named by the formula.
    n_boot : int
        Number of bootstrap replicates (rows are resampled with replacement).
    formula : mapping, optional
        Passed to ``run_simulation``; defaults to the notebook-level ``form``.

    Returns
    -------
    dict
        ``"lam01"``, ``"lam02"``, ``"lam03"`` -> ``(mean, lower, upper)``,
        each an array over the sorted unique event times of that transition in
        ``data``.

    Raises
    ------
    ValueError
        If a fit returns a hazard vector whose length differs from the number
        of unique event times in the resample it was fitted on.

    Notes
    -----
    Every resample has its own set of unique event times, so the hazard
    vectors returned by the fits differ in length and cannot be stacked
    directly (the previous implementation did so). Each replicate's
    cumulative hazard is therefore evaluated, as a right-continuous step
    function, on the event times of the original ``data`` before averaging.
    """
    if formula is None:
        formula = form  # notebook-level default

    keys = ("lam01", "lam02", "lam03")

    def _event_times(frame):
        """Sorted unique event times per transition (same rule as fit_dnn)."""
        y1 = frame['Yi1'].values
        d1 = frame['Di1'].values
        y2 = frame['Yi2'].values
        d2 = frame['Di2'].values
        return (np.unique(y1[d1 == 1]),
                np.unique(y2[(d1 == 0) & (d2 == 1)]),
                np.unique(y2[(d1 == 1) & (d2 == 1)]))

    def _cumhaz_on_grid(times, increments, grid):
        """Evaluate the step function sum(increments[times <= g]) for each g in grid."""
        if len(times) != len(increments):
            raise ValueError("hazard increments do not match the resample's event times")
        cumulative = np.concatenate(([0.0], np.cumsum(increments)))
        return cumulative[np.searchsorted(times, grid, side="right")]

    grids = _event_times(data)
    curves = {key: [] for key in keys}

    for _ in tqdm(range(n_boot)):
        # Bootstrap resample
        data_boot = data.sample(n=len(data), replace=True)

        # Train model on bootstrap sample
        result = run_simulation(data_boot, formula)

        # Store each cumulative hazard on the common grid
        boot_times = _event_times(data_boot)
        for key, times, grid in zip(keys, boot_times, grids):
            curves[key].append(_cumhaz_on_grid(times, result[key], grid))

    def mean_ci(rows, width):
        arr = np.array(rows, dtype=float).reshape(len(rows), width)
        mean = np.mean(arr, axis=0)
        lower = np.percentile(arr, 2.5, axis=0)
        upper = np.percentile(arr, 97.5, axis=0)
        return mean, lower, upper

    return {key: mean_ci(curves[key], len(grid)) for key, grid in zip(keys, grids)}