From 6f0576c15af0af14f18515817513c936ca0b3dfa Mon Sep 17 00:00:00 2001 From: Matt Dancho Date: Thu, 10 Jun 2021 13:58:06 -0400 Subject: [PATCH 01/11] add gen_additive_mod --- NAMESPACE | 5 + R/gen_additive_mod.R | 158 ++++++++++++++++++++++++++ R/gen_additive_mod_data.R | 225 ++++++++++++++++++++++++++++++++++++++ man/gen_additive_mod.Rd | 81 ++++++++++++++ 4 files changed, 469 insertions(+) create mode 100644 R/gen_additive_mod.R create mode 100644 R/gen_additive_mod_data.R create mode 100644 man/gen_additive_mod.Rd diff --git a/NAMESPACE b/NAMESPACE index 2cd256bcf..e8d9caa9d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -46,6 +46,7 @@ S3method(predict_time,model_fit) S3method(print,boost_tree) S3method(print,control_parsnip) S3method(print,decision_tree) +S3method(print,gen_additive_mod) S3method(print,linear_reg) S3method(print,logistic_reg) S3method(print,mars) @@ -76,6 +77,7 @@ S3method(tidy,nullmodel) S3method(translate,boost_tree) S3method(translate,decision_tree) S3method(translate,default) +S3method(translate,gen_additive_mod) S3method(translate,linear_reg) S3method(translate,logistic_reg) S3method(translate,mars) @@ -91,6 +93,7 @@ S3method(type_sum,model_fit) S3method(type_sum,model_spec) S3method(update,boost_tree) S3method(update,decision_tree) +S3method(update,gen_additive_mod) S3method(update,linear_reg) S3method(update,logistic_reg) S3method(update,mars) @@ -136,6 +139,7 @@ export(fit.model_spec) export(fit_control) export(fit_xy) export(fit_xy.model_spec) +export(gen_additive_mod) export(get_dependency) export(get_encoding) export(get_fit) @@ -248,6 +252,7 @@ importFrom(generics,tidy) importFrom(generics,varying_args) importFrom(glue,glue_collapse) importFrom(magrittr,"%>%") +importFrom(parsnip,translate) importFrom(purrr,as_vector) importFrom(purrr,imap) importFrom(purrr,imap_lgl) diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R new file mode 100644 index 000000000..dac541f24 --- /dev/null +++ b/R/gen_additive_mod.R @@ -0,0 +1,158 @@ +# gen_additive_mod() - General Interface to Linear GAM Models +# - backend: gam +# - prediction: +# - mode = "regression" (default) uses +# - mode = "classification" + +#' Interface for Generalized Additive Models (GAM) +#' +#' @param mode A single character string for the type of model. +#' @param select_features TRUE or FALSE. If this is TRUE then can add an +#' extra penalty to each term so that it can be penalized to zero. +#' This means that the smoothing parameter estimation that is part of +#' fitting can completely remove terms from the model. If the corresponding +#' smoothing parameter is estimated as zero then the extra penalty has no effect. +#' Use `adjust_deg_free` to increase level of penalization. +#' @param adjust_deg_free If `select_features = TRUE`, then acts as a multiplier for smoothness. +#' Increase this beyond 1 to produce smoother models. +#' +#' +#' @return +#' A `parsnip` model specification +#' +#' @details +#' +#' __Available Engines:__ +#' - __gam__: Connects to `mgcv::gam()` +#' +#' __Parameter Mapping:__ +#' +#' ```{r echo = FALSE} +#' tibble::tribble( +#' ~ "modelgam", ~ "mgcv::gam", +#' "select_features", "select (FALSE)", +#' "adjust_deg_free", "gamma (1)" +#' ) %>% knitr::kable() +#' ``` +#' +#' @section Engine Details: +#' +#' __gam__ +#' +#' This engine uses [mgcv::gam()] and has the following parameters, +#' which can be modified through the [set_engine()] function. 
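+#'
+#' For example, engine arguments can be passed through [set_engine()]. A
+#' sketch (assuming `method = "REML"`, one of the smoothness selection
+#' criteria accepted by `mgcv::gam()`):
+#'
+#' ``` r
+#' gen_additive_mod(select_features = TRUE) %>%
+#'   set_engine("gam", method = "REML")
+#' ```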
+#' +#' ``` {r echo=F} +#' str(mgcv::gam) +#' ``` +#' +#' @section Fit Details: +#' +#' __MGCV Formula Interface__ +#' +#' Fitting GAMs is accomplished using parameters including: +#' +#' - [mgcv::s()]: GAM spline smooths +#' - [mgcv::te()]: GAM tensor product smooths +#' +#' These are applied in the `fit()` function: +#' +#' ``` r +#' fit(value ~ s(date_mon, k = 12) + s(date_num), data = df) +#' ``` +#' +#' +#' @examples +#' +#' show_engines("gen_additive_mod") +#' +#' gen_additive_mod() +#' +#' +#' @export +gen_additive_mod <- function(mode = "regression", + select_features = NULL, + adjust_deg_free = NULL) { + + args <- list( + select_features = rlang::enquo(select_features), + adjust_deg_free = rlang::enquo(adjust_deg_free) + ) + + new_model_spec( + "gen_additive_mod", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + +} + +#' @export +print.gen_additive_mod <- function(x, ...) { + cat("GAM Model Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if(!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + + invisible(x) +} + +#' @export +#' @importFrom stats update +update.gen_additive_mod <- function(object, + select_features = NULL, + adjust_deg_free = NULL, + parameters = NULL, + fresh = FALSE, ...) { + + update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + + args <- list( + select_features = rlang::enquo(select_features), + adjust_deg_free = rlang::enquo(adjust_deg_free) + ) + + args <- update_main_parameters(args, parameters) + + if (fresh) { + object$args <- args + } else { + null_args <- purrr::map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + } + + new_model_spec( + "gen_additive_mod", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) +} + + +#' @export +#' @importFrom parsnip translate +translate.gen_additive_mod <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'gam'` for translation.") + engine <- "gam" + } + x <- translate.default(x, engine, ...) 
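+  # `translate.default()` maps the parsnip argument names to the mgcv names
+  # registered for this model (`select_features` -> `select`,
+  # `adjust_deg_free` -> `gamma`) and fills in the `mgcv::gam()` call template.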
+ + x +} diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R new file mode 100644 index 000000000..5120bed03 --- /dev/null +++ b/R/gen_additive_mod_data.R @@ -0,0 +1,225 @@ + +set_new_model("gen_additive_mod") + +#### REGRESION ---- +model = "gen_additive_mod" +mode = "regression" +engine = "gam" + +set_model_engine(model = model, mode = mode, eng = engine) +set_dependency(model = model, eng = engine, pkg = "mgcv") +set_dependency(model = model, eng = engine, pkg = "parnsip") + +#Args + +set_model_arg( + model = "gen_additive_mod", + eng = "gam", + parsnip = "select_features", + original = "select", + func = list(pkg = "parnsip", fun = "select_features"), + has_submodel = FALSE +) + +set_model_arg( + model = "gen_additive_mod", + eng = "gam", + parsnip = "adjust_deg_free", + original = "gamma", + func = list(pkg = "parnsip", fun = "adjust_deg_free"), + has_submodel = FALSE +) + +set_encoding( + model = model, + eng = engine, + mode = mode, + options = list( + predictor_indicators = "none", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) +) + +set_fit( + model = model, + eng = engine, + mode = mode, + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "mgcv", fun = "gam"), + defaults = list( + select = FALSE, + gamma = 1 + ) + ) +) + +set_pred( + model = model, + eng = engine, + mode = mode, + type = "numeric", + value = list( + pre = NULL, + post = function(x, object) as.numeric(x), + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = model, + eng = engine, + mode = mode, + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + res <-tibble::tibble(.pre_lower = results$fit - 2*results$se.fit, + .pre_upper = results$fit + 2*results$se.fit) + }, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "link", + se.fit = TRUE + ) + ) +) + +set_pred( + model = model, + eng = engine, + mode = mode, + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data) + ) + ) +) + +#### CLASSIFICATION + +model = "gen_additive_mod" +mode = "classification" +engine = "gam" + +set_model_engine(model = model, mode = mode, eng = engine) +set_dependency(model = model, eng = engine, pkg = "mgcv") +set_dependency(model = model, eng = engine, pkg = "parnsip") + +set_encoding( + model = model, + eng = engine, + mode = mode, + options = list( + predictor_indicators = "none", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) +) + +set_fit( + model = model, + eng = engine, + mode = mode, + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "mgcv", fun = "gam"), + defaults = list( + select = FALSE, + gamma = 1, + family = stats::binomial(link = "logit") + ) + ) +) + +prob_to_class_2 <- function(x, object){ + + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + unname(x) +} + +set_pred( + model = model, + eng = engine, + mode = mode, + type = "class", + value = list( + pre = NULL, + post = function(results, object) { + + tbl <-tibble::as_tibble(results) + + if (ncol(tbl)==1){ + res<-prob_to_class_2(tbl, object) %>% + tibble::as_tibble() %>% + stats::setNames("values") %>% + dplyr::mutate(values = as.factor(values)) + } else{ + 
res <- tbl %>% + apply(.,1,function(x) which(max(x)==x)[1])-1 %>% #modify in the future for something more elegant when gets the formula ok + tibble::as_tibble() + } + + }, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = model, + eng = engine, + mode = mode, + type = "prob", + value = list( + pre = NULL, + post = function(results, object) { + res <-tibble::as_tibble(results) + }, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = model, + eng = engine, + mode = mode, + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data) + ) + ) +) + + diff --git a/man/gen_additive_mod.Rd b/man/gen_additive_mod.Rd new file mode 100644 index 000000000..f160633b6 --- /dev/null +++ b/man/gen_additive_mod.Rd @@ -0,0 +1,81 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/gen_additive_mod.R +\name{gen_additive_mod} +\alias{gen_additive_mod} +\title{Interface for Generalized Additive Models (GAM)} +\usage{ +gen_additive_mod( + mode = "regression", + select_features = NULL, + adjust_deg_free = NULL +) +} +\arguments{ +\item{mode}{A single character string for the type of model.} + +\item{select_features}{TRUE or FALSE. If this is TRUE then can add an +extra penalty to each term so that it can be penalized to zero. +This means that the smoothing parameter estimation that is part of +fitting can completely remove terms from the model. If the corresponding +smoothing parameter is estimated as zero then the extra penalty has no effect. +Use \code{adjust_deg_free} to increase level of penalization.} + +\item{adjust_deg_free}{If \code{select_features = TRUE}, then acts as a multiplier for smoothness. +Increase this beyond 1 to produce smoother models.} +} +\value{ +A \code{parsnip} model specification +} +\description{ +Interface for Generalized Additive Models (GAM) +} +\details{ +\strong{Available Engines:} +\itemize{ +\item \strong{gam}: Connects to \code{mgcv::gam()} +} + +\strong{Parameter Mapping:}\tabular{ll}{ + modelgam \tab mgcv::gam \cr + select_features \tab select (FALSE) \cr + adjust_deg_free \tab gamma (1) \cr +} +} +\section{Engine Details}{ + + +\strong{gam} + +This engine uses \code{\link[mgcv:gam]{mgcv::gam()}} and has the following parameters, +which can be modified through the \code{\link[=set_engine]{set_engine()}} function.\preformatted{## function (formula, family = gaussian(), data = list(), weights = NULL, +## subset = NULL, na.action, offset = NULL, method = "GCV.Cp", optimizer = c("outer", +## "newton"), control = list(), scale = 0, select = FALSE, knots = NULL, +## sp = NULL, min.sp = NULL, H = NULL, gamma = 1, fit = TRUE, paraPen = NULL, +## G = NULL, in.out = NULL, drop.unused.levels = TRUE, drop.intercept = NULL, +## discrete = FALSE, ...) +} +} + +\section{Fit Details}{ + + +\strong{MGCV Formula Interface} + +Fitting GAMs is accomplished using parameters including: +\itemize{ +\item \code{\link[mgcv:s]{mgcv::s()}}: GAM spline smooths +\item \code{\link[mgcv:te]{mgcv::te()}}: GAM tensor product smooths +} + +These are applied in the \code{fit()} function:\if{html}{\out{
<div class="sourceCode r">}}\preformatted{fit(value ~ s(date_mon, k = 12) + s(date_num), data = df)
+}\if{html}{\out{</div>
}} +} + +\examples{ + +show_engines("gen_additive_mod") + +gen_additive_mod() + + +} From ea2c4efb4c1ce1164aa6add15f7d48b6dbd69b78 Mon Sep 17 00:00:00 2001 From: Matt Dancho Date: Thu, 10 Jun 2021 14:27:57 -0400 Subject: [PATCH 02/11] add mgcv to suggests --- DESCRIPTION | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 9512cea5f..f3c07f7db 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -54,6 +54,7 @@ Suggests: nlme, modeldata, LiblineaR, - Matrix + Matrix, + mgcv Remotes: topepo/C5.0 From ff95b1e6b922b3fc773a791b346972d3fa25fea1 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Fri, 11 Jun 2021 11:36:58 -0400 Subject: [PATCH 03/11] udpates for new documentation system --- R/gen_additive_mod.R | 36 +++++++++++++++++++----------------- man/gen_additive_mod.Rd | 34 +++++++++++++++++++--------------- man/parsnip_update.Rd | 18 ++++++++++++++---- 3 files changed, 52 insertions(+), 36 deletions(-) diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R index dac541f24..f6025c133 100644 --- a/R/gen_additive_mod.R +++ b/R/gen_additive_mod.R @@ -4,9 +4,18 @@ # - mode = "regression" (default) uses # - mode = "classification" -#' Interface for Generalized Additive Models (GAM) +#' Generalized additive models (GAMs) #' -#' @param mode A single character string for the type of model. +#' `gen_additive_mod()` defines a model that can use smoothed functions of +#' numeric predictors in a generalized linear model. +#' +#' There are different ways to fit this model. See the engine-specific pages +#' for more details +#' +#' More information on how `parsnip` is used for modeling is at +#' \url{https://www.tidymodels.org}. +#' +#' @inheritParams boost_tree #' @param select_features TRUE or FALSE. If this is TRUE then can add an #' extra penalty to each term so that it can be penalized to zero. #' This means that the smoothing parameter estimation that is part of @@ -22,20 +31,11 @@ #' #' @details #' -#' __Available Engines:__ -#' - __gam__: Connects to `mgcv::gam()` +#' This function only defines what _type_ of model is being fit. Once an engine +#' is specified, the _method_ to fit the model is also defined. #' -#' __Parameter Mapping:__ -#' -#' ```{r echo = FALSE} -#' tibble::tribble( -#' ~ "modelgam", ~ "mgcv::gam", -#' "select_features", "select (FALSE)", -#' "adjust_deg_free", "gamma (1)" -#' ) %>% knitr::kable() -#' ``` -#' -#' @section Engine Details: +#' The model is not trained or fit until the [fit.model_spec()] function is used +#' with the data. #' #' __gam__ #' @@ -61,7 +61,8 @@ #' fit(value ~ s(date_mon, k = 12) + s(date_num), data = df) #' ``` #' -#' +#' @references \url{https://www.tidymodels.org}, +#' [_Tidy Models with R_](https://tmwr.org) #' @examples #' #' show_engines("gen_additive_mod") @@ -92,7 +93,7 @@ gen_additive_mod <- function(mode = "regression", #' @export print.gen_additive_mod <- function(x, ...) { - cat("GAM Model Specification (", x$mode, ")\n\n", sep = "") + cat("GAM Specification (", x$mode, ")\n\n", sep = "") model_printer(x, ...) if(!is.null(x$method$fit$args)) { @@ -104,6 +105,7 @@ print.gen_additive_mod <- function(x, ...) 
{ } #' @export +#' @rdname parsnip_update #' @importFrom stats update update.gen_additive_mod <- function(object, select_features = NULL, diff --git a/man/gen_additive_mod.Rd b/man/gen_additive_mod.Rd index f160633b6..b9b9e769a 100644 --- a/man/gen_additive_mod.Rd +++ b/man/gen_additive_mod.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/gen_additive_mod.R \name{gen_additive_mod} \alias{gen_additive_mod} -\title{Interface for Generalized Additive Models (GAM)} +\title{Generalized additive models (GAMs)} \usage{ gen_additive_mod( mode = "regression", @@ -11,7 +11,9 @@ gen_additive_mod( ) } \arguments{ -\item{mode}{A single character string for the type of model.} +\item{mode}{A single character string for the type of model. +Possible values for this model are "unknown", "regression", or +"classification".} \item{select_features}{TRUE or FALSE. If this is TRUE then can add an extra penalty to each term so that it can be penalized to zero. @@ -27,22 +29,21 @@ Increase this beyond 1 to produce smoother models.} A \code{parsnip} model specification } \description{ -Interface for Generalized Additive Models (GAM) +\code{gen_additive_mod()} defines a model that can use smoothed functions of +numeric predictors in a generalized linear model. } \details{ -\strong{Available Engines:} -\itemize{ -\item \strong{gam}: Connects to \code{mgcv::gam()} -} +There are different ways to fit this model. See the engine-specific pages +for more details -\strong{Parameter Mapping:}\tabular{ll}{ - modelgam \tab mgcv::gam \cr - select_features \tab select (FALSE) \cr - adjust_deg_free \tab gamma (1) \cr -} -} -\section{Engine Details}{ +More information on how \code{parsnip} is used for modeling is at +\url{https://www.tidymodels.org}. +This function only defines what \emph{type} of model is being fit. Once an engine +is specified, the \emph{method} to fit the model is also defined. + +The model is not trained or fit until the \code{\link[=fit.model_spec]{fit.model_spec()}} function is used +with the data. \strong{gam} @@ -55,7 +56,6 @@ which can be modified through the \code{\link[=set_engine]{set_engine()}} functi ## discrete = FALSE, ...) } } - \section{Fit Details}{ @@ -79,3 +79,7 @@ gen_additive_mod() } +\references{ +\url{https://www.tidymodels.org}, +\href{https://tmwr.org}{\emph{Tidy Models with R}} +} diff --git a/man/parsnip_update.Rd b/man/parsnip_update.Rd index b1cea173f..b9a0fc4e1 100644 --- a/man/parsnip_update.Rd +++ b/man/parsnip_update.Rd @@ -1,12 +1,13 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/boost_tree.R, R/decision_tree.R, -% R/linear_reg.R, R/logistic_reg.R, R/mars.R, R/mlp.R, R/multinom_reg.R, -% R/nearest_neighbor.R, R/proportional_hazards.R, R/rand_forest.R, -% R/surv_reg.R, R/survival_reg.R, R/svm_linear.R, R/svm_poly.R, R/svm_rbf.R, -% R/update.R +% R/gen_additive_mod.R, R/linear_reg.R, R/logistic_reg.R, R/mars.R, +% R/mlp.R, R/multinom_reg.R, R/nearest_neighbor.R, R/proportional_hazards.R, +% R/rand_forest.R, R/surv_reg.R, R/survival_reg.R, R/svm_linear.R, +% R/svm_poly.R, R/svm_rbf.R, R/update.R \name{update.boost_tree} \alias{update.boost_tree} \alias{update.decision_tree} +\alias{update.gen_additive_mod} \alias{update.linear_reg} \alias{update.logistic_reg} \alias{update.mars} @@ -48,6 +49,15 @@ ... ) +\method{update}{gen_additive_mod}( + object, + select_features = NULL, + adjust_deg_free = NULL, + parameters = NULL, + fresh = FALSE, + ... 
+) + \method{update}{linear_reg}( object, parameters = NULL, From fa35df8d0db2f263de8b2c1bad272b07aae70d57 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Fri, 11 Jun 2021 12:00:38 -0400 Subject: [PATCH 04/11] udpates for move form external package to parsnip --- NAMESPACE | 1 - R/gen_additive_mod.R | 9 ++- R/gen_additive_mod_data.R | 112 ++++++++++++++++---------------------- man/gen_additive_mod.Rd | 6 +- 4 files changed, 55 insertions(+), 73 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index e8d9caa9d..8574e5739 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -252,7 +252,6 @@ importFrom(generics,tidy) importFrom(generics,varying_args) importFrom(glue,glue_collapse) importFrom(magrittr,"%>%") -importFrom(parsnip,translate) importFrom(purrr,as_vector) importFrom(purrr,imap) importFrom(purrr,imap_lgl) diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R index f6025c133..f5fd1b446 100644 --- a/R/gen_additive_mod.R +++ b/R/gen_additive_mod.R @@ -65,13 +65,13 @@ #' [_Tidy Models with R_](https://tmwr.org) #' @examples #' -#' show_engines("gen_additive_mod") +#' #show_engines("gen_additive_mod") #' -#' gen_additive_mod() +#' #gen_additive_mod() #' #' #' @export -gen_additive_mod <- function(mode = "regression", +gen_additive_mod <- function(mode = "unknown", select_features = NULL, adjust_deg_free = NULL) { @@ -148,10 +148,9 @@ update.gen_additive_mod <- function(object, #' @export -#' @importFrom parsnip translate translate.gen_additive_mod <- function(x, engine = x$engine, ...) { if (is.null(engine)) { - message("Used `engine = 'gam'` for translation.") + message("Used `engine = 'mgcv'` for translation.") engine <- "gam" } x <- translate.default(x, engine, ...) diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R index 5120bed03..e67080c5b 100644 --- a/R/gen_additive_mod_data.R +++ b/R/gen_additive_mod_data.R @@ -1,39 +1,36 @@ set_new_model("gen_additive_mod") +# ------------------------------------------------------------------------------ #### REGRESION ---- -model = "gen_additive_mod" -mode = "regression" -engine = "gam" - -set_model_engine(model = model, mode = mode, eng = engine) -set_dependency(model = model, eng = engine, pkg = "mgcv") -set_dependency(model = model, eng = engine, pkg = "parnsip") +set_model_engine(model = "gen_additive_mod", mode = "regression", eng = "mgcv") +set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv") #Args +# TODO make dials PR set_model_arg( model = "gen_additive_mod", - eng = "gam", + eng = "mgcv", parsnip = "select_features", original = "select", - func = list(pkg = "parnsip", fun = "select_features"), + func = list(pkg = "dials", fun = "select_features"), has_submodel = FALSE ) set_model_arg( model = "gen_additive_mod", - eng = "gam", + eng = "mgcv", parsnip = "adjust_deg_free", original = "gamma", - func = list(pkg = "parnsip", fun = "adjust_deg_free"), + func = list(pkg = "dials", fun = "adjust_deg_free"), has_submodel = FALSE ) set_encoding( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", options = list( predictor_indicators = "none", compute_intercept = FALSE, @@ -43,24 +40,21 @@ set_encoding( ) set_fit( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", value = list( interface = "formula", protect = c("formula", "data"), func = c(pkg = "mgcv", fun = "gam"), - defaults = list( - select = FALSE, - gamma = 1 - ) + defaults = list() ) ) set_pred( - model = model, - eng = engine, - mode = 
mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", type = "numeric", value = list( pre = NULL, @@ -75,13 +69,14 @@ set_pred( ) set_pred( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", type = "conf_int", value = list( pre = NULL, post = function(results, object) { + # TODO fix this; see the logistic regression code res <-tibble::tibble(.pre_lower = results$fit - 2*results$se.fit, .pre_upper = results$fit + 2*results$se.fit) }, @@ -96,9 +91,9 @@ set_pred( ) set_pred( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", type = "raw", value = list( pre = NULL, @@ -111,20 +106,16 @@ set_pred( ) ) +# ------------------------------------------------------------------------------ #### CLASSIFICATION +set_model_engine(model = "gen_additive_mod", mode = "classification", eng = "mgcv") +set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv") -model = "gen_additive_mod" -mode = "classification" -engine = "gam" - -set_model_engine(model = model, mode = mode, eng = engine) -set_dependency(model = model, eng = engine, pkg = "mgcv") -set_dependency(model = model, eng = engine, pkg = "parnsip") set_encoding( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", options = list( predictor_indicators = "none", compute_intercept = FALSE, @@ -134,31 +125,23 @@ set_encoding( ) set_fit( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", value = list( interface = "formula", protect = c("formula", "data"), func = c(pkg = "mgcv", fun = "gam"), defaults = list( - select = FALSE, - gamma = 1, family = stats::binomial(link = "logit") ) ) ) -prob_to_class_2 <- function(x, object){ - - x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) - unname(x) -} - set_pred( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", type = "class", value = list( pre = NULL, @@ -166,14 +149,15 @@ set_pred( tbl <-tibble::as_tibble(results) - if (ncol(tbl)==1){ - res<-prob_to_class_2(tbl, object) %>% + if (ncol(tbl) == 1) { + res <- prob_to_class_2(tbl, object) %>% tibble::as_tibble() %>% stats::setNames("values") %>% dplyr::mutate(values = as.factor(values)) } else{ res <- tbl %>% - apply(.,1,function(x) which(max(x)==x)[1])-1 %>% #modify in the future for something more elegant when gets the formula ok + apply(., 1, function(x) + which(max(x) == x)[1]) - 1 %>% #modify in the future for something more elegant when gets the formula ok tibble::as_tibble() } @@ -188,14 +172,14 @@ set_pred( ) set_pred( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", type = "prob", value = list( pre = NULL, post = function(results, object) { - res <-tibble::as_tibble(results) + res <- tibble::as_tibble(results) }, func = c(fun = "predict"), args = list( @@ -207,9 +191,9 @@ set_pred( ) set_pred( - model = model, - eng = engine, - mode = mode, + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", type = "raw", value = list( pre = NULL, diff --git a/man/gen_additive_mod.Rd b/man/gen_additive_mod.Rd index b9b9e769a..3296bc04f 100644 --- a/man/gen_additive_mod.Rd +++ b/man/gen_additive_mod.Rd @@ -5,7 +5,7 @@ \title{Generalized additive models (GAMs)} \usage{ gen_additive_mod( - mode = "regression", + mode 
= "unknown", select_features = NULL, adjust_deg_free = NULL ) @@ -73,9 +73,9 @@ These are applied in the \code{fit()} function:\if{html}{\out{
}}\ \examples{ -show_engines("gen_additive_mod") +#show_engines("gen_additive_mod") -gen_additive_mod() +#gen_additive_mod() } From b535a07cdeb02dc6df9826ea598bbb50c2038970 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Fri, 11 Jun 2021 12:19:11 -0400 Subject: [PATCH 05/11] fixed missing doc references --- R/gen_additive_mod.R | 1 + R/gen_additive_mod_data.R | 4 ++-- man/parsnip_update.Rd | 10 ++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R index f5fd1b446..180730eb4 100644 --- a/R/gen_additive_mod.R +++ b/R/gen_additive_mod.R @@ -107,6 +107,7 @@ print.gen_additive_mod <- function(x, ...) { #' @export #' @rdname parsnip_update #' @importFrom stats update +#' @inheritParams gen_additive_mod update.gen_additive_mod <- function(object, select_features = NULL, adjust_deg_free = NULL, diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R index e67080c5b..db51c3ad0 100644 --- a/R/gen_additive_mod_data.R +++ b/R/gen_additive_mod_data.R @@ -77,8 +77,8 @@ set_pred( pre = NULL, post = function(results, object) { # TODO fix this; see the logistic regression code - res <-tibble::tibble(.pre_lower = results$fit - 2*results$se.fit, - .pre_upper = results$fit + 2*results$se.fit) + res <-tibble::tibble(.pred_lower = results$fit - 2*results$se.fit, + .pred_upper = results$fit + 2*results$se.fit) }, func = c(fun = "predict"), args = list( diff --git a/man/parsnip_update.Rd b/man/parsnip_update.Rd index b9a0fc4e1..8bec62e48 100644 --- a/man/parsnip_update.Rd +++ b/man/parsnip_update.Rd @@ -213,6 +213,16 @@ modified in-place or replaced wholesale.} \item{cost_complexity}{A positive number for the the cost/complexity parameter (a.k.a. \code{Cp}) used by CART models (\code{rpart} only).} +\item{select_features}{TRUE or FALSE. If this is TRUE then can add an +extra penalty to each term so that it can be penalized to zero. +This means that the smoothing parameter estimation that is part of +fitting can completely remove terms from the model. If the corresponding +smoothing parameter is estimated as zero then the extra penalty has no effect. +Use \code{adjust_deg_free} to increase level of penalization.} + +\item{adjust_deg_free}{If \code{select_features = TRUE}, then acts as a multiplier for smoothness. +Increase this beyond 1 to produce smoother models.} + \item{penalty}{A non-negative number representing the total amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only). 
For \code{keras} models, this corresponds to purely L2 regularization From 44213b24f0d381cd10a65deaebee35b962f4d712 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Mon, 14 Jun 2021 11:10:14 -0400 Subject: [PATCH 06/11] confidence intervals and other model info changes --- R/gen_additive_mod_data.R | 92 +++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 22 deletions(-) diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R index db51c3ad0..cbcd27bc4 100644 --- a/R/gen_additive_mod_data.R +++ b/R/gen_additive_mod_data.R @@ -76,9 +76,26 @@ set_pred( value = list( pre = NULL, post = function(results, object) { - # TODO fix this; see the logistic regression code - res <-tibble::tibble(.pred_lower = results$fit - 2*results$se.fit, - .pred_upper = results$fit + 2*results$se.fit) + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res <- + tibble( + .pred_lower = trans(results$fit - const * results$se.fit), + .pred_upper = trans(results$fit + const * results$se.fit) + ) + # In case of inverse or other links + if (any(res$.pred_upper < res$.pred_lower)) { + nms <- names(res) + res <- res[, 2:1] + names(res) <- nms + } + + if (object$spec$method$pred$conf_int$extras$std_error) { + res$.std_error <- results$se.fit + } + res }, func = c(fun = "predict"), args = list( @@ -145,22 +162,9 @@ set_pred( type = "class", value = list( pre = NULL, - post = function(results, object) { - - tbl <-tibble::as_tibble(results) - - if (ncol(tbl) == 1) { - res <- prob_to_class_2(tbl, object) %>% - tibble::as_tibble() %>% - stats::setNames("values") %>% - dplyr::mutate(values = as.factor(values)) - } else{ - res <- tbl %>% - apply(., 1, function(x) - which(max(x) == x)[1]) - 1 %>% #modify in the future for something more elegant when gets the formula ok - tibble::as_tibble() - } - + post = function(x, object) { + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + unname(x) }, func = c(fun = "predict"), args = list( @@ -177,9 +181,11 @@ set_pred( mode = "classification", type = "prob", value = list( - pre = NULL, - post = function(results, object) { - res <- tibble::as_tibble(results) + pre = NULL, + post = function(x, object) { + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x }, func = c(fun = "predict"), args = list( @@ -207,3 +213,45 @@ set_pred( ) +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res_2 <- + tibble( + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) + ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res_1) <- c(lo_nms[1], hi_nms[1]) + colnames(res_2) <- c(lo_nms[2], hi_nms[2]) + res <- bind_cols(res_1, res_2) + + if (object$spec$method$pred$conf_int$extras$std_error) { + res$.std_error <- results$se.fit + } + res + }, + func = c(fun = "predict"), + args = + list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "link", + se.fit = TRUE + ) + ) +) + From a4b293fbc3a5a535a3c3822d2b76bdd6810871ab Mon Sep 17 00:00:00 2001 From: 
Max Kuhn Date: Mon, 14 Jun 2021 19:31:32 -0400 Subject: [PATCH 07/11] test cases --- NAMESPACE | 1 + R/gen_additive_mod.R | 6 ++ tests/testthat/test_gen_additive_model.R | 99 ++++++++++++++++++++++++ 3 files changed, 106 insertions(+) create mode 100644 tests/testthat/test_gen_additive_model.R diff --git a/NAMESPACE b/NAMESPACE index 8574e5739..594d9e7bd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(augment,model_fit) S3method(fit,model_spec) +S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) S3method(glance,model_fit) S3method(has_multi_predict,default) diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R index 180730eb4..64f03ccb2 100644 --- a/R/gen_additive_mod.R +++ b/R/gen_additive_mod.R @@ -158,3 +158,9 @@ translate.gen_additive_mod <- function(x, engine = x$engine, ...) { x } + +#' @export +#' @keywords internal +fit_xy.gen_additive_mod <- function(object, ...) { + rlang::abort("`fit()` must be used with GAM models (due to its use of formulas).") +} diff --git a/tests/testthat/test_gen_additive_model.R b/tests/testthat/test_gen_additive_model.R new file mode 100644 index 000000000..407c1f7da --- /dev/null +++ b/tests/testthat/test_gen_additive_model.R @@ -0,0 +1,99 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tibble) +library(mgcv) + +data(two_class_dat, package = "modeldata") + +# ------------------------------------------------------------------------------ + +context("generalized additive models") + +# ------------------------------------------------------------------------------ + +reg_mod <- gen_additive_mod(select_features = TRUE) %>% set_engine("mgcv") %>% set_mode("regression") + +test_that('regression', { + skip_if_not_installed("mgcv") + + expect_error( + f_res <- fit( + reg_mod, + mpg ~ s(disp) + wt + gear, + data = mtcars + ), + regexp = NA + ) + expect_error( + xy_res <- fit_xy( + reg_mod, + x = mtcars[, 1:5], + y = mtcars$mpg, + control = ctrl + ), + regexp = "must be used with GAM models" + ) + mgcv_mod <- mgcv::gam(mpg ~ s(disp) + wt + gear, data = mtcars, select = TRUE) + expect_equal(coef(mgcv_mod), coef(f_res$fit)) + + f_pred <- predict(f_res, head(mtcars)) + mgcv_pred <- predict(mgcv_mod, head(mtcars), type = "response") + expect_equal(names(f_pred), ".pred") + expect_equivalent(f_pred[[".pred"]], unname(mgcv_pred)) + + f_ci <- predict(f_res, head(mtcars), type = "conf_int", std_error = TRUE) + mgcv_ci <- predict(mgcv_mod, head(mtcars), type = "link", se.fit = TRUE) + expect_equivalent(f_ci[[".std_error"]], unname(mgcv_ci$se.fit)) + lower <- + mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit + expect_equivalent(f_ci[[".pred_lower"]], unname(lower)) + +}) + + + +# ------------------------------------------------------------------------------ + +cls_mod <- gen_additive_mod(adjust_deg_free = 1.5) %>% set_engine("mgcv") %>% set_mode("classification") + +test_that('classification', { + skip_if_not_installed("mgcv") + expect_error( + f_res <- fit( + cls_mod, + Class ~ s(A, k = 10) + B, + data = two_class_dat + ), + regexp = NA + ) + expect_error( + xy_res <- fit_xy( + cls_mod, + x = two_class_dat[, 2:3], + y = two_class_dat$Class, + control = ctrl + ), + regexp = "must be used with GAM models" + ) + mgcv_mod <- + mgcv::gam(Class ~ s(A, k = 10) + B, + data = two_class_dat, + gamma = 1.5, + family = binomial) + expect_equal(coef(mgcv_mod), coef(f_res$fit)) + + f_pred <- predict(f_res, head(two_class_dat), type = "prob") + mgcv_pred <- predict(mgcv_mod, head(two_class_dat), type = 
"response") + expect_equal(names(f_pred), c(".pred_Class1", ".pred_Class2")) + expect_equivalent(f_pred[[".pred_Class2"]], unname(mgcv_pred)) + + f_ci <- predict(f_res, head(two_class_dat), type = "conf_int", std_error = TRUE) + mgcv_ci <- predict(mgcv_mod, head(two_class_dat), type = "link", se.fit = TRUE) + expect_equivalent(f_ci[[".std_error"]], unname(mgcv_ci$se.fit)) + lower <- + mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit + lower <- binomial()$linkinv(lower) + expect_equivalent(f_ci[[".pred_lower_Class2"]], unname(lower)) + +}) From cbb93d2fbe93b889e38e15586ad0892cf9278c17 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 16 Jun 2021 10:34:01 -0400 Subject: [PATCH 08/11] modularize confidence interval code --- NAMESPACE | 1 + R/aaa.R | 53 +++++++++++++++++++++++++++++++++++++++ R/gen_additive_mod_data.R | 48 ++--------------------------------- R/logistic_reg_data.R | 24 +----------------- 4 files changed, 57 insertions(+), 69 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 594d9e7bd..628fff2db 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -300,6 +300,7 @@ importFrom(stats,na.omit) importFrom(stats,na.pass) importFrom(stats,predict) importFrom(stats,qnorm) +importFrom(stats,qt) importFrom(stats,quantile) importFrom(stats,setNames) importFrom(stats,terms) diff --git a/R/aaa.R b/R/aaa.R index 1a7109d0e..b2dc334b9 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -30,6 +30,59 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { res } +# ------------------------------------------------------------------------------ + +#' @importFrom stats qt +# used by logistic_reg() and gen_additive_mod() +logistic_lp_to_conf_int <- function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res_2 <- + tibble( + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) + ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res_1) <- c(lo_nms[1], hi_nms[1]) + colnames(res_2) <- c(lo_nms[2], hi_nms[2]) + res <- bind_cols(res_1, res_2) + + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- results$se.fit + res +} + +# used by gen_additive_mod() +linear_lp_to_conf_int <- +function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res <- + tibble( + .pred_lower = trans(results$fit - const * results$se.fit), + .pred_upper = trans(results$fit + const * results$se.fit) + ) + # In case of inverse or other links + if (any(res$.pred_upper < res$.pred_lower)) { + nms <- names(res) + res <- res[, 2:1] + names(res) <- nms + } + + if (object$spec$method$pred$conf_int$extras$std_error) { + res$.std_error <- results$se.fit + } + res +} + # ------------------------------------------------------------------------------ # nocov diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R index cbcd27bc4..1bf389a27 100644 --- a/R/gen_additive_mod_data.R +++ b/R/gen_additive_mod_data.R @@ -75,28 +75,7 @@ set_pred( type = "conf_int", value = list( pre = NULL, - post = function(results, object) { - hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 - const <- - 
qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res <- - tibble( - .pred_lower = trans(results$fit - const * results$se.fit), - .pred_upper = trans(results$fit + const * results$se.fit) - ) - # In case of inverse or other links - if (any(res$.pred_upper < res$.pred_lower)) { - nms <- names(res) - res <- res[, 2:1] - names(res) <- nms - } - - if (object$spec$method$pred$conf_int$extras$std_error) { - res$.std_error <- results$se.fit - } - res - }, + post = linear_lp_to_conf_int, func = c(fun = "predict"), args = list( object = rlang::expr(object$fit), @@ -220,30 +199,7 @@ set_pred( type = "conf_int", value = list( pre = NULL, - post = function(results, object) { - hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 - const <- - qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res_2 <- - tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res_1) <- c(lo_nms[1], hi_nms[1]) - colnames(res_2) <- c(lo_nms[2], hi_nms[2]) - res <- bind_cols(res_1, res_2) - - if (object$spec$method$pred$conf_int$extras$std_error) { - res$.std_error <- results$se.fit - } - res - }, + post = logistic_lp_to_conf_int, func = c(fun = "predict"), args = list( diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 0fa45e7bf..906b733d8 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -95,29 +95,7 @@ set_pred( type = "conf_int", value = list( pre = NULL, - post = function(results, object) { - hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 - const <- - qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res_2 <- - tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res_1) <- c(lo_nms[1], hi_nms[1]) - colnames(res_2) <- c(lo_nms[2], hi_nms[2]) - res <- bind_cols(res_1, res_2) - - if (object$spec$method$pred$conf_int$extras$std_error) - res$.std_error <- results$se.fit - res - }, + post = logistic_lp_to_conf_int, func = c(fun = "predict"), args = list( From 96f9adae74c6ab9289165c5429ebf70e197573db Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 16 Jun 2021 10:34:35 -0400 Subject: [PATCH 09/11] set_model_mode --- R/gen_additive_mod_data.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R index 1bf389a27..bdb9261a6 100644 --- a/R/gen_additive_mod_data.R +++ b/R/gen_additive_mod_data.R @@ -1,5 +1,7 @@ set_new_model("gen_additive_mod") +set_model_mode("gen_additive_mod", "classification") +set_model_mode("gen_additive_mod", "regression") # ------------------------------------------------------------------------------ #### REGRESION ---- From 34f4ad5a637b4884006eb8462483088e46891973 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 16 Jun 2021 10:47:12 -0400 Subject: [PATCH 10/11] updated news --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 95edff59a..c54669f21 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ * Fix bug in `augment()` when non-predictor, 
non-outcome variables are included in data (#510). +* A model function (`gen_additive_mod()`) was added for generalized additive models. + # parsnip 0.1.6 ## Model Specification Changes From c2c2c3dfb45a709f360d96ce2eef4198ff2acc48 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 16 Jun 2021 10:49:06 -0400 Subject: [PATCH 11/11] updated unit test --- tests/testthat/test_gen_additive_model.R | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/testthat/test_gen_additive_model.R b/tests/testthat/test_gen_additive_model.R index 407c1f7da..4c82483c6 100644 --- a/tests/testthat/test_gen_additive_model.R +++ b/tests/testthat/test_gen_additive_model.R @@ -12,11 +12,15 @@ context("generalized additive models") # ------------------------------------------------------------------------------ -reg_mod <- gen_additive_mod(select_features = TRUE) %>% set_engine("mgcv") %>% set_mode("regression") test_that('regression', { skip_if_not_installed("mgcv") + reg_mod <- + gen_additive_mod(select_features = TRUE) %>% + set_engine("mgcv") %>% + set_mode("regression") + expect_error( f_res <- fit( reg_mod, @@ -51,14 +55,16 @@ test_that('regression', { }) - - # ------------------------------------------------------------------------------ -cls_mod <- gen_additive_mod(adjust_deg_free = 1.5) %>% set_engine("mgcv") %>% set_mode("classification") - test_that('classification', { skip_if_not_installed("mgcv") + + cls_mod <- + gen_additive_mod(adjust_deg_free = 1.5) %>% + set_engine("mgcv") %>% + set_mode("classification") + expect_error( f_res <- fit( cls_mod,