Skip to content

Commit

Permalink
added functionality for distribution changes
Browse files Browse the repository at this point in the history
See issue #128
(https://github.com/topepo/caret/issues/128)
  • Loading branch information
topepo committed Apr 2, 2015
1 parent b26543f commit cd7d095
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 34 deletions.
21 changes: 18 additions & 3 deletions RegressionTests/Code/gam.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ test_class_cv_model <- train(trainX, trainY,
metric = "ROC",
preProc = c("center", "scale"))

## Classification fit that overrides the default outcome distribution:
## `family` is passed through train()'s `...` to the underlying gam fit
## (negbin() is presumably mgcv::negbin — confirm against the loaded package).
## Exercises the new pass-through behavior described in the NEWS entry below.
set.seed(849)
test_class_cv_dist <- train(trainX, trainY,
method = "gam",
trControl = cctrl1,
metric = "ROC",
preProc = c("center", "scale"),
family = negbin(theta = 1))

set.seed(849)
test_class_cv_form <- train(Class ~ ., data = training,
method = "gam",
Expand Down Expand Up @@ -94,11 +102,18 @@ test_reg_cv_model <- train(trainX, trainY,
preProc = c("center", "scale"))
test_reg_pred <- predict(test_reg_cv_model, testX)

## Regression fit with a non-default distribution: scat() (presumably
## mgcv's scaled-t family — confirm) is forwarded through train()'s `...`
## to the model-fitting function.
set.seed(849)
test_reg_cv_dist <- train(trainX, trainY,
method = "gam",
trControl = rctrl1,
preProc = c("center", "scale"),
family = scat())

## Regression fit via the formula interface. The scraped diff had the
## `method`/`trControl`/`preProc` argument lines duplicated after the
## closing parenthesis (removed-side diff lines fused in), which does not
## parse; the call is reconstructed here with each argument given once.
set.seed(849)
test_reg_cv_form <- train(y ~ ., data = training,
method = "gam",
trControl = rctrl1,
preProc = c("center", "scale"))
test_reg_pred_form <- predict(test_reg_cv_form, testX)

set.seed(849)
Expand Down
15 changes: 15 additions & 0 deletions RegressionTests/Code/gamSpline.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ test_class_cv_model <- train(trainX, trainY,
metric = "ROC",
preProc = c("center", "scale"))

## Classification fit with a non-default link: a binomial family with a
## complementary log-log link is forwarded through train()'s `...` to the
## gamSpline fit instead of the model's default family.
set.seed(849)
test_class_cv_dist <- train(trainX, trainY,
method = "gamSpline",
trControl = cctrl1,
metric = "ROC",
preProc = c("center", "scale"),
family = binomial(link = "cloglog"))

set.seed(849)
test_class_cv_form <- train(Class ~ ., data = training,
method = "gamSpline",
Expand Down Expand Up @@ -95,6 +103,13 @@ test_reg_cv_model <- train(trainX, trainY,
preProc = c("center", "scale"))
test_reg_pred <- predict(test_reg_cv_model, testX)

## Regression fit with the Gamma family passed through train()'s `...`.
## abs(trainY) keeps the outcome strictly non-negative, which the Gamma
## family requires (assumes trainY has no exact zeros — TODO confirm).
## Note the family function itself (`Gamma`) is passed, not a call.
set.seed(849)
test_reg_cv_dist <- train(trainX, abs(trainY),
method = "gamSpline",
trControl = rctrl1,
preProc = c("center", "scale"),
family = Gamma)

set.seed(849)
test_reg_cv_form <- train(y ~ ., data = training,
method = "gamSpline",
Expand Down
20 changes: 20 additions & 0 deletions RegressionTests/Code/gbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ test_class_cv_model <- train(trainX, trainY,
tuneGrid = gbmGrid,
verbose = FALSE)

## Classification fit that overrides gbm's default distribution
## ("bernoulli" for two classes) with "adaboost" via train()'s `...`.
## Fix: the original was missing the comma after `verbose = FALSE`,
## which is a syntax error in R.
set.seed(849)
test_class_cv_dist <- train(trainX, trainY,
method = "gbm",
trControl = cctrl1,
metric = "ROC",
preProc = c("center", "scale"),
tuneGrid = gbmGrid,
verbose = FALSE,
distribution = "adaboost")


set.seed(849)
test_class_cv_form <- train(Class ~ ., data = training,
method = "gbm",
Expand Down Expand Up @@ -108,6 +119,15 @@ test_reg_cv_model <- train(trainX, trainY,
verbose = FALSE)
test_reg_pred <- predict(test_reg_cv_model, testX)

## Regression fit overriding gbm's default "gaussian" distribution with
## "laplace" (median/absolute-loss boosting), forwarded through train()'s
## `...` to the underlying gbm call.
set.seed(849)
test_reg_cv_dist <- train(trainX, trainY,
method = "gbm",
trControl = rctrl1,
preProc = c("center", "scale"),
tuneGrid = gbmGrid,
verbose = FALSE,
distribution = "laplace")

set.seed(849)
test_reg_cv_form <- train(y ~ ., data = training,
method = "gbm",
Expand Down
24 changes: 10 additions & 14 deletions models/files/gam.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,22 @@ modelInfo <- list(label = "Generalized Additive Model using Splines",
dat$.outcome <- y
dist <- gaussian()
}
out <- mgcv:::gam(modForm, data = dat, family = dist,
modelArgs <- list(formula = modForm,
data = dat,
select = param$select,
method = as.character(param$method),
...)
# if(is.null(wts)) {
#
# } else {
# out <- mgcv:::gam(modForm, data = dat, family = dist,
# select = param$select,
# method = as.character(param$method),
# weights = wts,
# ...)
# }
method = as.character(param$method))
## Intercept family if passed in
theDots <- list(...)
if(!any(names(theDots) == "family")) modelArgs$family <- dist
modelArgs <- c(modelArgs, theDots)

out <- do.call(getFromNamespace("gam", "mgcv"), modelArgs)
out

},
predict = function(modelFit, newdata, submodels = NULL) {
if(!is.data.frame(newdata)) newdata <- as.data.frame(newdata)
if(modelFit$problemType == "Classification")
{
if(modelFit$problemType == "Classification") {
probs <- predict(modelFit, newdata, type = "response")
out <- ifelse(probs < .5,
modelFit$obsLevel[1],
Expand Down
23 changes: 14 additions & 9 deletions models/files/gamLoess.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@ modelInfo <- list(label = "Generalized Additive Model using LOESS",
grid = function(x, y, len = NULL)
expand.grid(span = .5, degree = 1),
fit = function(x, y, wts, param, lev, last, classProbs, ...) {
  ## Fit a gam::gam LOESS-smoother model. The scraped diff fused the
  ## removed-side lines (the old direct `gam:::gam(..., data = dat, ...)`
  ## call, which referenced a `dat` variable that no longer exists and set
  ## `family` unconditionally) into the new implementation, making the span
  ## invalid R; this reconstructs the added-side version only.
  ##
  ## Build the argument list for do.call(): data (+ .outcome column) and
  ## a lo()-smoother formula over the predictors.
  args <- list(data = if(is.data.frame(x)) x else as.data.frame(x))
  args$data$.outcome <- y
  args$formula <- caret:::smootherFormula(x,
                                          smoother = "lo",
                                          span = param$span,
                                          degree = param$degree)
  theDots <- list(...)

  ## Only supply a default family (binomial for factor outcomes, gaussian
  ## otherwise) when the user has not passed one in through `...`.
  if(!any(names(theDots) == "family"))
    args$family <- if(is.factor(y)) binomial else gaussian

  if(length(theDots) > 0) args <- c(args, theDots)

  do.call(getFromNamespace("gam", "gam"), args)
},
predict = function(modelFit, newdata, submodels = NULL) {
if(!is.data.frame(newdata)) newdata <- as.data.frame(newdata)
Expand Down
7 changes: 5 additions & 2 deletions models/files/gamSpline.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ modelInfo <- list(label = "Generalized Additive Model using Splines",
args$formula <- caret:::smootherFormula(x,
smoother = "s",
df = param$df)
args$family <- if(is.factor(y)) binomial else gaussian

theDots <- list(...)


if(!any(names(theDots) == "family"))
args$family <- if(is.factor(y)) binomial else gaussian

if(length(theDots) > 0) args <- c(args, theDots)

do.call(getFromNamespace("gam", "gam"), args)
Expand Down
4 changes: 2 additions & 2 deletions models/files/gbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ modelInfo <- list(label = "Stochastic Gradient Boosting",
modDist <- theDots$distribution
theDots$distribution <- NULL
} else {
if(is.numeric(y))
{
if(is.numeric(y)) {
modDist <- "gaussian"
} else modDist <- if(length(lev) == 2) "bernoulli" else "multinomial"
}
Expand All @@ -45,6 +44,7 @@ modelInfo <- list(label = "Stochastic Gradient Boosting",
n.trees = param$n.trees,
shrinkage = param$shrinkage,
distribution = modDist)
if(any(names(theDots) == "family")) modArgs$distribution <- NULL

if(length(theDots) > 0) modArgs <- c(modArgs, theDots)

Expand Down
4 changes: 2 additions & 2 deletions pkg/caret/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: caret
Version: 6.0-43
Date: 2015-01-27
Version: 6.0-44
Date: 2015-04-01
Title: Classification and Regression Training
Author: Max Kuhn. Contributions from Jed Wing, Steve Weston, Andre
Williams, Chris Keefer, Allan Engelhardt, Tony Cooper, Zachary Mayer,
Expand Down
10 changes: 8 additions & 2 deletions pkg/caret/inst/NEWS.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@
\item A new option to \code{trainControl} called \code{trim} was added where, if implemented, will reduce the model's footprint. However, features beyond simple prediction may not work.
\item A rarely occurring bug in \code{gbm} model code was fixed (thanks to Wade Cooper)
\item \code{splom.resamples} now respects the \code{models} argument
\item A new argument to \code{lift} called \code{cuts} was added to allow more control over what thresholds are used to calculat the curve.
\item A new argument to \code{lift} called \code{cuts} was added to allow more control over what thresholds are used to calculate the curve.
\item The \code{cuts} argument of \code{calibration} now accepts a vector of cut points.
\item Jason Schadewald noticed and fixed a bug in the man page for \code{dummyVars}
\item Call objects were remoed from the following models: \code{avNNet}, \code{bagFDA}, \code{icr}, \code{knn3}, \code{knnreg}, \code{pcaNNet}, and \code{plsda}.
\item Call objects were removed from the following models: \code{avNNet}, \code{bagFDA}, \code{icr}, \code{knn3}, \code{knnreg}, \code{pcaNNet}, and \code{plsda}.
\item An argument was added to \code{createTimeSlices} to thin the number of resamples
\item The RFE-related functions \code{lrFuncs}, \code{lmFuncs}, and \code{gamFuncs} were updated so that \code{rfe} accepts a matrix \code{x} argument.
\item Using the default grid generation with \code{train} and \code{glmnet}, an initial \code{glmnet} fit is created with \code{alpha = 0.50} to define the \code{lambda} values.
\item \code{train} models for \code{"gbm"}, \code{"gam"}, \code{"gamSpline"}, and \code{"gamLoess"} now allow their respective arguments for the outcome probability distribution to be passed to the underlying function.
\item A bug in \code{print.varImp.train} was fixed.
\item \code{train} now returns an additional column called \code{rowIndex} that is exposed when calling the summary function during resampling.
\item The ability to compute class probabilities was removed from the \code{rpartCost} model since they are unlikely to agree with the class predictions.
\item \code{extractProb} no longer redundantly calls \code{extractPrediction} to generate the class predictions.
}
}
Expand Down

0 comments on commit cd7d095

Please sign in to comment.