diff --git a/R/caretList.R b/R/caretList.R index 531e13f..4231e15 100644 --- a/R/caretList.R +++ b/R/caretList.R @@ -7,7 +7,6 @@ #' @examples #' caretModelSpec("rf", tuneLength=5, preProcess="ica") caretModelSpec <- function(method="rf", ...){ - stopifnot(is.character(method)) out <- c(list(method=method), list(...)) return(out) } @@ -20,16 +19,18 @@ tuneCheck <- function(x){ #Check model methods stopifnot(is.list(x)) - methods <- sapply(x, function(a) a$method) + + methods <- lapply(x, function(m) m$method) methodCheck(methods) + method_names <- sapply(x, extractModelName) #Name models if(is.null(names(x))){ - names(x) <- methods + names(x) <- method_names } i <- names(x)=="" if(any(i)){ - names(x)[i] <- methods[i] + names(x)[i] <- method_names[i] } names(x) <- make.names(names(x), unique=TRUE) @@ -44,12 +45,36 @@ tuneCheck <- function(x){ #' @importFrom caret modelLookup #' @return NULL methodCheck <- function(x){ - all_models <- unique(modelLookup()$model) - bad_models <- setdiff(x, all_models) + + # Fetch list of existing caret models + supported_models <- unique(modelLookup()$model) + + # Split given model methods based on whether or not they + # are specified as strings or model info lists (ie custom models) + models <- lapply(x, function(m) { + if (is.list(m)){ + validateCustomModel(m) + data.frame(type="custom", model=m$method) + } else if (is.character(m)){ + data.frame(type="native", model=m) + } else { + stop(paste0( + "Method \"", m, "\" is invalid. Methods must either be character names ", + "supported by caret (e.g. \"gbm\") or modelInfo lists ", + "(e.g. getModelInfo(\"gbm\", regex=F))")) + } + }) + models <- do.call(rbind, models) + + # Ensure that all non-custom models are valid + native_models <- subset(models, type == "native")$model + bad_models <- setdiff(native_models, supported_models) + if(length(bad_models)>0){ msg <- paste(bad_models, collapse=", ") stop(paste("The following models are not valid caret models:", msg)) } + return(invisible(NULL)) } @@ -182,9 +207,7 @@ caretList <- function( #Make methodList into a tuneList and add onto tuneList if(!is.null(methodList)){ - methodCheck(methodList) - tuneList_extra <- lapply(methodList, caretModelSpec) - tuneList <- c(tuneList, tuneList_extra) + tuneList <- c(tuneList, lapply(methodList, caretModelSpec)) } #Make sure tuneList is valid @@ -283,6 +306,19 @@ predict.caretList <- function(object, newdata = NULL, ..., verbose = FALSE){ } preds <- as.matrix(t(preds)) } - colnames(preds) <- make.names(sapply(object, function(x) x$method), unique=TRUE) + + if (is.null(names(object))){ + # If the model list used for predictions is not currently named, + # then exctract the model names from each model individually. + # Note that this should only be possible when caretList objects + # are created manually + predcols <- sapply(object, extractModelName) + colnames(preds) <- make.names(predcols, unique=TRUE) + } else { + # Otherwise, assign the names of the prediction columns to be + # equal to the names in the given model list + colnames(preds) <- names(object) + } + return(preds) } diff --git a/R/caretStack.R b/R/caretStack.R index cb79d5d..9a5273e 100644 --- a/R/caretStack.R +++ b/R/caretStack.R @@ -69,6 +69,7 @@ predict.caretStack <- function( ...){ stopifnot(is(object$models, "caretList")) type <- extractModelTypes(object$models) + preds <- predict(object$models, newdata=newdata) if(type == "Classification"){ out <- predict(object$ens_model, newdata=preds, ...) @@ -167,9 +168,8 @@ summary.caretStack <- function(object, ...){ #' print(meta_model) #' } print.caretStack <- function(x, ...){ - model_count <- length(x$models) - model_names <- paste(sapply(x$models, function(x) x$method), collapse=", ") - cat(sprintf("A %s ensemble of %s base models: %s", x$ens_model$method, model_count, model_names)) + base.models <- paste(names(x$models), collapse=", ") + cat(sprintf("A %s ensemble of %s base models: %s", x$ens_model$method, length(x$models), base.models)) cat("\n\nEnsemble results:\n") print(x$ens_model) } diff --git a/R/helper_functions.R b/R/helper_functions.R index 4302245..2d28ce1 100644 --- a/R/helper_functions.R +++ b/R/helper_functions.R @@ -109,7 +109,7 @@ check_caretList_model_types <- function(list_of_models){ #Check that classification models saved probabilities #TODO: ALLOW NON PROB MODELS! if (type=="Classification"){ - probModels <- sapply(list_of_models, function(x) modelLookup(x$method)[1, "probModel"]) + probModels <- sapply(list_of_models, function(x) is.function(x$modelInfo$prob)) if(!all(probModels)) stop("All models for classification must be able to generate class probabilities.") classProbs <- sapply(list_of_models, function(x) x$control$classProbs) if(!all(classProbs)){ @@ -150,7 +150,7 @@ check_bestpreds_indexes <- function(modelLibrary){ names(rows) <- names(modelLibrary) check <- length(unique(rows)) if(check != 1){ - stop("Re-sampled predictions from each component model do not use the same rowIndexs from the origial dataset") + stop("Re-sampled predictions from each component model do not use the same rowIndexes from the origial dataset") } return(invisible(NULL)) } @@ -195,6 +195,36 @@ check_bestpreds_preds <- function(modelLibrary){ ##################################################### # Extraction functions ##################################################### +#' @title Extract the method name associated with a single train object +#' @description Extracts the method name associated with a single train object. Note +#' that for standard models (i.e. those already prespecified by caret), the +#' "method" attribute on the train object is used directly while for custom +#' models the "method" attribute within the model$modelInfo attribute is +#' used instead. +#' @param x a single caret train object +#' @return Name associated with model +extractModelName <- function(x) { + if (is.list(x$method)){ + validateCustomModel(x$method)$method + } else if (x$method == "custom"){ + validateCustomModel(x$modelInfo)$method + } else x$method +} + +#' @title Validate a custom caret model info list +#' @description Currently, this only ensures that all model info lists +#' were also assigned a "method" attribute for consistency with usage +#' of non-custom models +#' @param x a model info list (e.g. \code{getModelInfo("rf", regex=F)\[[1]]}) +#' @return validated model info list (i.e. x) +validateCustomModel <- function(x) { + if (is.null(x$method)) + stop(paste( + "Custom models must be defined with a \"method\" attribute containing the name", + "by which that model should be referenced. Example: my.glm.model$method <- \"custom_glm\"")) + x +} + #' @title Extracts the model types from a list of train model #' @description Extracts the model types from a list of train model #' @@ -236,7 +266,7 @@ bestPreds <- function(x){ extractBestPreds <- function(list_of_models){ out <- lapply(list_of_models, bestPreds) if(is.null(names(out))){ - names(out) <- make.names(sapply(list_of_models, function(x) x$method), unique=TRUE) + names(out) <- make.names(sapply(list_of_models, extractModelName), unique=TRUE) } sink <- gc(reset=TRUE) return(out) diff --git a/man/extractModelName.Rd b/man/extractModelName.Rd new file mode 100644 index 0000000..4d6d9e6 --- /dev/null +++ b/man/extractModelName.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/helper_functions.R +\name{extractModelName} +\alias{extractModelName} +\title{Extract the method name associated with a single train object} +\usage{ +extractModelName(x) +} +\arguments{ +\item{x}{a single caret train object} +} +\value{ +Name associated with model +} +\description{ +Extracts the method name associated with a single train object. Note +that for standard models (i.e. those already prespecified by caret), the +"method" attribute on the train object is used directly while for custom +models the "method" attribute within the model$modelInfo attribute is +used instead. +} + diff --git a/man/validateCustomModel.Rd b/man/validateCustomModel.Rd new file mode 100644 index 0000000..7065041 --- /dev/null +++ b/man/validateCustomModel.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/helper_functions.R +\name{validateCustomModel} +\alias{validateCustomModel} +\title{Validate a custom caret model info list} +\usage{ +validateCustomModel(x) +} +\arguments{ +\item{x}{a model info list (e.g. \code{getModelInfo("rf", regex=F)\[[1]]})} +} +\value{ +validated model info list (i.e. x) +} +\description{ +Currently, this only ensures that all model info lists +were also assigned a "method" attribute for consistency with usage +of non-custom models +} + diff --git a/tests/testthat/test-ensemble.R b/tests/testthat/test-ensemble.R index 7bc98e7..5d8f28f 100644 --- a/tests/testthat/test-ensemble.R +++ b/tests/testthat/test-ensemble.R @@ -186,3 +186,59 @@ test_that("It works for classification models", { expect_is(pred.classc, "numeric") expect_equal(length(pred.classc), 1) }) + +context("Do ensembles of custom models work?") + +test_that("Ensembles using custom models work correctly", { + set.seed(1234) + + # Create custom caret models with a properly assigned method attribute + custom.rf <- getModelInfo("rf", regex=F)[[1]] + custom.rf$method <- "custom.rf" + + custom.rpart <- getModelInfo("rpart", regex=F)[[1]] + custom.rpart$method <- "custom.rpart" + + # Define models to be used in ensemble + tune.list <- list( + # Add an unnamed model to ensure that method names are extracted from model info + caretModelSpec(method=custom.rf, tuneLength=1), + # Add a named custom model, to contrast the above + myrpart=caretModelSpec(method=custom.rpart, tuneLength=1), + # Add a non-custom model + treebag=caretModelSpec(method="treebag", tuneLength=1) + ) + train.control <- trainControl(method="cv", number=2, classProbs=T) + X.df <- as.data.frame(X.class) + + # Create an ensemble using the above models + expect_warning(cl <- caretList(X.df, Y.class, tuneList=tune.list, trControl=train.control)) + expect_that(cl, is_a("caretList")) + expect_silent(cs <- caretEnsemble(cl)) + expect_that(cs, is_a("caretEnsemble")) + + # Validate names assigned to ensembled models + expect_equal(sort(names(cs$models)), c("custom.rf", "myrpart", "treebag")) + + # Validate ensemble predictions + expect_warning(pred.classa <- predict(cs, type="prob")) + expect_silent(pred.classb <- predict(cs, newdata = X.df, type="prob")) + expect_silent(pred.classc <- predict(cs, newdata = X.df[2, ], type="prob")) + expect_true(is.numeric(pred.classa)) + expect_true(is.numeric(pred.classb)) + expect_true(is.numeric(pred.classc)) + expect_true(length(pred.classa)==150) + expect_true(length(pred.classb)==150) + expect_true(length(pred.classc)==1) + expect_identical(pred.classa, pred.classb) + expect_less_than(abs(0.9749462 - pred.classc), 0.01) + + # Verify that not specifying a method attribute for custom models causes an error + tune.list <- list( + # Add a custom caret model WITHOUT a properly assigned method attribute + caretModelSpec(method=getModelInfo("rf", regex=F)[[1]], tuneLength=1), + treebag=caretModelSpec(method="treebag", tuneLength=1) + ) + msg <- "Custom models must be defined with a \"method\" attribute" + expect_error(caretList(X.class, Y.class, tuneList=tune.list, trControl=train.control), regexp=msg) +})