Skip to content

Commit

Permalink
Fault tolerance for LOO
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 12, 2017
1 parent 53c702e commit 969bc32
Showing 1 changed file with 53 additions and 22 deletions.
75 changes: 53 additions & 22 deletions pkg/caret/R/recipes.R
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,11 @@ model_failed <- function(x) {
FALSE
}

pred_failed <- function(x)
inherits(x, "try-error")



## Convert the recipe to holdout data. rename this to something like
## get_perf_data
#' @importFrom recipes bake all_predictors all_outcomes has_role
Expand Down Expand Up @@ -809,37 +814,64 @@ loo_train_rec <- function(rec, dat, info, method,
submod <- info$submodels[[parm]]
} else submod <- NULL

mod_rec <- rec_model(rec, dat[ ctrl$index[[iter]], ],
method = method,
tuneValue = info$loop[parm,,drop = FALSE],
obsLevels = lev,
classProbs = ctrl$classProbs,
sampling = ctrl$sampling,
...)

holdoutIndex <- -unique(ctrl$index[[iter]])
mod_rec <-
try(
rec_model(rec, dat[ ctrl$index[[iter]], ],
method = method,
tuneValue = info$loop[parm,,drop = FALSE],
obsLevels = lev,
classProbs = ctrl$classProbs,
sampling = ctrl$sampling,
...),
silent = TRUE)

predicted <- rec_pred(method = method,
object = mod_rec,
newdata = subset_x(dat, holdoutIndex),
param = submod)
holdoutIndex <- ctrl$indexOut[[iter]]

predicted <- trim_values(predicted, ctrl, is_regression)
if(!model_failed(mod_rec)) {
predicted <- try(
rec_pred(method = method,
object = mod_rec,
newdata = subset_x(dat, holdoutIndex),
param = submod),
silent = TRUE)

if(pred_failed(predicted)) {
fail_warning(settings = printed[parm,,drop = FALSE],
msg = predicted,
where = "predictions",
iter = names(ctrl$index)[iter],
verb = ctrl$verboseIter)

predicted <- fill_failed_pred(index = holdoutIndex, lev = lev, submod)
}
} else {
fail_warning(settings = printed[parm,,drop = FALSE],
msg = mod_rec,
iter = names(ctrl$index)[iter],
verb = ctrl$verboseIter)
predicted <- fill_failed_pred(index = holdoutIndex, lev = lev, submod)
}

if(testing) print(head(predicted))
if(ctrl$classProbs) {
probValues <- rec_prob(method = method,
object = mod_rec,
newdata = subset_x(dat, holdoutIndex),
param = submod)
if(!model_failed(mod_rec)) {
probValues <- rec_prob(method = method,
object = mod_rec,
newdata = subset_x(dat, holdoutIndex),
param = submod)
} else {
probValues <- fill_failed_prob(holdoutIndex, lev, submod)
}
if(testing) print(head(probValues))
}

predicted <- trim_values(predicted, ctrl, is_regression)

##################################

## We'll attach data points/columns to the object used
## to assess holdout performance

ho_data <- holdout_rec(mod_rec, dat, holdoutIndex)

if(!is.null(info$submodels)) {
Expand All @@ -853,7 +885,6 @@ loo_train_rec <- function(rec, dat, info, method,
lv = lev,
dat = ho_data)
if(testing) print(head(predicted))

## same for the class probabilities
if(ctrl$classProbs) {
for(k in seq(along = predicted)) predicted[[k]] <-
Expand All @@ -870,11 +901,11 @@ loo_train_rec <- function(rec, dat, info, method,
predicted$pred <- pred_val
if(ctrl$classProbs) predicted <- cbind(predicted, probValues)
predicted <- cbind(predicted, info$loop[parm,,drop = FALSE])

}
if(ctrl$verboseIter)
progress(printed[parm,,drop = FALSE],
names(ctrl$index), iter, FALSE)

predicted
}

Expand Down Expand Up @@ -1005,7 +1036,7 @@ train_rec <- function(rec, dat, info, method, ctrl, lev, testing = FALSE, ...) {
param = submod),
silent = TRUE)

if(class(predicted)[1] == "try-error") {
if(pred_failed(predicted)) {
fail_warning(settings = printed[parm,,drop = FALSE],
msg = predicted,
where = "predictions",
Expand Down

0 comments on commit 969bc32

Please sign in to comment.