Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for custom models #198

Merged
merged 1 commit into from
Mar 21, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#' @examples
#' caretModelSpec("rf", tuneLength=5, preProcess="ica")
caretModelSpec <- function(method="rf", ...){
stopifnot(is.character(method))
out <- c(list(method=method), list(...))
return(out)
}
Expand All @@ -20,16 +19,18 @@ tuneCheck <- function(x){

#Check model methods
stopifnot(is.list(x))
methods <- sapply(x, function(a) a$method)

methods <- lapply(x, function(m) m$method)
methodCheck(methods)
method_names <- sapply(x, extractModelName)

#Name models
if(is.null(names(x))){
names(x) <- methods
names(x) <- method_names
}
i <- names(x)==""
if(any(i)){
names(x)[i] <- methods[i]
names(x)[i] <- method_names[i]
}
names(x) <- make.names(names(x), unique=TRUE)

Expand All @@ -44,12 +45,36 @@ tuneCheck <- function(x){
#' @importFrom caret modelLookup
#' @return NULL
methodCheck <- function(x){
all_models <- unique(modelLookup()$model)
bad_models <- setdiff(x, all_models)

# Fetch list of existing caret models
supported_models <- unique(modelLookup()$model)

# Split given model methods based on whether or not they
# are specified as strings or model info lists (ie custom models)
models <- lapply(x, function(m) {
if (is.list(m)){
validateCustomModel(m)
data.frame(type="custom", model=m$method)
} else if (is.character(m)){
data.frame(type="native", model=m)
} else {
stop(paste0(
"Method \"", m, "\" is invalid. Methods must either be character names ",
"supported by caret (e.g. \"gbm\") or modelInfo lists ",
"(e.g. getModelInfo(\"gbm\", regex=F))"))
}
})
models <- do.call(rbind, models)

# Ensure that all non-custom models are valid
native_models <- subset(models, type == "native")$model
bad_models <- setdiff(native_models, supported_models)

if(length(bad_models)>0){
msg <- paste(bad_models, collapse=", ")
stop(paste("The following models are not valid caret models:", msg))
}

return(invisible(NULL))
}

Expand Down Expand Up @@ -182,9 +207,7 @@ caretList <- function(

#Make methodList into a tuneList and add onto tuneList
if(!is.null(methodList)){
methodCheck(methodList)
tuneList_extra <- lapply(methodList, caretModelSpec)
tuneList <- c(tuneList, tuneList_extra)
tuneList <- c(tuneList, lapply(methodList, caretModelSpec))
}

#Make sure tuneList is valid
Expand Down Expand Up @@ -283,6 +306,19 @@ predict.caretList <- function(object, newdata = NULL, ..., verbose = FALSE){
}
preds <- as.matrix(t(preds))
}
colnames(preds) <- make.names(sapply(object, function(x) x$method), unique=TRUE)

if (is.null(names(object))){
# If the model list used for predictions is not currently named,
# then exctract the model names from each model individually.
# Note that this should only be possible when caretList objects
# are created manually
predcols <- sapply(object, extractModelName)
colnames(preds) <- make.names(predcols, unique=TRUE)
} else {
# Otherwise, assign the names of the prediction columns to be
# equal to the names in the given model list
colnames(preds) <- names(object)
}

return(preds)
}
6 changes: 3 additions & 3 deletions R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ predict.caretStack <- function(
...){
stopifnot(is(object$models, "caretList"))
type <- extractModelTypes(object$models)

preds <- predict(object$models, newdata=newdata)
if(type == "Classification"){
out <- predict(object$ens_model, newdata=preds, ...)
Expand Down Expand Up @@ -167,9 +168,8 @@ summary.caretStack <- function(object, ...){
#' print(meta_model)
#' }
print.caretStack <- function(x, ...){
model_count <- length(x$models)
model_names <- paste(sapply(x$models, function(x) x$method), collapse=", ")
cat(sprintf("A %s ensemble of %s base models: %s", x$ens_model$method, model_count, model_names))
base.models <- paste(names(x$models), collapse=", ")
cat(sprintf("A %s ensemble of %s base models: %s", x$ens_model$method, length(x$models), base.models))
cat("\n\nEnsemble results:\n")
print(x$ens_model)
}
Expand Down
36 changes: 33 additions & 3 deletions R/helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ check_caretList_model_types <- function(list_of_models){
#Check that classification models saved probabilities
#TODO: ALLOW NON PROB MODELS!
if (type=="Classification"){
probModels <- sapply(list_of_models, function(x) modelLookup(x$method)[1, "probModel"])
probModels <- sapply(list_of_models, function(x) is.function(x$modelInfo$prob))
if(!all(probModels)) stop("All models for classification must be able to generate class probabilities.")
classProbs <- sapply(list_of_models, function(x) x$control$classProbs)
if(!all(classProbs)){
Expand Down Expand Up @@ -150,7 +150,7 @@ check_bestpreds_indexes <- function(modelLibrary){
names(rows) <- names(modelLibrary)
check <- length(unique(rows))
if(check != 1){
stop("Re-sampled predictions from each component model do not use the same rowIndexs from the origial dataset")
stop("Re-sampled predictions from each component model do not use the same rowIndexes from the origial dataset")
}
return(invisible(NULL))
}
Expand Down Expand Up @@ -195,6 +195,36 @@ check_bestpreds_preds <- function(modelLibrary){
#####################################################
# Extraction functions
#####################################################
#' @title Extract the method name associated with a single train object
#' @description Extracts the method name associated with a single train object. Note
#' that for standard models (i.e. those already prespecified by caret), the
#' "method" attribute on the train object is used directly while for custom
#' models the "method" attribute within the model$modelInfo attribute is
#' used instead.
#' @param x a single caret train object
#' @return Name associated with model
extractModelName <- function(x) {
if (is.list(x$method)){
validateCustomModel(x$method)$method
} else if (x$method == "custom"){
validateCustomModel(x$modelInfo)$method
} else x$method
}

#' @title Validate a custom caret model info list
#' @description Currently, this only ensures that all model info lists
#' were also assigned a "method" attribute for consistency with usage
#' of non-custom models
#' @param x a model info list (e.g. \code{getModelInfo("rf", regex=F)\[[1]]})
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks correct to me, but doesn't match the .Rd file (see below). I think running devtools::document() will fix the problem.

You can check your work locally with devtools::check(), which conveniently also runs devtools::document().

#' @return validated model info list (i.e. x)
validateCustomModel <- function(x) {
if (is.null(x$method))
stop(paste(
"Custom models must be defined with a \"method\" attribute containing the name",
"by which that model should be referenced. Example: my.glm.model$method <- \"custom_glm\""))
x
}

#' @title Extracts the model types from a list of train model
#' @description Extracts the model types from a list of train model
#'
Expand Down Expand Up @@ -236,7 +266,7 @@ bestPreds <- function(x){
extractBestPreds <- function(list_of_models){
out <- lapply(list_of_models, bestPreds)
if(is.null(names(out))){
names(out) <- make.names(sapply(list_of_models, function(x) x$method), unique=TRUE)
names(out) <- make.names(sapply(list_of_models, extractModelName), unique=TRUE)
}
sink <- gc(reset=TRUE)
return(out)
Expand Down
22 changes: 22 additions & 0 deletions man/extractModelName.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/validateCustomModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions tests/testthat/test-ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,59 @@ test_that("It works for classification models", {
expect_is(pred.classc, "numeric")
expect_equal(length(pred.classc), 1)
})

context("Do ensembles of custom models work?")

test_that("Ensembles using custom models work correctly", {
set.seed(1234)

# Create custom caret models with a properly assigned method attribute
custom.rf <- getModelInfo("rf", regex=F)[[1]]
custom.rf$method <- "custom.rf"

custom.rpart <- getModelInfo("rpart", regex=F)[[1]]
custom.rpart$method <- "custom.rpart"

# Define models to be used in ensemble
tune.list <- list(
# Add an unnamed model to ensure that method names are extracted from model info
caretModelSpec(method=custom.rf, tuneLength=1),
# Add a named custom model, to contrast the above
myrpart=caretModelSpec(method=custom.rpart, tuneLength=1),
# Add a non-custom model
treebag=caretModelSpec(method="treebag", tuneLength=1)
)
train.control <- trainControl(method="cv", number=2, classProbs=T)
X.df <- as.data.frame(X.class)

# Create an ensemble using the above models
expect_warning(cl <- caretList(X.df, Y.class, tuneList=tune.list, trControl=train.control))
expect_that(cl, is_a("caretList"))
expect_silent(cs <- caretEnsemble(cl))
expect_that(cs, is_a("caretEnsemble"))

# Validate names assigned to ensembled models
expect_equal(sort(names(cs$models)), c("custom.rf", "myrpart", "treebag"))

# Validate ensemble predictions
expect_warning(pred.classa <- predict(cs, type="prob"))
expect_silent(pred.classb <- predict(cs, newdata = X.df, type="prob"))
expect_silent(pred.classc <- predict(cs, newdata = X.df[2, ], type="prob"))
expect_true(is.numeric(pred.classa))
expect_true(is.numeric(pred.classb))
expect_true(is.numeric(pred.classc))
expect_true(length(pred.classa)==150)
expect_true(length(pred.classb)==150)
expect_true(length(pred.classc)==1)
expect_identical(pred.classa, pred.classb)
expect_less_than(abs(0.9749462 - pred.classc), 0.01)

# Verify that not specifying a method attribute for custom models causes an error
tune.list <- list(
# Add a custom caret model WITHOUT a properly assigned method attribute
caretModelSpec(method=getModelInfo("rf", regex=F)[[1]], tuneLength=1),
treebag=caretModelSpec(method="treebag", tuneLength=1)
)
msg <- "Custom models must be defined with a \"method\" attribute"
expect_error(caretList(X.class, Y.class, tuneList=tune.list, trControl=train.control), regexp=msg)
})