Skip to content

Commit

Permalink
Used a new recipes function to save some execution time
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jun 21, 2017
1 parent 266da40 commit 26887d6
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 43 deletions.
1 change: 1 addition & 0 deletions pkg/caret/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ S3method(summary,resamples)
S3method(summary,train)
S3method(train,default)
S3method(train,formula)
S3method(train,recipe)
S3method(trim,train)
S3method(update,gafs)
S3method(update,rfe)
Expand Down
93 changes: 50 additions & 43 deletions pkg/caret/R/recipes.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

## Overall method for recipes
#' @export
train.recipe <- function(recipe,
data,
method = "rf",
Expand All @@ -22,17 +23,18 @@ train.recipe <- function(recipe,
# prep and bake recipe on entire training set

trained_rec <- prepare(recipe, training = data, fresh = TRUE,
retain = TRUE,
verbose = FALSE, stringsAsFactors = TRUE)
x <- bake(trained_rec, newdata = data, all_predictors())
y <- bake(trained_rec, newdata = data, all_outcomes())
x <- recipes::extract(trained_rec, all_predictors())
y <- recipes::extract(trained_rec, all_outcomes())
if(ncol(y) > 1)
stop("`train` doesn't support multivariate outcomes")
y <- getElement(y, names(y))
is_weight <- summary(trained_rec)$role == "case weight"
if(any(is_weight)) {
if(sum(is_weight) > 1)
stop("Ony one column can be used as a case weight.")
weights <- bake(trained_rec, newdata = data, has_role("case weight"))
weights <- recipes::extract(trained_rec, has_role("case weight"))
weights <- getElement(weights, names(weights))
} else weights <- NULL

Expand All @@ -48,7 +50,7 @@ train.recipe <- function(recipe,
models <- method
method <- "custom"
} else {
models <- caret:::getModelInfo(method, regex = FALSE)[[1]]
models <- getModelInfo(method, regex = FALSE)[[1]]
if (length(models) == 0)
stop(paste("Model", method, "is not in caret's built-in library"), call. = FALSE)
}
Expand All @@ -62,7 +64,7 @@ train.recipe <- function(recipe,
paramNames <- as.character(models$parameters$parameter)

funcCall <- match.call(expand.dots = TRUE)
modelType <- caret:::get_model_type(y)
modelType <- get_model_type(y)
if(!(modelType %in% models$type))
stop(paste("wrong model type for", tolower(modelType)), call. = FALSE)

Expand Down Expand Up @@ -95,7 +97,7 @@ train.recipe <- function(recipe,
trControl$sampling <- parse_sampling(trControl$sampling)
}

caret:::check_dims(x = x, y = y)
check_dims(x = x, y = y)
n <- if(class(y)[1] == "Surv") nrow(y) else length(y)

## Some models that use RWeka start multiple threads and this conflicts with multicore:
Expand Down Expand Up @@ -214,7 +216,7 @@ train.recipe <- function(recipe,
list(origIndex = y_index, bootIndex = training)
})
}
names(trControl$indexOut) <- caret:::prettySeq(trControl$indexOut)
names(trControl$indexOut) <- prettySeq(trControl$indexOut)
} else {
trControl$indexOut <- createTimeSlices(seq(along = y),
initialWindow = trControl$initialWindow,
Expand All @@ -225,11 +227,11 @@ train.recipe <- function(recipe,
}

if(trControl$method != "oob" & is.null(trControl$index))
names(trControl$index) <- caret:::prettySeq(trControl$index)
names(trControl$index) <- prettySeq(trControl$index)
if(trControl$method != "oob" & is.null(names(trControl$index)))
names(trControl$index) <- caret:::prettySeq(trControl$index)
names(trControl$index) <- prettySeq(trControl$index)
if(trControl$method != "oob" & is.null(names(trControl$indexOut)))
names(trControl$indexOut) <- caret:::prettySeq(trControl$indexOut)
names(trControl$indexOut) <- prettySeq(trControl$indexOut)

if(is.null(tuneGrid)) {
tuneGrid <- models$grid(x = x, y = y, len = tuneLength, search = trControl$search)
Expand All @@ -243,7 +245,7 @@ train.recipe <- function(recipe,
"tuning parameter for evaluation"), call. = FALSE)
}

dotNames <- caret:::hasDots(tuneGrid, models)
dotNames <- hasDots(tuneGrid, models)
if(dotNames) colnames(tuneGrid) <- gsub("^\\.", "", colnames(tuneGrid))
## Check tuning parameter names
tuneNames <- as.character(models$parameters$parameter)
Expand All @@ -259,7 +261,7 @@ train.recipe <- function(recipe,

## In case prediction bounds are used, compute the limits. For now,
## store these in the control object since that gets passed everywhere
trControl$yLimits <- if(is.numeric(y)) caret:::get_range(y) else NULL
trControl$yLimits <- if(is.numeric(y)) get_range(y) else NULL

if(trControl$method != "none") {

Expand Down Expand Up @@ -300,10 +302,10 @@ train.recipe <- function(recipe,
perfNames <- metric
} else {
## run some data thru the summary function and see what we get
testSummary <- caret:::evalSummaryFunction(y,
wts = weights, ctrl = trControl,
lev = classLevels, metric = metric,
method = method)
testSummary <- evalSummaryFunction(y,
wts = weights, ctrl = trControl,
lev = classLevels, metric = metric,
method = method)
perfNames <- names(testSummary)
}

Expand Down Expand Up @@ -363,7 +365,7 @@ train.recipe <- function(recipe,
}
}
}

## Remove extra indices
trControl$indexExtra <- NULL

Expand Down Expand Up @@ -452,7 +454,7 @@ train.recipe <- function(recipe,
performance <- cbind(performance, tuneGrid)
performance <- performance[-1,,drop = FALSE]
tmp <- resampledCM <- NULL

} # end(trControl$method != "none")

## Save some or all of the resampling summary metrics
Expand Down Expand Up @@ -494,20 +496,20 @@ train.recipe <- function(recipe,

indexFinal <- if(is.null(trControl$indexFinal))
seq(along = y) else trControl$indexFinal

if(!(length(trControl$seeds) == 1 && is.na(trControl$seeds)))
set.seed(trControl$seeds[[length(trControl$seeds)]][1])
finalTime <- system.time(
finalModel <- rec_model(recipe,
caret:::subset_x(data, indexFinal),
subset_x(data, indexFinal),
method = models,
tuneValue = bestTune,
obsLevels = classLevels,
last = TRUE,
classProbs = trControl$classProbs,
sampling = trControl$sampling,
...)
)
)

if(trControl$trim && !is.null(models$trim)) {
if(trControl$verboseIter) old_size <- object.size(finalModel$fit)
Expand Down Expand Up @@ -603,6 +605,8 @@ predict.train.recipe <- function(object,

## drop dimensions from a `tibble`
get_vector <- function(object) {
if(!inherits(object, "tbl_df") & !is.data.frame(object))
return(object)
if(ncol(object) > 1)
stop("Only one column should be available")
getElement(object, names(object)[1])
Expand All @@ -628,23 +632,24 @@ preproc_dots <- function(...) {

## Convert the recipe to holdout data
holdout_rec <- function(object, dat, index) {
##
ho_data <- bake(object$recipe,
newdata = caret:::subset_x(dat, index),
newdata = subset_x(dat, index),
all_outcomes())
names(ho_data) <- "obs"
## ~~~~~~ move these two to other functions:
wt_cols <- role_cols(object$recipe, "case weight")
if(length(wt_cols) > 0) {
wts <- bake(object$recipe,
newdata = caret:::subset_x(dat, index),
newdata = subset_x(dat, index),
has_role("case weight"))
ho_data$weights <- get_vector(wts)
rm(wts)
}
perf_cols <- role_cols(object$recipe, "performance var")
if(length(perf_cols) > 0) {
perf_data <- bake(object$recipe,
newdata = caret:::subset_x(dat, index),
newdata = subset_x(dat, index),
has_role("performance var"))
ho_data <- cbind(ho_data, perf_data)
}
Expand All @@ -662,8 +667,9 @@ rec_model <- function(rec, dat, method, tuneValue, obsLevels,
## get original column names for downsamping then reassemble
## the training set prior to making the recipe
var_info <- summary(rec)
y <- dat[, subset(summary(rec), role == "outcome")$variable]
if(ncol(y) > 1)
y_cols <- role_cols(rec, "outcome")
y <- dat[, y_cols]
if(length(y_cols) > 1)
stop("`train` doesn't support multivariate outcomes")
if(is.data.frame(y)) y <- getElement(y, names(y))
other_cols <- subset(var_info,
Expand All @@ -674,14 +680,16 @@ rec_model <- function(rec, dat, method, tuneValue, obsLevels,
tmp <- sampling$func(other_dat, y)
orig_dat <- dat
dat <- tmp$x
dat[, y] <- tmp$y
dat[, y_cols] <- tmp$y
rm(tmp, y, other_cols, other_dat, orig_dat)
}

trained_rec <- prepare(rec, training = dat, fresh = TRUE,
verbose = FALSE, stringsAsFactors = TRUE)
x <- bake(trained_rec, newdata = dat, all_predictors())
y <- get_vector(bake(trained_rec, newdata = dat, all_outcomes()))
verbose = FALSE, stringsAsFactors = TRUE,
retain = TRUE)
x <- recipes::extract(trained_rec, all_predictors())
y <- recipes::extract(trained_rec, all_outcomes())
y <- get_vector(y)

## Add an extra identifier for later
trained_rec$.numeric <- is.numeric(y)
Expand Down Expand Up @@ -751,7 +759,7 @@ loo_train_rec <- function(rec, dat, info, method,
printed <- format(info$loop)
colnames(printed) <- gsub("^\\.", "", colnames(printed))

`%op%` <- caret:::getOper(ctrl$allowParallel && getDoParWorkers() > 1)
`%op%` <- getOper(ctrl$allowParallel && getDoParWorkers() > 1)

pkgs <- c("methods", "caret", "recipes")
if(!is.null(method$library))
Expand Down Expand Up @@ -794,7 +802,7 @@ loo_train_rec <- function(rec, dat, info, method,

predicted <- rec_pred(method = method,
object = mod_rec,
newdata = caret:::subset_x(dat, holdoutIndex),
newdata = subset_x(dat, holdoutIndex),
param = submod)

predicted <- trim_values(predicted, ctrl, mod_rec$recipe$.numeric)
Expand All @@ -803,7 +811,7 @@ loo_train_rec <- function(rec, dat, info, method,
if(ctrl$classProbs) {
probValues <- rec_prob(method = method,
object = mod_rec,
newdata = caret:::subset_x(dat, holdoutIndex),
newdata = subset_x(dat, holdoutIndex),
param = submod)
if(testing) print(head(probValues))
}
Expand Down Expand Up @@ -833,7 +841,7 @@ loo_train_rec <- function(rec, dat, info, method,
cbind(predicted[[k]], probValues[[k]])
}
predicted <- do.call("rbind", predicted)
allParam <- caret:::expandParameters(info$loop[parm,,drop = FALSE], submod)
allParam <- expandParameters(info$loop[parm,,drop = FALSE], submod)
rownames(predicted) <- NULL
predicted <- cbind(predicted, allParam)
## if saveDetails then save and export 'predicted'
Expand Down Expand Up @@ -868,7 +876,7 @@ oob_train_rec <- function(rec, dat, info, method,

printed <- format(info$loop)
colnames(printed) <- gsub("^\\.", "", colnames(printed))
`%op%` <- caret:::getOper(ctrl$allowParallel && getDoParWorkers() > 1)
`%op%` <- getOper(ctrl$allowParallel && getDoParWorkers() > 1)
pkgs <- c("methods", "caret", "recipes")
if(!is.null(method$library)) pkgs <- c(pkgs, method$library)
result <- foreach(
Expand Down Expand Up @@ -919,7 +927,7 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {
if(!is.null(ctrl$indexExtra))
ctrl$indexExtra <- c(list("AllData" = NULL), ctrl$indexExtra)
}
`%op%` <- caret:::getOper(ctrl$allowParallel && getDoParWorkers() > 1)
`%op%` <- getOper(ctrl$allowParallel && getDoParWorkers() > 1)
keep_pred <- isTRUE(ctrl$savePredictions) || ctrl$savePredictions %in% c("all", "final")
pkgs <- c("methods", "caret", "recipes")
if(!is.null(method$library)) pkgs <- c(pkgs, method$library)
Expand Down Expand Up @@ -958,7 +966,7 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {

mod_rec <- try(
rec_model(rec,
caret:::subset_x(dat, modelIndex),
subset_x(dat, modelIndex),
method = method,
tuneValue = info$loop[parm,,drop = FALSE],
obsLevels = lev,
Expand All @@ -974,7 +982,7 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {
predicted <- try(
rec_pred(method = method,
object = mod_rec,
newdata = caret:::subset_x(dat, holdoutIndex),
newdata = subset_x(dat, holdoutIndex),
param = submod),
silent = TRUE)

Expand All @@ -990,7 +998,7 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {
predictedExtra <- lapply(extraIndex, function(idx) {
rec_pred(method = method,
object = mod_rec,
newdata = caret:::subset_x(dat, idx),
newdata = subset_x(dat, idx),
param = submod)
})
}
Expand All @@ -1009,14 +1017,14 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {
if(class(mod_rec)[1] != "try-error") {
probValues <- rec_prob(method = method,
object = mod_rec,
newdata = caret:::subset_x(dat, holdoutIndex),
newdata = subset_x(dat, holdoutIndex),
param = submod)

if (!is.null(extraIndex))
probValuesExtra <- lapply(extraIndex, function(index) {
rec_prob(method = method,
object = mod_rec,
newdata = caret:::subset_x(dat, index),
newdata = subset_x(dat, index),
param = submod)
})
} else {
Expand All @@ -1040,7 +1048,6 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {

## We'll attach data points/columns to the object used
## to assess holdout performance

ho_data <- holdout_rec(mod_rec, dat, holdoutIndex)

if(!is.null(submod)) {
Expand All @@ -1066,7 +1073,7 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {
y <- y[rows]
wts <- wts[rows]

x <- caret:::outcome_conversion(x, lv = lev)
x <- outcome_conversion(x, lv = lev)
out <- data.frame(pred = x, obs = y, stringsAsFactors = FALSE)
if(!is.null(wts)) out$weights <- wts
out$rowIndex <- rows
Expand Down

0 comments on commit 26887d6

Please sign in to comment.