Skip to content

Commit

Permalink
Required colnames (issue #561), cleaned up some stops, and updated ma…
Browse files Browse the repository at this point in the history
…n files
  • Loading branch information
topepo committed Apr 11, 2017
1 parent bb1d58b commit fc69e02
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 34 deletions.
68 changes: 36 additions & 32 deletions pkg/caret/R/train.default.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#' @aliases train train.default train.formula
#' @param x an object where samples are in rows and features are in columns.
#' This could be a simple matrix, data frame or other type (e.g. sparse
#' matrix). See Details below. Preprocessing using the \code{preProcess}
#' matrix) but must have column names. See Details below. Preprocessing using the \code{preProcess}
#' argument only supports matrices or data frames.
#' @param y a numeric or factor vector containing the outcome for each sample.
#' @param form A formula of the form \code{y ~ x1 + x2 + ...}
Expand Down Expand Up @@ -233,21 +233,25 @@ train.default <- function(x, y,
tuneGrid = NULL,
tuneLength = ifelse(trControl$method == "none", 1, 3)) {
startTime <- proc.time()


if(is.null(colnames(x)))
stop("Please use column names for `x`", call. = FALSE)

if(is.character(y)) y <- as.factor(y)

if(is.list(method)) {
minNames <- c("library", "type", "parameters", "grid",
"fit", "predict", "prob")
nameCheck <- minNames %in% names(method)
if(!all(nameCheck)) stop(paste("some required components are missing:",
paste(minNames[!nameCheck], collapse = ", ")))
paste(minNames[!nameCheck], collapse = ", ")),
call. = FALSE)
models <- method
method <- "custom"
} else {
models <- getModelInfo(method, regex = FALSE)[[1]]
if (length(models) == 0)
stop(paste("Model", method, "is not in caret's built-in library"))
stop(paste("Model", method, "is not in caret's built-in library"), call. = FALSE)
}
checkInstall(models$library)
for(i in seq(along = models$library)) do.call("require", list(package = models$library[i]))
Expand All @@ -260,17 +264,17 @@ train.default <- function(x, y,

funcCall <- match.call(expand.dots = TRUE)
modelType <- get_model_type(y)
if(!(modelType %in% models$type)) stop(paste("wrong model type for", tolower(modelType)))
if(!(modelType %in% models$type)) stop(paste("wrong model type for", tolower(modelType)), call. = FALSE)

if(grepl("^svm", method) & grepl("String$", method)) {
if(is.vector(x) && is.character(x)) {
stop("'x' should be a character matrix with a single column for string kernel methods")
stop("'x' should be a character matrix with a single column for string kernel methods", call. = FALSE)
}
if(is.matrix(x) && is.numeric(x)) {
stop("'x' should be a character matrix with a single column for string kernel methods")
stop("'x' should be a character matrix with a single column for string kernel methods", call. = FALSE)
}
if(is.data.frame(x)) {
stop("'x' should be a character matrix with a single column for string kernel methods")
stop("'x' should be a character matrix with a single column for string kernel methods", call. = FALSE)
}
}

Expand All @@ -280,7 +284,7 @@ train.default <- function(x, y,
"If so, use a 2 level factor as your outcome column."))

if(modelType != "Classification" & !is.null(trControl$sampling))
stop("sampling methods are only implemented for classification problems")
stop("sampling methods are only implemented for classification problems", call. = FALSE)
if(!is.null(trControl$sampling)) {
trControl$sampling <- parse_sampling(trControl$sampling)
}
Expand All @@ -297,7 +301,7 @@ train.default <- function(x, y,
flush.console()

if(!is.null(preProcess) && !(all(names(preProcess) %in% ppMethods)))
stop(paste('pre-processing methods are limited to:', paste(ppMethods, collapse = ", ")))
stop(paste('pre-processing methods are limited to:', paste(ppMethods, collapse = ", ")), call. = FALSE)
if(modelType == "Classification") {
## We should get and save the class labels to ensure that predictions are coerced
## to factors that have the same levels as the original data. This is especially
Expand All @@ -308,7 +312,7 @@ train.default <- function(x, y,
xtab <- table(y)
if(any(xtab == 0)) {
xtab_msg <- paste("'", names(xtab)[xtab == 0], "'", collapse = ", ", sep = "")
stop(paste("One or more factor levels in the outcome has no data:", xtab_msg))
stop(paste("One or more factor levels in the outcome has no data:", xtab_msg), call. = FALSE)
}

if(trControl$classProbs && any(classLevels != make.names(classLevels))) {
Expand All @@ -317,15 +321,15 @@ train.default <- function(x, y,
"the variables names will be converted to ",
paste(make.names(classLevels), collapse = ", "),
". Please use factor levels that can be used as valid R variable names",
" (see ?make.names for help)."))
" (see ?make.names for help)."), call. = FALSE)
}

if(metric %in% c("RMSE", "Rsquared"))
stop(paste("Metric", metric, "not applicable for classification models"))
stop(paste("Metric", metric, "not applicable for classification models"), call. = FALSE)
if(!trControl$classProbs && metric == "ROC")
stop(paste("Class probabilities are needed to score models using the",
"area under the ROC curve. Set `classProbs = TRUE`",
"in the trainControl() function."))
"in the trainControl() function."), call. = FALSE)

if(trControl$classProbs) {
if(!is.function(models$prob)) {
Expand All @@ -335,7 +339,7 @@ train.default <- function(x, y,
}
} else {
if(metric %in% c("Accuracy", "Kappa"))
stop(paste("Metric", metric, "not applicable for regression models"))
stop(paste("Metric", metric, "not applicable for regression models"), call. = FALSE)
classLevels <- NA
if(trControl$classProbs) {
warning("cannnot compute class probabilities for regression")
Expand All @@ -345,14 +349,14 @@ train.default <- function(x, y,


if(trControl$method == "oob" & is.null(models$oob))
stop("Out of bag estimates are not implemented for this model")
stop("Out of bag estimates are not implemented for this model", call. = FALSE)

## SURV TODO: make resampling functions classes or ifelses based on data type

## If they don't exist, make the data partitions for the resampling iterations.
if(is.null(trControl$index)) {
if(trControl$method == "custom")
stop("'custom' resampling is appropriate when the `trControl` argument `index` is used")
stop("'custom' resampling is appropriate when the `trControl` argument `index` is used", call. = FALSE)
trControl$index <- switch(tolower(trControl$method),
oob = NULL,
none = list(seq(along = y)),
Expand All @@ -373,11 +377,11 @@ train.default <- function(x, y,
} else {
index_types <- unlist(lapply(trControl$index, is.integer))
if(!isTRUE(all(index_types)))
stop("`index` should be lists of integers.")
stop("`index` should be lists of integers.", call. = FALSE)
if(!is.null(trControl$indexOut)) {
index_types <- unlist(lapply(trControl$indexOut, is.integer))
if(!isTRUE(all(index_types)))
stop("`indexOut` should be lists of integers.")
stop("`indexOut` should be lists of integers.", call. = FALSE)
}
}

Expand All @@ -393,7 +397,7 @@ train.default <- function(x, y,
trControl$savePredictions <- if(trControl$savePredictions) "all" else "none"
} else {
if(!(trControl$savePredictions %in% c("all", "final", "none")))
stop('`savePredictions` should be either logical or "all", "final" or "none"')
stop('`savePredictions` should be either logical or "all", "final" or "none"', call. = FALSE)
}

## Create holdout indices
Expand Down Expand Up @@ -453,7 +457,7 @@ train.default <- function(x, y,
## Check to make sure that there are tuning parameters in some cases
if(grepl("adaptive", trControl$method) & nrow(tuneGrid) == 1) {
stop(paste("For adaptive resampling, there needs to be more than one",
"tuning parameter for evaluation"))
"tuning parameter for evaluation"), call. = FALSE)
}

dotNames <- hasDots(tuneGrid, models)
Expand All @@ -464,11 +468,11 @@ train.default <- function(x, y,

if(!is.logical(goodNames) || !goodNames) {
stop(paste("The tuning parameter grid should have columns",
paste(tuneNames, collapse = ", ", sep = "")))
paste(tuneNames, collapse = ", ", sep = "")), call. = FALSE)
}

if(trControl$method == "none" && nrow(tuneGrid) != 1)
stop("Only one model should be specified in tuneGrid with no resampling")
stop("Only one model should be specified in tuneGrid with no resampling", call. = FALSE)


## In case prediction bounds are used, compute the limits. For now,
Expand Down Expand Up @@ -557,7 +561,7 @@ train.default <- function(x, y,
if(is.function(models$loop) && nrow(tuneGrid) > 1){
trainInfo <- models$loop(tuneGrid)
if(!all(c("loop", "submodels") %in% names(trainInfo)))
stop("The 'loop' function should produce a list with elements 'loop' and 'submodels'")
stop("The 'loop' function should produce a list with elements 'loop' and 'submodels'", call. = FALSE)
lengths <- unlist(lapply(trainInfo$submodels, nrow))
if(all(lengths == 0)) trainInfo$submodels <- NULL
} else trainInfo <- list(loop = tuneGrid)
Expand All @@ -583,8 +587,8 @@ train.default <- function(x, y,
num_rs + 1, "with",
num_rs, "integer vectors of size",
nrow(trainInfo$loop), "and the last list element having at least a",
"single integer"))
if(any(is.na(unlist(trControl$seeds)))) stop("At least one seed is missing (NA)")
"single integer"), call. = FALSE)
if(any(is.na(unlist(trControl$seeds)))) stop("At least one seed is missing (NA)", call. = FALSE)
}
}

Expand Down Expand Up @@ -683,7 +687,7 @@ train.default <- function(x, y,
if(all(is.na(performance[, metric]))) {
cat(paste("Something is wrong; all the", metric, "metric values are missing:\n"))
print(summary(performance[, perfCols[!grepl("SD$", perfCols)], drop = FALSE]))
stop("Stopping")
stop("Stopping", call. = FALSE)
}

## Sort the tuning parameters from least complex to most complex
Expand Down Expand Up @@ -726,7 +730,7 @@ train.default <- function(x, y,
}
}

if(is.na(bestIter) || length(bestIter) != 1) stop("final tuning parameters could not be determined")
if(is.na(bestIter) || length(bestIter) != 1) stop("final tuning parameters could not be determined", call. = FALSE)

if(grepl("adapt", trControl$method)) {
best_perf <- perf_check[bestIter,as.character(models$parameters$parameter),drop = FALSE]
Expand Down Expand Up @@ -901,7 +905,7 @@ train.formula <- function (form, data, ..., weights, subset, na.action = na.fail
# do we need the double colon here?
m[[1]] <- quote(stats::model.frame)
m <- eval.parent(m)
if(nrow(m) < 1) stop("Every row has at least one missing value were found")
if(nrow(m) < 1) stop("Every row has at least one missing value were found", call. = FALSE)
Terms <- attr(m, "terms")
x <- model.matrix(Terms, m, contrasts)
cons <- attr(x, "contrast")
Expand Down Expand Up @@ -936,12 +940,12 @@ summary.train <- function(object, ...) summary(object$finalModel, ...)
#' @importFrom stats predict residuals
#' @export
residuals.train <- function(object, ...) {
if(object$modelType != "Regression") stop("train() only produces residuals on numeric outcomes")
if(object$modelType != "Regression") stop("train() only produces residuals on numeric outcomes", call. = FALSE)
resid <- residuals(object$finalModel, ...)
if(is.null(resid)) {
if(!is.null(object$trainingData)) {
resid <- object$trainingData$.outcome - predict(object, object$trainingData[, names(object$trainingData) != ".outcome",drop = FALSE])
} else stop("The training data must be saved to produce residuals")
} else stop("The training data must be saved to produce residuals", call. = FALSE)
}
resid
}
Expand All @@ -953,7 +957,7 @@ fitted.train <- function(object, ...) {
if(is.null(prd)) {
if(!is.null(object$trainingData)) {
prd <- predict(object, object$trainingData[, names(object$trainingData) != ".outcome",drop = FALSE])
} else stop("The training data must be saved to produce fitted values")
} else stop("The training data must be saved to produce fitted values", call. = FALSE)
}
prd

Expand Down
2 changes: 1 addition & 1 deletion pkg/caret/man/modelLookup.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/caret/man/train.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fc69e02

Please sign in to comment.