Skip to content

Commit

Permalink
added glmnet model object trimming
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jan 28, 2015
1 parent 93f2378 commit c29d5aa
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
8 changes: 7 additions & 1 deletion models/files/glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,10 @@ modelInfo <- list(label = "glmnet",
tags = c("Generalized Linear Model", "Implicit Feature Selection",
"L1 Regularization", "L2 Regularization", "Linear Classifier",
"Linear Regression"),
sort = function(x) x[order(-x$lambda, x$alpha),])
sort = function(x) x[order(-x$lambda, x$alpha),],
trim = function(x) {
x$call <- NULL
x$df <- NULL
x$dev.ratio <- NULL
x
})
Binary file modified pkg/caret/inst/models/models.RData
Binary file not shown.
57 changes: 57 additions & 0 deletions pkg/caret/tests/testthat/trim_glmnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
library(caret)

test_that('glmnet classification', {
skip_on_cran()
set.seed(1)
tr_dat <- twoClassSim(200)
te_dat <- twoClassSim(200)

set.seed(2)
class_trim <- train(Class ~ ., data = tr_dat,
method = "glmnet",
tuneGrid = data.frame(lambda = .1, alpha = .5),
trControl = trainControl(method = "none",
classProbs = TRUE,
trim = TRUE))

set.seed(2)
class_notrim <- train(Class ~ ., data = tr_dat,
method = "glmnet",
tuneGrid = data.frame(lambda = .1, alpha = .5),
trControl = trainControl(method = "none",
classProbs = TRUE,
trim = FALSE))

expect_equal(predict(class_trim, te_dat),
predict(class_notrim, te_dat))

expect_equal(predict(class_trim, te_dat, type = "prob"),
predict(class_notrim, te_dat, type = "prob"))

expect_less_than(object.size(class_trim)-object.size(class_notrim), 0)
})

test_that('glmnet regression', {
skip_on_cran()
set.seed(1)
tr_dat <- SLC14_1(200)
te_dat <- SLC14_1(200)

set.seed(2)
reg_trim <- train(y ~ ., data = tr_dat,
method = "glmnet",
tuneGrid = data.frame(lambda = .1, alpha = .5),
trControl = trainControl(method = "none",
trim = TRUE))

set.seed(2)
reg_notrim <- train(y ~ ., data = tr_dat,
method = "glmnet",
tuneGrid = data.frame(lambda = .1, alpha = .5),
trControl = trainControl(method = "none",
trim = FALSE))
expect_equal(predict(reg_trim, te_dat),
predict(reg_notrim, te_dat))
expect_less_than(object.size(reg_trim)-object.size(reg_notrim), 0)
})

0 comments on commit c29d5aa

Please sign in to comment.