In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random

def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [35]:
manifest = pd.read_csv("./AD_Patient_Manifest.csv")
manifest

Unnamed: 0,PTID,path
0,002_S_0295,patients_csv\002_S_0295.pkl
1,002_S_0413,patients_csv\002_S_0413.pkl
2,002_S_0619,patients_csv\002_S_0619.pkl
3,002_S_0685,patients_csv\002_S_0685.pkl
4,002_S_0729,patients_csv\002_S_0729.pkl
...,...,...
377,137_S_0994,patients_csv\137_S_0994.pkl
378,137_S_1041,patients_csv\137_S_1041.pkl
379,137_S_1414,patients_csv\137_S_1414.pkl
380,941_S_1194,patients_csv\941_S_1194.pkl


In [36]:
patient_df = pd.read_pickle("patients_csv/002_S_0295.pkl")
print(patient_df.columns.tolist())

['RID', 'PTID', 'DX', 'MMSE', 'AGE', 'PTGENDER', 'PTEDUCAT', 'PTETHCAT', 'PTRACCAT', 'PTMARRY', 'ADAS11', 'ADAS13', 'ADASQ4', 'Years_bl', 'EXAMDATE', 'image_path', 'DX_encoded', 'PTGENDER_encoded', 'PTETHCAT_encoded', 'PTRACCAT_encoded', 'PTMARRY_encoded']


In [37]:
demographic_columns = [
    "AGE", "PTEDUCAT",
    "PTGENDER_encoded", "PTETHCAT_encoded",
    "PTRACCAT_encoded", "PTMARRY_encoded"
]

label_mapping = {"CN": 0, "MCI": 1, "AD": 2}
max_seq_length = 5

In [38]:
cox_time = []
cox_event = []
cox_features = []

for _, row in manifest.iterrows():
    patient_df = pd.read_pickle(row["path"])

    if 'image_path' in patient_df.columns:
        patient_df['image_path'] = patient_df['image_path'].apply(
            lambda x: x.replace("/home/mason/ADNI_Dataset/", "../ADNI_Dataset/")
        )

    dx_sequence = patient_df["DX"].values.tolist()

    if "MCI" not in dx_sequence:
        continue

    mci_idx = dx_sequence.index("MCI")

    ad_idx = -1
    for i in range(mci_idx + 1, len(dx_sequence)):
        if dx_sequence[i] in ["AD", "Dementia"]:
            ad_idx = i
            break

    if ad_idx != -1:
        time = patient_df["Years_bl"].iloc[ad_idx] - patient_df["Years_bl"].iloc[mci_idx]
        event = 1
    else:
        time = patient_df["Years_bl"].iloc[-1] - patient_df["Years_bl"].iloc[mci_idx]
        event = 0

    try:
        demo = patient_df[demographic_columns].iloc[mci_idx].values.astype(np.float32)
    except:
        continue

    cox_time.append(time)
    cox_event.append(event)
    cox_features.append(demo)

In [39]:
X_tensor = torch.tensor(np.array(cox_features), dtype=torch.float32).to(device)
time_tensor = torch.tensor(np.array(cox_time), dtype=torch.float32).to(device)
event_tensor = torch.tensor(np.array(cox_event), dtype=torch.float32).to(device)


In [40]:
input_dim = X_tensor.shape[1]

class CoxModel(nn.Module):
    def __init__(self, input_dim):
        super(CoxModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.fc(x)


In [41]:
def cox_ph_loss(pred_risk, time, event):
    pred_risk = pred_risk.view(-1)
    order = torch.argsort(time, descending=True)
    sorted_time = time[order]
    sorted_event = event[order]
    sorted_risk = pred_risk[order]

    exp_risk = torch.exp(sorted_risk)
    cum_sum = torch.cumsum(exp_risk, dim=0)
    log_cum_sum = torch.log(cum_sum)

    diff = sorted_risk - log_cum_sum
    observed = diff * sorted_event
    loss = -torch.mean(observed)
    return loss


In [42]:
for _, row in manifest.iterrows():
    df = pd.read_pickle(row["path"])
    print(df["DX"].unique())
    break


['CN']


In [49]:
learning_rate = 0.0001
epochs = 10000

In [52]:
epochs += 1 # for clear prints
model = CoxModel(input_dim=X_tensor.shape[1]).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
from tqdm import trange

for epoch in trange(epochs, desc="Training"):
    model.train()
    optimizer.zero_grad()

    risk_scores = model(X_tensor)
    loss = cox_ph_loss(risk_scores, time_tensor, event_tensor)
    loss.backward()
    optimizer.step()

    # tqdm에 loss 표시
    trange.set_description(f"Epoch {epoch} Loss: {loss.item():.4f}")



Epoch 0, Loss: 2.2033
Epoch 10, Loss: 2.1930
Epoch 20, Loss: 2.1834
Epoch 30, Loss: 2.1747
Epoch 40, Loss: 2.1667
Epoch 50, Loss: 2.1594
Epoch 60, Loss: 2.1527
Epoch 70, Loss: 2.1468
Epoch 80, Loss: 2.1414
Epoch 90, Loss: 2.1366
Epoch 100, Loss: 2.1323
Epoch 110, Loss: 2.1284
Epoch 120, Loss: 2.1249
Epoch 130, Loss: 2.1218
Epoch 140, Loss: 2.1191
Epoch 150, Loss: 2.1166
Epoch 160, Loss: 2.1144
Epoch 170, Loss: 2.1124
Epoch 180, Loss: 2.1107
Epoch 190, Loss: 2.1092
Epoch 200, Loss: 2.1078
Epoch 210, Loss: 2.1065
Epoch 220, Loss: 2.1055
Epoch 230, Loss: 2.1046
Epoch 240, Loss: 2.1038
Epoch 250, Loss: 2.1031
Epoch 260, Loss: 2.1025
Epoch 270, Loss: 2.1020
Epoch 280, Loss: 2.1016
Epoch 290, Loss: 2.1012
Epoch 300, Loss: 2.1008
Epoch 310, Loss: 2.1005
Epoch 320, Loss: 2.1002
Epoch 330, Loss: 2.0999
Epoch 340, Loss: 2.0997
Epoch 350, Loss: 2.0995
Epoch 360, Loss: 2.0993
Epoch 370, Loss: 2.0991
Epoch 380, Loss: 2.0989
Epoch 390, Loss: 2.0988
Epoch 400, Loss: 2.0986
Epoch 410, Loss: 2.0985
Epo

In [51]:
from lifelines.utils import concordance_index

model.eval()
with torch.no_grad():
    pred = model(X_tensor).cpu().numpy().reshape(-1)
    time_np = time_tensor.cpu().numpy().reshape(-1)
    event_np = event_tensor.cpu().numpy().reshape(-1)

    c_idx = concordance_index(time_np, -pred, event_np)
    print(f"Concordance Index: {c_idx:.4f}")


Concordance Index: 0.6248
