Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@ r:
- devel

env:
- KERAS_BACKEND="tensorflow"
global:
- MAKEFLAGS="-j 2"
- KERAS_BACKEND="tensorflow"
- MAKEFLAGS="-j 2"

# until we troubleshoot these issues
matrix:
allow_failures:
- r: 3.1
- r: 3.2

r_binary_packages:
- rstan
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ S3method(predict_confint,model_fit)
S3method(predict_num,"_elnet")
S3method(predict_num,model_fit)
S3method(predict_predint,model_fit)
S3method(predict_quantile,model_fit)
S3method(predict_raw,"_elnet")
S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
Expand Down Expand Up @@ -95,6 +96,8 @@ export(predict_num)
export(predict_num.model_fit)
export(predict_predint)
export(predict_predint.model_fit)
export(predict_quantile)
export(predict_quantile.model_fit)
export(predict_raw)
export(predict_raw.model_fit)
export(rand_forest)
Expand All @@ -113,10 +116,12 @@ import(rlang)
importFrom(dplyr,arrange)
importFrom(dplyr,as_tibble)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,collect)
importFrom(dplyr,full_join)
importFrom(dplyr,funs)
importFrom(dplyr,group_by)
importFrom(dplyr,mutate)
importFrom(dplyr,pull)
importFrom(dplyr,rename)
importFrom(dplyr,rename_at)
Expand Down Expand Up @@ -159,6 +164,7 @@ importFrom(stats,predict)
importFrom(stats,qnorm)
importFrom(stats,qt)
importFrom(stats,quantile)
importFrom(stats,setNames)
importFrom(stats,terms)
importFrom(stats,update)
importFrom(tibble,as_tibble)
Expand Down
10 changes: 4 additions & 6 deletions R/aaa_spark_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
# Format class-probability predictions coming back from Spark.
#
# `results`: prediction output; only its columns whose names start with
#   "probability_" are kept.
# `object`: the fitted model; `object$fit$.index_labels` supplies the labels
#   used to build the new `pred_<label>` column names.
#
# Returns `results` restricted to the probability columns, renamed from
# `probability_<i>` to `pred_<label>`.
#
# NOTE(review): assumes the selected columns are named `probability_0`, ...,
# `probability_{p-1}` and that `.index_labels` has `p` entries in the same
# order -- confirm against the Spark output.
#' @importFrom dplyr starts_with rename rename_at vars funs
format_spark_probs <- function(results, object) {
  results <- dplyr::select(results, starts_with("probability_"))
  p <- ncol(results)
  # Build a `new_name = old_name` map. The columns must still carry their
  # original `probability_<i>` names here, so no earlier renaming step is
  # applied (a preceding gsub-based rename would make these lookups fail).
  lvl <- paste0("probability_", 0:(p - 1))
  names(lvl) <- paste0("pred_", object$fit$.index_labels)
  results %>% rename(!!!syms(lvl))
}

format_spark_class <- function(results, object) {
Expand Down
12 changes: 8 additions & 4 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ form_form <-
function(object, control, env, ...) {
opts <- quos(...)

y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels
env$formula,
env$data
)
if (object$mode != "regression") {
y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels
env$formula,
env$data
)
} else {
y_levels <- NULL
}

object <- check_mode(object, y_levels)

Expand Down
12 changes: 12 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,15 @@ check_args <- function(object) {
# `check_args` method for the "default" class: performs no checks and hands
# the object back invisibly, unchanged.
check_args.default <- function(object) invisible(object)

# ------------------------------------------------------------------------------

# copied from recipes

# Build `num` sequential names of equal width: `prefix` followed by the
# index, left-padded with zeros (e.g. "x01", ..., "x10").
#
# `num`: how many names to create; values below 1 are rejected with an error.
# `prefix`: string placed in front of every padded index.
#
# Returns a character vector of length `num`.
names0 <- function (num, prefix = "x") {
  if (num < 1) {
    stop("`num` should be > 0", call. = FALSE)
  }
  # `format()` pads the indices to a common width with spaces; turning the
  # padding into zeros gives names that sort in numeric order.
  padded <- gsub(" ", "0", format(1:num))
  paste0(prefix, padded)
}
12 changes: 9 additions & 3 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#' @param object An object of class `model_fit`
#' @param new_data A rectangular data object, such as a data frame.
#' @param type A single character value or `NULL`. Possible values
#' are "numeric", "class", "probs", "conf_int", "pred_int", or
#' "raw". When `NULL`, `predict` will choose an appropriate value
#' are "numeric", "class", "probs", "conf_int", "pred_int", "quantile",
#' or "raw". When `NULL`, `predict` will choose an appropriate value
#' based on the model's mode.
#' @param opts A list of optional arguments to the underlying
#' predict function that will be used when `type = "raw"`. The
Expand Down Expand Up @@ -45,6 +45,10 @@
#' produces for class probabilities (or other non-scalar outputs),
#' the columns will be named `.pred_lower_classlevel` and so on.
#'
#' Quantile predictions return a tibble with a column `.pred`, which is
#' a list-column. Each list element contains a tibble with columns
#' `.pred` and `.quantile` (and perhaps others).
#'
#' Using `type = "raw"` with `predict.model_fit` (or using
#' `predict_raw`) will return the unadulterated results of the
#' prediction function.
Expand Down Expand Up @@ -96,6 +100,7 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...
prob = predict_classprob(object = object, new_data = new_data, ...),
conf_int = predict_confint(object = object, new_data = new_data, ...),
pred_int = predict_predint(object = object, new_data = new_data, ...),
quantile = predict_quantile(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
stop("I don't know about type = '", "'", type, call. = FALSE)
)
Expand All @@ -112,7 +117,8 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...
res
}

# Names of the prediction types known to the package. Defined once; an
# earlier assignment without "quantile" was immediately overwritten and has
# been dropped.
# NOTE(review): presumably this is what `check_pred_type()` validates `type`
# against -- confirm.
pred_types <- c(
  "raw", "numeric", "class", "link", "prob",
  "conf_int", "pred_int", "quantile"
)

#' @importFrom glue glue_collapse
check_pred_type <- function(object, type) {
Expand Down
41 changes: 41 additions & 0 deletions R/predict_quantile.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#' @keywords internal
#' @rdname other_predict
#' @param quantile A vector of numbers between 0 and 1 for the quantile being
#'  predicted.
#' @inheritParams predict.model_fit
#' @method predict_quantile model_fit
#' @export predict_quantile.model_fit
#' @export
predict_quantile.model_fit <-
  function (object, new_data, quantile = (1:9)/10, ...) {

    # The engine must declare a quantile prediction method in its spec.
    if (is.null(object$spec$method$quantile)) {
      stop("No quantile prediction method defined for this ",
           "engine.", call. = FALSE)
    }

    new_data <- prepare_data(object, new_data)

    # Optional engine-specific pre-processing of the new data.
    if (!is.null(object$spec$method$quantile$pre)) {
      new_data <- object$spec$method$quantile$pre(new_data, object)
    }

    # Store the requested quantiles in the method's arguments: they become
    # part of the prediction call and stay available to the post-processor
    # through `object`.
    object$spec$method$quantile$args$p <- quantile
    pred_call <- make_pred_call(object$spec$method$quantile)

    res <- eval_tidy(pred_call)

    # Optional engine-specific post-processing of the raw predictions.
    if (!is.null(object$spec$method$quantile$post)) {
      res <- object$spec$method$quantile$post(res, object)
    }

    res
  }

#' @export
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict.model_fit
predict_quantile <- function (object, ...) {
  # S3 generic: dispatches on the class of `object`.
  UseMethod("predict_quantile")
}
81 changes: 76 additions & 5 deletions R/surv_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,39 @@
#' `strata` function cannot be used. To achieve the same effect,
#' the extra parameter roles can be used (as described above).
#'
#' The model can be created using the `fit()` function using the
#' following _engines_:
#' \itemize{
#' \item \pkg{R}: `"flexsurv"`
#' }
#' @inheritParams boost_tree
#' @param mode A single character string for the type of model.
#' The only possible value for this model is "regression".
#' @param dist A character string for the outcome distribution. "weibull" is
#' the default.
#' @details
#' For `surv_reg`, the mode will always be "regression".
#'
#' The model can be created using the `fit()` function using the
#' following _engines_:
#' \itemize{
#' \item \pkg{R}: `"flexsurv"`, `"survreg"`
#' }
#'
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
#' model fit call. These can be changed by using the `...`
#' argument to pass in the preferred values. For this type of
#' model, the templates of the fit calls are:
#'
#' \pkg{flexsurv}
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")}
#'
#' \pkg{survreg}
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")}
#'
#' Note that `model = TRUE` is needed to produce quantile
#' predictions when there is a stratification variable and can be
#' overridden in other cases.
#'
#' @seealso [varying()], [fit()], [survival::Surv()]
#' @references Jackson, C. (2016). `flexsurv`: A Platform for Parametric Survival
#' Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33.
Expand Down Expand Up @@ -160,3 +183,51 @@ check_args.surv_reg <- function(object) {

invisible(object)
}

# ------------------------------------------------------------------------------

# Post-process quantile predictions into a tibble with a single list-column
# `.pred`; each list element is a tibble with columns `.pred` and
# `.quantile`, one element per row of `results`.
#
# `results`: rectangular prediction output.
#   NOTE(review): presumably a matrix with one row per observation and one
#   column per requested quantile -- confirm against the engine's output.
# `object`: the fitted model; the requested quantiles are read from
#   `object$spec$method$quantile$args$p`.
#   NOTE(review): `rep(pctl, n)` only lines up when `length(pctl)` equals
#   `ncol(results)` -- confirm.
#' @importFrom stats setNames
#' @importFrom dplyr mutate
survreg_quant <- function(results, object) {
  pctl <- object$spec$method$quantile$args$p
  n <- nrow(results)
  p <- ncol(results)
  results <-
    results %>%
    as_tibble() %>%
    # zero-padded names so that sorting by label keeps the column order
    setNames(names0(p)) %>%
    # seq_len() rather than 1:n so zero-row input does not yield c(1, 0)
    mutate(.row = seq_len(n)) %>%
    gather(.label, .pred, -.row) %>%
    arrange(.row, .label) %>%
    mutate(.quantile = rep(pctl, n)) %>%
    dplyr::select(-.label)
  # Keep the row index aside for splitting, then drop it from the data.
  .row <- results[[".row"]]
  results <-
    results %>%
    dplyr::select(-.row)
  results <- split(results, .row)
  names(results) <- NULL
  tibble(.pred = results)
}

# ------------------------------------------------------------------------------

# Post-process flexsurv mean predictions: strip the class from `results`,
# stack its elements row-wise and return the `est` column.
# `object` is accepted but not used.
#' @importFrom dplyr bind_rows
flexsurv_mean <- function(results, object) {
  stacked <- bind_rows(unclass(results))
  stacked$est
}

# Post-process flexsurv quantile predictions: convert every element of
# `results` to a tibble, drop the list names, and rename each tibble's
# columns to `.quantile`, `.pred`, `.pred_lower`, `.pred_upper`.
# `object` is accepted but not used.
#
# Returns the list of renamed tibbles.
# NOTE(review): assumes every element has exactly four columns in that
# order -- confirm against the flexsurv output.
#' @importFrom stats setNames
flexsurv_quant <- function(results, object) {
  results <- map(results, as_tibble)
  names(results) <- NULL
  # Return the mapped list directly; ending on an assignment would return it
  # invisibly.
  map(
    results,
    setNames,
    c(".quantile", ".pred", ".pred_lower", ".pred_upper")
  )
}

# ------------------------------------------------------------------------------

# `.label` is used as a bare column name inside `survreg_quant()` (in the
# gather/arrange/select steps), so it is declared as a known global name.
# NOTE(review): presumably to avoid an "undefined global variable" note from
# R CMD check -- confirm.
#' @importFrom utils globalVariables
utils::globalVariables(".label")

Loading