diff --git a/models/files/glmnet.R b/models/files/glmnet.R index 4650971b9..85106b8b4 100644 --- a/models/files/glmnet.R +++ b/models/files/glmnet.R @@ -134,4 +134,10 @@ modelInfo <- list(label = "glmnet", tags = c("Generalized Linear Model", "Implicit Feature Selection", "L1 Regularization", "L2 Regularization", "Linear Classifier", "Linear Regression"), - sort = function(x) x[order(-x$lambda, x$alpha),]) + sort = function(x) x[order(-x$lambda, x$alpha),], + trim = function(x) { + x$call <- NULL + x$df <- NULL + x$dev.ratio <- NULL + x + }) diff --git a/pkg/caret/inst/models/models.RData b/pkg/caret/inst/models/models.RData index 453da3f38..e966abc33 100644 Binary files a/pkg/caret/inst/models/models.RData and b/pkg/caret/inst/models/models.RData differ diff --git a/pkg/caret/tests/testthat/trim_glmnet.R b/pkg/caret/tests/testthat/trim_glmnet.R new file mode 100644 index 000000000..edba16e97 --- /dev/null +++ b/pkg/caret/tests/testthat/trim_glmnet.R @@ -0,0 +1,57 @@ +library(caret) + +test_that('glmnet classification', { + skip_on_cran() + set.seed(1) + tr_dat <- twoClassSim(200) + te_dat <- twoClassSim(200) + + set.seed(2) + class_trim <- train(Class ~ ., data = tr_dat, + method = "glmnet", + tuneGrid = data.frame(lambda = .1, alpha = .5), + trControl = trainControl(method = "none", + classProbs = TRUE, + trim = TRUE)) + + set.seed(2) + class_notrim <- train(Class ~ ., data = tr_dat, + method = "glmnet", + tuneGrid = data.frame(lambda = .1, alpha = .5), + trControl = trainControl(method = "none", + classProbs = TRUE, + trim = FALSE)) + + expect_equal(predict(class_trim, te_dat), + predict(class_notrim, te_dat)) + + expect_equal(predict(class_trim, te_dat, type = "prob"), + predict(class_notrim, te_dat, type = "prob")) + + expect_less_than(object.size(class_trim)-object.size(class_notrim), 0) +}) + +test_that('glmnet regression', { + skip_on_cran() + set.seed(1) + tr_dat <- SLC14_1(200) + te_dat <- SLC14_1(200) + + set.seed(2) + reg_trim <- train(y ~ ., data = tr_dat, + method = "glmnet", + tuneGrid = data.frame(lambda = .1, alpha = .5), + trControl = trainControl(method = "none", + trim = TRUE)) + + set.seed(2) + reg_notrim <- train(y ~ ., data = tr_dat, + method = "glmnet", + tuneGrid = data.frame(lambda = .1, alpha = .5), + trControl = trainControl(method = "none", + trim = FALSE)) + expect_equal(predict(reg_trim, te_dat), + predict(reg_notrim, te_dat)) + expect_less_than(object.size(reg_trim)-object.size(reg_notrim), 0) +}) +