In [1]:
import numpy as np
import pandas as pd
import random
#import argparse
from sklearn.model_selection import KFold, GridSearchCV
from sklearn.linear_model import ElasticNet, ElasticNetCV
#from sksurv.linear_model import CoxnetSurvivalAnalysis as CoxPH
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from utils import *
from model_functions import *

random.seed(7)

In [2]:
from mvdr.ajive.AJIVE import AJIVE

In [3]:
scaler = StandardScaler()

prots_train_df = pd.read_csv("Data/proteins_metabolites_train.csv", index_col = 'eid')
metabolites_train = scaler.fit_transform(prots_train_df.loc[:, 'Clinical_LDL_C':])
proteins_train =  scaler.fit_transform(prots_train_df.loc[:, : 'ZBTB17'])
prots_val_df = pd.read_csv("Data/proteins_metabolites_set2.csv", index_col = 'eid')
metabolites_set2 =  scaler.fit_transform(prots_val_df.loc[:, 'Clinical_LDL_C':])
proteins_set2 = scaler.fit_transform(prots_val_df.loc[:, : 'ZBTB17'])

In [4]:
#Initial signal ranks based on PCA
aj10 = AJIVE(init_signal_ranks = [10, 10], center = True).fit([proteins_train, metabolites_train]) 
aj20 = AJIVE(init_signal_ranks = [20, 20], center = True).fit([proteins_train, metabolites_train])

In [11]:
#aj10.summary()
aj20.summary()

'AJIVE, joint rank: 8, view 0 indiv rank: 13, view 1 indiv rank: 14'

In [12]:
aj10andersom = AJIVE(init_signal_ranks = [10, 10], center = True).fit([metabolites_train, proteins_train]) 
aj10andersom.summary()

'AJIVE, joint rank: 2, view 0 indiv rank: 9, view 1 indiv rank: 8'

In [14]:
aj20.common_.view_loadings_[0].shape
aj20.view_specific_[0].individual_.loadings_

array([[ 0.04617623,  0.05959563,  0.00725135, ...,  0.00897097,
         0.03368331,  0.02009749],
       [ 0.03068757,  0.03582454, -0.0090216 , ..., -0.07948508,
         0.07375266, -0.00139573],
       [ 0.07466134, -0.05115334, -0.00469359, ...,  0.01898551,
         0.01817606, -0.00487485],
       ...,
       [ 0.10960103, -0.10055149, -0.01068108, ..., -0.01236189,
         0.07109607,  0.03170717],
       [ 0.01660382,  0.03316564,  0.01055851, ...,  0.03509621,
        -0.08635792,  0.06777673],
       [ 0.04988403,  0.02400433,  0.17362339, ..., -0.00822945,
        -0.02978541, -0.03501392]])

In [3]:
def apply_ajive(train_ajive, protset, metset, eidset):
    # make array
    protset_array = np.array(protset)
    metset_array = np.array(metset)
    
    #Joint loadings
    common_loadings = train_ajive.common_.view_loadings_
    prot_common_loadings = common_loadings[0]
    met_common_loadings = common_loadings[1]
    prot_scores = proteins_train @ prot_common_loadings
    met_scores = metabolites_train @ met_common_loadings
    scores = np.add(prot_scores, met_scores)
    common_ajive = pd.DataFrame(scores * (1/np.linalg.norm(scores, axis=0)), columns = [f'common_comp_{i}' for i in range(common_loadings[0].shape[1])])
    
    #Individual loadings proteins
    view_0_pca = train_ajive.view_specific_[0].individual_
    prot_indiv_loadings_array = np.array(view_0_pca.loadings_) # this are the loadings vectors, you can now manually prject
    #Individual loadings metabolites
    view_1_pca = train_ajive.view_specific_[1].individual_
    met_indiv_loadings_array = np.array(view_1_pca.loadings_) # this are the loadings vectors, you can now manually prject   
    
    #Multiply
    proteins_ind_array = np.dot(protset_array, prot_indiv_loadings_array)
    metabolites_ind_array = np.dot(metset, met_indiv_loadings_array)
    
    # Make dataframe with same names as train
    proteins_ind_ajive = pd.DataFrame(proteins_ind_array, columns=[f'prot_indiv_comp_{i}' for i in range(prot_indiv_loadings_array.shape[1])])
    metabo_ind_ajive = pd.DataFrame(metabolites_ind_array, columns=[f'met_indiv_comp_{i}' for i in range(met_indiv_loadings_array.shape[1])])
    
    ajive_df = pd.concat(
        [pd.DataFrame(eidset.index, columns=['eid']),
         common_ajive,
         proteins_ind_ajive,
         metabo_ind_ajive
        ], 
        axis=1
    )
    return(ajive_df)

In [6]:
train_df10 = apply_ajive(aj10, proteins_train, metabolites_train, prots_train_df) #to validate function
set2_df10 = apply_ajive(aj10, proteins_set2, metabolites_set2, prots_val_df)
train_df20 = apply_ajive(aj20, proteins_train, metabolites_train, prots_train_df) #to validate function
set2_df20 = apply_ajive(aj20, proteins_set2, metabolites_set2, prots_val_df)

In [7]:
train_df10

Unnamed: 0,eid,common_comp_0,common_comp_1,prot_indiv_comp_0,prot_indiv_comp_1,prot_indiv_comp_2,prot_indiv_comp_3,prot_indiv_comp_4,prot_indiv_comp_5,prot_indiv_comp_6,prot_indiv_comp_7,met_indiv_comp_0,met_indiv_comp_1,met_indiv_comp_2,met_indiv_comp_3,met_indiv_comp_4,met_indiv_comp_5,met_indiv_comp_6,met_indiv_comp_7,met_indiv_comp_8
0,1004826.0,-0.014590,0.009190,-2.295032,-0.181259,-3.156176,-0.729482,-5.181388,1.197846,-2.459575,-4.313563,0.739892,-3.812825,6.868530,2.953371,0.343638,-1.165441,-0.246077,0.594332,3.048320
1,1008846.0,-0.002533,0.007394,-0.179740,4.288007,-2.794244,1.981952,-2.931818,1.699250,-1.861359,-1.140551,-0.752580,-6.197115,0.202377,-0.382413,-0.562534,-1.483049,0.610820,-0.478063,-0.150475
2,1009759.0,-0.010791,0.010788,7.969363,-0.628423,0.255349,-1.067967,5.025046,-2.185289,3.125351,0.381738,-6.478258,4.418395,12.625224,6.469841,4.162171,-5.536678,0.184981,-3.084732,5.075778
3,1010092.0,-0.001702,-0.009557,0.734608,-7.064392,-1.651508,-3.611970,-3.214327,0.716006,-0.926133,0.147869,-10.569365,-2.629327,-3.164699,-3.283982,-3.410241,1.375359,-0.040631,-0.103084,-1.622879
4,1013849.0,-0.012581,0.005670,-12.790275,4.363558,-4.255856,-1.667149,-3.055727,-0.963758,0.401739,-0.922648,-13.238631,0.261297,5.382949,4.445263,-1.621560,-1.949476,-0.976898,0.680405,3.042420
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15055,2283833.0,0.000034,-0.002047,4.774565,-0.810435,0.786640,-3.179956,-1.189937,1.230472,3.043736,2.680127,7.419017,-6.983900,-0.650400,-2.182430,1.143292,0.087070,-0.535263,1.164170,-0.811403
15056,5520133.0,0.015068,-0.001645,9.602181,5.558214,-3.426143,4.308532,0.418078,-4.309248,-0.143049,-0.072106,6.674860,11.543647,-4.864849,-3.195221,-2.599063,2.013852,-2.163485,1.061025,-0.829818
15057,2176533.0,-0.010009,0.002106,-0.874586,-8.744317,-1.034583,-2.031086,-1.379083,-1.250721,-2.485160,0.859011,3.953948,-3.440687,4.507848,1.488227,-0.176245,-1.050540,-1.947282,0.029716,2.069998
15058,4888291.0,-0.005175,-0.008782,-4.808610,-9.133658,2.128831,-0.475566,0.912156,-0.594169,-0.563737,-2.660702,-7.308661,-2.151861,-3.538833,1.896550,0.470213,0.946013,-1.163970,-1.046711,-1.789294


In [None]:
aj20

In [8]:
train_df10.to_csv('Data/ajive10_prot_met_train.csv', index = False)
set2_df10.to_csv('Data/ajive10_prot_met_set2.csv', index = False)
train_df20.to_csv('Data/ajive20_prot_met_train.csv', index = False)
set2_df20.to_csv('Data/ajive20_prot_met_set2.csv', index = False)

In [4]:
def get_data_ajive(config):
    prefix = "Data/Processed/Full"

    normalize = True
    if config['dset'] == "cmb_met_ajive10":
        prots_train_df = pd.read_csv('Data/ajive10_prot_met_train.csv')
        prots_val_df = pd.read_csv('Data/ajive10_prot_met_set2.csv')
    elif config['dset'] == "cmb_met_ajive20":
        prots_train_df = pd.read_csv('Data/ajive20_prot_met_train.csv')
        prots_val_df = pd.read_csv('Data/ajive20_prot_met_set2.csv')
       
    else:
        get_data(config)
    
    
    if config['target'] == "mort":
        target_train = pd.read_csv(prefix + "/mort_full_train.csv", index_col = 'eid')
        target_val = pd.read_csv(prefix + "/mort_full_test.csv", index_col = 'eid')
        #mort_test = pd.read_csv("Data/Processed/Full/mort_full_val.csv", index_col = 'eid')

    elif config['target'] == "frailty":
        target_train = pd.read_csv("Data/frailty_clean_train.csv", index_col = 'eid')
        target_val = pd.read_csv("Data/frailty_clean_set2.csv", index_col = 'eid')


        #add age to dset if we train for frailty
        basicinfo = pd.read_csv("Data/basicinfo_instance_0.csv", index_col = "eid")

        prots_train_df = prots_train_df.merge(basicinfo['age_center.0.0'], on = 'eid')
        prots_val_df = prots_val_df.merge(basicinfo['age_center.0.0'], on = 'eid')



    full_train, train_eids = preprocess(prots_train_df, target_train, target = config['target'] , normalize = normalize)
    full_val, val_eids = preprocess(prots_val_df, target_val, target = config['target'], normalize = normalize)


    return full_train, full_val, train_eids, val_eids
     

In [6]:
prots_train_df = pd.read_csv('Data/ajive10_prot_met_train.csv', index_col = 'eid')
prots_train_df

Unnamed: 0_level_0,common_comp_0,common_comp_1,prot_indiv_comp_0,prot_indiv_comp_1,prot_indiv_comp_2,prot_indiv_comp_3,prot_indiv_comp_4,prot_indiv_comp_5,prot_indiv_comp_6,prot_indiv_comp_7,met_indiv_comp_0,met_indiv_comp_1,met_indiv_comp_2,met_indiv_comp_3,met_indiv_comp_4,met_indiv_comp_5,met_indiv_comp_6,met_indiv_comp_7,met_indiv_comp_8
eid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
1004826.0,-0.014590,0.009190,-2.295032,-0.181259,-3.156176,-0.729482,-5.181388,1.197846,-2.459575,-4.313563,0.739892,-3.812825,6.868530,2.953371,0.343638,-1.165441,-0.246077,0.594332,3.048320
1008846.0,-0.002533,0.007394,-0.179740,4.288007,-2.794244,1.981952,-2.931818,1.699250,-1.861359,-1.140551,-0.752580,-6.197115,0.202377,-0.382413,-0.562534,-1.483049,0.610820,-0.478063,-0.150475
1009759.0,-0.010791,0.010788,7.969363,-0.628423,0.255349,-1.067967,5.025046,-2.185289,3.125351,0.381738,-6.478258,4.418395,12.625224,6.469841,4.162171,-5.536678,0.184981,-3.084732,5.075778
1010092.0,-0.001702,-0.009557,0.734608,-7.064392,-1.651508,-3.611970,-3.214327,0.716006,-0.926133,0.147869,-10.569365,-2.629327,-3.164699,-3.283982,-3.410241,1.375359,-0.040631,-0.103084,-1.622879
1013849.0,-0.012581,0.005670,-12.790275,4.363558,-4.255856,-1.667149,-3.055727,-0.963758,0.401739,-0.922648,-13.238631,0.261297,5.382949,4.445263,-1.621560,-1.949476,-0.976898,0.680405,3.042420
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2283833.0,0.000034,-0.002047,4.774565,-0.810435,0.786640,-3.179956,-1.189937,1.230472,3.043736,2.680127,7.419017,-6.983900,-0.650400,-2.182430,1.143292,0.087070,-0.535263,1.164170,-0.811403
5520133.0,0.015068,-0.001645,9.602181,5.558214,-3.426143,4.308532,0.418078,-4.309248,-0.143049,-0.072106,6.674860,11.543647,-4.864849,-3.195221,-2.599063,2.013852,-2.163485,1.061025,-0.829818
2176533.0,-0.010009,0.002106,-0.874586,-8.744317,-1.034583,-2.031086,-1.379083,-1.250721,-2.485160,0.859011,3.953948,-3.440687,4.507848,1.488227,-0.176245,-1.050540,-1.947282,0.029716,2.069998
4888291.0,-0.005175,-0.008782,-4.808610,-9.133658,2.128831,-0.475566,0.912156,-0.594169,-0.563737,-2.660702,-7.308661,-2.151861,-3.538833,1.896550,0.470213,0.946013,-1.163970,-1.046711,-1.789294


In [10]:
prots_train_df = pd.read_csv('Data/ajive20_prot_met_train.csv')


In [12]:
prots_train_df.shape

(15060, 36)

In [15]:
from utils import get_data

data = get_data({'dset' : "cmb_met_ajive", "target" : "frailty"})

data[0][0].shape

(15050, 20)

In [5]:
def find_best_model_modified(dset, target, bootstrap=False):
    # Retrieve data
    trainset, set2, eids_train, eids_set2 = get_data_ajive({'dset': dset, 'target': target})
    trainset_names = get_trainsetnames(dset, target)
    colname = f'en_{target}_{dset}'
     # Get the number of samples from the 'cmb_met' dataset
    cmb_met_data, _, _, _ = get_data({'dset': 'cmb_met', 'target': target})
    num_samples = cmb_met_data[0].shape[0]  # Get the number of participants

    # Initialize storage for results if using bootstrap
    combined_coefs = []
    combined_scores = []
    combined_new_col = []
    combined_r2_or_C = []



    # Perform bootstrap sampling if requested
    if bootstrap:
        for i in range(100):
            # Sample the training data with replacement using the number of samples from 'cmb_met'
            np.random.seed(i)  # Set the seed for reproducibility
            indices = np.random.choice(trainset[0].shape[0], size=num_samples, replace=True)
            sampled_trainset_X = trainset[0][indices]
            sampled_trainset_y = trainset[1][indices]
            if target == 'mort':
                sampled_trainset_event = trainset[2][indices]
                sampled_trainset = (sampled_trainset_X, sampled_trainset_y, sampled_trainset_event)
            else:
                sampled_trainset = (sampled_trainset_X, sampled_trainset_y)
                
            if target == 'frailty':
                best_coefs, scores, R2_save, new_col = find_best_lm_model(sampled_trainset, trainset_names, set2, colname)
                combined_r2_or_C.append(R2_save)
                np.save(f"./output_linear/rsquared/R2_{dset}.npy", R2_save)
            elif target == 'mort':
                best_model, best_coefs, scores, C_save, new_col = find_best_coxph_model(sampled_trainset, trainset_names, set2, colname)
                combined_r2_or_C.append(C_save)
            else:
                raise ValueError(f"Target '{target}' not recognized. Use 'frailty' or 'mort'.")

            combined_coefs.append(best_coefs)
            combined_scores.append(scores)
            combined_new_col.append(new_col)
        
        # Combine the results
        combined_coefs = pd.concat(combined_coefs, axis=1)
        combined_scores = np.array(combined_scores)
        combined_new_col = pd.concat(combined_new_col, axis=1)
        combined_r2_or_C = np.array(combined_r2_or_C)
        np.save(f"./output_linear/bootstrap/{target}/combined_scores_{target}_{dset}.npy", combined_scores)
        combined_coefs.to_csv(f"./output_linear/bootstrap/{target}/coefs_{target}_{dset}.csv")
        combined_new_col.to_csv(f"./output_linear/bootstrap/{target}/en_{target}_{dset}_set2.csv")
        np.save(f"./output_linear/bootstrap/{target}/metric_{target}_{dset}.npy", combined_r2_or_C)


    elif bootstrap == False and target == 'frailty':
        # Find best linear model
        best_coefs, scores, R2_save, new_col = find_best_lm_model(trainset, trainset_names, set2, colname)

        # Save results
        np.save(f"./output_linear/scores_frail/scores_frailty_{dset}.npy", scores)
        best_coefs.to_csv(f'./output_linear/coefs_frail/coefs_frailty_{dset}.csv')
        new_col.to_csv(f'./output_linear/set2frail/en_frailty_{dset}_set2.csv')
        np.save(f"./output_linear/rsquared/R2_{dset}.npy", R2_save)

    elif bootstrap == False and target == 'mort':
        # Find best Cox proportional hazards model
        best_model, best_coefs, scores, C_save, new_col = find_best_coxph_model(trainset, trainset_names, set2, colname)

        # Save results
        np.save(f"./output_linear/scores_mort/scores_mort_{dset}.npy", scores)
        best_coefs.to_csv(f'./output_linear/coefs_mort/coefs_mort_{dset}.csv')
        new_col.to_csv(f'./output_linear/set2mort/en_mort_{dset}_set2.csv')
        np.save(f"./output_linear/concordance/C_mort_{dset}.npy", C_save)
        np.save(f"./output_linear/bestmodel_mort/best_model_mort_{dset}.npy", best_model)

    else:
        raise ValueError(f"Target '{target}' not recognized. Use 'frailty' or 'mort'.")

In [12]:
find_best_model_modified('cmb_met_ajive10', 'frailty')
find_best_model_modified('cmb_met_ajive20', 'frailty')

In [13]:
find_best_model_modified('cmb_met_ajive10', 'mort')
find_best_model_modified('cmb_met_ajive20', 'mort')

  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_

In [14]:
find_best_model('cmb_met_pca20', 'frailty')

In [None]:
find_best_model('cmb_met_pca20', 'mort')

  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_params)
  estimator.fit(X_train, y_train, **fit_