In the README of this package I demonstrated how to use the multisnpnet-Cox-tv package to fit a Cox model with user-provided regularization parameters and with the time-varying covariates provided in the form of an operation time. Here I will show how to use this package to fit a Cox-Lasso path using BASIL to screen the SNPs.

In [1]:
library(coxtv)
library(pgenlibr)
library(data.table)
library(tidyverse)

# Input files: MI phenotype, death records, master phenotype table and the
# plink2 genotype file prefix (.pgen/.pvar.zst/.psam share this prefix).
phe.file <- "/oak/stanford/groups/mrivas/ukbb24983/phenotypedata/master_phe/cox/phenotypefiles/f.131298.0.0.phe"
death.file <- "/oak/stanford/groups/mrivas/projects/ukbb-phenotyping/20200404_icd_death/ukb41413_icd_death.tsv"
masterphe.file <- "/oak/stanford/groups/mrivas/ukbb24983/phenotypedata/master_phe/master.phe"
genotype.pfile <- "/oak/stanford/groups/mrivas/ukbb24983/array_combined/pgen/ukb24983_cal_hla_cnv"

# Sample IDs (IID column, read as character) in the order they appear in the
# .psam file; used later to map phenotype rows to genotype rows.
psamid <- data.table::fread(
    paste0(genotype.pfile, '.psam'),
    colClasses = list(character = c("IID")),
    select = c("IID")
)$IID

# Time-independent covariates: sex, age and the first ten principal components.
covs <- c("sex", "age", paste0("PC", 1:10))

# Fraction of samples assigned to the training set.
train_ratio <- 0.8

# Settings passed to the snpnet helper functions (computeStats /
# computeProduct) used later in this notebook.
configs <- list(
    gcount.full.prefix = '/scratch/users/ruilinli/tvtest/gcount/test',
    plink2.path = "/scratch/users/ruilinli/prox_grad_cox_block/plink2",
    nCores = 6,
    mem = 60000,
    vzs = TRUE,
    save = TRUE,
    zstdcat.path = "/home/groups/mrivas/software/anaconda2_sherlock2/bin/zstdcat",
    save.computeProduct = TRUE,
    results.dir = "/scratch/users/ruilinli/tvtest/result/",
    save.dir = "/scratch/users/ruilinli/tvtest/save",
    KKT.verbose = TRUE,
    endian = "little",
    standardize.variant = FALSE,
    missing.rate = 0.1,
    MAF.thresh = 0.001
)

── [1mAttaching packages[22m ─────────────────────────────────────── tidyverse 1.3.0 ──

[32m✔[39m [34mggplot2[39m 3.3.0     [32m✔[39m [34mpurrr  [39m 0.3.3
[32m✔[39m [34mtibble [39m 2.1.3     [32m✔[39m [34mdplyr  [39m 0.8.5
[32m✔[39m [34mtidyr  [39m 1.0.2     [32m✔[39m [34mstringr[39m 1.4.0
[32m✔[39m [34mreadr  [39m 1.3.1     [32m✔[39m [34mforcats[39m 0.4.0

── [1mConflicts[22m ────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mbetween()[39m   masks [34mdata.table[39m::between()
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m    masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mfirst()[39m     masks [34mdata.table[39m::first()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m       masks [34mstats[39m::lag()
[31m✖[39m [34mdplyr[39m::[32mlast()[39m      masks [34mdata.table[39m::last()
[31m✖[39m [34mpurrr[39m::[32mtranspose()[39m masks [34mdata.table[39m::transpose(

In [2]:
# Some preprocessing: load the three input tables, reading IDs as character
# and the remaining selected columns as numeric.

# MI phenotype table; columns are renamed to ID / t0 / MI and rows are
# restricted to individuals present in the genotype .psam file.
phe <- data.table::fread(
    phe.file,
    colClasses = list(
        character = c("FID"),
        numeric = c("coxnet_y_f.131298.0.0", "coxnet_status_f.131298.0.0")
    ),
    select = c("FID", "coxnet_y_f.131298.0.0", "coxnet_status_f.131298.0.0")
)
names(phe) <- c("ID", "t0", "MI")
phe <- phe %>% dplyr::filter(ID %in% psamid)

# Death records: one ID column and a numeric `val` column (used later as the
# event time).
event <- data.table::fread(
    death.file,
    colClasses = list(character = c("#IID"), numeric = c("val")),
    select = c("#IID", "val")
)
names(event) <- c("ID", "val")

# Master phenotype table: ID plus the time-independent covariates in `covs`.
masterphe <- data.table::fread(
    masterphe.file,
    colClasses = list(character = c("FID"), numeric = covs),
    select = c("FID", covs)
)
names(masterphe) <- c("ID", covs)


In [3]:
# Keep only people with MI, then drop the MI indicator: it is constant (1)
# for everyone left in the dataset.
phe <- dplyr::filter(phe, MI == 1)
phe <- dplyr::select(phe, -MI)

# People with MI who also had the event (= death). A logical mask is used
# instead of negative integer indexing: `x[-which(cond)]` selects NOTHING when
# no row satisfies `cond`, which would silently leave every censoring time NA.
had_event <- phe$ID %in% event$ID
eID <- which(had_event) # integer positions, kept for backward compatibility
phe$status <- as.numeric(had_event)

# t1 is the event time for people with the event, and the last follow-up time
# for everyone else.
phe$t1 <- NA_real_
phe$t1[had_event] <- event$val[match(phe$ID[had_event], event$ID)]
# NOTE(review): 1.1219 is an unexplained offset added to `age` to obtain the
# last follow-up time -- confirm its meaning and units.
phe$t1[!had_event] <- masterphe$age[match(phe$ID[!had_event], masterphe$ID)] + 1.1219



In [4]:
# Attach the time-independent covariates to phe. `age` is dropped from the
# covariate list here and is not copied over.
covs <- covs[covs != "age"]

# Row of masterphe corresponding to each row of phe (NA when the ID is absent).
masterphe_row <- match(phe$ID, masterphe$ID)
for (cov_name in covs) {
    phe[[cov_name]] <- masterphe[[cov_name]][masterphe_row]
}

In [5]:
# Response: time from t0 to event / last follow-up. Rows with a non-positive
# (or missing) t1 are dropped, then any row with a missing value in any column.
phe <- phe %>%
    dplyr::mutate(y = t1 - t0) %>%
    dplyr::filter(t1 > 0)
phe <- phe[complete.cases(phe), ]

# Smallest response among people who had the event.
min_event_time <- min(phe$y[as.logical(phe$status)])

# non-events before the first event will never be used
phe <- phe %>% dplyr::filter(y >= max(0, min_event_time))


In [6]:
# Now we get the time-varying covariates into the format we need:
# one table per covariate with columns ID, time (relative to t0) and value.
tv_files <- c(
    "/oak/stanford/groups/mrivas/projects/primary_care/gp_clinical/example_phenotypes/final/LDL.tsv",
    "/oak/stanford/groups/mrivas/projects/primary_care/gp_clinical/example_phenotypes/final/Weight.tsv"
)

tv_list <- vector("list", length(tv_files))
for (i in seq_along(tv_files)) {
    tv <- data.table::fread(
        tv_files[i],
        colClasses = list(character = c("id"), numeric = c("age", "value")),
        select = c("id", "age", "value")
    )
    names(tv) <- c("ID", "time", "value")
    tv <- tv %>% filter(!is.na(value))

    # remove extreme observations: keep values strictly between the 0.2% and
    # 99.9% quantiles
    bounds <- quantile(tv$value, c(0.002, 0.999))
    print(bounds)
    tv <- tv %>% filter(value > bounds[1], value < bounds[2])

    # take the intersection: phe keeps only people with at least one
    # measurement, and tv keeps only people still present in phe
    phe <- phe %>% filter(ID %in% tv$ID)
    tv <- tv %>% filter(ID %in% phe$ID)

    # express measurement times relative to each person's t0
    tv$time <- tv$time - phe$t0[match(tv$ID, phe$ID)]
    tv_list[[i]] <- tv
}

 0.2% 99.9% 
  0.0   6.9 
 0.2% 99.9% 
    0   165 


In [7]:
# We split the data into training and validation set, the original split column may not be appropriate here
# since we are using a small subset of founders (because only a small number of people had MI).
# The split is stratified: `train_ratio` of the events and `train_ratio` of the
# non-events go to the training set.
# NOTE(review): no RNG seed is set, so the split differs between runs; call
# set.seed() before this cell if a reproducible split is needed.
total_events <- sum(phe$status)
non_events <- nrow(phe) - total_events
is_event <- as.logical(phe$status)
train_id <- c(
    sample(phe$ID[is_event], round(total_events * train_ratio)),
    sample(phe$ID[!is_event], round(non_events * train_ratio))
)
val_id <- phe$ID[!phe$ID %in% train_id]

phe$split <- 'train'
phe$split[phe$ID %in% val_id] <- 'val'

phe_val <- as.data.table(filter(phe, split == 'val'))
phe_train <- as.data.table(filter(phe, split == 'train'))

# Split every time-varying covariate table the same way. lapply() is used
# instead of `for (i in 1:length(tv_list))`, which would iterate over c(1, 0)
# and fail if tv_list were empty.
tv_train <- lapply(tv_list, function(tv) filter(tv, ID %in% phe_train$ID))
tv_val <- lapply(tv_list, function(tv) filter(tv, ID %in% phe_val$ID))

# Precompute the time-varying covariate information coxtv needs; each info
# object must always be used together with the phe table it was built from.
info_train <- coxtv::get_info(phe_train, tv_train)
info_val <- coxtv::get_info(phe_val, tv_val)

# The unsplit objects are no longer needed.
rm(tv_list, phe)

“2438 people do not have time-varying covariates measured before the first event. The most recent measurement after the event is used.”
“1474 people do not have time-varying covariates measured before the first event. The most recent measurement after the event is used.”
“599 people do not have time-varying covariates measured before the first event. The most recent measurement after the event is used.”
“363 people do not have time-varying covariates measured before the first event. The most recent measurement after the event is used.”


Now we have the data in the right format to be fed to the coxtv functions. To summarize, we need
- A dataframe (here phe_train and phe_val) that has columns:
    - y that contains the time-to-event response
    - status, a binary vector that indicates whether the event has occurred
    - ID, which will be used to identify each person in this dataframe
    - some covariates columns that are time independent
- A list that contains the time-varying covariates (here tv_train and tv_val). Each element of the list corresponds to one time-varying covariate and must have the columns:
    - ID, same set of IDs as used in the phe data frame
    - time, the time at which the measurement was taken (relative to each person's t0)
    - value, the value of the measurements at the corresponding time
- A set of time-independent covariate names that will be used to fit a Cox model; these names must be available as columns in the phe dataframe
- (Optional, to save some compute) A list info that is obtained using coxtv::get_info(phe, tv_list). This list contains information about the time-varying covariates in the form that coxtv can readily use to fit a model. Keeping info can save some computation, especially when n or the number of events is large. For now the user needs to make sure that, when fitting a Cox model, info corresponds to phe and tv_list. Otherwise a segmentation fault might happen. In a future version I might encapsulate BASIL into a package to solve this problem.

##### It is important that users make sure that each person has at least one measurement before the first event time, for each of the time-varying covariates! If this is not satisfied, a warning will be thrown and the most recent measurement after the event will be used.

##### Now let's fit the first iteration, which is supposed to be unpenalized:

In [8]:
# Put t0 at the front of the time-independent covariates. union() keeps this
# cell idempotent: re-running it does not add a second "t0" entry, which
# c("t0", covs) would.
covs <- union("t0", covs)

In [9]:
# Fit with a single regularization parameter lambda = 0 (fourth argument).
# Because `info` is provided, the second argument is not needed and is NULL.
result <- coxtv(phe_train, NULL, covs, 0, info = info_train)

In [10]:
# Display the coefficients of the first (and only) fitted solution as a column.
as.matrix(result[[1L]])

0,1
TV1,0.0022575249
TV2,-0.0008927381
t0,0.1277776373
sex,0.3273922625
PC1,0.002249861
PC2,0.0032595022
PC3,-0.0097584263
PC4,-0.0229086856
PC5,-0.0109442029
PC6,-0.0467834069


Now let's fit a Lasso path using BASIL to screen the SNPs. First we compute the residual:

In [11]:
# Need to use some snpnet helper functions
source("/scratch/users/ruilinli/prox_grad_cox_block/snpnet/R/functions.R")

# Residuals of the fit from the previous cell.
# 1/n already multiplied in this residual
residuals <- cox_residual(phe_train, covs, info_train, result[[1]])

# to compute the gradient (with respect to the SNP coefficients) let's first load the genotype files
# code copied from snpnet
# Variant IDs are built as <ID>_<ALT> from the zstd-compressed .pvar file.
vars <- (
    data.table::fread(
        cmd = paste(configs[['zstdcat.path']], paste0(genotype.pfile, '.pvar.zst'))
    ) %>%
        dplyr::rename('CHROM' = '#CHROM') %>%
        dplyr::mutate(VAR_ID = paste(ID, ALT, sep = '_'))
)$VAR_ID

# Open one pgen reader per split, each restricted to that split's samples
# (positions of the sample IDs within the .psam file).
pvar <- pgenlibr::NewPvar(paste0(genotype.pfile, '.pvar.zst'))
pgen_train <- pgenlibr::NewPgen(
    paste0(genotype.pfile, '.pgen'),
    pvar = pvar,
    sample_subset = match(phe_train$ID, psamid)
)
pgen_val <- pgenlibr::NewPgen(
    paste0(genotype.pfile, '.pgen'),
    pvar = pvar,
    sample_subset = match(phe_val$ID, psamid)
)
pgenlibr::ClosePvar(pvar)

# Per-variant statistics on the training samples; samples are identified as
# "<ID>_<ID>".
stats <- computeStats(
    genotype.pfile,
    paste(phe_train$ID, phe_train$ID, sep = "_"),
    configs = configs
)

In [12]:
# Reshape the residuals into an n x 1 matrix whose row names are "<ID>_<ID>"
# and whose single column is named "0", then compute the product of the
# genotypes with the residuals (the gradient with respect to the SNP
# coefficients).
residuals <- matrix(
    residuals,
    nrow = length(phe_train$ID),
    ncol = 1,
    dimnames = list(paste(phe_train$ID, phe_train$ID, sep = '_'), c("0"))
)
gradient <- computeProduct(residuals, genotype.pfile, vars, stats, configs, iter = 0)

[2020-04-30 13:27:43 snpnet]     Start computeProduct()
           used  (Mb) gc trigger  (Mb) max used  (Mb)
Ncells  2623258 140.1    7186685 383.9  4123368 220.3
Vcells 31143942 237.7   56498432 431.1 49372753 376.7
[2020-04-30 13:27:43 snpnet]       Start plink2 --variant-score
[2020-04-30 13:30:22 snpnet]         End plink2 --variant-score. Time elapsed: 2.6477 mins
[2020-04-30 13:30:22 snpnet]       End computeProduct(). Time elapsed: 2.6503 mins


In [13]:
# Now we can generate a lambda sequence: nlambda values, log-spaced from
# lambda_max (the largest absolute gradient entry) down to
# lambda_max * lambda.min.ratio.
nlambda <- 100
lambda.min.ratio <- 0.01
# na.rm must be TRUE/FALSE; the original `na.rm = NA` was a typo for TRUE
# (missing gradient entries are to be ignored).
lambda_max <- max(abs(gradient), na.rm = TRUE)
lambda_min <- lambda_max * lambda.min.ratio
lambda_seq <- exp(seq(from = log(lambda_max), to = log(lambda_min), length.out = nlambda))

# The first lambda solution is already obtained
max_valid_index <- 1
prev_valid_index <- 0

# Use validation C-index to determine early stop
max_cindex <- 0
cindex <- numeric(nlambda)

# out[[k]] holds the solution for lambda_seq[k]; the first one is the
# fit obtained above.
out <- list()
out[[1]] <- result[[1]]
features.to.discard <- NULL

In [14]:
# Initial state for the screening loop.
iter <- 1
current_B <- result[[1]]          # warm start for the next fit
score <- abs(gradient[, 1])       # per-variant screening score
ever.active <- covs               # the base covariates are never discarded
print(ever.active)

 [1] "t0"   "sex"  "PC1"  "PC2"  "PC3"  "PC4"  "PC5"  "PC6"  "PC7"  "PC8" 
[11] "PC9"  "PC10"


In [15]:
# BASIL loop. Each iteration:
#   1. computes the validation C-index of the solutions validated in the
#      previous iteration and checks the early-stopping rules;
#   2. drops features that were never active and adds the (up to) 1000
#      highest-scoring variants not yet in the model;
#   3. fits the next (up to) 10 lambdas on the enlarged feature set;
#   4. checks, for each fitted lambda, that the largest score among variants
#      outside the model does not exceed lambda, and accepts the leading run
#      of lambdas that pass.
while (max_valid_index < nlambda) {
    if (max_valid_index > prev_valid_index) {
        new_idx <- (prev_valid_index + 1):max_valid_index
        for (i in new_idx) {
            cindex[i] <- cindex_tv(phe_val, NULL, covs, out[[i]], info = info_val)
        }

        cindex_this_iter <- cindex[new_idx]

        # Stop if this batch did not reach the best C-index seen so far ...
        max_cindex_this_iter <- max(cindex_this_iter)
        if (max_cindex_this_iter >= max_cindex) {
            max_cindex <- max_cindex_this_iter
        } else {
            print("early stop reached")
            break
        }

        # ... or if the best C-index of this batch is not at its last lambda.
        if (which.max(cindex_this_iter) != length(cindex_this_iter)) {
            print("early stop reached")
            break
        }
    }
    prev_valid_index <- max_valid_index
    print(paste("current maximum valid index is:", max_valid_index))
    print("Current C-Indices are:")
    print(cindex[1:max_valid_index])

    # Remove features flagged for discarding at the end of the previous iteration.
    if (length(features.to.discard) > 0) {
        phe_train[, (features.to.discard) := NULL]
        phe_val[, (features.to.discard) := NULL]
        covs <- covs[!covs %in% features.to.discard] # the name is a bit confusing, maybe change it to ti_names?
    }

    # Screening: exclude features already in the model, then take the
    # top-scoring candidates. head() returns an empty vector when no candidate
    # is left (indexing with 1:0 would produce an NA feature name).
    which.in.model <- which(names(score) %in% covs)
    score[which.in.model] <- NA
    sorted.score <- sort(score, decreasing = TRUE, na.last = NA)
    features.to.add <- head(names(sorted.score), 1000)
    covs <- c(covs, features.to.add)
    B_init <- c(current_B, rep(0.0, length(features.to.add)))

    # Load the genotypes of the new features into both phenotype tables.
    if (length(features.to.add) > 0) {
        tmp.features.add <- prepareFeatures(pgen_train, vars, features.to.add, stats)
        phe_train[, colnames(tmp.features.add) := tmp.features.add]

        tmp.features.add <- prepareFeatures(pgen_val, vars, features.to.add, stats)
        phe_val[, colnames(tmp.features.add) := tmp.features.add]

        rm(tmp.features.add)
    }

    # Now fit a regularized Cox model for the next 10 lambdas
    lambda_seq_local <- lambda_seq[(max_valid_index + 1):min(max_valid_index + 10, length(lambda_seq))]
    # Need better ways to set p.fac
    # The first 14 coefficients are left unpenalized; in this example that is
    # presumably the 2 time-varying covariates plus the 12 base covariates
    # (t0, sex, PC1-PC10) -- adjust if the covariate set changes.
    p.fac <- rep(1, length(B_init))
    p.fac[1:14] <- 0.0
    print(paste("Number of variables to be fitted is:", length(B_init)))

    result <- coxtv(phe_train, NULL, covs, lambda_seq_local, B0 = B_init, p.fac = p.fac, info = info_train)

    # One residual column per fitted lambda.
    residuals <- matrix(nrow = length(phe_train$ID), ncol = length(lambda_seq_local),
                        dimnames = list(paste(phe_train$ID, phe_train$ID, sep = '_'), signif(lambda_seq_local, 3)))
    for (i in seq_along(result)) {
        residuals[, i] <- cox_residual(phe_train, covs, info_train, result[[i]])
    }
    new_score <- abs(computeProduct(residuals, genotype.pfile, vars, stats, configs, iter = iter))
    # Largest score among variants NOT in the model, for each lambda.
    # (na.rm must be TRUE/FALSE; the original `na.rm=NA` was a typo for TRUE.)
    max_score <- apply(new_score, 2, function(x) {
        max(x[!names(x) %in% covs], na.rm = TRUE)
    })
    print(max_score)

    # Number of leading lambdas that pass the check (position of the first
    # failure, minus one).
    local_valid <- match(FALSE, c(max_score <= lambda_seq_local, FALSE)) - 1

    if (local_valid == 0) {
        # The very first lambda failed, so nothing can be accepted this
        # iteration, even if some later lambda happens to pass. (Testing
        # `all(max_score > lambda_seq_local)` instead would let that case fall
        # into the branch below and index result[[0]].)
        features.to.discard <- NULL
        current_B <- result[[1]]
        score <- new_score[, 1]
    } else {
        for (j in seq_len(local_valid)) {
            out[[max_valid_index + j]] <- result[[j]]
        }

        max_valid_index <- max_valid_index + local_valid
        # Features with a non-zero coefficient in any accepted solution.
        ever.active <- union(ever.active, names(which(apply(sapply(result[seq_len(local_valid)], function(x) {x != 0}), 1, any))))
        features.to.discard <- setdiff(covs, ever.active)
        score <- new_score[, local_valid]
        # Warm start: last accepted solution, without the discarded features.
        current_B <- result[[local_valid]]
        current_B <- current_B[!names(current_B) %in% features.to.discard]
        print(paste("Number of features discarded in this iteration is", length(features.to.discard)))
    }
    iter <- iter + 1

}

[1] "current maximum valid index is: 1"
[1] "Current C-Indices are:"
[1] 0.7681487
[1] "Number of variables to be fitted is: 1014"
[2020-04-30 13:32:54 snpnet]     Start computeProduct()
           used  (Mb) gc trigger  (Mb) max used  (Mb)
Ncells  2639400 141.0    7186685 383.9  5088656 271.8
Vcells 41945038 320.1   81533741 622.1 80496732 614.2
[2020-04-30 13:32:54 snpnet]       Start plink2 --variant-score
[2020-04-30 13:35:12 snpnet]         End plink2 --variant-score. Time elapsed: 2.2911 mins
[2020-04-30 13:35:13 snpnet]       End computeProduct(). Time elapsed: 2.3099 mins
lambda_idx_0.0155 lambda_idx_0.0148 lambda_idx_0.0141 lambda_idx_0.0135 
      0.009576207       0.009633053       0.009704023       0.009845673 
lambda_idx_0.0129 lambda_idx_0.0123 lambda_idx_0.0117 lambda_idx_0.0112 
      0.010046540       0.010166109       0.010284680       0.010346207 
lambda_idx_0.0107 lambda_idx_0.0102 
      0.010277798       0.010504970 
[1] "Number of features discarded in this itera