Skip to content

Commit

Permalink
fixed multi_predict column names
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Sep 3, 2019
1 parent f305412 commit 43c15db
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ utils::globalVariables(
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",
"sub_neighbors")
"sub_neighbors", ".pred_class")
)

# nocov end
6 changes: 4 additions & 2 deletions R/aaa_multi_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
#' such as `type`.
#' @return A tibble with the same number of rows as the data being predicted.
#' Mostly likely, there is a list-column named `.pred` that is a tibble with
#' multiple rows per sub-model.
#' There is a list-column named `.pred` that contains tibbles with
#' multiple rows per sub-model. Note that, within the tibbles, the column names
#' follow the usual standard based on prediction `type` (i.e. `.pred_class` for
#' `type = "class"` and so on).
#' @export
multi_predict <- function(object, ...) {
if (inherits(object$fit, "try-error")) {
Expand Down
4 changes: 2 additions & 2 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
} else {
if (type == "class") {
pred <- object$spec$method$pred$class$post(pred, object)
pred <- tibble(.pred = factor(pred, levels = object$lvl))
pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
} else {
pred <- object$spec$method$pred$prob$post(pred, object)
pred <- as_tibble(pred)
Expand Down Expand Up @@ -503,7 +503,7 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {

# switch based on prediction type
if (type == "class") {
pred <- tibble(.pred = factor(pred, levels = object$lvl))
pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
} else {
pred <- as_tibble(pred)
names(pred) <- paste0(".pred_", names(pred))
Expand Down
9 changes: 5 additions & 4 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ multi_predict._lognet <-
if (is.null(type))
type <- "class"
if (!(type %in% c("class", "prob", "link", "raw"))) {
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
stop("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
}
if (type == "prob")
dots$type <- "response"
Expand All @@ -321,12 +321,12 @@ multi_predict._lognet <-
param_key <- tibble(group = colnames(pred), penalty = penalty)
pred <- as_tibble(pred)
pred$.row <- 1:nrow(pred)
pred <- gather(pred, group, .pred, -.row)
pred <- gather(pred, group, .pred_class, -.row)
if (dots$type == "class") {
pred[[".pred"]] <- factor(pred[[".pred"]], levels = object$lvl)
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = object$lvl)
} else {
if (dots$type == "response") {
pred[[".pred2"]] <- 1 - pred[[".pred"]]
pred[[".pred2"]] <- 1 - pred[[".pred_class"]]
names(pred) <- c(".row", "group", paste0(".pred_", rev(object$lvl)))
pred <- pred[, c(".row", "group", paste0(".pred_", object$lvl))]
}
Expand Down Expand Up @@ -371,3 +371,4 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
object$spec <- eval_args(object$spec)
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}

2 changes: 1 addition & 1 deletion R/multinom_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ multi_predict._multnet <-
pred <-
tibble(
.row = rep(1:nrow(new_data), length(penalty)),
.pred = as.vector(pred),
.pred_class = as.vector(pred),
penalty = rep(penalty, each = nrow(new_data))
)
}
Expand Down
6 changes: 4 additions & 2 deletions man/multi_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/test_logistic_reg_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ test_that('glmnet prediction, mulitiple lambda', {
mult_pred$rows <- rep(1:7, 2)
mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]
mult_pred <- mult_pred[, c("penalty", "values")]
names(mult_pred) <- c("penalty", ".pred")
names(mult_pred) <- c("penalty", ".pred_class")
mult_pred <- tibble::as_tibble(mult_pred)

expect_equal(
Expand Down Expand Up @@ -148,7 +148,7 @@ test_that('glmnet prediction, mulitiple lambda', {
form_pred$rows <- rep(1:7, 2)
form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]
form_pred <- form_pred[, c("penalty", "values")]
names(form_pred) <- c("penalty", ".pred")
names(form_pred) <- c("penalty", ".pred_class")
form_pred <- tibble::as_tibble(form_pred)

expect_equal(
Expand Down Expand Up @@ -180,7 +180,7 @@ test_that('glmnet prediction, no lambda', {
mult_pred$rows <- rep(1:7, 2)
mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]
mult_pred <- mult_pred[, c("penalty", "values")]
names(mult_pred) <- c("penalty", ".pred")
names(mult_pred) <- c("penalty", ".pred_class")
mult_pred <- tibble::as_tibble(mult_pred)

expect_equal(mult_pred, multi_predict(xy_fit, lending_club[1:7, num_pred]) %>% unnest())
Expand All @@ -206,7 +206,7 @@ test_that('glmnet prediction, no lambda', {
form_pred$rows <- rep(1:7, 2)
form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]
form_pred <- form_pred[, c("penalty", "values")]
names(form_pred) <- c("penalty", ".pred")
names(form_pred) <- c("penalty", ".pred_class")
form_pred <- tibble::as_tibble(form_pred)

expect_equal(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_multinom_reg_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ test_that('glmnet probabilities, mulitiple lambda', {

mult_class <- names(mult_probs)[apply(mult_probs, 1, which.max)]
mult_class <- tibble(
.pred = mult_class,
.pred_class = mult_class,
penalty = rep(lams, each = 3),
row = rep(1:3, 2)
)
Expand Down

0 comments on commit 43c15db

Please sign in to comment.