From 304f956a0a8cf33234ebb31dbfec2720acab8fe4 Mon Sep 17 00:00:00 2001 From: Nima Hejazi Date: Thu, 8 Feb 2018 15:06:52 -0800 Subject: [PATCH] unbreak the breakage --- R/adaptest.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/R/adaptest.R b/R/adaptest.R index ff10ee9..243d02a 100644 --- a/R/adaptest.R +++ b/R/adaptest.R @@ -36,11 +36,11 @@ data_adapt <- function(Y, negative, parameter_wrapper, SL_lib) { - if (!data.table::is.data.table(Y)) { + if (!is.data.frame(Y)) { if (!is.matrix(Y)) { stop("Argument Y must be a data.frame or a matrix.") } - Y <- (Y) + Y <- as.matrix(Y) } if (!is.vector(A)) stop("Argument A must be numeric.") if (!is.null(W)) if (!is.matrix(W)) stop("Argument W must be matrix.") @@ -362,7 +362,6 @@ adaptest <- function(Y, #' @param SL_lib character of \code{SuperLearner} library #' #' @importFrom origami training validation -#' @importFrom data.table as.data.table #' @importFrom tmle tmle # cv_param_est <- function(fold, @@ -375,6 +374,7 @@ cv_param_est <- function(fold, Y_name, A_name, W_name) { + # define training and validation sets based on input object of class "folds" param_data <- origami::training(data) estim_data <- origami::validation(data) @@ -382,11 +382,11 @@ cv_param_est <- function(fold, # get param generating data (NOTE: these are data.table's) A_param <- param_data[, grep(A_name, colnames(data))] Y_param <- param_data[, grep(Y_name, colnames(data))] - W_param <- param_data[, grep(W_name, colnames(data))] + W_param <- param_data[, grep(W_name, colnames(data)), FALSE] # get estimation data (NOTE: these are data.table's) A_estim <- estim_data[, grep(A_name, colnames(data))] Y_estim <- estim_data[, grep(Y_name, colnames(data))] - W_estim <- estim_data[, grep(W_name, colnames(data))] + W_estim <- estim_data[, grep(W_name, colnames(data)), FALSE] # generate data-adaptive target parameter data_adaptive_index <- parameter_wrapper(