Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Suggests:
nlme,
modeldata,
LiblineaR,
Matrix
Matrix,
mgcv
Remotes:
topepo/C5.0
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method(augment,model_fit)
S3method(fit,model_spec)
S3method(fit_xy,gen_additive_mod)
S3method(fit_xy,model_spec)
S3method(glance,model_fit)
S3method(has_multi_predict,default)
Expand Down Expand Up @@ -46,6 +47,7 @@ S3method(predict_time,model_fit)
S3method(print,boost_tree)
S3method(print,control_parsnip)
S3method(print,decision_tree)
S3method(print,gen_additive_mod)
S3method(print,linear_reg)
S3method(print,logistic_reg)
S3method(print,mars)
Expand Down Expand Up @@ -76,6 +78,7 @@ S3method(tidy,nullmodel)
S3method(translate,boost_tree)
S3method(translate,decision_tree)
S3method(translate,default)
S3method(translate,gen_additive_mod)
S3method(translate,linear_reg)
S3method(translate,logistic_reg)
S3method(translate,mars)
Expand All @@ -91,6 +94,7 @@ S3method(type_sum,model_fit)
S3method(type_sum,model_spec)
S3method(update,boost_tree)
S3method(update,decision_tree)
S3method(update,gen_additive_mod)
S3method(update,linear_reg)
S3method(update,logistic_reg)
S3method(update,mars)
Expand Down Expand Up @@ -136,6 +140,7 @@ export(fit.model_spec)
export(fit_control)
export(fit_xy)
export(fit_xy.model_spec)
export(gen_additive_mod)
export(get_dependency)
export(get_encoding)
export(get_fit)
Expand Down Expand Up @@ -295,6 +300,7 @@ importFrom(stats,na.omit)
importFrom(stats,na.pass)
importFrom(stats,predict)
importFrom(stats,qnorm)
importFrom(stats,qt)
importFrom(stats,quantile)
importFrom(stats,setNames)
importFrom(stats,terms)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).

* A model function (`gen_additive_mod()`) was added for generalized additive models.

# parsnip 0.1.6

## Model Specification Changes
Expand Down
53 changes: 53 additions & 0 deletions R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,59 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
res
}

# ------------------------------------------------------------------------------

#' @importFrom stats qt
# used by logistic_reg() and gen_additive_mod()
logistic_lp_to_conf_int <- function(results, object) {
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
const <-
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
trans <- object$fit$family$linkinv
res_2 <-
tibble(
lo = trans(results$fit - const * results$se.fit),
hi = trans(results$fit + const * results$se.fit)
)
res_1 <- res_2
res_1$lo <- 1 - res_2$hi
res_1$hi <- 1 - res_2$lo
lo_nms <- paste0(".pred_lower_", object$lvl)
hi_nms <- paste0(".pred_upper_", object$lvl)
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
res <- bind_cols(res_1, res_2)

if (object$spec$method$pred$conf_int$extras$std_error)
res$.std_error <- results$se.fit
res
}

# used by gen_additive_mod()
linear_lp_to_conf_int <-
function(results, object) {
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
const <-
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
trans <- object$fit$family$linkinv
res <-
tibble(
.pred_lower = trans(results$fit - const * results$se.fit),
.pred_upper = trans(results$fit + const * results$se.fit)
)
# In case of inverse or other links
if (any(res$.pred_upper < res$.pred_lower)) {
nms <- names(res)
res <- res[, 2:1]
names(res) <- nms
}

if (object$spec$method$pred$conf_int$extras$std_error) {
res$.std_error <- results$se.fit
}
res
}

# ------------------------------------------------------------------------------
# nocov

Expand Down
166 changes: 166 additions & 0 deletions R/gen_additive_mod.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# gen_additive_mod() - General Interface to Linear GAM Models
# - backend: gam
# - prediction:
# - mode = "regression" (default) uses
# - mode = "classification"

#' Generalized additive models (GAMs)
#'
#' `gen_additive_mod()` defines a model that can use smoothed functions of
#' numeric predictors in a generalized linear model.
#'
#' There are different ways to fit this model. See the engine-specific pages
#' for more details
#'
#' More information on how `parsnip` is used for modeling is at
#' \url{https://www.tidymodels.org}.
#'
#' @inheritParams boost_tree
#' @param select_features TRUE or FALSE. If this is TRUE then can add an
#' extra penalty to each term so that it can be penalized to zero.
#' This means that the smoothing parameter estimation that is part of
#' fitting can completely remove terms from the model. If the corresponding
#' smoothing parameter is estimated as zero then the extra penalty has no effect.
#' Use `adjust_deg_free` to increase level of penalization.
#' @param adjust_deg_free If `select_features = TRUE`, then acts as a multiplier for smoothness.
#' Increase this beyond 1 to produce smoother models.
#'
#'
#' @return
#' A `parsnip` model specification
#'
#' @details
#'
#' This function only defines what _type_ of model is being fit. Once an engine
#' is specified, the _method_ to fit the model is also defined.
#'
#' The model is not trained or fit until the [fit.model_spec()] function is used
#' with the data.
#'
#' __gam__
#'
#' This engine uses [mgcv::gam()] and has the following parameters,
#' which can be modified through the [set_engine()] function.
#'
#' ``` {r echo=F}
#' str(mgcv::gam)
#' ```
#'
#' @section Fit Details:
#'
#' __MGCV Formula Interface__
#'
#' Fitting GAMs is accomplished using parameters including:
#'
#' - [mgcv::s()]: GAM spline smooths
#' - [mgcv::te()]: GAM tensor product smooths
#'
#' These are applied in the `fit()` function:
#'
#' ``` r
#' fit(value ~ s(date_mon, k = 12) + s(date_num), data = df)
#' ```
#'
#' @references \url{https://www.tidymodels.org},
#' [_Tidy Models with R_](https://tmwr.org)
#' @examples
#'
#' #show_engines("gen_additive_mod")
#'
#' #gen_additive_mod()
#'
#'
#' @export
gen_additive_mod <- function(mode = "unknown",
select_features = NULL,
adjust_deg_free = NULL) {

args <- list(
select_features = rlang::enquo(select_features),
adjust_deg_free = rlang::enquo(adjust_deg_free)
)

new_model_spec(
"gen_additive_mod",
args = args,
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
)

}

#' @export
print.gen_additive_mod <- function(x, ...) {
cat("GAM Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
cat("Model fit template:\n")
print(show_call(x))
}

invisible(x)
}

#' @export
#' @rdname parsnip_update
#' @importFrom stats update
#' @inheritParams gen_additive_mod
update.gen_additive_mod <- function(object,
select_features = NULL,
adjust_deg_free = NULL,
parameters = NULL,
fresh = FALSE, ...) {

update_dot_check(...)

if (!is.null(parameters)) {
parameters <- check_final_param(parameters)
}

args <- list(
select_features = rlang::enquo(select_features),
adjust_deg_free = rlang::enquo(adjust_deg_free)
)

args <- update_main_parameters(args, parameters)

if (fresh) {
object$args <- args
} else {
null_args <- purrr::map_lgl(args, null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
}

new_model_spec(
"gen_additive_mod",
args = object$args,
eng_args = object$eng_args,
mode = object$mode,
method = NULL,
engine = object$engine
)
}


#' @export
translate.gen_additive_mod <- function(x, engine = x$engine, ...) {
if (is.null(engine)) {
message("Used `engine = 'mgcv'` for translation.")
engine <- "gam"
}
x <- translate.default(x, engine, ...)

x
}

#' @export
#' @keywords internal
fit_xy.gen_additive_mod <- function(object, ...) {
rlang::abort("`fit()` must be used with GAM models (due to its use of formulas).")
}
Loading