# Summary: This notebook shows the cross-validation results of the BIKG-based model on the OAK data

In [None]:
import sys
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from karateclub import SocioDim
import networkx as nx


In [None]:
from KMPlot import subplots
from KMPlot import KMPlot

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
from sksurv.util import Surv
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from pysurvival.models.survival_forest import RandomSurvivalForestModel
from lifelines.utils import concordance_index as lfcindex
from sklearn.tree import DecisionTreeRegressor


## set random seed to make sure results are repeatable

In [None]:
import random

# Seed both global random number generators so that reruns are repeatable.
randomSeed = 10
random.seed(randomSeed)     # stdlib RNG: was imported but never seeded before
np.random.seed(randomSeed)  # NumPy global RNG


## helper functions

In [None]:
def display_summary(df, name:str=None):
    """Print the shape of a DataFrame and display a random sample of its rows.

    Args:
        df: DataFrame to summarise.
        name: Optional label; when truthy it is printed as a heading first.

    At most five rows are sampled. A frame with fewer than five rows is
    shown completely (in random order) instead of raising ``ValueError``.
    """
    if name:
        print(f"Summary of data for: {name}")
    n_rows, n_cols = df.shape
    print(f"Number of rows: {n_rows}")
    print(f"Number of columns: {n_cols}")
    print("\nSample of data:")
    # `display` is not imported in this file; it is expected to be supplied
    # by the notebook environment.
    display(df.sample(min(5, n_rows)))
    
def intersection(lst1, lst2):
    """Return the items of ``lst1`` that also occur in ``lst2``.

    The order and the duplicates of ``lst1`` are preserved.

    Args:
        lst1: Items to filter.
        lst2: Items to keep.

    Returns:
        A new list with every element of ``lst1`` that is contained in ``lst2``.
    """
    try:
        # Set lookup makes each membership test O(1) instead of O(len(lst2)).
        lookup = set(lst2)
        return [value for value in lst1 if value in lookup]
    except TypeError:
        # Unhashable items (e.g. lists): fall back to the linear scan.
        return [value for value in lst1 if value in lst2]

### Load the preprocessed dataset, including patient embeddings, genomic features, and survival information

In [None]:
# Read the preprocessed OAK table and label each row with a synthetic,
# position-based ID ("Patient0", "Patient1", ...), which becomes the index.
workdir = "../Data/inputs/inputDatasetOAK/"
all_features = pd.read_csv(workdir + 'OAK-IO.csv')
all_features['SAMPLE_ID'] = [f'Patient{row}' for row in range(len(all_features))]
all_features = all_features.set_index('SAMPLE_ID')

#clinical_subgroup = clinical_features[clinical_features['SAMPLE_TYPE'].isin(['Metastasis'])]
# Keep only the two survival columns used downstream ('OS' and 'OS.CNSR').
survival_outcomes = all_features.loc[:, ['OS', 'OS.CNSR']]

survival_outcomes

In [None]:
# Keep the `btmb` column as a one-column DataFrame that shares
# all_features' index.
TMB = all_features.loc[:, ['btmb']]
TMB

In [None]:
# Build a 0/1 matrix from the `molecular_*` columns of all_features.
molecular_prefix = 'molecular_'
# Match on the prefix (not on a substring) so that stripping the prefix
# below can never cut into a column name.
genomicFeaturesColumn = [col for col in all_features.columns if col.startswith(molecular_prefix)]
# Work on an explicit copy so the edits below never touch all_features.
genomic_features = all_features.loc[:, genomicFeaturesColumn].copy()
# Missing values count as "no alteration".
genomic_features = genomic_features.fillna(0)
# Collapse every positive value to 1; other values are left unchanged.
genomic_features = genomic_features.mask(genomic_features > 0, 1)
# Drop the prefix so that the column names are the bare feature names.
genomic_features.columns = [col[len(molecular_prefix):] for col in genomic_features.columns]
#genomic_features

### Patient cohort statistics. The OAK dataset contains 324 patients, and the gene panel contains 396 genes.

In [None]:
# Print the shape of the genomic feature matrix and show a few random rows.
display_summary(
    genomic_features,
    name="patient genomic features",
)


In [None]:
def getGenePatientEdges(genomic_features):
    """List every (row label, column label) pair whose cell equals 1.

    Args:
        genomic_features: DataFrame; in this notebook the rows are patients
            and the columns are genes.

    Returns:
        DataFrame with the columns ``source_label`` (row label) and
        ``target_label`` (column label), one row per cell equal to 1,
        in row-major order.
    """
    # Positions of all cells that are exactly 1.
    row_pos, col_pos = np.nonzero((genomic_features == 1).to_numpy())
    patients = genomic_features.index[row_pos]
    genes = genomic_features.columns[col_pos]
    return pd.DataFrame(
        list(zip(patients, genes)),
        columns=['source_label', 'target_label'],
    )

In [None]:
def learnBIKGGraphEmbeddings(subgraph_edges,gene_patient_edges, dim=16, seed=42):
    """Embed the graph formed by the subgraph edges plus the patient-gene edges.

    Args:
        subgraph_edges: DataFrame with ``source_label`` / ``target_label`` columns.
        gene_patient_edges: DataFrame with the same two columns.
        dim: Number of embedding dimensions passed to ``SocioDim``.
        seed: Seed passed to ``SocioDim``.

    Returns:
        A pair ``(embedding, node_to_num)``: the array returned by
        ``SocioDim.get_embedding()`` and the dict mapping every node label to
        the integer id (0..n-1) it was given in the graph. Callers use that
        id as the row index into the embedding.
    """
    bikg_edges = pd.concat([subgraph_edges, gene_patient_edges])
    sources = bikg_edges["source_label"].values.tolist()
    targets = bikg_edges["target_label"].values.tolist()
    nodes = set(sources).union(targets)
    # Number the nodes in sorted order. Iterating the set directly would give
    # string labels a different order in every interpreter session (hash
    # randomisation), so the ids - and therefore the learned embedding -
    # would not be reproducible even with a fixed seed.
    node_to_num = {node: num for num, node in enumerate(sorted(nodes, key=str))}
    edge_list = [
        (node_to_num[source], node_to_num[target])
        for source, target in zip(sources, targets)
    ]
    upstream_model = SocioDim(dimensions=dim,seed=seed)
    upstream_model.fit(nx.from_edgelist(edge_list))
    BIKG_graph_embedding = upstream_model.get_embedding()
    return BIKG_graph_embedding, node_to_num



In [None]:
def buildPatientEmbeddingUsingGeneEmbedding(gene_patient_edges,BIKG_graph_embedding,node_to_num):
    """Average the gene embeddings of every patient into one vector per patient.

    Args:
        gene_patient_edges: DataFrame with ``source_label`` (patient) and
            ``target_label`` (gene) columns.
        BIKG_graph_embedding: 2-D array indexed by node number on axis 0.
        node_to_num: dict mapping a node label to its row in the embedding.

    Returns:
        DataFrame indexed by ``source_label`` with the columns ``X0`` ...
        ``X{dim-1}``. Genes missing from ``node_to_num`` contribute a zero
        vector but still count in the divisor.
    """
    n_dims = BIKG_graph_embedding.shape[1]
    # One row per patient, holding the list of that patient's genes.
    per_patient = (
        gene_patient_edges.groupby('source_label')['target_label']
        .apply(list)
        .reset_index(name='target_labels')
    )

    averaged = []
    for genes in per_patient['target_labels']:
        total = np.zeros(n_dims)
        for gene in genes:
            if gene in node_to_num:
                total = total + np.array(BIKG_graph_embedding[node_to_num[gene], :])
        averaged.append(total / len(genes))

    column_names = [f'X{i}' for i in range(n_dims)]
    return pd.DataFrame(
        averaged,
        index=per_patient['source_label'],
        columns=column_names,
    )


In [None]:
def getSurvivalInformation(survival_outcomes,patient_embedding):
    """Build the survival target for the rows of ``patient_embedding``.

    Args:
        survival_outcomes: DataFrame with the columns ``OS`` and ``OS.CNSR``.
        patient_embedding: DataFrame whose index selects the patients.

    Returns:
        The structured array created by ``Surv.from_dataframe`` with
        ``OS.CNSR`` as the event field and ``OS`` as the time field.
    """
    # Right join: keep the rows of patient_embedding.
    aligned = survival_outcomes.join(patient_embedding, how='right')
    return Surv.from_dataframe(event="OS.CNSR", time="OS", data=aligned)


# 1. Load BIKG subgraph

In [None]:
# Read the edge list of the BIKG subgraph and keep only its two label columns.
#subgraph_edges = pd.read_csv('/projects/aa/ktnt055/BIKG_project/BIKGImmuneGeneGeneSubgraphs/wholeGraph.csv', sep=',')
subgraph_path = '../Data/inputs/BIKGImmuneGeneGeneSubgraphs/subgraph4_1.csv'
subgraph_edges = pd.read_csv(subgraph_path, sep=',')
subgraph_edges = subgraph_edges.loc[:, ['source_label', 'target_label']]
subgraph_edges

In [None]:
# Directory that holds one sub-folder of patient embeddings per CV fold.
save_dir = '../Data/outputs/RobustTestUsingOAK_CrossValidation/'
# True  -> learn the embeddings in every fold and write them to save_dir.
# False -> read previously written embeddings from save_dir.
#save_patient_embedding_flag=True
save_patient_embedding_flag = False


In [None]:
# Patients present in the genomic matrix, the survival table and the TMB
# table; the order follows genomic_features.
patientsID = intersection(
    intersection(
        genomic_features.index.tolist(),
        survival_outcomes.index.tolist(),
    ),
    TMB.index.tolist(),
)


# BIKG-based OS predictive performance

In [None]:

import os

from sklearn.model_selection import KFold

# Number of folds
num_folds = 5

# Initialize KFold
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Split the patients into `num_folds` folds
fold_splits = list(kf.split(patientsID))
c_index_test=[]
# Hold-out prediction tables, one per fold; concatenated once after the loop.
fold_predictions_BIKG=[]

# How often each genomic feature appears among the per-fold top-10 features.
importantFeature={}

# Access each fold
for fold_index, (train_indices, test_indices) in enumerate(fold_splits):
    print("fold_index:"+str(fold_index))
    training_index = [patientsID[i] for i in train_indices]
    print ("Training set size:" +str(len(training_index)))

    holdout_index = [patientsID[i] for i in test_indices]
    print ("Holdout set size:" +str(len(holdout_index)))

    #build patient genomic feature graph
    patientGraph_holdout=getGenePatientEdges(genomic_features.loc[holdout_index,:])
    patientGraph_training=getGenePatientEdges(genomic_features.loc[training_index,:])

    fold_dir=save_dir+str(fold_index)
    if save_patient_embedding_flag:
        # learn the graph embedding on the training patients only, then
        # derive the patient embeddings of both splits from it
        BIKG_graph_embedding,node_to_num=learnBIKGGraphEmbeddings(subgraph_edges,patientGraph_training)

        patient_embedding_holdout=buildPatientEmbeddingUsingGeneEmbedding(patientGraph_holdout,BIKG_graph_embedding,node_to_num)
        patient_embedding_dataframe=buildPatientEmbeddingUsingGeneEmbedding(patientGraph_training,BIKG_graph_embedding,node_to_num)

        # save patient embedding (create the fold directory first, otherwise
        # writing fails on a fresh checkout)
        os.makedirs(fold_dir, exist_ok=True)
        patient_embedding_holdout.to_parquet(fold_dir+'/patient_embedding_holdout.parquet')
        patient_embedding_dataframe.to_parquet(fold_dir+'/patient_embedding_dataframe.parquet')
    else:
        # load patient embedding
        patient_embedding_dataframe=pd.read_parquet(fold_dir+'/patient_embedding_dataframe.parquet')
        patient_embedding_holdout=pd.read_parquet(fold_dir+'/patient_embedding_holdout.parquet')

    # get patient survival information (record field 0 = OS.CNSR, field 1 = OS)
    y_dataframe=getSurvivalInformation(survival_outcomes,patient_embedding_dataframe)
    y_holdout=getSurvivalInformation(survival_outcomes,patient_embedding_holdout)

    # random survival forest model training and evaluation using one fold
    downstream_model = RandomSurvivalForestModel(num_trees=100)
    y_train_censorship=[x[0] for x in y_dataframe]
    y_train_time=[x[1] for x in y_dataframe]
    downstream_model.fit(X=patient_embedding_dataframe, T=y_train_time, E=y_train_censorship,seed=23)

    # get the High/Low cutoff from the training predictions only
    y_pred_dataframe=downstream_model.predict_risk(patient_embedding_dataframe)
    cutoff_75_percentile=np.quantile(y_pred_dataframe, 0.75)

    y_test_censorship=[x[0] for x in y_holdout]
    y_test_time=[x[1] for x in y_holdout]
    y_pred=downstream_model.predict_risk(patient_embedding_holdout)
    c_index = lfcindex(y_test_time, y_pred, y_test_censorship)
    # NOTE(review): max(c, 1 - c) reports a fold below 0.5 as its mirror
    # image, so the direction of the score is chosen per fold and the mean
    # can look better than the model is. If predict_risk is "higher = worse",
    # passing -y_pred without the flip would be the unbiased form - confirm
    # against the lifelines docs before changing the reported numbers.
    c_index_test.append(max(c_index, 1-c_index))

    df=pd.DataFrame([y_pred,y_test_time,y_test_censorship]).T
    df.columns=['predictRisk','OS','OS.CNSR']
    df['group'] = 'Unknown'
    df.loc[df.predictRisk >= cutoff_75_percentile,'group']= "High"
    df.loc[df.predictRisk < cutoff_75_percentile,'group']= "Low"
    df.set_index(patient_embedding_holdout.index,inplace=True)
    fold_predictions_BIKG.append(df)
    #print(len(patient_embedding_holdout))

    # identify the most important embedding features for survival prediction
    mostImportantFeatures=downstream_model.variable_importance_table.head(10)
    # Take the first row by position; this assumes the table is ordered by
    # decreasing importance - TODO confirm. (.loc[0] would depend on the
    # index labels instead.)
    top_embedding_feature=mostImportantFeatures['feature'].iloc[0]
    # fit a decision tree regression model to associate that embedding
    # feature with the molecular features
    regressor = DecisionTreeRegressor(random_state=23)
    genomic_features_train=genomic_features.loc[patient_embedding_dataframe.index,:]
    regressor.fit(genomic_features_train, patient_embedding_dataframe[top_embedding_feature])
    # sort the genomic features in decreasing order of their importance
    importance = regressor.feature_importances_
    indices = np.argsort(importance)[::-1]
    # select the top 10 genomic features
    rankTable=pd.DataFrame(list(zip(genomic_features_train.columns[indices],importance[indices])),columns=['FeatureName','Importance'])
    selected=rankTable.iloc[0:10,:]
    # count how often each feature name is selected across the folds
    for gene in selected['FeatureName']:
        importantFeature[gene]=importantFeature.get(gene,0)+1

# Pooled hold-out predictions of all folds.
df_all_BIKG=pd.concat(fold_predictions_BIKG,axis=0)

print (f"Average c-index of {num_folds} fold CV:")
print (np.mean(c_index_test))
print (np.std(c_index_test))

In [None]:
# Kaplan-Meier plot of the pooled BIKG hold-out predictions, comparing the
# "High" and "Low" predicted-risk groups.
axs = subplots(cols=1, rows=1, w=6, h=4)

# Appearance settings passed to KMPlot.plot.
km_style_bikg = dict(
    saturation=0.9,
    linewidth=1.5,
    palette='Set1',
    template_color='black',
    xy_font_size=18,
    hr_color='black',
    x_legend=0.5,
    y_legend=0.95,
    legend_font_size=12,
    label_height_adj=0.06,
    x_hr_legand=0.0,
    y_hr_legend=.1,
    hr_font_size=12,
)
km_bikg = KMPlot(df_all_BIKG, time='OS', event='OS.CNSR', label=['group'], score='predictRisk')
km_bikg.plot(
    ['High', 'Low'],
    ax=axs[0],
    comparisons=[['Low', 'High', 'Low vs High']],
    **km_style_bikg,
);



sns.despine(offset=2)

# TMB

In [None]:
# Split the patients by the `btmb` value at a fixed cutoff and draw the
# Kaplan-Meier curves of the two groups.
TMB_cutoff = 16

df = pd.DataFrame([TMB.btmb, survival_outcomes['OS'], survival_outcomes['OS.CNSR']]).T
df.columns = ['predictRisk', 'OS', 'censorLabel']
# Rows that satisfy neither comparison (e.g. missing btmb) stay 'Unknown'.
df['group'] = 'Unknown'
tmb_is_high = df['predictRisk'] >= TMB_cutoff
tmb_is_low = df['predictRisk'] < TMB_cutoff
df.loc[tmb_is_high, 'group'] = "High"
df.loc[tmb_is_low, 'group'] = "Low"



axs = subplots(cols=1, rows=1, w=6, h=4)

# Appearance settings passed to KMPlot.plot.
km_style_tmb = dict(
    saturation=0.9,
    linewidth=1.5,
    palette='Set1',
    template_color='black',
    xy_font_size=18,
    hr_color='black',
    x_legend=0.5,
    y_legend=0.95,
    legend_font_size=12,
    label_height_adj=0.06,
    x_hr_legand=0.0,
    y_hr_legend=.1,
    hr_font_size=12,
)
km_tmb = KMPlot(df, time='OS', event='censorLabel', label=['group'], score='predictRisk')
km_tmb.plot(
    ['High', 'Low'],
    ax=axs[0],
    comparisons=[['Low', 'High', 'Low vs High']],
    **km_style_tmb,
);



sns.despine(offset=2)

# genomic features

In [None]:
# list of patient IDs
patientsID=intersection(intersection(list(genomic_features.index),list(survival_outcomes.index)),list(TMB.index))

from sklearn.model_selection import KFold

# Number of folds
num_folds = 5

# Initialize KFold
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Split the patients into `num_folds` folds
fold_splits = list(kf.split(patientsID))
c_index_test_Genomic=[]
# Hold-out prediction tables, one per fold; concatenated once after the loop.
fold_predictions_genomic=[]


# Access each fold
for fold_index, (train_indices, test_indices) in enumerate(fold_splits):
    print("fold_index:"+str(fold_index))
    training_index = [patientsID[i] for i in train_indices]
    print ("Training set size:" +str(len(training_index)))

    holdout_index = [patientsID[i] for i in test_indices]
    print ("Holdout set size:" +str(len(holdout_index)))

    # get patient features (raw 0/1 genomic matrix, no graph embedding)
    genomic_features_training=genomic_features.loc[training_index,:]
    genomic_features_holdout=genomic_features.loc[holdout_index,:]

    # get patient survival information (record field 0 = OS.CNSR, field 1 = OS)
    y_train=getSurvivalInformation(survival_outcomes,genomic_features_training)
    y_test=getSurvivalInformation(survival_outcomes,genomic_features_holdout)

    # random survival forest model training and evaluation using one fold
    downstream_model = RandomSurvivalForestModel(num_trees=100)
    y_train_censorship=[x[0] for x in y_train]
    y_train_time=[x[1] for x in y_train]
    downstream_model.fit(X=genomic_features_training, T=y_train_time, E=y_train_censorship,seed=23)

    # get the High/Low cutoff from the training predictions only
    y_pred_dataframe=downstream_model.predict_risk(genomic_features_training)
    cutoff_75_percentile=np.quantile(y_pred_dataframe, 0.75)

    y_test_censorship=[x[0] for x in y_test]
    y_test_time=[x[1] for x in y_test]
    y_pred=downstream_model.predict_risk(genomic_features_holdout)
    c_index = lfcindex(y_test_time, y_pred, y_test_censorship)
    # NOTE(review): max(c, 1 - c) reports a fold below 0.5 as its mirror
    # image, so the direction of the score is chosen per fold and the mean
    # can look better than the model is. If predict_risk is "higher = worse",
    # passing -y_pred without the flip would be the unbiased form - confirm
    # against the lifelines docs before changing the reported numbers.
    c_index_test_Genomic.append(max(c_index, 1-c_index))

    df=pd.DataFrame([y_pred,y_test_time,y_test_censorship]).T
    df.columns=['predictRisk','OS','OS.CNSR']
    df['group'] = 'Unknown'
    df.loc[df.predictRisk >= cutoff_75_percentile,'group']= "High"
    df.loc[df.predictRisk < cutoff_75_percentile,'group']= "Low"
    df.set_index(genomic_features_holdout.index,inplace=True)
    fold_predictions_genomic.append(df)

# Pooled hold-out predictions of all folds.
df_all_genomicfeatures=pd.concat(fold_predictions_genomic,axis=0)

print (f"Average c-index of {num_folds} fold CV:")
print (np.mean(c_index_test_Genomic))
print (np.std(c_index_test_Genomic))

In [None]:



# Kaplan-Meier plot of the pooled hold-out predictions of the genomic-feature
# model, comparing the "High" and "Low" predicted-risk groups.
axs = subplots(cols=1, rows=1, w=6, h=4)

# Appearance settings passed to KMPlot.plot.
km_style_genomic = dict(
    saturation=0.9,
    linewidth=1.5,
    palette='Set1',
    template_color='black',
    xy_font_size=18,
    hr_color='black',
    x_legend=0.5,
    y_legend=0.95,
    legend_font_size=12,
    label_height_adj=0.06,
    x_hr_legand=0.0,
    y_hr_legend=.1,
    hr_font_size=12,
)
km_genomic = KMPlot(df_all_genomicfeatures, time='OS', event='OS.CNSR', label=['group'], score='predictRisk')
km_genomic.plot(
    ['High', 'Low'],
    ax=axs[0],
    comparisons=[['Low', 'High', 'Low vs High']],
    **km_style_genomic,
);



sns.despine(offset=2)