In [5]:
import pandas as pd
import numpy as np
from torch import Tensor
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.loss.cox import neg_partial_log_likelihood
import tensorflow as tf
from tqdm import tqdm
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis
import torch.nn.functional as F
import random
import os

g = torch.Generator()
g.manual_seed(1)
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
def set_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    torch.use_deterministic_algorithms(True)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(1)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
path_train = pd.read_csv("~//Documents//Survival VAE//lung_top_predicted_0.25_IMPACT_adeno.csv")
survival = pd.read_csv("~//Documents//Survival VAE//lung_clin_top_0.25_IMPACT_adeno.csv")
filtered_survival = survival.dropna(subset=['DFS_MONTHS', 'DFS_STATUS'])

path_train = path_train[path_train['PatientID'].isin(filtered_survival['PatientID'])]
filtered_survival = filtered_survival[filtered_survival['PatientID'].isin(path_train['PatientID'])]


filtered_survival = filtered_survival[['PatientID', 'DFS_STATUS', 'DFS_MONTHS']]
filtered_survival = filtered_survival.rename(columns={'PatientID':'PatientID','DFS_STATUS':'Status', 'DFS_MONTHS':'Survival_in_days'})

path_train = pd.merge(path_train, filtered_survival, on='PatientID')

path_x_train = path_train.drop(columns = ['PatientID', 'Status', 'Survival_in_days'])
features = path_x_train.columns

path_y_train = path_train[['Status', 'Survival_in_days']]

path_y_train.loc[path_y_train['Status'] == '0:DiseaseFree', 'Status'] = 0
path_y_train.loc[path_y_train['Status'] == '1:Recurred/Progressed', 'Status'] = 1

features = path_x_train.columns

In [9]:
path_val = pd.read_csv("~//Documents//Survival VAE//Validation Data//Breast_MSK_Scores_IMPACT_tnbc.csv")
path_val = path_val.dropna(subset=['DFS_MONTHS', 'DFS_EVENT'])
path_x_val = path_val.drop(columns = ['Unnamed: 0','PatientID', 'DFS_MONTHS', 'DFS_EVENT'])
path_y_val = path_val[['DFS_EVENT', 'DFS_MONTHS']]
path_y_val = path_y_val.rename(columns={'DFS_EVENT':'Status', 'DFS_MONTHS':'Survival_in_days'})

#path_x_train = path_val
#path_y_train = path_y_val

In [11]:
pathway_phenotypes = pd.read_csv("~//Documents//Survival VAE//Pathway_Phenotypes_v4.csv")
dna_repair_pathways = pathway_phenotypes['DNA_REPAIR'].tolist()
dna_repair_pathways = [x for x in dna_repair_pathways if str(x) != 'nan']
dna_repair_pathways = list(set(dna_repair_pathways) & set(features))
dna_repair = path_x_train[dna_repair_pathways]

growth_suppressor_pathways = pathway_phenotypes['EVADE_GROWTH_SUPPRESSOR'].tolist()
growth_suppressor_pathways = [x for x in growth_suppressor_pathways if str(x) != 'nan']
growth_suppressor_pathways = list(set(growth_suppressor_pathways) & set(features))
evade_growth_suppressor = path_x_train[growth_suppressor_pathways]

hormone_signaling_pathways = pathway_phenotypes['HORMONE_SIGNALING'].tolist()
hormone_signaling_pathways = [x for x in hormone_signaling_pathways if str(x) != 'nan']
hormone_signaling_pathways = list(set(hormone_signaling_pathways) & set(features))
hormone_signaling = path_x_train[hormone_signaling_pathways]

immune_pathways = pathway_phenotypes['IMMUNE'].tolist()
immune_pathways = [x for x in immune_pathways if str(x) != 'nan']
immune_pathways = list(set(immune_pathways) & set(features))
immune = path_x_train[immune_pathways]

inflammation_pathways = pathway_phenotypes['INFLAMMATION'].tolist()
inflammation_pathways = [x for x in inflammation_pathways if str(x) != 'nan']
inflammation_pathways = list(set(inflammation_pathways) & set(features))
inflammation = path_x_train[inflammation_pathways]

metabolism_pathways = pathway_phenotypes['METABOLISM'].tolist()
metabolism_pathways = [x for x in metabolism_pathways if str(x) != 'nan']
metabolism_pathways = list(set(metabolism_pathways) & set(features))
metabolism = path_x_train[metabolism_pathways]

metastasis_pathways = pathway_phenotypes['METASTASIS'].tolist()
metastasis_pathways = [x for x in metastasis_pathways if str(x) != 'nan']
metastasis_pathways = list(set(metastasis_pathways) & set(features))
metastasis = path_x_train[metastasis_pathways]

plasticity_pathways = pathway_phenotypes['PLASTICITY'].tolist()
plasticity_pathways = [x for x in plasticity_pathways if str(x) != 'nan']
plasticity_pathways = list(set(plasticity_pathways) & set(features))
plasticity = path_x_train[plasticity_pathways]

proliferation_pathways = pathway_phenotypes['PROLIFERATION'].tolist()
proliferation_pathways = [x for x in proliferation_pathways if str(x) != 'nan']
proliferation_pathways = list(set(proliferation_pathways) & set(features))
proliferation = path_x_train[proliferation_pathways]

resist_cell_death_pathways = pathway_phenotypes['RESIST_CELL_DEATH'].tolist()
resist_cell_death_pathways = [x for x in resist_cell_death_pathways if str(x) != 'nan']
resist_cell_death_pathways = list(set(resist_cell_death_pathways) & set(features))
resist_cell_death = path_x_train[resist_cell_death_pathways]

vasculator_pathways = pathway_phenotypes['VASCULATOR'].tolist()
vasculator_pathways = [x for x in vasculator_pathways if str(x) != 'nan']
vasculator_pathways = list(set(vasculator_pathways) & set(features))
vasculator = path_x_train[vasculator_pathways]


path_y_train = path_y_train.to_numpy()

dna_repair = dna_repair.to_numpy()
evade_growth_suppressor = evade_growth_suppressor.to_numpy()
hormone_signaling = hormone_signaling.to_numpy()
immune = immune.to_numpy()
inflammation = inflammation.to_numpy()
metabolism = metabolism.to_numpy()
metastasis = metastasis.to_numpy()
plasticity = plasticity.to_numpy()
proliferation = proliferation.to_numpy()
resist_cell_death = resist_cell_death.to_numpy()
vasculator = vasculator.to_numpy()

In [13]:
dna_repair_x_train = dna_repair
growth_suppressor_x_train = evade_growth_suppressor
hormone_signaling_x_train = hormone_signaling
immune_x_train = immune
inflammation_x_train = inflammation
metabolism_x_train = metabolism
metastasis_x_train = metastasis
plasticity_x_train = plasticity
proliferation_x_train = proliferation
resist_cell_death_x_train = resist_cell_death
vasculator_x_train = vasculator

y_train = path_y_train

dna_repair_x_dim = dna_repair_x_train.shape[1]
growth_suppressor_x_dim = growth_suppressor_x_train.shape[1]
hormone_signaling_x_dim = hormone_signaling_x_train.shape[1]
immune_x_dim = immune_x_train.shape[1]
inflammation_x_dim = inflammation_x_train.shape[1]
metabolism_x_dim = metabolism_x_train.shape[1]
metastasis_x_dim = metastasis_x_train.shape[1]
plasticity_x_dim = plasticity_x_train.shape[1]
proliferation_x_dim = proliferation_x_train.shape[1]
resist_cell_death_x_dim = resist_cell_death_x_train.shape[1]
vasculator_x_dim = vasculator_x_train.shape[1]

In [15]:
xmin = np.amin(dna_repair_x_train)
xmax = np.amax(dna_repair_x_train)
dna_repair_norm_train = (dna_repair_x_train - xmin) / (xmax - xmin)

xmin = np.amin(growth_suppressor_x_train)
xmax = np.amax(growth_suppressor_x_train)
growth_suppressor_norm_train = (growth_suppressor_x_train - xmin) / (xmax - xmin)

xmin = np.amin(hormone_signaling_x_train)
xmax = np.amax(hormone_signaling_x_train)
hormone_signaling_norm_train = (hormone_signaling_x_train - xmin) / (xmax - xmin)

xmin = np.amin(immune_x_train)
xmax = np.amax(immune_x_train)
immune_norm_train = (immune_x_train - xmin) / (xmax - xmin)

xmin = np.amin(inflammation_x_train)
xmax = np.amax(inflammation_x_train)
inflammation_norm_train = (inflammation_x_train - xmin) / (xmax - xmin)

xmin = np.amin(metabolism_x_train)
xmax = np.amax(metabolism_x_train)
metabolism_norm_train = (metabolism_x_train - xmin) / (xmax - xmin)

xmin = np.amin(metastasis_x_train)
xmax = np.amax(metastasis_x_train)
metastasis_norm_train = (metastasis_x_train - xmin) / (xmax - xmin)

xmin = np.amin(plasticity_x_train)
xmax = np.amax(plasticity_x_train)
plasticity_norm_train = (plasticity_x_train - xmin) / (xmax - xmin)

xmin = np.amin(proliferation_x_train)
xmax = np.amax(proliferation_x_train)
proliferation_norm_train = (proliferation_x_train - xmin) / (xmax - xmin)

xmin = np.amin(resist_cell_death_x_train)
xmax = np.amax(resist_cell_death_x_train)
resist_cell_death_norm_train = (resist_cell_death_x_train - xmin) / (xmax - xmin)

xmin = np.amin(vasculator_x_train)
xmax = np.amax(vasculator_x_train)
vasculator_norm_train = (vasculator_x_train - xmin) / (xmax - xmin)

In [17]:
dna_repair_norm_train = torch.from_numpy(dna_repair_norm_train.astype(np.float32))

growth_suppressor_norm_train = torch.from_numpy(growth_suppressor_norm_train.astype(np.float32))

hormone_signaling_norm_train = torch.from_numpy(hormone_signaling_norm_train.astype(np.float32))

immune_norm_train = torch.from_numpy(immune_norm_train.astype(np.float32))

inflammation_norm_train = torch.from_numpy(inflammation_norm_train.astype(np.float32))

metabolism_norm_train = torch.from_numpy(metabolism_norm_train.astype(np.float32))

metastasis_norm_train = torch.from_numpy(metastasis_norm_train.astype(np.float32))

plasticity_norm_train = torch.from_numpy(plasticity_norm_train.astype(np.float32))

proliferation_norm_train = torch.from_numpy(proliferation_norm_train.astype(np.float32))

resist_cell_death_norm_train = torch.from_numpy(resist_cell_death_norm_train.astype(np.float32))

vasculator_norm_train = torch.from_numpy(vasculator_norm_train.astype(np.float32))

y_train = torch.from_numpy(y_train.astype(np.float32))

dataset = TensorDataset(Tensor(dna_repair_norm_train),
                        Tensor(growth_suppressor_norm_train),
                        Tensor(hormone_signaling_norm_train),
                        Tensor(immune_norm_train),
                        Tensor(inflammation_norm_train),
                        Tensor(metabolism_norm_train),
                        Tensor(metastasis_norm_train),
                        Tensor(plasticity_norm_train),
                        Tensor(proliferation_norm_train),
                        Tensor(resist_cell_death_norm_train),
                        Tensor(vasculator_norm_train),
                        Tensor(y_train))
torch.manual_seed(1)
train_loader = DataLoader(dataset, batch_size = 8, shuffle = True,
                         generator = g, worker_init_fn=seed_worker,
                                             num_workers = 0)

In [19]:
def SurvVAE_Total(dna_repair_input_dim, growth_suppressor_input_dim, 
                 hormone_signaling_input_dim, immune_input_dim,
                 inflammation_input_dim, metabolism_input_dim,
                 metastasis_input_dim, plasticity_input_dim,
                 proliferation_input_dim, resist_cell_death_input_dim,
                 vasculator_input_dim, hidden_dim, latent_dim, train_loader,
                 DEVICE = DEVICE):

    class Encoder_Surv(nn.Module):
        def __init__(self, dna_repair_input_dim, growth_suppressor_input_dim, 
                     hormone_signaling_input_dim, immune_input_dim,
                     inflammation_input_dim, metabolism_input_dim,
                     metastasis_input_dim, plasticity_input_dim,
                     proliferation_input_dim, resist_cell_death_input_dim,
                     vasculator_input_dim, hidden_dim, latent_dim):
            super(Encoder_Surv, self).__init__()
            self.dropout = nn.Dropout(0.2)
            self.fc_dna_repair_1 = nn.Linear(dna_repair_input_dim, hidden_dim)
            self.fc_dna_repair_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_growth_suppressor_1 = nn.Linear(growth_suppressor_input_dim, hidden_dim)
            self.fc_growth_suppressor_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_hormone_signaling_1 = nn.Linear(hormone_signaling_input_dim, hidden_dim)
            self.fc_hormone_signaling_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_immune_1 = nn.Linear(immune_input_dim, hidden_dim)
            self.fc_immune_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_inflammation_1 = nn.Linear(inflammation_input_dim, hidden_dim)
            self.fc_inflammation_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_metabolism_1 = nn.Linear(metabolism_input_dim, hidden_dim)
            self.fc_metabolism_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_metastasis_1 = nn.Linear(metastasis_input_dim, hidden_dim)
            self.fc_metastasis_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_plasticity_1 = nn.Linear(plasticity_input_dim, hidden_dim)
            self.fc_plasticity_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_proliferation_1 = nn.Linear(proliferation_input_dim, hidden_dim)
            self.fc_proliferation_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_resist_cell_death_1 = nn.Linear(resist_cell_death_input_dim, hidden_dim)
            self.fc_resist_cell_death_2 = nn.Linear(hidden_dim, 1)
    
            self.fc_vasculator_1 = nn.Linear(vasculator_input_dim, hidden_dim)
            self.fc_vasculator_2 = nn.Linear(hidden_dim, 1)

            self.LeakyReLU = nn.LeakyReLU(0.2)
            
            self.FC_mean  = nn.Linear(latent_dim, latent_dim)
            self.FC_var   = nn.Linear(latent_dim, latent_dim)
            
            self.training = True
            self.reg = nn.Linear(latent_dim, 1)
            
            self.dna_repair_stop = dna_repair_input_dim 
            self.growth_suppressor_stop = growth_suppressor_input_dim + self.dna_repair_stop 
            self.hormone_signaling_stop = hormone_signaling_input_dim + self.growth_suppressor_stop 
            self.immune_stop = immune_input_dim + self.hormone_signaling_stop
            self.inflammation_stop = inflammation_input_dim + self.immune_stop 
            self.metabolism_stop = metabolism_input_dim + self.inflammation_stop 
            self.metastasis_stop = metastasis_input_dim + self.metabolism_stop 
            self.plasticity_stop = plasticity_input_dim + self.metastasis_stop 
            self.proliferation_stop = proliferation_input_dim + self.plasticity_stop 
            self.resist_cell_death_stop = resist_cell_death_input_dim + self.proliferation_stop 
            self.vasculator_stop = vasculator_input_dim + self.resist_cell_death_stop 
            
        def forward(self, x):
            h_dna_repair = self.LeakyReLU(self.fc_dna_repair_1(self.dropout(x[:,0:self.dna_repair_stop])))
            h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_1(self.dropout(x[:,(self.dna_repair_stop):self.growth_suppressor_stop])))
            h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_1(self.dropout(x[:,(self.growth_suppressor_stop):self.hormone_signaling_stop])))
            h_immune = self.LeakyReLU(self.fc_immune_1(self.dropout(x[:,(self.hormone_signaling_stop):self.immune_stop])))
            h_inflammation = self.LeakyReLU(self.fc_inflammation_1(self.dropout(x[:,(self.immune_stop):self.inflammation_stop])))
            h_metabolism = self.LeakyReLU(self.fc_metabolism_1(self.dropout(x[:,(self.inflammation_stop):self.metabolism_stop])))
            h_metastasis = self.LeakyReLU(self.fc_metastasis_1(self.dropout(x[:,(self.metabolism_stop):self.metastasis_stop])))
            h_plasticity = self.LeakyReLU(self.fc_plasticity_1(self.dropout(x[:,(self.metastasis_stop):self.plasticity_stop])))
            h_proliferation = self.LeakyReLU(self.fc_proliferation_1(self.dropout(x[:,(self.plasticity_stop):self.proliferation_stop])))
            h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_1(self.dropout(x[:,(self.proliferation_stop):self.resist_cell_death_stop])))
            h_vasculator = self.LeakyReLU(self.fc_vasculator_1(self.dropout(x[:,(self.resist_cell_death_stop):self.vasculator_stop])))
    
            h_dna_repair = self.LeakyReLU(self.fc_dna_repair_2(self.dropout(h_dna_repair)))
            h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_2(self.dropout(h_growth_suppressor)))
            h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_2(self.dropout(h_hormone_signaling)))
            h_immune = self.LeakyReLU(self.fc_immune_2(self.dropout(h_immune)))
            h_inflammation = self.LeakyReLU(self.fc_inflammation_2(self.dropout(h_inflammation)))
            h_metabolism = self.LeakyReLU(self.fc_metabolism_2(self.dropout(h_metabolism)))
            h_metastasis = self.LeakyReLU(self.fc_metastasis_2(self.dropout(h_metastasis)))
            h_plasticity = self.LeakyReLU(self.fc_plasticity_2(self.dropout(h_plasticity)))
            h_proliferation = self.LeakyReLU(self.fc_proliferation_2(self.dropout(h_proliferation)))
            h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_2(self.dropout(h_resist_cell_death)))
            h_vasculator = self.LeakyReLU(self.fc_vasculator_2(self.dropout(h_vasculator)))
            h_ = torch.cat((h_dna_repair, h_growth_suppressor, h_hormone_signaling, 
                          h_immune, h_inflammation, h_metabolism, h_metastasis,
                          h_plasticity, h_proliferation, h_resist_cell_death, h_vasculator),1)
    
            mean     = self.FC_mean(h_)
            log_var  = self.FC_var(h_)                   
            haz      = self.reg(mean) 
            return mean, log_var, haz, h_

    class Decoder_Surv(nn.Module):
        def __init__(self, dna_repair_input_dim, growth_suppressor_input_dim, 
                     hormone_signaling_input_dim, immune_input_dim,
                     inflammation_input_dim, metabolism_input_dim,
                     metastasis_input_dim, plasticity_input_dim,
                     proliferation_input_dim, resist_cell_death_input_dim,
                     vasculator_input_dim, hidden_dim, latent_dim):
            super(Decoder_Surv, self).__init__()
            
            self.fc_dna_repair_1 = nn.Linear(1, hidden_dim)
            self.fc_dna_repair_2 = nn.Linear(hidden_dim, dna_repair_input_dim)
    
            self.fc_growth_suppressor_1 = nn.Linear(1, hidden_dim)
            self.fc_growth_suppressor_2 = nn.Linear(hidden_dim, growth_suppressor_input_dim)
    
            self.fc_hormone_signaling_1 = nn.Linear(1, hidden_dim)
            self.fc_hormone_signaling_2 = nn.Linear(hidden_dim, hormone_signaling_input_dim)
    
            self.fc_immune_1 = nn.Linear(1, hidden_dim)
            self.fc_immune_2 = nn.Linear(hidden_dim, immune_input_dim)
    
            self.fc_inflammation_1 = nn.Linear(1, hidden_dim)
            self.fc_inflammation_2 = nn.Linear(hidden_dim, inflammation_input_dim)
    
            self.fc_metabolism_1 = nn.Linear(1, hidden_dim)
            self.fc_metabolism_2 = nn.Linear(hidden_dim, metabolism_input_dim)
    
            self.fc_metastasis_1 = nn.Linear(1, hidden_dim)
            self.fc_metastasis_2 = nn.Linear(hidden_dim, metastasis_input_dim)
    
            self.fc_plasticity_1 = nn.Linear(1, hidden_dim)
            self.fc_plasticity_2 = nn.Linear(hidden_dim, plasticity_input_dim)
    
            self.fc_proliferation_1 = nn.Linear(1, hidden_dim)
            self.fc_proliferation_2 = nn.Linear(hidden_dim, proliferation_input_dim)
    
            self.fc_resist_cell_death_1 = nn.Linear(1, hidden_dim)
            self.fc_resist_cell_death_2 = nn.Linear(hidden_dim, resist_cell_death_input_dim)
    
            self.fc_vasculator_1 = nn.Linear(1, hidden_dim)
            self.fc_vasculator_2 = nn.Linear(hidden_dim, vasculator_input_dim)
            
            self.LeakyReLU = nn.LeakyReLU(0.2)

        def forward(self, x):
            h_dna_repair = self.LeakyReLU(self.fc_dna_repair_1(x[:,0].unsqueeze(dim=1)))
            h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_1(x[:,1].unsqueeze(dim=1)))
            h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_1(x[:,2].unsqueeze(dim=1)))
            h_immune = self.LeakyReLU(self.fc_immune_1(x[:,3].unsqueeze(dim=1)))
            h_inflammation = self.LeakyReLU(self.fc_inflammation_1(x[:,4].unsqueeze(dim=1)))
            h_metabolism = self.LeakyReLU(self.fc_metabolism_1(x[:,5].unsqueeze(dim=1)))
            h_metastasis = self.LeakyReLU(self.fc_metastasis_1(x[:,6].unsqueeze(dim=1)))
            h_plasticity = self.LeakyReLU(self.fc_plasticity_1(x[:,7].unsqueeze(dim=1)))
            h_proliferation = self.LeakyReLU(self.fc_proliferation_1(x[:,8].unsqueeze(dim=1)))
            h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_1(x[:,9].unsqueeze(dim=1)))
            h_vasculator = self.LeakyReLU(self.fc_vasculator_1(x[:,10].unsqueeze(dim=1)))
    
            h_dna_repair = self.LeakyReLU(self.fc_dna_repair_2(h_dna_repair))
            h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_2(h_growth_suppressor))
            h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_2(h_hormone_signaling))
            h_immune = self.LeakyReLU(self.fc_immune_2(h_immune))
            h_inflammation = self.LeakyReLU(self.fc_inflammation_2(h_inflammation))
            h_metabolism = self.LeakyReLU(self.fc_metabolism_2(h_metabolism))
            h_metastasis = self.LeakyReLU(self.fc_metastasis_2(h_metastasis))
            h_plasticity = self.LeakyReLU(self.fc_plasticity_2(h_plasticity))
            h_proliferation = self.LeakyReLU(self.fc_proliferation_2(h_proliferation))
            h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_2(h_resist_cell_death))
            h_vasculator = self.LeakyReLU(self.fc_vasculator_2(h_vasculator))
    
    
            dna_repair_x_hat = torch.sigmoid(h_dna_repair)
            growth_suppressor_x_hat = torch.sigmoid(h_growth_suppressor)
            hormone_signaling_x_hat = torch.sigmoid(h_hormone_signaling)
            immune_x_hat = torch.sigmoid(h_immune)
            inflammation_x_hat = torch.sigmoid(h_inflammation)
            metabolism_x_hat = torch.sigmoid(h_metabolism)
            metastasis_x_hat = torch.sigmoid(h_metastasis)
            plasticity_x_hat = torch.sigmoid(h_plasticity)
            proliferation_x_hat = torch.sigmoid(h_proliferation)
            resist_cell_death_x_hat = torch.sigmoid(h_resist_cell_death)
            vasculator_x_hat = torch.sigmoid(h_vasculator)
            return dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat

    class Model_Surv(nn.Module):
        def __init__(self, Encoder_Surv, Decoder_Surv):
            super(Model_Surv, self).__init__()
            self.Encoder_Surv = Encoder_Surv
            self.Decoder_Surv = Decoder_Surv
    
        def reparameterization(self, mean, var):
            epsilon = torch.randn_like(var).to(DEVICE)        # sampling epsilon        
            z = mean + var*epsilon                          # reparameterization trick
            return z
    
        def forward(self, x):
            mean, log_var, haz, h_ = self.Encoder_Surv(x)
            z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
            dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat = self.Decoder_Surv(z)
    
            return [dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat, mean, log_var, haz, h_]

    encoder_surv = Encoder_Surv(dna_repair_input_dim = dna_repair_input_dim, 
                                growth_suppressor_input_dim = growth_suppressor_input_dim, 
                     hormone_signaling_input_dim = hormone_signaling_input_dim,
                                immune_input_dim = immune_input_dim,
                     inflammation_input_dim = inflammation_input_dim,
                                metabolism_input_dim = metabolism_input_dim,
                     metastasis_input_dim = metastasis_input_dim,
                                plasticity_input_dim = plasticity_input_dim,
                     proliferation_input_dim = proliferation_input_dim,
                                resist_cell_death_input_dim = resist_cell_death_input_dim,
                     vasculator_input_dim = vasculator_input_dim,
                                hidden_dim = hidden_dim, latent_dim = latent_dim)
    decoder_surv = Decoder_Surv(dna_repair_input_dim = dna_repair_input_dim, 
                                growth_suppressor_input_dim = growth_suppressor_input_dim, 
                     hormone_signaling_input_dim = hormone_signaling_input_dim,
                                immune_input_dim = immune_input_dim,
                     inflammation_input_dim = inflammation_input_dim,
                                metabolism_input_dim = metabolism_input_dim,
                     metastasis_input_dim = metastasis_input_dim,
                                plasticity_input_dim = plasticity_input_dim,
                     proliferation_input_dim = proliferation_input_dim,
                                resist_cell_death_input_dim = resist_cell_death_input_dim,
                     vasculator_input_dim = vasculator_input_dim,
                                hidden_dim = hidden_dim, latent_dim = latent_dim)
    
    model_surv = Model_Surv(Encoder_Surv=encoder_surv, Decoder_Surv=decoder_surv).to(DEVICE)
            
    
    BCE_loss = nn.BCELoss()

    def loss_function(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11,
                      dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat,
                      immune_x_hat, inflammation_x_hat, metabolism_x_hat,
                      metastasis_x_hat, plasticity_x_hat, proliferation_x_hat,
                      resist_cell_death_x_hat, vasculator_x_hat, mean, log_var, haz, y):
        reproduction_loss1 = nn.functional.binary_cross_entropy(dna_repair_x_hat, x1, reduction='sum')
        reproduction_loss2 = nn.functional.binary_cross_entropy(growth_suppressor_x_hat, x2, reduction='sum')
        reproduction_loss3 = nn.functional.binary_cross_entropy(hormone_signaling_x_hat, x3, reduction='sum')
        reproduction_loss4 = nn.functional.binary_cross_entropy(immune_x_hat, x4, reduction='sum')
        reproduction_loss5 = nn.functional.binary_cross_entropy(inflammation_x_hat, x5, reduction='sum')
        reproduction_loss6 = nn.functional.binary_cross_entropy(metabolism_x_hat, x6, reduction='sum')
        reproduction_loss7 = nn.functional.binary_cross_entropy(metastasis_x_hat, x7, reduction='sum')
        reproduction_loss8 = nn.functional.binary_cross_entropy(plasticity_x_hat, x8, reduction='sum')
        reproduction_loss9 = nn.functional.binary_cross_entropy(proliferation_x_hat, x9, reduction='sum')
        reproduction_loss10 = nn.functional.binary_cross_entropy(resist_cell_death_x_hat, x10, reduction='sum')
        reproduction_loss11 = nn.functional.binary_cross_entropy(vasculator_x_hat, x11, reduction='sum')
        KLD      = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())
        surv_loss = neg_partial_log_likelihood(haz, y[:,0].bool(), y[:,1])
    
        return surv_loss + KLD + (reproduction_loss1/(x1.shape[0]*x1.shape[1])) + (reproduction_loss2/(x2.shape[0]*x2.shape[1])) + (reproduction_loss3/(x3.shape[0]*x3.shape[1])) + (reproduction_loss4/(x4.shape[0]*x4.shape[1])) + (reproduction_loss5/(x5.shape[0]*x5.shape[1]))+ (reproduction_loss6/(x6.shape[0]*x6.shape[1])) + (reproduction_loss7/(x7.shape[0]*x7.shape[1])) + (reproduction_loss8/(x8.shape[0]*x8.shape[1])) + (reproduction_loss9/(x9.shape[0]*x9.shape[1]))+ (reproduction_loss10/(x10.shape[0]*x10.shape[1])) + (reproduction_loss11/(x11.shape[0]*x11.shape[1]))
    
    optimizer = Adam(model_surv.parameters(), lr=0.001, weight_decay = 0.0001)

    epochs = 1000

    print("Start training VAE...")
    model_surv.train()
    torch.manual_seed(1)
    for epoch in tqdm(range(epochs)):
        overall_loss = 0
        for batch_idx, (x1, x2, x3, x4, x5, x6, x7, x8,
                        x9, x10, x11, y) in enumerate(train_loader):
            x = [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11]
            x = torch.cat(x, 1)
            x = x.to(DEVICE)
            optimizer.zero_grad()
    
            dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat, mean, log_var, haz, h_ = model_surv(x)
    
            loss = loss_function(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11,
                      dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat,
                      immune_x_hat, inflammation_x_hat, metabolism_x_hat,
                      metastasis_x_hat, plasticity_x_hat, proliferation_x_hat,
                      resist_cell_death_x_hat, vasculator_x_hat, mean, log_var, haz, y)
            
            overall_loss += loss.item()
            
            loss.backward()
            optimizer.step()
        
    print("Finish!!")
    torch.manual_seed(1)
    model_surv.eval()
    with torch.no_grad():
        x_train = [dna_repair_norm_train, growth_suppressor_norm_train, hormone_signaling_norm_train,
                  immune_norm_train, inflammation_norm_train, metabolism_norm_train,
                   metastasis_norm_train, plasticity_norm_train,
                  proliferation_norm_train, resist_cell_death_norm_train, vasculator_norm_train]
        x_train = torch.cat(x_train, 1)
        dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat, mean, log_var, haz, train_latent = model_surv(x_train)
        cindex = ConcordanceIndex()
        y_path_train = haz
        ci = cindex(haz, y_train[:,0].bool(), y_train[:,1])
        print(ci)
       

    return model_surv 


In [21]:
torch.manual_seed(1)
mod = SurvVAE_Total(dna_repair_input_dim = dna_repair_x_dim, growth_suppressor_input_dim = growth_suppressor_x_dim, 
                 hormone_signaling_input_dim = hormone_signaling_x_dim, immune_input_dim = immune_x_dim,
                 inflammation_input_dim = inflammation_x_dim, metabolism_input_dim = metabolism_x_dim,
                 metastasis_input_dim = metastasis_x_dim, plasticity_input_dim = plasticity_x_dim,
                 proliferation_input_dim = proliferation_x_dim, resist_cell_death_input_dim = resist_cell_death_x_dim,
                 vasculator_input_dim = vasculator_x_dim, hidden_dim = 10, latent_dim = 11, train_loader = train_loader,
                   DEVICE = DEVICE)

Start training VAE...


100%|██████████| 1000/1000 [01:18<00:00, 12.71it/s]

Finish!!
tensor(0.6985)





In [22]:
torch.save(mod.state_dict(), '..//Documents//Survival VAE//FINAL PHENOSURV//lung_model_weights_PC_FINAL_adeno.pth')

In [23]:
torch.manual_seed(1) 
mod.eval() 
with torch.no_grad():
    x_train = [dna_repair_norm_train, growth_suppressor_norm_train, hormone_signaling_norm_train,
              immune_norm_train, inflammation_norm_train, metabolism_norm_train,
               metastasis_norm_train, plasticity_norm_train,
              proliferation_norm_train, resist_cell_death_norm_train, vasculator_norm_train]
    x_train = torch.cat(x_train, 1)
    dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat, mean, log_var, train_haz, train_latent = mod(x_train)
    cindex = ConcordanceIndex()
    ci = cindex(train_haz, y_train[:,0].bool(), y_train[:,1])
    print(ci)

tensor(0.6985)


In [24]:
train_preds = train_latent.numpy()
train_preds = pd.DataFrame(train_preds)
y_train = y_train.numpy()
y_train = pd.DataFrame(y_train)
train_data = pd.concat([train_preds, y_train], axis = 1)

In [25]:
train_data.columns = ['DNA_REPAIR', 'EVADE_GROWTH_SUPPRESSORS', 'HORMONE_SIGNALING',
                    'IMMUNE', 'INFLAMMATION', 'METABOLISM', 'METASTASIS',
                    'PLASTICITY', 'PROLIFERATION', 'RESIST_CELL_DEATH', 'VASCULATOR',
                   'STATUS', 'TIME']

train_data.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_train_data_PC_FINAL_adeno.csv', index=False)

In [26]:
class Encoder_Surv(nn.Module):
    def __init__(self, dna_repair_input_dim, growth_suppressor_input_dim, 
                 hormone_signaling_input_dim, immune_input_dim,
                 inflammation_input_dim, metabolism_input_dim,
                 metastasis_input_dim, plasticity_input_dim,
                 proliferation_input_dim, resist_cell_death_input_dim,
                 vasculator_input_dim, hidden_dim, latent_dim):
        super(Encoder_Surv, self).__init__()
        self.dropout = nn.Dropout(0.2)
        self.fc_dna_repair_1 = nn.Linear(dna_repair_input_dim, hidden_dim)
        self.fc_dna_repair_2 = nn.Linear(hidden_dim, 1)

        self.fc_growth_suppressor_1 = nn.Linear(growth_suppressor_input_dim, hidden_dim)
        self.fc_growth_suppressor_2 = nn.Linear(hidden_dim, 1)

        self.fc_hormone_signaling_1 = nn.Linear(hormone_signaling_input_dim, hidden_dim)
        self.fc_hormone_signaling_2 = nn.Linear(hidden_dim, 1)

        self.fc_immune_1 = nn.Linear(immune_input_dim, hidden_dim)
        self.fc_immune_2 = nn.Linear(hidden_dim, 1)

        self.fc_inflammation_1 = nn.Linear(inflammation_input_dim, hidden_dim)
        self.fc_inflammation_2 = nn.Linear(hidden_dim, 1)

        self.fc_metabolism_1 = nn.Linear(metabolism_input_dim, hidden_dim)
        self.fc_metabolism_2 = nn.Linear(hidden_dim, 1)

        self.fc_metastasis_1 = nn.Linear(metastasis_input_dim, hidden_dim)
        self.fc_metastasis_2 = nn.Linear(hidden_dim, 1)

        self.fc_plasticity_1 = nn.Linear(plasticity_input_dim, hidden_dim)
        self.fc_plasticity_2 = nn.Linear(hidden_dim, 1)

        self.fc_proliferation_1 = nn.Linear(proliferation_input_dim, hidden_dim)
        self.fc_proliferation_2 = nn.Linear(hidden_dim, 1)

        self.fc_resist_cell_death_1 = nn.Linear(resist_cell_death_input_dim, hidden_dim)
        self.fc_resist_cell_death_2 = nn.Linear(hidden_dim, 1)

        self.fc_vasculator_1 = nn.Linear(vasculator_input_dim, hidden_dim)
        self.fc_vasculator_2 = nn.Linear(hidden_dim, 1)

        self.LeakyReLU = nn.LeakyReLU(0.2)
        
        self.FC_mean  = nn.Linear(latent_dim, latent_dim)
        self.FC_var   = nn.Linear(latent_dim, latent_dim)
        
        self.training = True
        self.reg = nn.Linear(latent_dim, 1)
        
        self.dna_repair_stop = dna_repair_input_dim 
        self.growth_suppressor_stop = growth_suppressor_input_dim + self.dna_repair_stop 
        self.hormone_signaling_stop = hormone_signaling_input_dim + self.growth_suppressor_stop 
        self.immune_stop = immune_input_dim + self.hormone_signaling_stop
        self.inflammation_stop = inflammation_input_dim + self.immune_stop 
        self.metabolism_stop = metabolism_input_dim + self.inflammation_stop 
        self.metastasis_stop = metastasis_input_dim + self.metabolism_stop 
        self.plasticity_stop = plasticity_input_dim + self.metastasis_stop 
        self.proliferation_stop = proliferation_input_dim + self.plasticity_stop 
        self.resist_cell_death_stop = resist_cell_death_input_dim + self.proliferation_stop 
        self.vasculator_stop = vasculator_input_dim + self.resist_cell_death_stop 
        
    def forward(self, x):
        h_dna_repair = self.LeakyReLU(self.fc_dna_repair_1(self.dropout(x[:,0:self.dna_repair_stop])))
        h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_1(self.dropout(x[:,(self.dna_repair_stop):self.growth_suppressor_stop])))
        h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_1(self.dropout(x[:,(self.growth_suppressor_stop):self.hormone_signaling_stop])))
        h_immune = self.LeakyReLU(self.fc_immune_1(self.dropout(x[:,(self.hormone_signaling_stop):self.immune_stop])))
        h_inflammation = self.LeakyReLU(self.fc_inflammation_1(self.dropout(x[:,(self.immune_stop):self.inflammation_stop])))
        h_metabolism = self.LeakyReLU(self.fc_metabolism_1(self.dropout(x[:,(self.inflammation_stop):self.metabolism_stop])))
        h_metastasis = self.LeakyReLU(self.fc_metastasis_1(self.dropout(x[:,(self.metabolism_stop):self.metastasis_stop])))
        h_plasticity = self.LeakyReLU(self.fc_plasticity_1(self.dropout(x[:,(self.metastasis_stop):self.plasticity_stop])))
        h_proliferation = self.LeakyReLU(self.fc_proliferation_1(self.dropout(x[:,(self.plasticity_stop):self.proliferation_stop])))
        h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_1(self.dropout(x[:,(self.proliferation_stop):self.resist_cell_death_stop])))
        h_vasculator = self.LeakyReLU(self.fc_vasculator_1(self.dropout(x[:,(self.resist_cell_death_stop):self.vasculator_stop])))

        h_dna_repair = self.LeakyReLU(self.fc_dna_repair_2(self.dropout(h_dna_repair)))
        h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_2(self.dropout(h_growth_suppressor)))
        h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_2(self.dropout(h_hormone_signaling)))
        h_immune = self.LeakyReLU(self.fc_immune_2(self.dropout(h_immune)))
        h_inflammation = self.LeakyReLU(self.fc_inflammation_2(self.dropout(h_inflammation)))
        h_metabolism = self.LeakyReLU(self.fc_metabolism_2(self.dropout(h_metabolism)))
        h_metastasis = self.LeakyReLU(self.fc_metastasis_2(self.dropout(h_metastasis)))
        h_plasticity = self.LeakyReLU(self.fc_plasticity_2(self.dropout(h_plasticity)))
        h_proliferation = self.LeakyReLU(self.fc_proliferation_2(self.dropout(h_proliferation)))
        h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_2(self.dropout(h_resist_cell_death)))
        h_vasculator = self.LeakyReLU(self.fc_vasculator_2(self.dropout(h_vasculator)))
        h_ = torch.cat((h_dna_repair, h_growth_suppressor, h_hormone_signaling, 
                      h_immune, h_inflammation, h_metabolism, h_metastasis,
                      h_plasticity, h_proliferation, h_resist_cell_death, h_vasculator),1)
        mean     = self.FC_mean(h_)
        log_var  = self.FC_var(h_)                   
        haz      = self.reg(mean) 
        return mean, log_var, haz, h_


class Decoder_Surv(nn.Module):
    def __init__(self, dna_repair_input_dim, growth_suppressor_input_dim, 
                 hormone_signaling_input_dim, immune_input_dim,
                 inflammation_input_dim, metabolism_input_dim,
                 metastasis_input_dim, plasticity_input_dim,
                 proliferation_input_dim, resist_cell_death_input_dim,
                 vasculator_input_dim, hidden_dim, latent_dim):
        super(Decoder_Surv, self).__init__()
        
        self.fc_dna_repair_1 = nn.Linear(1, hidden_dim)
        self.fc_dna_repair_2 = nn.Linear(hidden_dim, dna_repair_input_dim)

        self.fc_growth_suppressor_1 = nn.Linear(1, hidden_dim)
        self.fc_growth_suppressor_2 = nn.Linear(hidden_dim, growth_suppressor_input_dim)

        self.fc_hormone_signaling_1 = nn.Linear(1, hidden_dim)
        self.fc_hormone_signaling_2 = nn.Linear(hidden_dim, hormone_signaling_input_dim)

        self.fc_immune_1 = nn.Linear(1, hidden_dim)
        self.fc_immune_2 = nn.Linear(hidden_dim, immune_input_dim)

        self.fc_inflammation_1 = nn.Linear(1, hidden_dim)
        self.fc_inflammation_2 = nn.Linear(hidden_dim, inflammation_input_dim)

        self.fc_metabolism_1 = nn.Linear(1, hidden_dim)
        self.fc_metabolism_2 = nn.Linear(hidden_dim, metabolism_input_dim)

        self.fc_metastasis_1 = nn.Linear(1, hidden_dim)
        self.fc_metastasis_2 = nn.Linear(hidden_dim, metastasis_input_dim)

        self.fc_plasticity_1 = nn.Linear(1, hidden_dim)
        self.fc_plasticity_2 = nn.Linear(hidden_dim, plasticity_input_dim)

        self.fc_proliferation_1 = nn.Linear(1, hidden_dim)
        self.fc_proliferation_2 = nn.Linear(hidden_dim, proliferation_input_dim)

        self.fc_resist_cell_death_1 = nn.Linear(1, hidden_dim)
        self.fc_resist_cell_death_2 = nn.Linear(hidden_dim, resist_cell_death_input_dim)

        self.fc_vasculator_1 = nn.Linear(1, hidden_dim)
        self.fc_vasculator_2 = nn.Linear(hidden_dim, vasculator_input_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        h_dna_repair = self.LeakyReLU(self.fc_dna_repair_1(x[:,0].unsqueeze(dim=1)))
        h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_1(x[:,1].unsqueeze(dim=1)))
        h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_1(x[:,2].unsqueeze(dim=1)))
        h_immune = self.LeakyReLU(self.fc_immune_1(x[:,3].unsqueeze(dim=1)))
        h_inflammation = self.LeakyReLU(self.fc_inflammation_1(x[:,4].unsqueeze(dim=1)))
        h_metabolism = self.LeakyReLU(self.fc_metabolism_1(x[:,5].unsqueeze(dim=1)))
        h_metastasis = self.LeakyReLU(self.fc_metastasis_1(x[:,6].unsqueeze(dim=1)))
        h_plasticity = self.LeakyReLU(self.fc_plasticity_1(x[:,7].unsqueeze(dim=1)))
        h_proliferation = self.LeakyReLU(self.fc_proliferation_1(x[:,8].unsqueeze(dim=1)))
        h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_1(x[:,9].unsqueeze(dim=1)))
        h_vasculator = self.LeakyReLU(self.fc_vasculator_1(x[:,10].unsqueeze(dim=1)))

        h_dna_repair = self.LeakyReLU(self.fc_dna_repair_2(h_dna_repair))
        h_growth_suppressor = self.LeakyReLU(self.fc_growth_suppressor_2(h_growth_suppressor))
        h_hormone_signaling = self.LeakyReLU(self.fc_hormone_signaling_2(h_hormone_signaling))
        h_immune = self.LeakyReLU(self.fc_immune_2(h_immune))
        h_inflammation = self.LeakyReLU(self.fc_inflammation_2(h_inflammation))
        h_metabolism = self.LeakyReLU(self.fc_metabolism_2(h_metabolism))
        h_metastasis = self.LeakyReLU(self.fc_metastasis_2(h_metastasis))
        h_plasticity = self.LeakyReLU(self.fc_plasticity_2(h_plasticity))
        h_proliferation = self.LeakyReLU(self.fc_proliferation_2(h_proliferation))
        h_resist_cell_death = self.LeakyReLU(self.fc_resist_cell_death_2(h_resist_cell_death))
        h_vasculator = self.LeakyReLU(self.fc_vasculator_2(h_vasculator))


        dna_repair_x_hat = torch.sigmoid(h_dna_repair)
        growth_suppressor_x_hat = torch.sigmoid(h_growth_suppressor)
        hormone_signaling_x_hat = torch.sigmoid(h_hormone_signaling)
        immune_x_hat = torch.sigmoid(h_immune)
        inflammation_x_hat = torch.sigmoid(h_inflammation)
        metabolism_x_hat = torch.sigmoid(h_metabolism)
        metastasis_x_hat = torch.sigmoid(h_metastasis)
        plasticity_x_hat = torch.sigmoid(h_plasticity)
        proliferation_x_hat = torch.sigmoid(h_proliferation)
        resist_cell_death_x_hat = torch.sigmoid(h_resist_cell_death)
        vasculator_x_hat = torch.sigmoid(h_vasculator)
        return dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat

class Model_Surv(nn.Module):
    def __init__(self, Encoder_Surv, Decoder_Surv):
        super(Model_Surv, self).__init__()
        self.Encoder_Surv = Encoder_Surv
        self.Decoder_Surv = Decoder_Surv

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(DEVICE)        # sampling epsilon        
        z = mean + var*epsilon                          # reparameterization trick
        return z

    def forward(self, x):
        mean, log_var, haz, h_ = self.Encoder_Surv(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
        dna_repair_x_hat, growth_suppressor_x_hat, hormone_signaling_x_hat, immune_x_hat, inflammation_x_hat, metabolism_x_hat, metastasis_x_hat, plasticity_x_hat, proliferation_x_hat, resist_cell_death_x_hat, vasculator_x_hat = self.Decoder_Surv(z)

        return h_

encoder_surv = Encoder_Surv(dna_repair_input_dim = dna_repair_x_dim, 
                            growth_suppressor_input_dim = growth_suppressor_x_dim, 
                 hormone_signaling_input_dim = hormone_signaling_x_dim,
                            immune_input_dim = immune_x_dim,
                 inflammation_input_dim = inflammation_x_dim,
                            metabolism_input_dim = metabolism_x_dim,
                 metastasis_input_dim = metastasis_x_dim,
                            plasticity_input_dim = plasticity_x_dim,
                 proliferation_input_dim = proliferation_x_dim,
                            resist_cell_death_input_dim = resist_cell_death_x_dim,
                 vasculator_input_dim = vasculator_x_dim,
                            hidden_dim = 10, latent_dim = 11)
decoder_surv = Decoder_Surv(dna_repair_input_dim = dna_repair_x_dim, 
                            growth_suppressor_input_dim = growth_suppressor_x_dim, 
                 hormone_signaling_input_dim = hormone_signaling_x_dim,
                            immune_input_dim = immune_x_dim,
                 inflammation_input_dim = inflammation_x_dim,
                            metabolism_input_dim = metabolism_x_dim,
                 metastasis_input_dim = metastasis_x_dim,
                            plasticity_input_dim = plasticity_x_dim,
                 proliferation_input_dim = proliferation_x_dim,
                            resist_cell_death_input_dim = resist_cell_death_x_dim,
                 vasculator_input_dim = vasculator_x_dim,
                            hidden_dim = 10, latent_dim = 11)

model_surv = Model_Surv(Encoder_Surv=encoder_surv, Decoder_Surv=decoder_surv)

In [27]:
x = [Tensor(dna_repair_norm_train),
                        Tensor(growth_suppressor_norm_train),
                        Tensor(hormone_signaling_norm_train),
                        Tensor(immune_norm_train),
                        Tensor(inflammation_norm_train),
                        Tensor(metabolism_norm_train),
                        Tensor(metastasis_norm_train),
                        Tensor(plasticity_norm_train),
                        Tensor(proliferation_norm_train),
                        Tensor(resist_cell_death_norm_train),
                        Tensor(vasculator_norm_train)]
x = torch.cat(x, 1)

In [28]:
background = x[:50]
test_images = x[50:80]

In [29]:
model_surv.load_state_dict(torch.load('..//Documents//Survival VAE//FINAL PHENOSURV//lung_model_weights_PC_FINAL_adeno.pth', weights_only=False))

<All keys matched successfully>

In [30]:
import shap
e = shap.DeepExplainer(model_surv, background)

  from .autonotebook import tqdm as notebook_tqdm


In [31]:
shap_values = e.shap_values(test_images, check_additivity=False)
shap1, shap2, shap3, shap4, shap5, shap6, shap7, shap8, shap9, shap10, shap11 = np.dsplit(shap_values,11)

In [32]:
shap_df1 = pd.DataFrame(shap1.squeeze().T)
shap_df2 = pd.DataFrame(shap2.squeeze().T)
shap_df3 = pd.DataFrame(shap3.squeeze().T)
shap_df4 = pd.DataFrame(shap4.squeeze().T)
shap_df5 = pd.DataFrame(shap5.squeeze().T)
shap_df6 = pd.DataFrame(shap6.squeeze().T)
shap_df7 = pd.DataFrame(shap7.squeeze().T)
shap_df8 = pd.DataFrame(shap8.squeeze().T)
shap_df9 = pd.DataFrame(shap9.squeeze().T)
shap_df10 = pd.DataFrame(shap10.squeeze().T)
shap_df11 = pd.DataFrame(shap11.squeeze().T)

In [33]:
shap_df1['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df2['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df3['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df4['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df5['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df6['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df7['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df8['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df9['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df10['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways
shap_df11['Pathway'] = dna_repair_pathways + growth_suppressor_pathways + hormone_signaling_pathways + immune_pathways + inflammation_pathways + metabolism_pathways + metastasis_pathways + plasticity_pathways + proliferation_pathways + resist_cell_death_pathways + vasculator_pathways

In [34]:
shap_df1.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_dna_repair_PC_FINAL_adeno.csv', index=False)
shap_df2.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_growth_suppressor_PC_FINAL_adeno.csv', index=False)
shap_df3.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_hormone_signaling_PC_FINAL_adeno.csv', index=False)
shap_df4.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_immune_PC_FINAL_adeno.csv', index=False)
shap_df5.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_inflammation_PC_FINAL_adeno.csv', index=False)
shap_df6.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_metabolism_PC_FINAL_adeno.csv', index=False)
shap_df7.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_metastasis_PC_FINAL_adeno.csv', index=False)
shap_df8.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_plasticity_PC_FINAL_adeno.csv', index=False)
shap_df9.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_proliferation_PC_FINAL_adeno.csv', index=False)
shap_df10.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_resist_cell_death_PC_FINAL_adeno.csv', index=False)
shap_df11.to_csv('~//Documents//Survival VAE//FINAL PHENOSURV//lung_test_shapvalues_vasculator_PC_FINAL_adeno.csv', index=False)