In [1]:
!pip install lifelines

Collecting lifelines
  Downloading lifelines-0.30.0-py3-none-any.whl.metadata (3.2 kB)
Collecting autograd-gamma>=0.3 (from lifelines)
  Downloading autograd-gamma-0.5.0.tar.gz (4.0 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting formulaic>=0.2.2 (from lifelines)
  Downloading formulaic-1.1.1-py3-none-any.whl.metadata (6.9 kB)
Collecting interface-meta>=1.2.0 (from formulaic>=0.2.2->lifelines)
  Downloading interface_meta-1.3.0-py3-none-any.whl.metadata (6.7 kB)
Downloading lifelines-0.30.0-py3-none-any.whl (349 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m349.3/349.3 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading formulaic-1.1.1-py3-none-any.whl (115 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.7/115.7 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading interface_meta-1.3.0-py3-none-any.whl (14 kB)
Building wheels for collected packages: autograd-gamma
  Building wheel for autograd-gamma (s

In [2]:
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold

from lifelines.utils import concordance_index
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Mount Google Drive (Colab only) so the input files under /content/drive/MyDrive can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# ---- Reproducibility: seed every RNG used in the notebook ----
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# NOTE(review): cuDNN may still pick non-deterministic kernels; set
# torch.backends.cudnn.deterministic = True if bit-exact reruns are required.

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {DEVICE}")

# ---- Input files (on the mounted Google Drive) ----
SURVIVAL_DF_PATH = '/content/drive/MyDrive/surv.csv'
IMAGING_PKL_PATH = '/content/drive/MyDrive/x_with_coord.pkl'

# ---- Column names ----
TABULAR_COLS = ["AGE", "GENDER", "PTEDUCAT", "APOE4"]
PATIENT_ID_COL = "ID"
TIME_COL = "Years_bl"
EVENT_COL = "status"
GROUP_COL = PATIENT_ID_COL

# ---- Image data loading ----
# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
imaging_df = pd.read_pickle(IMAGING_PKL_PATH)

device: cuda


In [5]:
# ---- Mean-pool the voxel table onto a coarser grid ----
# Each voxel's x/y/z coordinate is floor-divided by `block`, and all voxels that land in
# the same coarse cell are averaged. Every column after the first three is treated as a
# signal column (assumes the first three are x, y, z — TODO confirm column order).
block = 2
df = imaging_df.assign(
    Xb=imaging_df["x"] // block,
    Yb=imaging_df["y"] // block,
    Zb=imaging_df["z"] // block,
)
signal_cols = imaging_df.columns[3:]

print("Mean‑pooling …")
pooled = (
    df.groupby(["Xb", "Yb", "Zb"])[signal_cols]
    .mean()
    .reset_index()
)
pooled

Mean‑pooling …


Unnamed: 0,Xb,Yb,Zb,002_S_0729_2006-07-17_5.csv,002_S_0782_2006-08-14_3.csv,002_S_0954_2006-10-10_7.csv,002_S_1070_2006-11-28_11.csv,002_S_1155_2006-12-14_9.csv,002_S_1268_2007-02-14_11.csv,002_S_2010_2010-06-24_11.csv,...,941_S_2060_2010-09-08_733.csv,941_S_4036_2011-05-10_735.csv,941_S_4187_2011-06-22_734.csv,941_S_4377_2012-01-04_740.csv,941_S_4420_2012-03-28_739.csv,941_S_4764_2012-06-01_742.csv,941_S_6017_2017-05-17_822.csv,941_S_6052_2017-07-20_825.csv,941_S_6068_2017-08-21_831.csv,941_S_6345_2018-05-10_839.csv
0,1,37,44,-2.792912,-2.378889,-1.773856,-2.335286,0.596719,-1.765281,-2.833100,...,-1.701152,0.875000,0.201460,-2.583830,-2.595766,-1.767830,-2.355316,0.226981,1.646702,-2.352754
1,1,37,45,-2.808298,-2.378889,-1.773856,-2.329250,0.901965,-1.765281,-2.829968,...,-1.701152,0.155349,0.600329,-2.571290,-2.511538,-1.328341,-2.355316,0.690027,1.329351,-2.352754
2,1,37,46,-2.789328,-2.378889,-1.773856,-2.329084,0.651638,-1.765281,-2.832926,...,-1.701152,-1.182283,0.757916,-2.688458,-2.369983,-2.284303,-2.355316,-0.780878,1.504705,-2.352754
3,1,38,43,-2.791528,-2.375753,-1.773856,-2.517997,0.632981,-1.765281,-2.769821,...,-1.701152,1.011001,0.763914,-0.858805,-2.498166,-2.478895,-2.352305,1.436756,1.877898,-2.345762
4,1,38,44,-0.863732,-2.348688,-1.773856,-2.496181,1.103007,-1.765281,-2.746532,...,-1.701152,0.858369,0.688716,-1.114168,-2.000347,-0.240278,-2.330434,1.394198,1.328678,-2.335502
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
239163,77,49,50,-2.728616,-2.360885,-1.773856,-2.321014,-2.579449,-1.765281,-2.817723,...,-1.761741,-2.519014,-2.863345,-2.497635,-2.421602,-2.173714,-2.367502,-2.477654,-2.565042,-2.344331
239164,77,49,51,-2.726615,-2.380032,-1.773856,-2.313229,-2.566476,-1.765281,-2.826856,...,-1.780595,-2.517990,-2.874205,-2.528827,-2.411746,-2.184370,-2.366173,-2.471008,-2.574568,-2.326821
239165,77,49,52,-2.722288,-2.413499,-1.773856,-2.313831,-2.569669,-1.765281,-2.828936,...,-1.737624,-2.516753,-2.876511,-2.547080,-2.399212,-2.199603,-2.369792,-2.468890,-2.563396,-2.334988
239166,77,51,40,-2.762474,-2.382130,-1.773856,-2.328615,-2.588504,-1.765231,-2.810227,...,-1.701075,-2.520640,-2.848475,-2.551446,-2.426327,-2.171802,-2.352401,-2.462863,-2.548314,-2.340197


In [6]:
class SurvivalCNN(nn.Module):
    """Additive image + tabular model that outputs one risk score per sample.

    Image branch: four stages of Conv3d(k=3, pad=1) -> ReLU -> BatchNorm3d ->
    MaxPool3d(2), then global average pooling and flattening. An MLP head
    (Linear -> ReLU -> Dropout(0.5) -> Linear) maps the pooled features to a scalar.
    Tabular branch: a single Linear layer to a scalar. The two scalars are summed.

    Args:
        num_tabular_features: number of tabular covariates per sample.
        cnn_output_dim: channel count of the last conv stage, i.e. the length of the
            pooled image feature vector consumed by ``image_head`` (default 256).

    ``forward`` inputs:
        x_image: (B, 1, D1, D2, D3) volume. Every spatial dim must be >= 16, or the
            four 2x max-pools shrink it to zero.
        x_tabular: (B, num_tabular_features).
    ``forward`` output: (B, 1) risk scores.
    """

    def __init__(self, num_tabular_features, cnn_output_dim=256):
        super().__init__()
        # Layer order and Sequential indices are unchanged so existing state_dicts load.
        self.cnn_branch = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm3d(32),
            nn.MaxPool3d(2, 2),

            nn.Conv3d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm3d(64),
            nn.MaxPool3d(2, 2),

            nn.Conv3d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm3d(128),
            nn.MaxPool3d(2, 2),

            # The last stage emits `cnn_output_dim` channels so the image head's input
            # size always matches. A hard-coded 256 here would break any other value.
            nn.Conv3d(128, cnn_output_dim, kernel_size=3, padding=1), nn.ReLU(),
            nn.BatchNorm3d(cnn_output_dim),
            nn.MaxPool3d(2, 2),

            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Flatten())

        # Linear contribution of the tabular covariates: one scalar per sample.
        self.tabular_branch = nn.Linear(num_tabular_features, 1)

        # Maps the pooled image features (length cnn_output_dim) to one scalar.
        self.image_head = nn.Sequential(
            nn.Linear(cnn_output_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

        self.cnn_output_dim = cnn_output_dim

        self._initialize_weights()

    def forward(self, x_image, x_tabular):
        """Return image_head(cnn_branch(x_image)) + tabular_branch(x_tabular), shape (B, 1)."""
        image_features = self.cnn_branch(x_image)

        tabular_risk = self.tabular_branch(x_tabular)
        image_risk = self.image_head(image_features)
        risk_score = image_risk + tabular_risk

        return risk_score

    def _initialize_weights(self):
        """Initialise parameters of all submodules.

        Conv3d and Linear weights get Kaiming-uniform (fan_in, ReLU gain) init and zero
        biases. BatchNorm3d layers get scale 1 and shift 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

# ---- Voxel table -> dense 3D volumes, one per scan column ----
imaging_df = pooled
coord_cols = ['Xb', 'Yb', 'Zb']
x_coords_all, y_coords_all, z_coords_all = (imaging_df[c].values for c in coord_cols)

# np.unique returns its result sorted. Each axis' grid is the set of coordinates
# actually present, so gaps along an axis are collapsed (neighbouring grid cells would
# then not be spatially adjacent) — TODO confirm coordinates are contiguous.
x_unique_sorted, y_unique_sorted, z_unique_sorted = (
    np.unique(axis_vals) for axis_vals in (x_coords_all, y_coords_all, z_coords_all)
)
dim_x, dim_y, dim_z = (len(grid) for grid in (x_unique_sorted, y_unique_sorted, z_unique_sorted))
print(f"3D dimension: (X: {dim_x}, Y: {dim_y}, Z: {dim_z})")

# Coordinate value -> grid position along each axis
x_map, y_map, z_map = (
    {val: i for i, val in enumerate(grid)}
    for grid in (x_unique_sorted, y_unique_sorted, z_unique_sorted)
)
# Every coordinate exists in its sorted grid, so searchsorted yields exactly the map lookup.
x_indices, y_indices, z_indices = (
    np.searchsorted(grid, coords)
    for grid, coords in zip(
        (x_unique_sorted, y_unique_sorted, z_unique_sorted),
        (x_coords_all, y_coords_all, z_coords_all),
    )
)

# Every column after Xb, Yb, Zb holds one scan. Voxels absent from the table stay 0.
patient_id_cols_from_img = imaging_df.columns[3:]
num_patients = len(patient_id_cols_from_img)
all_patient_images = np.zeros((num_patients, dim_x, dim_y, dim_z), dtype=np.float32)
for i, patient_col_name in enumerate(patient_id_cols_from_img):
    all_patient_images[i, x_indices, y_indices, z_indices] = imaging_df[patient_col_name].to_numpy()

# Insert the singleton channel axis PyTorch expects -> (N, 1, X, Y, Z)
X_data_images = all_patient_images[:, np.newaxis]
print(f"Image shape: {X_data_images.shape}")

# Strip the file extension from each scan column name
cleaned_patient_ids = [col_name.split('.')[0] for col_name in patient_id_cols_from_img]

# NOTE(review): ID is just the 1-based position of each scan column. The merge below
# therefore assumes surv.csv's ID column follows the same ordering as the image columns.
# Nothing here checks that — confirm against the data source.
imaging_ids_df = pd.DataFrame({
    'imaging_patient_id': cleaned_patient_ids,
    PATIENT_ID_COL: np.arange(1, num_patients + 1),
    'image_idx': np.arange(num_patients),
})

# ---- Survival / tabular data loading and merging ----
print("\n--- Loading survival data ---")
survival_df = pd.read_csv(SURVIVAL_DF_PATH)
df_final = survival_df.merge(imaging_ids_df, on=PATIENT_ID_COL, how='inner')

print(f"Total patients: {len(df_final)}")
print(f"Event rate: {df_final[EVENT_COL].mean():.2%}")
print(f"Min follow-up time: {df_final[TIME_COL].min():.2f}")
print(f"Max follow-up time: {df_final[TIME_COL].max():.2f}")


3D dimension: (X: 77, Y: 92, Z: 74)
Image shape: (742, 1, 77, 92, 74)

--- Loading survival data ---
Total patients: 742
Event rate: 38.01%
Min follow-up time: 0.00
Max follow-up time: 13.71


In [7]:
class CoxPHloss(nn.Module):
    """Negative Cox partial log-likelihood, Breslow handling of tied times.

    For every sample i with an observed event the loss term is
        log(sum_{j : t_j >= t_i} exp(eta_j)) - eta_i
    and censored samples contribute nothing. The risk set of t_i contains every
    sample whose time is >= t_i, including all samples tied with t_i. The result
    therefore does not depend on the (arbitrary) order of tied samples.

    Args:
        reduction: "mean" (divide by the number of events) or "sum".
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        if reduction not in ("mean", "sum"):
            raise ValueError(f"reduction must be 'mean' or 'sum', got {reduction!r}")
        self.reduction = reduction

    def forward(self, risk_scores, times, events):
        """Compute the loss.

        Args:
            risk_scores: predicted log-risk eta, one value per sample. Any shape with B
                elements ((B,), (B, 1) or a 0-d scalar) is flattened.
            times: follow-up times, B elements.
            events: event indicator (nonzero = event observed, 0 = censored), B elements.

        Returns:
            Scalar tensor. With reduction="mean" and no events it is 0, still attached
            to the autograd graph of ``risk_scores``.
        """
        # Flatten so (B, 1) scores cannot broadcast against (B,) events into (B, B).
        rs_all = risk_scores.reshape(-1)
        t_all = times.reshape(-1)
        e_all = events.reshape(-1)

        # 1. Sort by ascending time.
        order = torch.argsort(t_all)
        t = t_all[order]
        rs = rs_all[order]
        e = e_all[order].to(rs.dtype)

        # 2. Suffix log-sum-exp: entry k = log sum_{m >= k} exp(eta_m) in sorted order.
        suffix_lse = torch.logcumsumexp(rs.flip(0), dim=0).flip(0)
        # Tied times must share one risk set that starts at the first sorted position
        # holding that time. searchsorted (left side) returns exactly that position.
        first_pos = torch.searchsorted(t, t)
        log_risk_set = suffix_lse[first_pos]

        # 3. -(eta_i - log sum) for events only.
        neg_log_pl = (log_risk_set - rs) * e
        loss = neg_log_pl.sum()

        # 4. Reduction.
        if self.reduction == "mean":
            num_events = e.sum()
            if num_events > 0:
                loss = loss / num_events
            else:
                # All samples censored: no likelihood contribution. Multiply by zero
                # instead of creating a fresh tensor so the result keeps a grad_fn.
                return (rs * 0.0).sum()

        return loss
class SurvivalDataset(Dataset):
    """Map-style dataset pairing an image volume with tabular covariates and survival labels.

    Args:
        df: one row per sample. Must contain an ``image_idx`` column (row index into
            ``image_data``) plus ``tabular_cols``, ``time_col`` and ``event_col``.
        image_data: array indexed along axis 0 by ``image_idx``.
        tabular_cols: columns used as the tabular feature vector.
        time_col: follow-up time column.
        event_col: event indicator column.

    Each item is a dict of float32 tensors with keys 'image', 'tabular', 'time', 'event'.
    """

    def __init__(self, df, image_data, tabular_cols, time_col, event_col):
        self.df = df.reset_index(drop=True)
        self.image_data = image_data
        self.tabular_cols = tabular_cols
        self.time_col = time_col
        self.event_col = event_col
        # Extract typed arrays once. This avoids a pandas row lookup per item, and keeps
        # image_idx integral: a row Series of an all-numeric frame is upcast to float64,
        # which cannot index a NumPy array.
        self._image_idx = self.df['image_idx'].to_numpy(dtype=np.int64)
        self._tabular = self.df[list(tabular_cols)].to_numpy(dtype=np.float32)
        self._time = self.df[time_col].to_numpy(dtype=np.float32)
        self._event = self.df[event_col].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image = self.image_data[self._image_idx[idx]]

        return {
            'image': torch.tensor(image, dtype=torch.float32),
            'tabular': torch.tensor(self._tabular[idx], dtype=torch.float32),
            'time': torch.tensor(self._time[idx], dtype=torch.float32),
            'event': torch.tensor(self._event[idx], dtype=torch.float32)
        }

## 5. Training function: one optimisation pass over the training loader

def train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda=1e-4):
    """Run one optimisation pass over ``train_loader``.

    Each step minimises ``criterion(risk, time, event) + l1_lambda * sum(|p|)``, where
    the L1 sum runs over every parameter of ``model``.

    Args:
        model: module called as ``model(images, tabular)``. It returns one score per sample.
        train_loader: iterable of dicts with 'image', 'tabular', 'time', 'event' tensors.
        criterion: callable ``(risk_scores, times, events) -> scalar loss``.
        optimizer: torch optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.
        l1_lambda: weight of the L1 penalty (0 disables it).

    Returns:
        Mean over batches of the penalised loss (criterion + L1 term), or NaN if the
        loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in train_loader:
        images = batch['image'].to(device)
        tabular = batch['tabular'].to(device)
        times = batch['time'].to(device)
        events = batch['event'].to(device)

        # reshape(-1) rather than squeeze(): squeeze() turns a single-sample batch into
        # a 0-d tensor that no longer lines up with `times`/`events`.
        risk_scores = model(images, tabular).reshape(-1)

        loss = criterion(risk_scores, times, events)

        # L1 norm over all parameters (skipped entirely when disabled)
        if l1_lambda:
            l1_norm = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_norm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else float('nan')
## 6. Evaluation function: concordance index (C-index) of the model's risk scores
def evaluate_model(model, dataloader, device):
    """Score every sample in ``dataloader`` and compute the concordance index.

    The model output is a risk score (higher = earlier expected event). It is negated
    before being passed to lifelines' ``concordance_index``, which treats larger
    predicted values as longer survival.

    Args:
        model: module called as ``model(images, tabular)``. It returns one score per sample.
        dataloader: iterable of dicts with 'image', 'tabular', 'time', 'event' tensors.
        device: device the model inputs are moved to.

    Returns:
        (c_index, all_risks, all_times, all_events). The last three are lists in the
        order the loader yielded the samples.
    """
    model.eval()
    all_risks = []
    all_times = []
    all_events = []

    with torch.no_grad():
        for batch in dataloader:
            images = batch['image'].to(device)
            tabular = batch['tabular'].to(device)

            # reshape(-1) rather than squeeze(): a single-sample batch would otherwise
            # become a 0-d array, which list.extend cannot iterate.
            risk_scores = model(images, tabular).reshape(-1)

            all_risks.extend(risk_scores.cpu().numpy())
            all_times.extend(batch['time'].reshape(-1).numpy())
            all_events.extend(batch['event'].reshape(-1).numpy())

    c_index = concordance_index(all_times, -np.array(all_risks), all_events)

    return c_index, all_risks, all_times, all_events



In [12]:
def run_5fold_cv_simple(df_final, X_data_images, tabular_cols, time_col, event_col,
                        n_epochs=100, batch_size=32, learning_rate=0.001, device=DEVICE,
                        n_splits=5):
    """Stratified k-fold cross-validation of ``SurvivalCNN`` trained with ``CoxPHloss``.

    For every fold:
    - tabular columns are standardised with statistics from the training split only;
    - a fresh model is trained for ``n_epochs`` with AdamW plus ReduceLROnPlateau on
      the validation C-index;
    - the weights from the epoch with the highest finite validation C-index are
      restored (and also written to ``best_model_fold_{fold}.pth``) for the final
      evaluation.

    NOTE(review): the checkpoint is selected on the same fold whose C-index is
    reported, so ``val_c_index`` is an optimistic estimate. Use an inner split of the
    training data for model selection if an unbiased estimate is needed.

    NOTE(review): folds are split by row and stratified on ``event_col`` only. If
    ``df_final`` can contain several scans of one subject, use StratifiedGroupKFold
    with a subject-level group column to avoid leakage.

    Args:
        df_final: one row per sample with ``image_idx``, ``tabular_cols``,
            ``time_col`` and ``event_col`` columns.
        X_data_images: image array indexed by ``image_idx``.
        tabular_cols, time_col, event_col: column names.
        n_epochs, batch_size, learning_rate: training hyper-parameters.
        device: torch device for the model and batches.
        n_splits: number of CV folds (default 5).

    Returns:
        (cv_results, results_df): a dict of per-fold lists under 'fold',
        'train_c_index' and 'val_c_index', and the same data as a DataFrame.
    """
    cv_results = {
        'fold': [],
        'train_c_index': [],
        'val_c_index': []
    }

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    for fold, (train_idx, val_idx) in enumerate(skf.split(df_final, df_final[event_col])):
        print(f"\n--- Fold {fold + 1}/{n_splits} ---")

        train_df = df_final.iloc[train_idx].copy()
        val_df = df_final.iloc[val_idx].copy()

        # Fit the scaler on the training split only so no validation statistics leak in.
        scaler = StandardScaler()
        train_df[tabular_cols] = scaler.fit_transform(train_df[tabular_cols])
        val_df[tabular_cols] = scaler.transform(val_df[tabular_cols])

        train_dataset = SurvivalDataset(train_df, X_data_images, tabular_cols, time_col, event_col)
        val_dataset = SurvivalDataset(val_df, X_data_images, tabular_cols, time_col, event_col)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        model = SurvivalCNN(num_tabular_features=len(tabular_cols)).to(device)

        criterion = CoxPHloss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
        # mode='max' because a higher C-index is better. Stepping a negated C-index in
        # 'min' mode makes the default relative threshold point the wrong way (slight
        # drops count as improvements). `verbose` is deprecated in recent PyTorch, so
        # LR reductions are printed manually below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', patience=30, factor=0.5, min_lr=1e-8)

        best_val_c_index = -np.inf
        best_state = None
        checkpoint_path = f'best_model_fold_{fold}.pth'

        for epoch in range(n_epochs):
            train_loss = train_epoch(model, train_loader, criterion, optimizer, device)

            train_c_index, _, _, _ = evaluate_model(model, train_loader, device)
            val_c_index, _, _, _ = evaluate_model(model, val_loader, device)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_c_index)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f"Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}")

            # Keep the best weights in memory as well as on disk. The final reload then
            # never depends on a file that may be missing or left over from another run.
            if np.isfinite(val_c_index) and val_c_index > best_val_c_index:
                best_val_c_index = val_c_index
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                torch.save(best_state, checkpoint_path)

            if (epoch + 1) % 5 == 0:
                print(f"Epoch {epoch+1}/{n_epochs}: "
                      f"Train Loss: {train_loss:.4f}, "
                      f"Train C-index: {train_c_index:.4f}, "
                      f"Val C-index: {val_c_index:.4f}")

        if best_state is None:
            print(f"Fold {fold + 1}: no finite validation C-index; using last-epoch weights")
        else:
            model.load_state_dict(best_state)
        final_train_c_index, _, _, _ = evaluate_model(model, train_loader, device)
        final_val_c_index, _, _, _ = evaluate_model(model, val_loader, device)

        cv_results['fold'].append(fold + 1)
        cv_results['train_c_index'].append(final_train_c_index)
        cv_results['val_c_index'].append(final_val_c_index)

        print(f"Fold {fold + 1} - Best Val C-index: {final_val_c_index:.4f}")
        print(f"Fold {fold + 1} - Train C-index at best: {final_train_c_index:.4f}")

    avg_train_c_index = np.mean(cv_results['train_c_index'])
    std_train_c_index = np.std(cv_results['train_c_index'])
    avg_val_c_index = np.mean(cv_results['val_c_index'])
    std_val_c_index = np.std(cv_results['val_c_index'])

    print(f"\n--- {n_splits}-Fold CV Results ---")
    print(f"Average Train C-index: {avg_train_c_index:.4f} ± {std_train_c_index:.4f}")
    print(f"Average Val C-index: {avg_val_c_index:.4f} ± {std_val_c_index:.4f}")

    results_df = pd.DataFrame(cv_results)
    print("\nDetailed Results:")
    print(results_df)

    return cv_results, results_df

In [14]:
# Full cross-validation run with this experiment's hyper-parameters.
cv_results, results_df = run_5fold_cv_simple(
    df_final, X_data_images, TABULAR_COLS, TIME_COL, EVENT_COL,
    learning_rate=0.001,
    batch_size=80,
    n_epochs=200,
    device=DEVICE,
)


--- Fold 1/5 ---
Epoch 5/200: Train Loss: 7.2042, Train C-index: 0.5981, Val C-index: 0.5837
Epoch 10/200: Train Loss: 6.5666, Train C-index: 0.6236, Val C-index: 0.5821
Epoch 15/200: Train Loss: 6.3184, Train C-index: 0.6661, Val C-index: 0.6031
Epoch 20/200: Train Loss: 5.8653, Train C-index: 0.6890, Val C-index: 0.6049
Epoch 25/200: Train Loss: 5.3171, Train C-index: 0.7635, Val C-index: 0.6210
Epoch 30/200: Train Loss: 5.1422, Train C-index: 0.8017, Val C-index: 0.6638
Epoch 35/200: Train Loss: 4.6860, Train C-index: 0.7125, Val C-index: 0.6487
Epoch 40/200: Train Loss: 4.6109, Train C-index: 0.8054, Val C-index: 0.6473
Epoch 45/200: Train Loss: 4.7595, Train C-index: 0.8256, Val C-index: 0.6691
Epoch 50/200: Train Loss: 4.1343, Train C-index: 0.9312, Val C-index: 0.7103
Epoch 55/200: Train Loss: 3.8397, Train C-index: 0.9117, Val C-index: 0.6764
Epoch 60/200: Train Loss: 3.6579, Train C-index: 0.9373, Val C-index: 0.7022
Epoch 65/200: Train Loss: 3.6362, Train C-index: 0.9351, Va