diff --git a/.travis.yml b/.travis.yml index 7871b5512..ebf585975 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,9 +15,15 @@ r: - devel env: - - KERAS_BACKEND="tensorflow" global: - - MAKEFLAGS="-j 2" + - KERAS_BACKEND="tensorflow" + - MAKEFLAGS="-j 2" + +# until we troubleshoot these issues +matrix: + allow_failures: + - r: 3.1 + - r: 3.2 r_binary_packages: - rstan diff --git a/NAMESPACE b/NAMESPACE index 1a9dc2bc8..88b02f030 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -22,6 +22,7 @@ S3method(predict_confint,model_fit) S3method(predict_num,"_elnet") S3method(predict_num,model_fit) S3method(predict_predint,model_fit) +S3method(predict_quantile,model_fit) S3method(predict_raw,"_elnet") S3method(predict_raw,"_lognet") S3method(predict_raw,"_multnet") @@ -95,6 +96,8 @@ export(predict_num) export(predict_num.model_fit) export(predict_predint) export(predict_predint.model_fit) +export(predict_quantile) +export(predict_quantile.model_fit) export(predict_raw) export(predict_raw.model_fit) export(rand_forest) @@ -113,10 +116,12 @@ import(rlang) importFrom(dplyr,arrange) importFrom(dplyr,as_tibble) importFrom(dplyr,bind_cols) +importFrom(dplyr,bind_rows) importFrom(dplyr,collect) importFrom(dplyr,full_join) importFrom(dplyr,funs) importFrom(dplyr,group_by) +importFrom(dplyr,mutate) importFrom(dplyr,pull) importFrom(dplyr,rename) importFrom(dplyr,rename_at) @@ -159,6 +164,7 @@ importFrom(stats,predict) importFrom(stats,qnorm) importFrom(stats,qt) importFrom(stats,quantile) +importFrom(stats,setNames) importFrom(stats,terms) importFrom(stats,update) importFrom(tibble,as_tibble) diff --git a/R/aaa_spark_helpers.R b/R/aaa_spark_helpers.R index 4257d7c93..fe3d5b455 100644 --- a/R/aaa_spark_helpers.R +++ b/R/aaa_spark_helpers.R @@ -3,12 +3,10 @@ #' @importFrom dplyr starts_with rename rename_at vars funs format_spark_probs <- function(results, object) { results <- dplyr::select(results, starts_with("probability_")) - results <- dplyr::rename_at( - results, - vars(starts_with("probability_")), - funs(gsub("probability", "pred", .)) - ) - results + p <- ncol(results) + lvl <- paste0("probability_", 0:(p - 1)) + names(lvl) <- paste0("pred_", object$fit$.index_labels) + results %>% rename(!!!syms(lvl)) } format_spark_class <- function(results, object) { diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 4676dfb05..fce3d77bf 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -8,10 +8,14 @@ form_form <- function(object, control, env, ...) { opts <- quos(...) - y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels - env$formula, - env$data - ) + if (object$mode != "regression") { + y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels + env$formula, + env$data + ) + } else { + y_levels <- NULL + } object <- check_mode(object, y_levels) diff --git a/R/misc.R b/R/misc.R index 06307a1e6..5748cae92 100644 --- a/R/misc.R +++ b/R/misc.R @@ -178,3 +178,15 @@ check_args <- function(object) { check_args.default <- function(object) { invisible(object) } + +# ------------------------------------------------------------------------------ + +# copied form recipes + +names0 <- function (num, prefix = "x") { + if (num < 1) + stop("`num` should be > 0", call. = FALSE) + ind <- format(1:num) + ind <- gsub(" ", "0", ind) + paste0(prefix, ind) +} diff --git a/R/predict.R b/R/predict.R index fec443e8d..5dfd42823 100644 --- a/R/predict.R +++ b/R/predict.R @@ -7,8 +7,8 @@ #' @param object An object of class `model_fit` #' @param new_data A rectangular data object, such as a data frame. #' @param type A single character value or `NULL`. Possible values -#' are "numeric", "class", "probs", "conf_int", "pred_int", or -#' "raw". When `NULL`, `predict` will choose an appropriate value +#' are "numeric", "class", "probs", "conf_int", "pred_int", "quantile", +#' or "raw". When `NULL`, `predict` will choose an appropriate value #' based on the model's mode. #' @param opts A list of optional arguments to the underlying #' predict function that will be used when `type = "raw"`. The @@ -45,6 +45,10 @@ #' produces for class probabilities (or other non-scalar outputs), #' the columns will be named `.pred_lower_classlevel` and so on. #' +#' Quantile predictions return a tibble with a column `.pred`, which is +#' a list-column. Each list element contains a tibble with columns +#' `.pred` and `.quantile` (and perhaps others). +#' #' Using `type = "raw"` with `predict.model_fit` (or using #' `predict_raw`) will return the unadulterated results of the #' prediction function. @@ -96,6 +100,7 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ... prob = predict_classprob(object = object, new_data = new_data, ...), conf_int = predict_confint(object = object, new_data = new_data, ...), pred_int = predict_predint(object = object, new_data = new_data, ...), + quantile = predict_quantile(object = object, new_data = new_data, ...), raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), stop("I don't know about type = '", "'", type, call. = FALSE) ) @@ -112,7 +117,8 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ... res } -pred_types <- c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int") +pred_types <- + c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile") #' @importFrom glue glue_collapse check_pred_type <- function(object, type) { diff --git a/R/predict_quantile.R b/R/predict_quantile.R new file mode 100644 index 000000000..ed8cfdbe3 --- /dev/null +++ b/R/predict_quantile.R @@ -0,0 +1,41 @@ +#' @keywords internal +#' @rdname other_predict +#' @param quant A vector of numbers between 0 and 1 for the quantile being +#' predicted. +#' @inheritParams predict.model_fit +#' @method predict_quantile model_fit +#' @export predict_quantile.model_fit +#' @export +predict_quantile.model_fit <- + function (object, new_data, quantile = (1:9)/10, ...) { + + if (is.null(object$spec$method$quantile)) + stop("No quantile prediction method defined for this ", + "engine.", call. = FALSE) + + new_data <- prepare_data(object, new_data) + + # preprocess data + if (!is.null(object$spec$method$quantile$pre)) + new_data <- object$spec$method$quantile$pre(new_data, object) + + # Pass some extra arguments to be used in post-processor + object$spec$method$quantile$args$p <- quantile + pred_call <- make_pred_call(object$spec$method$quantile) + + res <- eval_tidy(pred_call) + + # post-process the predictions + if(!is.null(object$spec$method$quantile$post)) { + res <- object$spec$method$quantile$post(res, object) + } + + res + } + +#' @export +#' @keywords internal +#' @rdname other_predict +#' @inheritParams predict.model_fit +predict_quantile <- function (object, ...) + UseMethod("predict_quantile") diff --git a/R/surv_reg.R b/R/surv_reg.R index 07aad237e..29c3489ab 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -25,16 +25,39 @@ #' `strata` function cannot be used. To achieve the same effect, #' the extra parameter roles can be used (as described above). #' -#' The model can be created using the `fit()` function using the -#' following _engines_: -#' \itemize{ -#' \item \pkg{R}: `"flexsurv"` -#' } #' @inheritParams boost_tree #' @param mode A single character string for the type of model. #' The only possible value for this model is "regression". #' @param dist A character string for the outcome distribution. "weibull" is #' the default. +#' @details +#' For `surv_reg`, the mode will always be "regression". +#' +#' The model can be created using the `fit()` function using the +#' following _engines_: +#' \itemize{ +#' \item \pkg{R}: `"flexsurv"`, `"survreg"` +#' } +#' +#' @section Engine Details: +#' +#' Engines may have pre-set default arguments when executing the +#' model fit call. These can be changed by using the `...` +#' argument to pass in the preferred values. For this type of +#' model, the template of the fit calls are: +#' +#' \pkg{flexsurv} +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} +#' +#' \pkg{survreg} +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")} +#' +#' Note that `model = TRUE` is needed to produce quantile +#' predictions when there is a stratification variable and can be +#' overridden in other cases. +#' #' @seealso [varying()], [fit()], [survival::Surv()] #' @references Jackson, C. (2016). `flexsurv`: A Platform for Parametric Survival #' Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33. @@ -160,3 +183,51 @@ check_args.surv_reg <- function(object) { invisible(object) } + +# ------------------------------------------------------------------------------ + +#' @importFrom stats setNames +#' @importFrom dplyr mutate +survreg_quant <- function(results, object) { + pctl <- object$spec$method$quantile$args$p + n <- nrow(results) + p <- ncol(results) + results <- + results %>% + as_tibble() %>% + setNames(names0(p)) %>% + mutate(.row = 1:n) %>% + gather(.label, .pred, -.row) %>% + arrange(.row, .label) %>% + mutate(.quantile = rep(pctl, n)) %>% + dplyr::select(-.label) + .row <- results[[".row"]] + results <- + results %>% + dplyr::select(-.row) + results <- split(results, .row) + names(results) <- NULL + tibble(.pred = results) +} + +# ------------------------------------------------------------------------------ + +#' @importFrom dplyr bind_rows +flexsurv_mean <- function(results, object) { + results <- unclass(results) + results <- bind_rows(results) + results$est +} + +#' @importFrom stats setNames +flexsurv_quant <- function(results, object) { + results <- map(results, as_tibble) + names(results) <- NULL + results <- map(results, setNames, c(".quantile", ".pred", ".pred_lower", ".pred_upper")) +} + +# ------------------------------------------------------------------------------ + +#' @importFrom utils globalVariables +utils::globalVariables(".label") + diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 73f7e7d6b..43f55cecb 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -1,6 +1,7 @@ surv_reg_arg_key <- data.frame( - flexsurv = c("dist"), + flexsurv = c("dist"), + survreg = c("dist"), stringsAsFactors = FALSE, row.names = c("dist") ) @@ -9,6 +10,7 @@ surv_reg_modes <- "regression" surv_reg_engines <- data.frame( flexsurv = TRUE, + survreg = TRUE, stringsAsFactors = TRUE, row.names = c("regression") ) @@ -23,5 +25,96 @@ surv_reg_flexsurv_data <- protect = c("formula", "data", "weights"), func = c(pkg = "flexsurv", fun = "flexsurvreg"), defaults = list() + ), + pred = list( + pre = NULL, + post = flexsurv_mean, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "mean" + ) + ), + quantile = list( + pre = NULL, + post = flexsurv_quant, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + quantiles = expr(quantile) + ) ) ) + +# ------------------------------------------------------------------------------ + +surv_reg_survreg_data <- + list( + libs = c("survival"), + fit = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "survival", fun = "survreg"), + defaults = list(model = TRUE) + ), + pred = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) + ), + quantile = list( + pre = NULL, + post = survreg_quant, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + p = expr(quantile) + ) + ) + ) + +# ------------------------------------------------------------------------------ + +# surv_reg_stan_data <- +# list( +# libs = c("brms"), +# fit = list( +# interface = "formula", +# protect = c("formula", "data", "weights"), +# func = c(pkg = "brms", fun = "brm"), +# defaults = list( +# family = expr(brms::weibull()), +# seed = expr(sample.int(10^5, 1)) +# ) +# ), +# pred = list( +# pre = NULL, +# post = function(results, object) { +# tibble::as_tibble(results) %>% +# dplyr::select(Estimate) %>% +# setNames(".pred") +# }, +# func = c(fun = "predict"), +# args = +# list( +# object = expr(object$fit), +# newdata = expr(new_data), +# type = "response" +# ) +# ) +# ) + diff --git a/docs/articles/articles/Classification.html b/docs/articles/articles/Classification.html index 4728624de..ed83c3d28 100644 --- a/docs/articles/articles/Classification.html +++ b/docs/articles/articles/Classification.html @@ -178,17 +178,17 @@