Skip to content

Commit

Permalink
Merge 0801d92 into e20620d
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-czech committed Feb 16, 2016
2 parents e20620d + 0801d92 commit be15bf9
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 4 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ export(caretEnsemble)
export(caretList)
export(caretModelSpec)
export(caretStack)
export(getBinaryTargetLevel)
export(getMetric)
export(getMetricSD)
export(is.caretEnsemble)
export(is.caretList)
export(is.caretStack)
export(setBinaryTargetLevel)
importFrom(caret,createDataPartition)
importFrom(caret,createFolds)
importFrom(caret,createMultiFolds)
Expand Down
4 changes: 3 additions & 1 deletion R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,9 @@ predict.caretList <- function(object, newdata = NULL, ..., verbose = FALSE){
type <- x$modelType
if (type=="Classification"){
if(x$control$classProbs){
caret::predict.train(x, type="prob", newdata=newdata, ...)[, 2]
# Return probability predictions for only one of the classes
# as determined by configured default response class level
caret::predict.train(x, type="prob", newdata=newdata, ...)[, getBinaryTargetLevel()]
} else{
caret::predict.train(x, type="raw", newdata=newdata, ...)
}
Expand Down
4 changes: 3 additions & 1 deletion R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ predict.caretStack <- function(
out <- predict(object$ens_model, newdata=preds, ...)
# Need a check here
if(class(out) %in% c("data.frame", "matrix")){
est <- out[, 2, drop = TRUE] # return only the probabilities for the second class
# Return probability predictions for only one of the classes
# as determined by configured default response class level
est <- out[, getBinaryTargetLevel(), drop = TRUE]
} else{
est <- out
}
Expand Down
57 changes: 56 additions & 1 deletion R/helper_functions.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,52 @@
#####################################################
# Configuration Functions
#####################################################
#' @title Look up the configured binary target class level
#' @description In binary classification problems, ensemble stacks
#' and certain performance measures need to know which class of a
#' two-level outcome factor is the "target" class. Unless configured
#' otherwise, the first level of the outcome factor is used; a different
#' level can be selected with \code{setBinaryTargetLevel(2L)}.
#' @seealso setBinaryTargetLevel
#' @return Currently configured binary target level (as integer equal to 1 or 2)
#' @export
getBinaryTargetLevel <- function() {
  # Read the stored option (falling back to the first level when it is
  # unset) and validate it before handing it back, since the option may
  # have been changed directly through options() rather than the setter
  validateBinaryTargetLevel(
    getOption("caret.ensemble.binary.target.level", default = 1L))
}

#' @title Choose the target binary class level
#' @description In binary classification problems, ensemble stacks
#' and certain performance measures need to know which class of a
#' two-level outcome factor is the "target" class. The first level of
#' the outcome factor is used by default, and this function overrides
#' that choice, e.g. \code{setBinaryTargetLevel(2L)}.
#' @param level an integer in \{1, 2\} to be used as target outcome level
#' @return Invisibly, whatever \code{options()} returns when the option is
#' set (a list holding the previous value of the option)
#' @seealso getBinaryTargetLevel
#' @export
setBinaryTargetLevel <- function(level){
  # Refuse anything that does not resolve to 1 or 2 before storing it
  checked <- validateBinaryTargetLevel(level)
  options(caret.ensemble.binary.target.level = checked)
}

#' @title Validate arguments given as binary target level
#' @description Helper function used to ensure that target
#' binary class levels given by clients can be coerced to an integer
#' and that the resulting integer is a single value in \{1, 2\}.
#' Anything else (values outside \{1, 2\}, \code{NA}, \code{NULL},
#' vectors with more than one element, or objects that cannot be
#' coerced to integer at all) raises an error describing the problem.
#' @param arg argument to potentially be used as new target level
#' @return Binary target level (as integer equal to 1 or 2)
validateBinaryTargetLevel <- function(arg){
  # Coerce defensively: coercion warnings (e.g. "x" -> NA) are expected and
  # silenced, and objects that cannot be coerced at all are mapped to NA so
  # that they fall through to the regular validation error below
  val <- tryCatch(
    suppressWarnings(as.integer(arg)),
    error = function(e) NA_integer_)

  # The length check must come first so that the condition handed to `if`
  # is always a single TRUE/FALSE (zero-length or multi-element input would
  # otherwise break the scalar `&&`); NA is rejected because %in% is FALSE
  is.valid <- length(val) == 1L && val %in% c(1L, 2L)

  if (!is.valid) {
    # Build a printable form of the offending argument. Non-empty atomic
    # values are shown as-is; everything else (NULL, functions, lists, ...)
    # is deparsed so that creating the message itself can never fail
    shown <- if (is.atomic(arg) && length(arg) > 0L) {
      paste(as.character(arg), collapse = ", ")
    } else {
      deparse(arg)[1L]
    }
    stop(paste0(
      "Specified target binary class level is not valid. ",
      "Value should be either 1 or 2 but '", shown, "' was given ",
      "(see caretEnsemble::setBinaryTargetLevel for more details)"),
      call. = FALSE)
  }
  val
}


#####################################################
# Misc. Functions
#####################################################
Expand Down Expand Up @@ -238,7 +287,13 @@ makePredObsMatrix <- function(list_of_models){
#For classification models that produce probs, use the probs as preds
#Otherwise, just use class predictions
if (type=="Classification"){
positive <- as.character(unique(modelLibrary$obs)[2]) #IMPROVE THIS!
# Determine the string name for the positive class
if (!is.factor(modelLibrary$obs) || length(levels(modelLibrary$obs)) != 2)
stop("Response vector must be a two-level factor for classification.")
positive <- levels(modelLibrary$obs)[getBinaryTargetLevel()]

# Use the string name for the positive class determined above to select
# predictions from base estimators as predictors for ensemble model
pos <- as.numeric(modelLibrary[[positive]])
good_pos_values <- which(is.finite(pos))
set(modelLibrary, j="pred", value=as.numeric(modelLibrary[["pred"]]))
Expand Down
23 changes: 23 additions & 0 deletions man/getBinaryTargetLevel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions man/setBinaryTargetLevel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/validateBinaryTargetLevel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

140 changes: 140 additions & 0 deletions tests/testthat/test-classSelection.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
context("Does binary class selection work?")
library(caret)
library(caretEnsemble)

# Fixed seed shared by every test in this file so that the data split
# below and the resampling done inside the tests are reproducible
seed <- 2239
set.seed(seed)

# Example classification data shipped with the package
data(models.class)
data(X.class)
data(Y.class)

# Hold out 20% of the rows for testing; the remaining 80% are used for training
index <- createDataPartition(Y.class, p = 0.8)[[1]]
X.train <- X.class[index, ]
X.test <- X.class[-index, ]
Y.train <- Y.class[index]
Y.test <- Y.class[-index]

#############################################################################
context("Do classifier predictions use the correct target classes?")
#############################################################################

# Train a small caret ensemble on the given response and check that the
# factor levels of the response are respected end to end.
#
# Arguments:
#   Y.train   - response used for training (expected to be a two-level factor)
#   Y.test    - response used for evaluation (must have the same levels)
#   pos.level - index (1 or 2) of the level treated as the positive class
#               when thresholding probabilities and building confusion matrices
#
# Note that the predictors are not arguments: X.train and X.test are read
# from the enclosing (file-level) environment.
runBinaryLevelValidation <- function(Y.train, Y.test, pos.level=1){

  # Extract levels of response input data
  Y.levels <- levels(Y.train)
  expect_identical(Y.levels, levels(Y.test))
  expect_equal(length(Y.levels), 2)

  # Manually generate fold indexes. Note that this must be
  # done explicitly because using built-in caret functions like
  # "createFolds" generate splits that are dependent upon the alphabetical
  # order of the factor levels in the response (which is what this test module
  # needs to prove invariance to). This happens because createFolds uses
  # the base table function to create class frequency counts and that table
  # command sorts results alphabetically.
  k <- 3
  folds <- sample(seq_along(Y.train))
  folds <- setNames(
    split(folds, seq_along(folds) %% k),
    sprintf("Fold%s", seq_len(k)))
  # trainControl(index=) expects the rows used for *training* in each
  # resample, so turn each held-out chunk into its complement
  folds <- lapply(folds, function(x) setdiff(seq_along(Y.train), x))
  fold.idx <- sort(unique(unlist(folds)))
  expect_true(
    all(fold.idx == seq_along(Y.train)),
    info = "CV indexes not generated correctly")

  # Train a caret ensemble
  ctrl <- trainControl(
    method="cv", savePredictions="final", classProbs=TRUE, index=folds)
  model.list <- caretList(
    X.train, Y.train, metric = "Accuracy",
    trControl = ctrl, methodList = c("rpart", "glmnet"))
  model.ens <- caretEnsemble(model.list)

  # Verify that the observed responses in each fold, for each model,
  # have the same levels and that the first level is equal to the first
  # level in the original data (i.e. Y.class). This check exists to
  # avoid regressions to bugs like this:
  # https://github.com/zachmayer/caretEnsemble/pull/190
  first.levels <- vapply(
    model.ens$models,
    function(x) levels(x$pred$obs)[1],
    character(1))
  unique.levels <- unique(first.levels)
  expect_identical(unique.levels, Y.levels[1])

  # Verify that the training data given to the ensemble model has the
  # same levels in the response as the original, raw data
  expect_identical(levels(model.ens$ens_model$trainingData$.outcome), Y.levels)

  # Create class and probability predictions, as well as class predictions
  # generated from probability predictions using a .5 cutoff
  Y.pred <- predict(model.ens, newdata=X.test, type="raw")
  Y.prob <- predict(model.ens, newdata=X.test, type="prob")
  Y.cutoff <- factor(
    ifelse(Y.prob > .5, Y.levels[pos.level], Y.levels[-pos.level]),
    levels=Y.levels)

  # Create confusion matrices for each class prediction vector
  cmat.pred <- confusionMatrix(Y.pred, Y.test, positive=Y.levels[pos.level])
  cmat.cutoff <- confusionMatrix(Y.cutoff, Y.test, positive=Y.levels[pos.level])

  # Verify that the positive level of the Y response is equal to the positive
  # class label used by caret. This could potentially become untrue if
  # the levels of the response were ever rearranged by caretEnsemble at some point.
  expect_identical(cmat.pred$positive, Y.levels[pos.level])

  # Verify that the accuracy score on predicted classes is relatively high. This
  # check exists to avoid previous errors where classifier ensemble predictions were
  # being made using the incorrect level of the response, causing the opposite
  # class labels to be predicted with new data.
  expect_equal(
    as.numeric(cmat.pred$overall["Accuracy"]), 0.7586, tolerance = 0.0001)

  # Similar to the above, ensure that probability predictions are working correctly
  # by checking to see that accuracy is also high for class predictions created
  # from probabilities
  expect_equal(
    as.numeric(cmat.cutoff$overall["Accuracy"]), 0.7586, tolerance = 0.0001)
}

test_that("Ensembled classifiers do not rearrange outcome factor levels", {
  skip_on_cran()

  # Remember the currently configured target level so that it can be
  # put back once this test is done
  bin.level <- getBinaryTargetLevel()

  # The option is global state shared with every other test, so restore it
  # in a `finally` clause: this runs even if model training or one of the
  # checks below raises an error
  tryCatch({
    # Make sure that caretEnsemble uses the first level in the
    # outcome factor as the target class
    setBinaryTargetLevel(1L)

    # First run the level selection test using the default levels
    # of the response (i.e. c('No', 'Yes'))
    set.seed(seed)
    runBinaryLevelValidation(Y.train, Y.test)

    # Now reverse the assignment of the response labels as well as
    # the levels of the response factor. Reversing the assignment
    # is necessary to make sure the expected accuracy numbers are
    # the same (i.e. Making a "No" into a "Yes" in the response means
    # predictions of the first class will still be as accurate).
    # Reversing the level order then ensures that the outcome is not
    # releveled at some point by caretEnsemble.
    Y.levels <- levels(Y.train)
    refactor <- function(d) factor(
      ifelse(d == Y.levels[1], Y.levels[2], Y.levels[1]),
      levels=rev(Y.levels))

    set.seed(seed)
    runBinaryLevelValidation(refactor(Y.train), refactor(Y.test))
  }, finally = {
    # Set the target binary level back to what it was before this test
    setBinaryTargetLevel(bin.level)
  })
})

test_that("Target class selection configuration works", {
  skip_on_cran()

  # Get the current target binary level so it can be restored afterwards
  bin.level <- getBinaryTargetLevel()

  # The option is global state shared with every other test, so restore it
  # in a `finally` clause: this runs even if model training or one of the
  # checks below raises an error
  tryCatch({
    # Verify binary target level argument validation
    expect_error(setBinaryTargetLevel("x"))

    # Configure caret ensemble to use the second class as the target
    setBinaryTargetLevel(2L)

    # Keep every observation's label but reverse the order of the factor
    # levels, so the original first level becomes the second (target) level
    Y.levels <- levels(Y.train)
    refactor <- function(d) factor(as.character(d), levels=rev(Y.levels))

    set.seed(seed)
    runBinaryLevelValidation(refactor(Y.train), refactor(Y.test), pos.level=2)
  }, finally = {
    # Set the target binary level back to what it was before this test
    setBinaryTargetLevel(bin.level)
  })
})
2 changes: 1 addition & 1 deletion tests/testthat/test-ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ test_that("It works for classification models", {
expect_true(is.numeric(pred.class))
expect_true(length(pred.class)==150)
expect_identical(pred.class, pred.classb)
expect_less_than(abs(0.03727609 - pred.classc), 0.01)
expect_less_than(abs(0.9633519 - pred.classc), 0.01)
expect_is(pred.class, "numeric")
expect_is(pred.classb, "numeric")
expect_is(pred.classc, "numeric")
Expand Down

0 comments on commit be15bf9

Please sign in to comment.