diff --git a/DESCRIPTION b/DESCRIPTION index 9512cea5f..f3c07f7db 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -54,6 +54,7 @@ Suggests: nlme, modeldata, LiblineaR, - Matrix + Matrix, + mgcv Remotes: topepo/C5.0 diff --git a/NAMESPACE b/NAMESPACE index 2cd256bcf..628fff2db 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(augment,model_fit) S3method(fit,model_spec) +S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) S3method(glance,model_fit) S3method(has_multi_predict,default) @@ -46,6 +47,7 @@ S3method(predict_time,model_fit) S3method(print,boost_tree) S3method(print,control_parsnip) S3method(print,decision_tree) +S3method(print,gen_additive_mod) S3method(print,linear_reg) S3method(print,logistic_reg) S3method(print,mars) @@ -76,6 +78,7 @@ S3method(tidy,nullmodel) S3method(translate,boost_tree) S3method(translate,decision_tree) S3method(translate,default) +S3method(translate,gen_additive_mod) S3method(translate,linear_reg) S3method(translate,logistic_reg) S3method(translate,mars) @@ -91,6 +94,7 @@ S3method(type_sum,model_fit) S3method(type_sum,model_spec) S3method(update,boost_tree) S3method(update,decision_tree) +S3method(update,gen_additive_mod) S3method(update,linear_reg) S3method(update,logistic_reg) S3method(update,mars) @@ -136,6 +140,7 @@ export(fit.model_spec) export(fit_control) export(fit_xy) export(fit_xy.model_spec) +export(gen_additive_mod) export(get_dependency) export(get_encoding) export(get_fit) @@ -295,6 +300,7 @@ importFrom(stats,na.omit) importFrom(stats,na.pass) importFrom(stats,predict) importFrom(stats,qnorm) +importFrom(stats,qt) importFrom(stats,quantile) importFrom(stats,setNames) importFrom(stats,terms) diff --git a/NEWS.md b/NEWS.md index 95edff59a..c54669f21 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ * Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510). +* A model function (`gen_additive_mod()`) was added for generalized additive models. + # parsnip 0.1.6 ## Model Specification Changes diff --git a/R/aaa.R b/R/aaa.R index 1a7109d0e..b2dc334b9 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -30,6 +30,59 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { res } +# ------------------------------------------------------------------------------ + +#' @importFrom stats qt +# used by logistic_reg() and gen_additive_mod() +logistic_lp_to_conf_int <- function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res_2 <- + tibble( + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) + ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res_1) <- c(lo_nms[1], hi_nms[1]) + colnames(res_2) <- c(lo_nms[2], hi_nms[2]) + res <- bind_cols(res_1, res_2) + + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- results$se.fit + res +} + +# used by gen_additive_mod() +linear_lp_to_conf_int <- +function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res <- + tibble( + .pred_lower = trans(results$fit - const * results$se.fit), + .pred_upper = trans(results$fit + const * results$se.fit) + ) + # In case of inverse or other links + if (any(res$.pred_upper < res$.pred_lower)) { + nms <- names(res) + res <- res[, 2:1] + names(res) <- nms + } + + if (object$spec$method$pred$conf_int$extras$std_error) { + res$.std_error <- results$se.fit + } + res +} + # ------------------------------------------------------------------------------ # nocov diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R new file mode 100644 index 000000000..64f03ccb2 --- /dev/null +++ b/R/gen_additive_mod.R @@ -0,0 +1,166 @@ +# gen_additive_mod() - General Interface to Linear GAM Models +# - backend: gam +# - prediction: +# - mode = "regression" (default) uses +# - mode = "classification" + +#' Generalized additive models (GAMs) +#' +#' `gen_additive_mod()` defines a model that can use smoothed functions of +#' numeric predictors in a generalized linear model. +#' +#' There are different ways to fit this model. See the engine-specific pages +#' for more details +#' +#' More information on how `parsnip` is used for modeling is at +#' \url{https://www.tidymodels.org}. +#' +#' @inheritParams boost_tree +#' @param select_features TRUE or FALSE. If this is TRUE then can add an +#' extra penalty to each term so that it can be penalized to zero. +#' This means that the smoothing parameter estimation that is part of +#' fitting can completely remove terms from the model. If the corresponding +#' smoothing parameter is estimated as zero then the extra penalty has no effect. +#' Use `adjust_deg_free` to increase level of penalization. +#' @param adjust_deg_free If `select_features = TRUE`, then acts as a multiplier for smoothness. +#' Increase this beyond 1 to produce smoother models. +#' +#' +#' @return +#' A `parsnip` model specification +#' +#' @details +#' +#' This function only defines what _type_ of model is being fit. Once an engine +#' is specified, the _method_ to fit the model is also defined. +#' +#' The model is not trained or fit until the [fit.model_spec()] function is used +#' with the data. +#' +#' __gam__ +#' +#' This engine uses [mgcv::gam()] and has the following parameters, +#' which can be modified through the [set_engine()] function. +#' +#' ``` {r echo=F} +#' str(mgcv::gam) +#' ``` +#' +#' @section Fit Details: +#' +#' __MGCV Formula Interface__ +#' +#' Fitting GAMs is accomplished using parameters including: +#' +#' - [mgcv::s()]: GAM spline smooths +#' - [mgcv::te()]: GAM tensor product smooths +#' +#' These are applied in the `fit()` function: +#' +#' ``` r +#' fit(value ~ s(date_mon, k = 12) + s(date_num), data = df) +#' ``` +#' +#' @references \url{https://www.tidymodels.org}, +#' [_Tidy Models with R_](https://tmwr.org) +#' @examples +#' +#' #show_engines("gen_additive_mod") +#' +#' #gen_additive_mod() +#' +#' +#' @export +gen_additive_mod <- function(mode = "unknown", + select_features = NULL, + adjust_deg_free = NULL) { + + args <- list( + select_features = rlang::enquo(select_features), + adjust_deg_free = rlang::enquo(adjust_deg_free) + ) + + new_model_spec( + "gen_additive_mod", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + +} + +#' @export +print.gen_additive_mod <- function(x, ...) { + cat("GAM Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if(!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + + invisible(x) +} + +#' @export +#' @rdname parsnip_update +#' @importFrom stats update +#' @inheritParams gen_additive_mod +update.gen_additive_mod <- function(object, + select_features = NULL, + adjust_deg_free = NULL, + parameters = NULL, + fresh = FALSE, ...) { + + update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + + args <- list( + select_features = rlang::enquo(select_features), + adjust_deg_free = rlang::enquo(adjust_deg_free) + ) + + args <- update_main_parameters(args, parameters) + + if (fresh) { + object$args <- args + } else { + null_args <- purrr::map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + } + + new_model_spec( + "gen_additive_mod", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) +} + + +#' @export +translate.gen_additive_mod <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'mgcv'` for translation.") + engine <- "gam" + } + x <- translate.default(x, engine, ...) + + x +} + +#' @export +#' @keywords internal +fit_xy.gen_additive_mod <- function(object, ...) { + rlang::abort("`fit()` must be used with GAM models (due to its use of formulas).") +} diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R new file mode 100644 index 000000000..bdb9261a6 --- /dev/null +++ b/R/gen_additive_mod_data.R @@ -0,0 +1,215 @@ + +set_new_model("gen_additive_mod") +set_model_mode("gen_additive_mod", "classification") +set_model_mode("gen_additive_mod", "regression") + +# ------------------------------------------------------------------------------ +#### REGRESION ---- +set_model_engine(model = "gen_additive_mod", mode = "regression", eng = "mgcv") +set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv") + +#Args + +# TODO make dials PR +set_model_arg( + model = "gen_additive_mod", + eng = "mgcv", + parsnip = "select_features", + original = "select", + func = list(pkg = "dials", fun = "select_features"), + has_submodel = FALSE +) + +set_model_arg( + model = "gen_additive_mod", + eng = "mgcv", + parsnip = "adjust_deg_free", + original = "gamma", + func = list(pkg = "dials", fun = "adjust_deg_free"), + has_submodel = FALSE +) + +set_encoding( + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + options = list( + predictor_indicators = "none", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) +) + +set_fit( + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "mgcv", fun = "gam"), + defaults = list() + ) +) + +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = function(x, object) as.numeric(x), + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = linear_lp_to_conf_int, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "link", + se.fit = TRUE + ) + ) +) + +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data) + ) + ) +) + +# ------------------------------------------------------------------------------ +#### CLASSIFICATION +set_model_engine(model = "gen_additive_mod", mode = "classification", eng = "mgcv") +set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv") + + +set_encoding( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + options = list( + predictor_indicators = "none", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) +) + +set_fit( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "mgcv", fun = "gam"), + defaults = list( + family = stats::binomial(link = "logit") + ) + ) +) + +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + unname(x) + }, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data) + ) + ) +) + + +set_pred( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = logistic_lp_to_conf_int, + func = c(fun = "predict"), + args = + list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "link", + se.fit = TRUE + ) + ) +) + diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 0fa45e7bf..906b733d8 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -95,29 +95,7 @@ set_pred( type = "conf_int", value = list( pre = NULL, - post = function(results, object) { - hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 - const <- - qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res_2 <- - tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res_1) <- c(lo_nms[1], hi_nms[1]) - colnames(res_2) <- c(lo_nms[2], hi_nms[2]) - res <- bind_cols(res_1, res_2) - - if (object$spec$method$pred$conf_int$extras$std_error) - res$.std_error <- results$se.fit - res - }, + post = logistic_lp_to_conf_int, func = c(fun = "predict"), args = list( diff --git a/man/gen_additive_mod.Rd b/man/gen_additive_mod.Rd new file mode 100644 index 000000000..3296bc04f --- /dev/null +++ b/man/gen_additive_mod.Rd @@ -0,0 +1,85 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/gen_additive_mod.R +\name{gen_additive_mod} +\alias{gen_additive_mod} +\title{Generalized additive models (GAMs)} +\usage{ +gen_additive_mod( + mode = "unknown", + select_features = NULL, + adjust_deg_free = NULL +) +} +\arguments{ +\item{mode}{A single character string for the type of model. +Possible values for this model are "unknown", "regression", or +"classification".} + +\item{select_features}{TRUE or FALSE. If this is TRUE then can add an +extra penalty to each term so that it can be penalized to zero. +This means that the smoothing parameter estimation that is part of +fitting can completely remove terms from the model. If the corresponding +smoothing parameter is estimated as zero then the extra penalty has no effect. +Use \code{adjust_deg_free} to increase level of penalization.} + +\item{adjust_deg_free}{If \code{select_features = TRUE}, then acts as a multiplier for smoothness. +Increase this beyond 1 to produce smoother models.} +} +\value{ +A \code{parsnip} model specification +} +\description{ +\code{gen_additive_mod()} defines a model that can use smoothed functions of +numeric predictors in a generalized linear model. +} +\details{ +There are different ways to fit this model. See the engine-specific pages +for more details + +More information on how \code{parsnip} is used for modeling is at +\url{https://www.tidymodels.org}. + +This function only defines what \emph{type} of model is being fit. Once an engine +is specified, the \emph{method} to fit the model is also defined. + +The model is not trained or fit until the \code{\link[=fit.model_spec]{fit.model_spec()}} function is used +with the data. + +\strong{gam} + +This engine uses \code{\link[mgcv:gam]{mgcv::gam()}} and has the following parameters, +which can be modified through the \code{\link[=set_engine]{set_engine()}} function.\preformatted{## function (formula, family = gaussian(), data = list(), weights = NULL, +## subset = NULL, na.action, offset = NULL, method = "GCV.Cp", optimizer = c("outer", +## "newton"), control = list(), scale = 0, select = FALSE, knots = NULL, +## sp = NULL, min.sp = NULL, H = NULL, gamma = 1, fit = TRUE, paraPen = NULL, +## G = NULL, in.out = NULL, drop.unused.levels = TRUE, drop.intercept = NULL, +## discrete = FALSE, ...) +} +} +\section{Fit Details}{ + + +\strong{MGCV Formula Interface} + +Fitting GAMs is accomplished using parameters including: +\itemize{ +\item \code{\link[mgcv:s]{mgcv::s()}}: GAM spline smooths +\item \code{\link[mgcv:te]{mgcv::te()}}: GAM tensor product smooths +} + +These are applied in the \code{fit()} function:\if{html}{\out{
}}\preformatted{fit(value ~ s(date_mon, k = 12) + s(date_num), data = df) +}\if{html}{\out{
}} +} + +\examples{ + +#show_engines("gen_additive_mod") + +#gen_additive_mod() + + +} +\references{ +\url{https://www.tidymodels.org}, +\href{https://tmwr.org}{\emph{Tidy Models with R}} +} diff --git a/man/parsnip_update.Rd b/man/parsnip_update.Rd index b1cea173f..8bec62e48 100644 --- a/man/parsnip_update.Rd +++ b/man/parsnip_update.Rd @@ -1,12 +1,13 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/boost_tree.R, R/decision_tree.R, -% R/linear_reg.R, R/logistic_reg.R, R/mars.R, R/mlp.R, R/multinom_reg.R, -% R/nearest_neighbor.R, R/proportional_hazards.R, R/rand_forest.R, -% R/surv_reg.R, R/survival_reg.R, R/svm_linear.R, R/svm_poly.R, R/svm_rbf.R, -% R/update.R +% R/gen_additive_mod.R, R/linear_reg.R, R/logistic_reg.R, R/mars.R, +% R/mlp.R, R/multinom_reg.R, R/nearest_neighbor.R, R/proportional_hazards.R, +% R/rand_forest.R, R/surv_reg.R, R/survival_reg.R, R/svm_linear.R, +% R/svm_poly.R, R/svm_rbf.R, R/update.R \name{update.boost_tree} \alias{update.boost_tree} \alias{update.decision_tree} +\alias{update.gen_additive_mod} \alias{update.linear_reg} \alias{update.logistic_reg} \alias{update.mars} @@ -48,6 +49,15 @@ ... ) +\method{update}{gen_additive_mod}( + object, + select_features = NULL, + adjust_deg_free = NULL, + parameters = NULL, + fresh = FALSE, + ... +) + \method{update}{linear_reg}( object, parameters = NULL, @@ -203,6 +213,16 @@ modified in-place or replaced wholesale.} \item{cost_complexity}{A positive number for the the cost/complexity parameter (a.k.a. \code{Cp}) used by CART models (\code{rpart} only).} +\item{select_features}{TRUE or FALSE. If this is TRUE then can add an +extra penalty to each term so that it can be penalized to zero. +This means that the smoothing parameter estimation that is part of +fitting can completely remove terms from the model. If the corresponding +smoothing parameter is estimated as zero then the extra penalty has no effect. +Use \code{adjust_deg_free} to increase level of penalization.} + +\item{adjust_deg_free}{If \code{select_features = TRUE}, then acts as a multiplier for smoothness. +Increase this beyond 1 to produce smoother models.} + \item{penalty}{A non-negative number representing the total amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only). For \code{keras} models, this corresponds to purely L2 regularization diff --git a/tests/testthat/test_gen_additive_model.R b/tests/testthat/test_gen_additive_model.R new file mode 100644 index 000000000..4c82483c6 --- /dev/null +++ b/tests/testthat/test_gen_additive_model.R @@ -0,0 +1,105 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tibble) +library(mgcv) + +data(two_class_dat, package = "modeldata") + +# ------------------------------------------------------------------------------ + +context("generalized additive models") + +# ------------------------------------------------------------------------------ + + +test_that('regression', { + skip_if_not_installed("mgcv") + + reg_mod <- + gen_additive_mod(select_features = TRUE) %>% + set_engine("mgcv") %>% + set_mode("regression") + + expect_error( + f_res <- fit( + reg_mod, + mpg ~ s(disp) + wt + gear, + data = mtcars + ), + regexp = NA + ) + expect_error( + xy_res <- fit_xy( + reg_mod, + x = mtcars[, 1:5], + y = mtcars$mpg, + control = ctrl + ), + regexp = "must be used with GAM models" + ) + mgcv_mod <- mgcv::gam(mpg ~ s(disp) + wt + gear, data = mtcars, select = TRUE) + expect_equal(coef(mgcv_mod), coef(f_res$fit)) + + f_pred <- predict(f_res, head(mtcars)) + mgcv_pred <- predict(mgcv_mod, head(mtcars), type = "response") + expect_equal(names(f_pred), ".pred") + expect_equivalent(f_pred[[".pred"]], unname(mgcv_pred)) + + f_ci <- predict(f_res, head(mtcars), type = "conf_int", std_error = TRUE) + mgcv_ci <- predict(mgcv_mod, head(mtcars), type = "link", se.fit = TRUE) + expect_equivalent(f_ci[[".std_error"]], unname(mgcv_ci$se.fit)) + lower <- + mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit + expect_equivalent(f_ci[[".pred_lower"]], unname(lower)) + +}) + +# ------------------------------------------------------------------------------ + +test_that('classification', { + skip_if_not_installed("mgcv") + + cls_mod <- + gen_additive_mod(adjust_deg_free = 1.5) %>% + set_engine("mgcv") %>% + set_mode("classification") + + expect_error( + f_res <- fit( + cls_mod, + Class ~ s(A, k = 10) + B, + data = two_class_dat + ), + regexp = NA + ) + expect_error( + xy_res <- fit_xy( + cls_mod, + x = two_class_dat[, 2:3], + y = two_class_dat$Class, + control = ctrl + ), + regexp = "must be used with GAM models" + ) + mgcv_mod <- + mgcv::gam(Class ~ s(A, k = 10) + B, + data = two_class_dat, + gamma = 1.5, + family = binomial) + expect_equal(coef(mgcv_mod), coef(f_res$fit)) + + f_pred <- predict(f_res, head(two_class_dat), type = "prob") + mgcv_pred <- predict(mgcv_mod, head(two_class_dat), type = "response") + expect_equal(names(f_pred), c(".pred_Class1", ".pred_Class2")) + expect_equivalent(f_pred[[".pred_Class2"]], unname(mgcv_pred)) + + f_ci <- predict(f_res, head(two_class_dat), type = "conf_int", std_error = TRUE) + mgcv_ci <- predict(mgcv_mod, head(two_class_dat), type = "link", se.fit = TRUE) + expect_equivalent(f_ci[[".std_error"]], unname(mgcv_ci$se.fit)) + lower <- + mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit + lower <- binomial()$linkinv(lower) + expect_equivalent(f_ci[[".pred_lower_Class2"]], unname(lower)) + +})