In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
import torch
import torchtuples as tt
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv
import os

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed columns
cols_standardize = ['Age','BMI', 'poverty_level', 'PC1','PC2','PC3','PC4','PC5','PC6','PC7','PC8','PC9']
cols_leave = ['Race', 'sex', 'Mobility', 'diabetes.y', 'Asthma', 'Arthritis', 'heart_failure', 'coronary_heart_disease', 
              'angina', 'stroke', 'thyroid', 'bronchitis', 'cancer']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
x_mapper = DataFrameMapper(standardize + leave)

# Model architecture
def get_model(input_dim):
    num_nodes = [256, 256]
    net = tt.practical.MLPVanilla(input_dim, num_nodes, 1, batch_norm=True, dropout=0.7, output_bias=False).to(device)
    return CoxPH(net, tt.optim.Adam)

# For C-index storage
c_index_list = []

# Directory containing the splits
split_dir = "D:/Final_Year_Project/splits_final/"  # Modify this if needed

for i in range(1, 101):
    print(f"Processing split {i}...")

    # Load train/test split
    df_train = pd.read_csv(os.path.join(split_dir, f"train_split_{i}.csv"))
    df_test = pd.read_csv(os.path.join(split_dir, f"test_split_{i}.csv"))

    # Fit transformer only on training data
    x_train = torch.tensor(x_mapper.fit_transform(df_train).astype('float32')).to(device)
    x_val = torch.tensor(x_mapper.transform(df_test).astype('float32')).to(device)
    x_test = torch.tensor(x_mapper.transform(df_test).astype('float32')).to(device)

    # Get target
    get_target = lambda df: (
        torch.tensor(df['time_mort'].values, dtype=torch.float32).to(device),
        torch.tensor(df['mortstat'].values, dtype=torch.float32).to(device)
    )
    y_train = get_target(df_train)
    y_val = get_target(df_test)
    durations_test, events_test = get_target(df_test)
    val = x_val, y_val

    durations_np = durations_test.cpu().numpy()
    events_np = events_test.cpu().numpy()

    # Model
    model = get_model(x_train.shape[1])

    # Find LR
    batch_size = 64
    lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
    best_lr = lrfinder.get_best_lr()
    model.optimizer.set_lr(best_lr)

    # Train
    callbacks = [tt.callbacks.EarlyStopping()]
    model.fit(x_train, y_train, batch_size, 512, callbacks, verbose=False, 
              val_data=val, val_batch_size=batch_size)

    # Predict survival
    model.compute_baseline_hazards()
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_np, events_np, censor_surv='km')
    c_index = ev.concordance_td()
    c_index_list.append(c_index)

# Save C-indices
c_index_df = pd.DataFrame({'split': list(range(1, 101)), 'c_index': c_index_list})
# c_index_df.to_csv("D:/Final_Year_Project/c_index_results.csv", index=False)

print("All splits processed. Results saved to c_index_results.csv.")

Processing split 1...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 2...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 3...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 4...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 5...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 6...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 7...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 8...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 9...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 10...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 11...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 12...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 13...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 14...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 15...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 16...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 17...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 18...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 19...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 20...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 21...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 22...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 23...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 24...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 25...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 26...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 27...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 28...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 29...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 30...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 31...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 32...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 33...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 34...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 35...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 36...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 37...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 38...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 39...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 40...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 41...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 42...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 43...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 44...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 45...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 46...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 47...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 48...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 49...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 50...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 51...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 52...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 53...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 54...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 55...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 56...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 57...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 58...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 59...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 60...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 61...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 62...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 63...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 64...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 65...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 66...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 67...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 68...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 69...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 70...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 71...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 72...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 73...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 74...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 75...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 76...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 77...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 78...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 79...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 80...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 81...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 82...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 83...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 84...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 85...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 86...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 87...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 88...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 89...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 90...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 91...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 92...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 93...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 94...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 95...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 96...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 97...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 98...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 99...


  self.net.load_state_dict(torch.load(path, **kwargs))


Processing split 100...


  self.net.load_state_dict(torch.load(path, **kwargs))


All splits processed. Results saved to c_index_results.csv.


In [2]:
c_index_df['c_index'].mean()

np.float64(0.7812455008664488)

In [3]:
c_index_df.to_csv("D:/Final_Year_Project/c_index_FPCdeepsurv.csv", index=False)