Commit 2cf56c5

Merge 0bfaff9 into 330b74c
zachmayer committed Jun 23, 2015
2 parents 330b74c + 0bfaff9
Showing 6 changed files with 165 additions and 13 deletions.
6 changes: 0 additions & 6 deletions R/OptRMSE.R
@@ -29,11 +29,5 @@ greedOptRMSE <- function(X, Y, iter = 100L){
}
weights2 <- weights/sum(weights)
maxtest <- sqrt(sum((X %*% weights2 - Y) ^ 2L, na.rm=TRUE))
if(stopper < maxtest){
testresult <- round(maxtest/stopper, 5) * 100
wstr <- paste0("Optimized weights not better than best model. Ensembled result is ",
testresult, "%", " of best model RMSE. Try more iterations.")
message(wstr)
}
return(weights)
}
5 changes: 1 addition & 4 deletions R/caretStack.R
@@ -154,8 +154,5 @@ plot.caretStack <- function(x, ...){
#' dotplot.caretStack(meta_model)
#' }
dotplot.caretStack <- function(x, data=NULL, ...){
final <- list(x$ens_model)
names(final) <- paste(paste(x$ens_model$method, collapse="_"), "ENSEMBLE", sep="_")
base <- x$models
dotplot(resamples(c(final, base)), data=data, ...)
dotplot(resamples(x$models), data=data, ...)
}
100 changes: 100 additions & 0 deletions tests/testthat/test-caretList.R
@@ -29,6 +29,84 @@ test_that("caretModelSpec returns valid specs", {
expect_true(is.list(tuneList))
expect_equal(length(tuneList), 4)
expect_equal(sum(duplicated(names(tuneList))), 0)
})

test_that("caretModelSpec and checking functions work as expected", {

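# Every model caret knows about should round-trip through caretModelSpec unchanged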
all_models <- sort(unique(modelLookup()$model))
for(model in all_models){
expect_equal(caretModelSpec(model, tuneLength=5, preProcess='knnImpute')$method, model)
}

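# tuneCheck should accept a completely unnamed tuneList and return a list of the same length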
tuneList <- lapply(all_models, function(x) list(method=x, preProcess='pca'))
all_models_check <- tuneCheck(tuneList)
expect_is(all_models_check, 'list')
expect_equal(length(all_models), length(all_models_check))

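# tuneCheck should also repair a tuneList where only some elements are named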
tuneList <- lapply(all_models, function(x) list(method=x, preProcess='pca'))
names(tuneList) <- all_models
names(tuneList)[c(1, 5, 10)] <- ""
all_models_check <- tuneCheck(tuneList)
expect_is(all_models_check, 'list')
expect_equal(length(all_models), length(all_models_check))

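# methodCheck should pass for real caret methods and error on made-up ones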
methodCheck(all_models)
expect_error(methodCheck(c(all_models, 'THIS_IS_NOT_A_REAL_MODEL')))
expect_error(methodCheck(c(all_models, 'THIS_IS_NOT_A_REAL_MODEL', 'GBM')))
})

test_that("Target extraction functions work", {
data(iris)
expect_equal(extractCaretTarget(iris[,1:4], iris[,5]), iris[,5])
expect_equal(extractCaretTarget(iris[,2:5], iris[,1]), iris[,1])
expect_equal(extractCaretTarget(Species ~ ., iris), iris[,'Species'])
expect_equal(extractCaretTarget(Sepal.Width ~ ., iris), iris[,'Sepal.Width'])
})

test_that("caretList errors for bad models", {
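# No methodList or tuneList should error; duplicated methods should only warn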
data(iris)
expect_error(caretList(Sepal.Width ~ ., iris))
expect_warning(caretList(Sepal.Width ~ ., iris, methodList=c('lm', 'lm')))
expect_warning(expect_is(caretList(Sepal.Width ~ ., iris, methodList='lm', continue_on_fail=TRUE), 'caretList'))

my_control <- trainControl(method='cv', number=2)
bad_bad <- list(
bad1=caretModelSpec(method="glm", tuneLength=1),
bad2=caretModelSpec(method="glm", tuneLength=1)
)
good_bad <- list(
good=caretModelSpec(method="glmnet", tuneLength=1),
bad=caretModelSpec(method="glm", tuneLength=1)
)
sink <- capture.output(expect_error(caretList(iris[,1:4], iris[,5], tuneList=bad_bad, trControl=my_control)))
sink <- capture.output(expect_error(caretList(iris[,1:4], iris[,5], tuneList=good_bad, trControl=my_control)))
sink <- capture.output(expect_error(caretList(iris[,1:4], iris[,5], tuneList=bad_bad, trControl=my_control, continue_on_fail=TRUE)))
sink <- capture.output(expect_is(caretList(iris[,1:4], iris[,5], tuneList=good_bad, trControl=my_control, continue_on_fail=TRUE), 'caretList'))
})

test_that("caretList predictions", {
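# Without class probabilities, predictions should come back as a character matrix of class labels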
models <- caretList(
iris[,1:2], iris[,5],
tuneLength=1,
methodList=c('rf', 'gbm'),
trControl=trainControl(method='cv', number=2, savePredictions=TRUE, classProbs=FALSE))
p1 <- predict(models)
expect_is(p1, 'matrix')
expect_is(p1[,1], 'character')
expect_is(p1[,2], 'character')

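# With classProbs=TRUE, predictions should be a numeric matrix of class probabilities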
models <- caretList(
iris[,1:2], iris[,5],
tuneLength=1,
methodList=c('rf', 'gbm'),
trControl=trainControl(method='cv', number=2, savePredictions=TRUE, classProbs=TRUE))
p2 <- predict(models)
expect_is(p2, 'matrix')
expect_is(p2[,1], 'numeric')
expect_is(p2[,2], 'numeric')

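# Corrupting a model's modelType should make predict fail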
models[[1]]$modelType <- "Bogus"
expect_error(predict(models))
})

###############################################
@@ -65,6 +143,7 @@ test_that("We can handle different CV methods", {
"boot",
"adaptive_boot",
"cv",
"repeatedcv",
"adaptive_cv",
"LGOCV",
"adaptive_LGOCV")
@@ -111,6 +190,27 @@ test_that("We can handle different CV methods", {
}
})

test_that("CV methods we can't handle fail", {
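# None of these resampling schemes is supported for ensembling, so trControlCheck should error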
for(m in c(
"boot632",
"LOOCV",
"none",
"oob"
)
){
test_name <- paste0("CV doesn't work with method=", m)
test_that(test_name, {
data(iris)
model <- train(
Sepal.Length ~ Sepal.Width, tuneLength=1,
data=iris, method=ifelse(m=='oob', 'rf', 'lm'),
trControl=trainControl(method=m))
expect_is(model, 'train')
expect_error(trControlCheck(model))
})
}
})

###############################################
context("Classification models")
################################################
18 changes: 17 additions & 1 deletion tests/testthat/test-caretStack.R
@@ -21,20 +21,36 @@ test_that("We can stack regression models", {
ens.reg <- caretStack(models.reg, method="lm", preProcess="pca",
trControl=trainControl(number=2, allowParallel=FALSE))
expect_that(ens.reg, is_a("caretStack"))
expect_is(summary(ens.reg), 'summary.lm')
sink <- capture.output(print(ens.reg))
pred.reg <- predict(ens.reg, X.reg)
expect_true(is.numeric(pred.reg))
expect_true(length(pred.reg)==150)
})

test_that("We can stack classification models", {
set.seed(42)
ens.class <- caretStack(models.class, method="rpart",
ens.class <- caretStack(models.class, method="glm",
trControl=trainControl(number=2, allowParallel=FALSE))
expect_that(ens.class, is_a("caretStack"))
expect_is(summary(ens.class), 'summary.glm')
sink <- capture.output(print(ens.class))
pred.class <- predict(ens.class, X.class, type="prob")[,2]
expect_true(is.numeric(pred.class))
expect_true(length(pred.class)==150)
raw.class <- predict(ens.class, X.class, type="raw")
expect_true(is.factor(raw.class))
expect_true(length(raw.class)==150)
})

test_that("caretStack plots", {
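# plot and dotplot methods should both render to a png device without error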
test_plot_file <- "caretEnsemble_test_plots.png"
ens.reg <- caretStack(
models.reg, method="gbm", tuneLength=2, verbose=FALSE,
trControl=trainControl(number=2, allowParallel=FALSE))
png(filename = test_plot_file)
plot(ens.reg)
dotplot(ens.reg, metric='RMSE')
dev.off()
unlink(test_plot_file)
})
43 changes: 41 additions & 2 deletions tests/testthat/test-helper_functions.R
@@ -1,7 +1,7 @@

#TODO: add tests for every helper function

########################################################################
context("Do the helper functions work for regression objects?")
########################################################################
library("caret")
library("randomForest")
library("rpart")
@@ -19,6 +19,43 @@ load(system.file("testdata/X.class.rda",
load(system.file("testdata/Y.class.rda",
package="caretEnsemble", mustWork=TRUE))

test_that("Recycling weights generates an error", {
expect_error(wtd.sd(matrix(1:10, ncol=2), weights=1))
})

test_that("No predictions generates an error", {
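# A multiclass caretList should fail the model-type check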
models_multi <- caretList(
iris[,1:2], iris[,5],
tuneLength=1,
methodList=c('rf', 'gbm'),
trControl=trainControl(method='cv', number=2, savePredictions=TRUE, classProbs=TRUE))
expect_error(check_caretList_model_types(models_multi))

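# Mixing in a model fit with savePredictions=FALSE should also fail the check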
models <- caretList(
iris[,1:2], factor(ifelse(iris[,5]=='setosa', 'Yes', 'No')),
tuneLength=1,
methodList=c('rf', 'gbm'),
trControl=trainControl(method='cv', number=2, savePredictions=TRUE, classProbs=TRUE))
ctrl <- models[[1]]
new_model <- train(
iris[,1:2], factor(ifelse(iris[,5]=='setosa', 'Yes', 'No')),
tuneLength=1,
method=c('glmnet'),
trControl=trainControl(method='cv', number=2, savePredictions=FALSE, classProbs=TRUE)
)
models2 <- c(list('glmnet'=new_model), models)
models3 <- c(models, list('glmnet'=new_model))
check_caretList_model_types(models)
expect_error(check_caretList_model_types(models2))
#expect_error(check_caretList_model_types(models3)) #THIS IS A BUG THAT NEEDS FIXING!!!!!!!!!!
})

test_that("Multi-class generates an error", {
expect_error(wtd.sd(matrix(1:10, ncol=2), weights=1))
})

test_that("We can make the predobs matrix", {
out <- makePredObsMatrix(models.reg)
expect_that(out, is_a("list"))
@@ -33,7 +70,9 @@ test_that("We can predict", {
expect_true(all(colnames(out)==c("rf", "lm", "glm", "knn")))
})

########################################################################
context("Do the helper functions work for classification objects?")
########################################################################

test_that("We can make the predobs matrix", {
out <- makePredObsMatrix(models.class)
6 changes: 6 additions & 0 deletions tests/testthat/test-optimizers.R
@@ -30,6 +30,12 @@ myControl <- trainControl(

context("Test optimizer passing to caretEnsemble correctly")

test_that("optAUC converts to factor", {
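# greedOptAUC should coerce a character target to factor and return integer weights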
x <- matrix(1:20, ncol=2)
y <- rep(c('a', 'b'), 5)
expect_is(greedOptAUC(x, y), 'integer')
})

test_that("Test that optFUN does not take random values", {
skip_on_cran()
myCL <- caretList(
