Skip to content

Commit

Permalink
Adding mutators for binary target level
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-czech authored and zachmayer committed Feb 16, 2016
1 parent bd7ccf5 commit 56dc094
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 26 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ export(caretEnsemble)
export(caretList)
export(caretModelSpec)
export(caretStack)
export(getBinaryTargetLevel)
export(getMetric)
export(getMetricSD)
export(is.caretEnsemble)
export(is.caretList)
export(is.caretStack)
export(setBinaryTargetLevel)
importFrom(caret,createDataPartition)
importFrom(caret,createFolds)
importFrom(caret,createMultiFolds)
Expand Down
2 changes: 1 addition & 1 deletion R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ predict.caretList <- function(object, newdata = NULL, ..., verbose = FALSE){
if(x$control$classProbs){
# Return probability predictions for only one of the classes
# as determined by configured default response class level
caret::predict.train(x, type="prob", newdata=newdata, ...)[, getBinaryLevel()]
caret::predict.train(x, type="prob", newdata=newdata, ...)[, getBinaryTargetLevel()]
} else{
caret::predict.train(x, type="raw", newdata=newdata, ...)
}
Expand Down
2 changes: 1 addition & 1 deletion R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ predict.caretStack <- function(
if(class(out) %in% c("data.frame", "matrix")){
# Return probability predictions for only one of the classes
# as determined by configured default response class level
est <- out[, getBinaryLevel(), drop = TRUE]
est <- out[, getBinaryTargetLevel(), drop = TRUE]
} else{
est <- out
}
Expand Down
53 changes: 41 additions & 12 deletions R/helper_functions.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,49 @@
#####################################################
# Configuration Functions
#####################################################
#' @title Return the configured default binary class level
#' @title Return the configured target binary class level
#' @description For binary classification problems, ensemble
#' stacks and certain performance measures require an awareness
#' of which class in a two-factor outcome is "positive". By default,
#' this class will be assumed to be the first class in an outcome factor
#' but that value can be overriden using global options (e.g.
#' \code{options(caret.ensemble.target.bin.level=2)}).
getBinaryLevel <- function() {
value <- as.numeric(getOption("caret.ensemble.target.bin.level", default = 1))
if (!value %in% c(1, 2))
#' of which class in a two-factor outcome is the "target" class.
#' By default, this class will be assumed to be the first level in
#' an outcome factor but that setting can be overridden using
#' \code{setBinaryTargetLevel(2L)}.
#' @seealso setBinaryTargetLevel
#' @return Currently configured binary target level (as integer equal to 1 or 2)
#' @export
getBinaryTargetLevel <- function() {
  # Look up the globally configured level (falling back to the first level,
  # 1L, when the option is unset) and hand it straight to the validator so
  # that callers always receive a checked value.
  validateBinaryTargetLevel(
    getOption("caret.ensemble.binary.target.level", default = 1L)
  )
}

#' @title Set the target binary class level
#' @description Ensemble stacks and certain performance measures need to
#' know which class of a two-level outcome factor is the "target" class in
#' binary classification problems. The first level of the outcome factor is
#' used by default; a different level can be configured with, for example,
#' \code{setBinaryTargetLevel(2L)}.
#' @param level an integer in \{1, 2\} to be used as target outcome level
#' @return The value of the underlying \code{options()} call (the previous
#' setting of the option, returned invisibly).
#' @seealso getBinaryTargetLevel
#' @export
setBinaryTargetLevel <- function(level){
  # Validate first so that an invalid level never reaches the global options
  checked <- validateBinaryTargetLevel(level)
  options(caret.ensemble.binary.target.level = checked)
}

#' @title Validate arguments given as binary target level
#' @description Helper function used to ensure that a target binary class
#' level given by a client is a single value that can be coerced to an
#' integer and that the resulting integer is in \{1, 2\}. Anything else
#' (non-coercible values, \code{NA}, \code{NULL}, vectors with more than one
#' element, integers other than 1 or 2) raises an error.
#' @param arg argument to potentially be used as new target level
#' @return Binary target level (as integer equal to 1 or 2)
validateBinaryTargetLevel <- function(arg){
  # Coercion may warn (e.g. "x" becomes NA) or fail outright (e.g. for a
  # function). Both outcomes are treated as "invalid level" rather than being
  # surfaced to the caller as raw conversion conditions.
  val <- tryCatch(
    suppressWarnings(as.integer(arg)),
    error = function(e) NULL
  )

  # Insist on exactly one value: this keeps the scalar `&&` chain well
  # defined and prevents a multi-element vector from slipping through and
  # later being used as a column/level index.
  is_valid <- is.integer(val) && length(val) == 1L && val %in% c(1L, 2L)

  if (!is_valid) {
    # Render the offending value defensively so that building the error
    # message cannot itself fail for values without a character form.
    shown <- tryCatch(
      paste(as.character(arg), collapse = ", "),
      error = function(e) paste0("<", class(arg)[1L], ">")
    )
    stop(paste0(
      "Specified target binary class level is not valid. ",
      "Value should be either 1 or 2 but '", shown, "' was given ",
      "(see caretEnsemble::setBinaryTargetLevel for more details)"))
  }

  val
}


Expand Down Expand Up @@ -261,7 +290,7 @@ makePredObsMatrix <- function(list_of_models){
# Determine the string name for the positive class
if (!is.factor(modelLibrary$obs) || length(levels(modelLibrary$obs)) != 2)
stop("Response vector must be a two-level factor for classification.")
positive <- levels(modelLibrary$obs)[getBinaryLevel()]
positive <- levels(modelLibrary$obs)[getBinaryTargetLevel()]

# Use the string name for the positive class determined above to select
# predictions from base estimators as predictors for ensemble model
Expand Down
23 changes: 23 additions & 0 deletions man/getBinaryTargetLevel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions man/setBinaryTargetLevel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/validateBinaryTargetLevel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 19 additions & 12 deletions tests/testthat/test-classSelection.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ library(caret)
library(caretEnsemble)

# Load and prepare data for subsequent tests
set.seed(2239)
seed <- 2239
set.seed(seed)
data(models.class)
data(X.class)
data(Y.class)
Expand Down Expand Up @@ -89,11 +90,12 @@ test_that("Ensembled classifiers do not rearrange outcome factor levels", {

# Make sure that caretEnsemble uses the first level in the
# outcome factor as the target class
options(caret.ensemble.target.bin.level=1)
bin.level <- getBinaryTargetLevel()
setBinaryTargetLevel(1L)

# First run the level selection test using the default levels
# of the response (i.e. c('No', 'Yes'))
set.seed(2239)
set.seed(seed)
runBinaryLevelValidation(Y.train, Y.test)

# Now reverse the assignment of the response labels as well as
Expand All @@ -108,26 +110,31 @@ test_that("Ensembled classifiers do not rearrange outcome factor levels", {
ifelse(d == Y.levels[1], Y.levels[2], Y.levels[1]),
levels=rev(Y.levels))

set.seed(2239)
set.seed(seed)
runBinaryLevelValidation(refactor(Y.train), refactor(Y.test))

# Set the target binary level back to what it was before this test
setBinaryTargetLevel(bin.level)
})

test_that("Target class selection configuration works", {
  skip_on_cran()
  expect_equal(1, 1)

  # Remember the currently configured target level so it can be restored
  bin.level <- getBinaryTargetLevel()

  # `finally` puts the global setting back even if a step below errors, so a
  # failure in this test cannot leak a non-default target level into others.
  tryCatch({
    # Verify binary target level argument validation
    expect_error(setBinaryTargetLevel("x"))

    # Configure caret ensemble to use the second class as the target
    setBinaryTargetLevel(2L)

    # Reverse the order of the factor levels while keeping every
    # observation's label unchanged
    Y.levels <- levels(Y.train)
    refactor <- function(d) factor(as.character(d), levels=rev(Y.levels))

    # Seed immediately before the validation run
    set.seed(seed)
    runBinaryLevelValidation(refactor(Y.train), refactor(Y.test), pos.level=2)
  }, finally = setBinaryTargetLevel(bin.level))
})

0 comments on commit 56dc094

Please sign in to comment.