Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

Glmnet varimp fix 2 #190

Closed
wants to merge 6 commits into
from
View
@@ -122,7 +122,7 @@ modelInfo <- list(label = "glmnet",
if(length(lambda) > 1) stop("Only one value of lambda is allowed right now")
if(!is.null(x$lambdaOpt)) {
lambda <- x$lambdaOpt
- } else stop("must supply a vaue of lambda")
+ } else stop("must supply a value of lambda")
}
allVar <- if(is.list(x$beta)) rownames(x$beta[[1]]) else rownames(x$beta)
out <- unlist(predict(x, s = lambda, type = "nonzero"))
@@ -138,14 +138,14 @@ modelInfo <- list(label = "glmnet",
if(length(lambda) > 1) stop("Only one value of lambda is allowed right now")
if(!is.null(object$lambdaOpt)) {
lambda <- object$lambdaOpt
- } else stop("must supply a vaue of lambda")
+ } else stop("must supply a value of lambda")
}
beta <- predict(object, s = lambda, type = "coef")
if(is.list(beta)) {
out <- do.call("cbind", lapply(beta, function(x) x[,1]))
out <- as.data.frame(out)
} else out <- data.frame(Overall = beta[,1])
- out <- out[rownames(out) != "(Intercept)",,drop = FALSE]
+ out <- abs(out[rownames(out) != "(Intercept)",,drop = FALSE])
out
},
levels = function(x) if(any(names(x) == "obsLevels")) x$obsLevels else NULL,
Binary file not shown.
@@ -0,0 +1,25 @@
+library(caret)
+
+context('Testing varImp')
+
+test_that('glmnet varImp returns non-negative values', {
+ skip_on_cran()
+ skip_if_not_installed('glmnet')
+ set.seed(1)
+ dat <- SLC14_1(200)
+
+ reg <- train(y ~ ., data = dat,
+ method = "glmnet",
+ tuneGrid = data.frame(lambda = .1, alpha = .5),
+ trControl = trainControl(method = "none"))
+
+ # this checks that some coefficients are negative
+ coefs <- predict(reg$finalModel, s=0.1, type="coef")
+ expect_less_than(0, sum(0 > coefs))
+ # now check that all elements of varImp are nonnegative,
+ # in spite of negative coefficients
+ vis <- varImp(reg, s=0.1, scale=F)$importance
+ expect_equal(0, sum(0 > vis))
+})
+
+