<a href="https://colab.research.google.com/github/vandrearczyk/hecktor-euvip2024/blob/main/baseline_prediction_hecktor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
! pip install scikit-survival

import numpy as np
import os
import pandas as pd
from sksurv.datasets import get_x_y
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored
from sklearn.model_selection import train_test_split
from google.colab import files



In [29]:
def load_features(folder_path):
    """
    Load all CSV files from a specified folder and concatenate them into a single DataFrame.

    Args:
    folder_path (str): Path to the folder containing CSV files.

    Returns:
    pd.DataFrame: Combined DataFrame from all CSV files.
    """
    dfs = []
    for filename in os.listdir(folder_path):
        if filename.startswith("features_album") and filename.endswith(".csv"):
            file_path = os.path.join(folder_path, filename)
            df = pd.read_csv(file_path)
            dfs.append(df)
    combined_df = pd.concat(dfs, ignore_index=True)
    return combined_df

def preprocess_data(combined_df, prefixes=None):
    """
    Preprocess the combined DataFrame by keeping the first three columns and those starting with specified prefixes.
    Then pivot the table to combine 'Modality', 'ROI', and each feature.

    Args:
    combined_df (pd.DataFrame): Combined DataFrame from multiple CSV files.
    prefixes (list of str or None): List of prefixes to keep in the DataFrame columns.
                                    If None, all columns are retained.

    Returns:
    pd.DataFrame: Pivoted DataFrame ready for model training.
    """
    # Keep the first three columns
    first_three_columns = combined_df.iloc[:, :3]

    # If prefixes is None, keep all columns, otherwise filter columns by the specified prefixes
    if prefixes is None:
        filtered_df = combined_df
    else:
        filtered_columns = [col for col in combined_df.columns if any(col.startswith(prefix) for prefix in prefixes)]
        filtered_df = pd.concat([first_three_columns, combined_df[filtered_columns]], axis=1)

    # Melt the filtered DataFrame
    feature_columns = [col for col in filtered_df.columns if col not in first_three_columns.columns]
    melted_df = filtered_df.melt(id_vars=['PatientID', 'Modality', 'ROI'], value_vars=feature_columns, var_name='Feature')

    # Create combined feature names
    melted_df['Combined'] = melted_df['ROI'] + '_' + melted_df['Modality'] + '_' + melted_df['Feature']

    # Pivot the DataFrame
    pivoted_df = melted_df.pivot_table(index='PatientID', columns='Combined', values='value')
    pivoted_df.reset_index(inplace=True)

    return pivoted_df

In [30]:
# Upload features
if any(fn.startswith('features_album') for fn in os.listdir('.')):
  print('Features already uploaded')
else:
  uploaded = files.upload()

Features already uploaded


In [31]:
# Upload survival_data
if os.path.exists('dummy_survival_data.csv'):
  print('Survival data already uploaded')
else:
  uploaded = files.upload()

Survival data already uploaded


In [33]:
# Load and preprocess the data
features_df = load_features('./')
pivoted_df = preprocess_data(features_df, prefixes=['original_intensity'])
survival_df = pd.read_csv('dummy_survival_data.csv')


In [34]:
# Prepare data for training
X = pivoted_df.drop(columns=['PatientID'])
X = X.fillna(X.mean())
y = np.array([(status, time) for status, time in zip(survival_df['SurvivalStatus'], survival_df['SurvivalTime'])],
                dtype=[('event', 'bool'), ('time', 'float')])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the model
model = RandomSurvivalForest(n_estimators=100, min_samples_split=10, min_samples_leaf=15, random_state=42)
model.fit(X_train, y_train)

# Evaluate the model
cindex_train = concordance_index_censored(y_train['event'], y_train['time'], model.predict(X_train))[0]
cindex_test = concordance_index_censored(y_test['event'], y_test['time'], model.predict(X_test))[0]

print(f'Concordance Index (Train): {cindex_train:.2f}')
print(f'Concordance Index (Test): {cindex_test:.2f}')

Concordance Index (Train): 0.83
Concordance Index (Test): 0.53
