Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,13 @@ S3method(predict,model_fit)
S3method(predict,model_spec)
S3method(predict,nullmodel)
S3method(predict_class,"_lognet")
S3method(predict_class,model_fit)
S3method(predict_class,"_multnet")
S3method(predict_classprob,"_lognet")
S3method(predict_classprob,"_multnet")
S3method(predict_classprob,model_fit)
S3method(predict_confint,model_fit)
S3method(predict_numeric,"_elnet")
S3method(predict_numeric,model_fit)
S3method(predict_predint,model_fit)
S3method(predict_quantile,model_fit)
S3method(predict_raw,"_elnet")
S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
S3method(predict_raw,model_fit)
S3method(print,boost_tree)
S3method(print,decision_tree)
S3method(print,fit_control)
Expand Down Expand Up @@ -103,20 +97,6 @@ export(nearest_neighbor)
export(null_model)
export(nullmodel)
export(predict.model_fit)
export(predict_class)
export(predict_class.model_fit)
export(predict_classprob)
export(predict_classprob.model_fit)
export(predict_confint)
export(predict_confint.model_fit)
export(predict_numeric)
export(predict_numeric.model_fit)
export(predict_predint)
export(predict_predint.model_fit)
export(predict_quantile)
export(predict_quantile.model_fit)
export(predict_raw)
export(predict_raw.model_fit)
export(rand_forest)
export(rpart_train)
export(set_args)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## New Features

* A "null model" is now available that fits a predictor-free model (using the mean of the outcome for regression or the mode for classification).

* `fit_xy()` can take a single column data frame or matrix for `y` without error

## Other Changes
Expand All @@ -13,6 +14,8 @@ that are actually varying).

* `fit_control()` not returns an S3 method.

* The prediction modules (e.g. `predict_class`, `predict_numeric`, etc) were de-exported. These were internal functions that were not to be used by the users and the users were using them.

## Bug Fixes

* `varying_args()` now uses the version from the `generics` package. This means
Expand All @@ -33,6 +36,7 @@ column names once (#107).
* For multinomial regression using glmnet, `multi_predict()` now pulls the
correct default penalty (#108).

* Confidence and prediction intervals for logistic regression were only computed the intervals for a single level. Both are now computed. (#156)


# parsnip 0.0.1
Expand Down
96 changes: 77 additions & 19 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,20 @@
#' \pkg{spark}
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
#'
#'
#' \pkg{keras}
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
#'
#' When using `glmnet` models, there is the option to pass
#' multiple values (or no values) to the `penalty` argument.
#' This can have an effect on the model object results. When using
#' the `predict()` method in these cases, the return object type
#' depends on the value of `penalty`. If a single value is
#' given, the results will be a simple numeric vector. When
#' multiple values or no values for `penalty` are used in
#' `linear_reg()`, the `predict()` method will return a data frame with
#' columns `values` and `lambda`.
#' multiple values (or no values) to the `penalty` argument. This
#' can have an effect on the model object results. When using the
#' `predict()` method in these cases, the return value depends on
#' the value of `penalty`. When using `predict()`, only a single
#' value of the penalty can be used. When predicting on multiple
#' penalties, the `multi_predict()` function can be used. It
#' returns a tibble with a list column called `.pred` that contains
#' a tibble with all of the penalty results.
#'
#' For prediction, the `stan` engine can compute posterior
#' intervals analogous to confidence and prediction intervals. In
Expand Down Expand Up @@ -130,7 +130,7 @@ print.linear_reg <- function(x, ...) {
cat("Linear Regression Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (!is.null(x$method$fit$args)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down Expand Up @@ -216,12 +216,66 @@ organize_glmnet_pred <- function(x, object) {

# ------------------------------------------------------------------------------

# For `predict` methods that use `glmnet`, we have specific methods.
# Only one value of the penalty should be allowed when called by `predict()`:

check_penalty <- function(penalty = NULL, object, multi = FALSE) {

if (is.null(penalty)) {
penalty <- object$fit$lambda
}

# when using `predict()`, allow for a single lambda
if (!multi) {
if (length(penalty) != 1)
stop("`penalty` should be a single numeric value. ",
"`multi_predict()` can be used to get multiple predictions ",
"per row of data.", call. = FALSE)
}

if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
stop("The glmnet model was fit with a single penalty value of ",
object$fit$lambda, ". Predicting with a value of ",
penalty, " will give incorrect results from `glmnet()`.",
call. = FALSE)

penalty
}

# ------------------------------------------------------------------------------
# glmnet call stack for linear regression using `predict` when object has
# classes "_elnet" and "model_fit":
#
# predict()
# predict._elnet(penalty = NULL) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_numeric()
# predict_numeric._elnet()
# predict_numeric.model_fit()
# predict.elnet()


# glmnet call stack for linear regression using `multi_predict` when object has
# classes "_elnet" and "model_fit":
#
# multi_predict()
# multi_predict._elnet(penalty = NULL)
# predict._elnet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_raw()
# predict_raw._elnet()
# predict_raw.model_fit(opts = list(s = penalty))
# predict.elnet()


#' @export
predict._elnet <-
function(object, new_data, type = NULL, opts = list(), ...) {
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)


object$spec$args$penalty <- check_penalty(penalty, object, multi)

object$spec <- eval_args(object$spec)
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
}
Expand All @@ -230,7 +284,7 @@ predict._elnet <-
predict_numeric._elnet <- function(object, new_data, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_numeric.model_fit(object, new_data = new_data, ...)
}
Expand All @@ -239,8 +293,9 @@ predict_numeric._elnet <- function(object, new_data, ...) {
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
opts$s <- object$spec$args$penalty
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}

Expand All @@ -251,14 +306,17 @@ multi_predict._elnet <-
function(object, new_data, type = NULL, penalty = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

dots <- list(...)
if (is.null(penalty))
penalty <- object$fit$lambda
dots$s <- penalty

object$spec <- eval_args(object$spec)
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)

if (is.null(penalty)) {
penalty <- object$fit$lambda
}

pred <- predict._elnet(object, new_data = new_data, type = "raw",
opts = dots, penalty = penalty, multi = TRUE)
param_key <- tibble(group = colnames(pred), penalty = penalty)
pred <- as_tibble(pred)
pred$.row <- 1:nrow(pred)
Expand Down
113 changes: 74 additions & 39 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")}
#'
#' When using `glmnet` models, there is the option to pass
#' multiple values (or no values) to the `penalty` argument.
#' This can have an effect on the model object results. When using
#' the `predict()` method in these cases, the return object type
#' depends on the value of `penalty`. If a single value is
#' given, the results will be a simple numeric vector. When
#' multiple values or no values for `penalty` are used in
#' `logistic_reg()`, the `predict()` method will return a data frame with
#' columns `values` and `lambda`.
#' multiple values (or no values) to the `penalty` argument. This
#' can have an effect on the model object results. When using the
#' `predict()` method in these cases, the return value depends on
#' the value of `penalty`. When using `predict()`, only a single
#' value of the penalty can be used. When predicting on multiple
#' penalties, the `multi_predict()` function can be used. It
#' returns a tibble with a list column called `.pred` that contains
#' a tibble with all of the penalty results.
#'
#' For prediction, the `stan` engine can compute posterior
#' intervals analogous to confidence and prediction intervals. In
Expand Down Expand Up @@ -235,41 +235,41 @@ organize_glmnet_prob <- function(x, object) {
}

# ------------------------------------------------------------------------------
# glmnet call stack for linear regression using `predict` when object has
# classes "_lognet" and "model_fit" (for class predictions):
#
# predict()
# predict._lognet(penalty = NULL) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_class()
# predict_class._lognet()
# predict_class.model_fit()
# predict.lognet()


# glmnet call stack for linear regression using `multi_predict` when object has
# classes "_lognet" and "model_fit" (for class predictions):
#
# multi_predict()
# multi_predict._lognet(penalty = NULL)
# predict._lognet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_raw()
# predict_raw._lognet()
# predict_raw.model_fit(opts = list(s = penalty))
# predict.lognet()

#' @export
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
}

#' @export
predict_class._lognet <- function (object, new_data, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_class.model_fit(object, new_data = new_data, ...)
}
# ------------------------------------------------------------------------------

#' @export
predict_classprob._lognet <- function (object, new_data, ...) {
predict._lognet <- function (object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_classprob.model_fit(object, new_data = new_data, ...)
}

#' @export
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
object$spec$args$penalty <- check_penalty(penalty, object, multi)

object$spec <- eval_args(object$spec)
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
}


Expand All @@ -281,23 +281,26 @@ multi_predict._lognet <-
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

if (is_quosure(penalty))
penalty <- eval_tidy(penalty)

dots <- list(...)
if (is.null(penalty))
penalty <- object$fit$lambda
penalty <- eval_tidy(object$fit$lambda)
dots$s <- penalty

if (is.null(type))
type <- "class"
if (!(type %in% c("class", "prob", "link"))) {
stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE)
if (!(type %in% c("class", "prob", "link", "raw"))) {
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
}
if (type == "prob")
dots$type <- "response"
else
dots$type <- type

object$spec <- eval_args(object$spec)
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
param_key <- tibble(group = colnames(pred), penalty = penalty)
pred <- as_tibble(pred)
pred$.row <- 1:nrow(pred)
Expand All @@ -321,6 +324,38 @@ multi_predict._lognet <-
tibble(.pred = pred)
}





#' @export
predict_class._lognet <- function (object, new_data, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_class.model_fit(object, new_data = new_data, ...)
}

#' @export
predict_classprob._lognet <- function (object, new_data, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_classprob.model_fit(object, new_data = new_data, ...)
}

#' @export
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}


# ------------------------------------------------------------------------------

#' @importFrom utils globalVariables
Expand Down
Loading