4. How It Works
The regression and classification functions work the same way, so we will look only at the regression function, fast_regression(). It works for many model and engine combinations but fails for some, which means the design is flawed, and at this point I think the flaw may be fundamental. First I will post the code, and then we will look at the output of each step.
Here is the full function:
fast_regression <- function(.data, .rec_obj, .parsnip_fns = "all",
                            .parsnip_eng = "all", .split_type = "initial_split",
                            .split_args = NULL){

  # Tidy Eval ----
  call <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()

  rec_obj <- .rec_obj
  split_type <- .split_type
  split_args <- .split_args

  # Checks ----

  # Get data splits
  df <- dplyr::as_tibble(.data)
  splits_obj <- create_splits(
    .data = df,
    .split_type = split_type,
    .split_args = split_args
  )

  # Generate Model Spec Tbl
  mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
    .parsnip_fns = call,
    .parsnip_eng = engine
  )

  # Generate Workflow object
  mod_tbl <- mod_spec_tbl %>%
    dplyr::mutate(
      wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
    )

  mod_fitted_tbl <- mod_tbl %>%
    dplyr::mutate(
      fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj)
    )

  mod_pred_tbl <- mod_fitted_tbl %>%
    dplyr::mutate(
      pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
    )

  # Return ----
  class(mod_tbl) <- c("fst_reg_tbl", class(mod_tbl))
  attr(mod_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_tbl, ".parsnip_functions") <- .parsnip_fns
  attr(mod_tbl, ".split_type") <- .split_type
  attr(mod_tbl, ".split_args") <- .split_args

  return(mod_pred_tbl)
}
We see that the function is broken up into steps, each building on the output of the previous one. So what is happening?
This function, named fast_regression, is designed to generate model specifications for regression using the parsnip package in R, wrap them in workflows, fit them, and generate predictions. Let's break down the key components of the code:
- Function Signature:

      fast_regression <- function(.data, .rec_obj, .parsnip_fns = "all",
                                  .parsnip_eng = "all", .split_type = "initial_split",
                                  .split_args = NULL)

  This function takes several parameters:
  - .data: The data for the regression problem.
  - .rec_obj: A recipe object.
  - .parsnip_fns: The parsnip model functions to use. Default is "all".
  - .parsnip_eng: The parsnip model engines to use. Default is "all".
  - .split_type: The type of data split to use (e.g., initial split). Default is "initial_split".
  - .split_args: Additional arguments for data splitting. Default is NULL.

- Tidy Eval and Parameter Handling:

      call <- list(.parsnip_fns) %>%
        purrr::flatten_chr()
      engine <- list(.parsnip_eng) %>%
        purrr::flatten_chr()
      rec_obj <- .rec_obj
      split_type <- .split_type
      split_args <- .split_args

  Although the code comment labels this section "Tidy Eval", no tidy evaluation takes place here. The function flattens .parsnip_fns and .parsnip_eng into plain character vectors and copies the remaining arguments to local names.

- Data Splitting:

      df <- dplyr::as_tibble(.data)
      splits_obj <- create_splits(
        .data = df,
        .split_type = split_type,
        .split_args = split_args
      )

  The function converts the input data to a tibble and then uses the create_splits function to generate data splits based on the specified split type and arguments.

- Model Specification Table Generation:

      mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
        .parsnip_fns = call,
        .parsnip_eng = engine
      )

  The fast_regression_parsnip_spec_tbl function is called to generate a table of parsnip model specifications based on the specified parsnip functions and engines.

- Workflow and Fitted Model Generation:

      mod_tbl <- mod_spec_tbl %>%
        dplyr::mutate(
          wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
        )
      mod_fitted_tbl <- mod_tbl %>%
        dplyr::mutate(
          fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj)
        )

  The function creates a workflow (wflw) and a fitted workflow (fitted_wflw) using internal functions (internal_make_wflw and internal_make_fitted_wflw) based on the generated model specifications and data splits.

- Prediction Generation:

      mod_pred_tbl <- mod_fitted_tbl %>%
        dplyr::mutate(
          pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
        )

  Finally, the function generates predictions using the fitted workflows and the data splits.

- Return:

      return(mod_pred_tbl)

  The function returns a table (mod_pred_tbl) containing information about the generated model specifications, workflows, fitted models, and predictions. Note that the class and attributes set just before the return are attached to mod_tbl rather than mod_pred_tbl, so the returned table does not carry them.
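The staged design described above can be imitated in a few lines of base R. The following is a minimal sketch, not tidyAML's implementation: the fitting closures stand in for parsnip specifications, try_or_null() is a hypothetical helper, and the fixed row ranges replace a real split. It shows how each stage consumes the previous stage's output and how one failing entry can be carried forward as NULL.

```r
# Minimal base-R sketch of the staged pattern: spec -> fit -> predict.
# Fixed row ranges stand in for a real train/test split.
train <- mtcars[1:24, ]
test  <- mtcars[25:32, ]

# Plain fitting closures stand in for parsnip model specifications.
specs <- list(
  lm  = function(d) stats::lm(mpg ~ wt + hp, data = d),
  glm = function(d) stats::glm(mpg ~ wt + hp, data = d),
  bad = function(d) stop("engine not available")  # mimics a missing engine
)

# Hypothetical helper: return NULL instead of aborting when a step fails.
try_or_null <- function(expr) tryCatch(expr, error = function(e) NULL)

# Stage 1: fit every spec on the training rows.
fits <- lapply(specs, function(f) try_or_null(f(train)))

# Stage 2: predict on the test rows, skipping entries that failed to fit.
preds <- lapply(fits, function(m) {
  if (is.null(m)) return(NULL)
  data.frame(.pred = unname(stats::predict(m, newdata = test)))
})
```

Here fits$bad and preds$bad are both NULL, while the lm and glm entries each hold eight predictions.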
Let's now walk through an example that will build the output piece by piece. This starts with a primitive function of the package called make_regression_base_tbl(), which holds the arguments that get passed to the functions later down the line in order to build the specifications of the models.
> make_regression_base_tbl() |>
+ internal_make_spec_tbl()
# A tibble: 39 × 5
.model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
<int> <chr> <chr> <chr> <list>
1 1 lm regression linear_reg <spec[+]>
2 2 brulee regression linear_reg <spec[+]>
3 3 gee regression linear_reg <spec[+]>
4 4 glm regression linear_reg <spec[+]>
5 5 glmer regression linear_reg <spec[+]>
6 6 glmnet regression linear_reg <spec[+]>
7 7 gls regression linear_reg <spec[+]>
8 8 lme regression linear_reg <spec[+]>
9 9 lmer regression linear_reg <spec[+]>
10 10 stan regression linear_reg <spec[+]>
# ℹ 29 more rows
# ℹ Use `print(n = ...)` to see more rows
> fast_regression_parsnip_spec_tbl(
.parsnip_eng = c('lm','glm','gee'),
.parsnip_fns = 'linear_reg'
)
# A tibble: 3 × 5
.model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
<int> <chr> <chr> <chr> <list>
1 1 lm regression linear_reg <spec[+]>
2 2 gee regression linear_reg <spec[+]>
3 3 glm regression linear_reg <spec[+]>
Again we can see that the gee engine fails:
> fast_regression_parsnip_spec_tbl(.parsnip_eng = c('lm','glm','gee'), .parsnip_fns = 'linear_reg') |>
+ pull(model_spec) |> pluck(2)
! parsnip could not locate an implementation for `linear_reg` regression model specifications using
the `gee` engine.
ℹ The parsnip extension package multilevelmod implements support for this specification.
ℹ Please install (if needed) and load to continue.
Linear Regression Model Specification (regression)
Computational engine: gee
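One way to avoid carrying an unusable engine into later steps is to check up front whether the package behind each engine is installed. The sketch below is only an illustration: the engine_pkgs lookup is a hand-written assumption (the gee entry comes from the message above, which names multilevelmod), and an installed package still has to be loaded before parsnip can use it.

```r
# Hypothetical lookup from engine name to the package that provides it.
engine_pkgs <- c(lm = "stats", glm = "stats", gee = "multilevelmod")

# TRUE where the providing package is installed on this machine.
installed <- vapply(engine_pkgs, requireNamespace, logical(1), quietly = TRUE)

# Engines that can be attempted, and engines to skip or report.
usable_engines  <- names(engine_pkgs)[installed]
skipped_engines <- names(engine_pkgs)[!installed]
```

The usable_engines vector could then be passed as .parsnip_eng, with skipped_engines reported to the user.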
Let's save these results as mod_spec_tbl and then make the workflow.
mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
.parsnip_eng = c('lm','glm','gee'),
.parsnip_fns = 'linear_reg'
)
> mod_tbl <- mod_spec_tbl %>%
+ dplyr::mutate(
+ wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
+ )
Error in `.f()`:
! parsnip could not locate an implementation for `linear_reg` regression model specifications
using the `gee` engine.
ℹ The parsnip extension package multilevelmod implements support for this specification.
ℹ Please install (if needed) and load to continue.
> mod_tbl
# A tibble: 3 × 6
.model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec wflw
<int> <chr> <chr> <chr> <list> <list>
1 1 lm regression linear_reg <spec[+]> <workflow>
2 2 gee regression linear_reg <spec[+]> <NULL>
3 3 glm regression linear_reg <spec[+]> <workflow>
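The workflow step above relies on a recipe object named rec_obj, which this walkthrough does not define. A minimal sketch, assuming the goal is to predict mpg from every other column of mtcars (the outcome and formula are assumptions, chosen to match the mtcars data used later to create the splits):

```r
library(recipes)

# Assumed recipe: mpg as the outcome, all remaining mtcars columns as predictors.
rec_obj <- recipe(mpg ~ ., data = mtcars)
```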
The next step requires that the splits object splits_obj exists, so we will make it first and then proceed. We use the create_splits() function, which is internal to tidyAML:
splits_obj <- create_splits(
.data = mtcars,
.split_type = "initial_split",
.split_args = NULL
)
splits_obj
$splits
<Training/Testing/Total>
<24/8/32>
$split_type
[1] "initial_split"
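The 24/8/32 printout is what rsample::initial_split() produces for the 32 rows of mtcars with its default proportion of 3/4. Assuming create_splits() hands the work to rsample in this way (an assumption, since its source is not shown here), the training and testing sets could be pulled out like this:

```r
library(rsample)

set.seed(123)  # arbitrary seed so the split is reproducible

# Default prop = 3/4, so 24 of the 32 rows go to training.
split_obj <- initial_split(mtcars)

train_tbl <- training(split_obj)
test_tbl  <- testing(split_obj)
```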
mod_fitted_tbl <- mod_tbl %>%
dplyr::mutate(
fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj)
)
Error in UseMethod("fit"): no applicable method for 'fit' applied to an object of class "NULL"
mod_fitted_tbl
# A tibble: 3 × 7
.model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec wflw fitted_wflw
<int> <chr> <chr> <chr> <list> <list> <list>
1 1 lm regression linear_reg <spec[+]> <workflow> <workflow>
2 2 gee regression linear_reg <spec[+]> <NULL> <NULL>
3 3 glm regression linear_reg <spec[+]> <workflow> <workflow>
Now on to the predictions.
> mod_pred_tbl <- mod_fitted_tbl %>%
+ dplyr::mutate(
+ pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
+ )
Error in UseMethod("predict"): no applicable method for 'predict' applied to an object of class "NULL"
> mod_pred_tbl
# A tibble: 3 × 8
.model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec wflw fitted_wflw pred_wflw
<int> <chr> <chr> <chr> <list> <list> <list> <list>
1 1 lm regression linear_reg <spec[+]> <workflow> <workflow> <tibble>
2 2 gee regression linear_reg <spec[+]> <NULL> <NULL> <NULL>
3 3 glm regression linear_reg <spec[+]> <workflow> <workflow> <tibble>
> mod_pred_tbl |> pull(pred_wflw)
[[1]]
# A tibble: 8 × 1
.pred
<dbl>
1 22.1
2 22.1
3 12.9
4 14.8
5 27.0
6 17.0
7 17.6
8 23.6
[[2]]
NULL
[[3]]
# A tibble: 8 × 1
.pred
<dbl>
1 22.1
2 22.1
3 12.9
4 14.8
5 27.0
6 17.0
7 17.6
8 23.6
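Because a failed engine leaves NULL in the list columns, downstream code has to separate the rows that produced predictions from those that did not. A minimal sketch with dplyr and purrr, using a small stand-in table (res_tbl and its values are illustrative, not the real mod_pred_tbl):

```r
library(dplyr)
library(purrr)
library(tibble)

# Stand-in for mod_pred_tbl: the gee row has no predictions.
res_tbl <- tibble(
  .model_id       = 1:3,
  .parsnip_engine = c("lm", "gee", "glm"),
  pred_wflw       = list(
    tibble(.pred = c(22.1, 12.9)),
    NULL,
    tibble(.pred = c(22.1, 12.9))
  )
)

# Flag rows whose prediction slot is empty.
res_tbl <- res_tbl |>
  mutate(failed = map_lgl(pred_wflw, is.null))

# Keep the usable rows and collect the names of the engines that failed.
ok_tbl         <- filter(res_tbl, !failed)
failed_engines <- res_tbl |> filter(failed) |> pull(.parsnip_engine)
```

With this in place, ok_tbl can be passed on for evaluation while failed_engines is reported back to the user.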
That is how the fast_regression() and fast_classification() functions work.