Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make fast_regression() function #7

Closed
spsanderson opened this issue Dec 5, 2022 · 0 comments
Closed

Make fast_regression() function #7

spsanderson opened this issue Dec 5, 2022 · 0 comments
Assignees
Labels
enhancement New feature or request

Comments

@spsanderson
Copy link
Owner

spsanderson commented Dec 5, 2022

Function:

#' Fit many parsnip regression models in one call
#'
#' Splits `.data` with `create_splits()` and builds a model-spec table with
#' `fast_regression_parsnip_spec_tbl()`. For every row of that table, it then
#' pairs the shared recipe with the row's own model spec in a workflow, fits
#' the workflow on the training split and predicts on the testing split.
#'
#' @param .data A data frame; coerced with `dplyr::as_tibble()`.
#' @param .rec_obj A recipe (see `recipes::recipe()`) shared by every model.
#' @param .parsnip_fns Character vector of parsnip function names (default
#'   `"all"`), forwarded to `fast_regression_parsnip_spec_tbl()`.
#' @param .parsnip_eng Character vector of engine names (default `"all"`),
#'   forwarded to `fast_regression_parsnip_spec_tbl()`.
#' @param .split_type Split type forwarded to `create_splits()`
#'   (default `"initial_split"`).
#' @param .split_args Optional extra split arguments forwarded to
#'   `create_splits()`.
#' @return The model-spec tibble with these list-columns added:
#'   `.model_recipe`, `.wflw`, `.fitted_wflw` and `.pred_wflw`. Class
#'   `"fst_reg_tbl"` is prepended, and the call arguments are stored in the
#'   `.parsnip_engines`, `.parsnip_functions`, `.split_type` and `.split_args`
#'   attributes.
fast_regression <- function(.data, .rec_obj, .parsnip_fns = "all",
                            .parsnip_eng = "all", .split_type = "initial_split",
                            .split_args = NULL) {

  # Tidy Eval ----
  # Normalise the function/engine selectors to plain character vectors.
  fns <- list(.parsnip_fns) %>%
    purrr::flatten_chr()
  engine <- list(.parsnip_eng) %>%
    purrr::flatten_chr()

  rec_obj <- .rec_obj
  split_type <- .split_type
  split_args <- .split_args

  # Checks ----
  if (!inherits(rec_obj, "recipe")) {
    stop("'.rec_obj' must be a recipe object, see recipes::recipe().",
         call. = FALSE)
  }

  # Get data splits ----
  df <- dplyr::as_tibble(.data)
  splits_obj <- create_splits(
    .data = df,
    .split_type = split_type,
    .split_args = split_args
  )
  # Extract the partitions once rather than once per model.
  train_data <- rsample::training(splits_obj$splits)
  test_data <- rsample::testing(splits_obj$splits)

  # Generate Model Spec Tbl ----
  mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
    .parsnip_fns = fns,
    .parsnip_eng = engine
  )

  mod_rec_tbl <- mod_spec_tbl %>%
    dplyr::mutate(.model_recipe = list(rec_obj))

  # Build, fit and predict one workflow per row. Each step must map over the
  # rows: a single list(x[[1]]) value would be recycled by mutate(), giving
  # every row the first row's model spec.
  mod_tbl <- mod_rec_tbl %>%
    dplyr::mutate(
      .wflw = purrr::map2(
        .model_recipe, .model_spec,
        function(rec, spec) {
          workflows::workflow() %>%
            workflows::add_recipe(rec) %>%
            workflows::add_model(spec)
        }
      ),
      .fitted_wflw = purrr::map(
        .wflw,
        function(wflw) parsnip::fit(wflw, data = train_data)
      ),
      .pred_wflw = purrr::map(
        .fitted_wflw,
        function(fitted) predict(fitted, new_data = test_data)
      )
    )

  # Return ----
  class(mod_tbl) <- c("fst_reg_tbl", class(mod_tbl))
  attr(mod_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_tbl, ".parsnip_functions") <- .parsnip_fns
  attr(mod_tbl, ".split_type") <- .split_type
  # Assigning NULL removes the attribute, so it is absent when no args given.
  attr(mod_tbl, ".split_args") <- .split_args
  mod_tbl
}

Example:

> rec_obj <- recipes::recipe(mpg ~ ., data = mtcars)
> frt_tbl <- fast_regression(mtcars, rec_obj, .parsnip_eng = c("lm","glm"))
> frt_tbl
# A tibble: 3 × 8
  .parsnip_engine .parsn…¹ .pars…² .model_…³ .model…⁴ .wflw      .fitted_…⁵ .pred_…⁶
  <chr>           <chr>    <chr>   <list>    <list>   <list>     <list>     <list>  
1 lm              regress… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
2 glm             regress… linear… <spec[+]> <recipe> <workflow> <workflow> <tibble>
3 glm             regress… poisso… <spec[+]> <recipe> <workflow> <workflow> <tibble>
# … with abbreviated variable names ¹​.parsnip_mode, ²​.parsnip_fns, ³​.model_spec,
#   ⁴​.model_recipe, ⁵​.fitted_wflw, ⁶​.pred_wflw
> class(frt_tbl)
[1] "fst_reg_tbl"      "fst_reg_spec_tbl" "tbl_df"           "tbl"             
[5] "data.frame"      
> attributes(frt_tbl)
$names
[1] ".parsnip_engine" ".parsnip_mode"   ".parsnip_fns"    ".model_spec"    
[5] ".model_recipe"   ".wflw"           ".fitted_wflw"    ".pred_wflw"     

$row.names
[1] 1 2 3

$class
[1] "fst_reg_tbl"      "fst_reg_spec_tbl" "tbl_df"           "tbl"             
[5] "data.frame"      

$.parsnip_engines
[1] "lm"  "glm"

$.parsnip_functions
[1] "all"

$.split_type
[1] "initial_split"

> 
> frt_tbl$.fitted_wflw[[1]] %>%
+   broom::glance()
# A tibble: 1 × 12
  r.squared adj.r.s…¹ sigma stati…² p.value    df logLik   AIC   BIC devia…³ df.re…⁴
      <dbl>     <dbl> <dbl>   <dbl>   <dbl> <dbl>  <dbl> <dbl> <dbl>   <dbl>   <int>
1     0.880     0.788  2.58    9.55 1.76e-4    10  -49.5  123.  137.    86.8      13
# … with 1 more variable: nobs <int>, and abbreviated variable names
#   ¹​adj.r.squared, ²​statistic, ³​deviance, ⁴​df.residual
# ℹ Use `colnames()` to see all variable names
> 
> frt_tbl$.fitted_wflw[[1]] %>%
+   broom::tidy()
# A tibble: 11 × 5
   term        estimate std.error statistic p.value
   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
 1 (Intercept)  14.3      20.6       0.692    0.501
 2 cyl          -0.507     1.14     -0.443    0.665
 3 disp          0.0119    0.0252    0.473    0.644
 4 hp           -0.0279    0.0249   -1.12     0.282
 5 drat          1.72      1.89      0.912    0.379
 6 wt           -2.99      2.34     -1.28     0.223
 7 qsec          0.670     0.786     0.852    0.409
 8 vs           -0.391     2.43     -0.161    0.874
 9 am            2.27      2.22      1.02     0.326
10 gear         -0.0871    1.59     -0.0549   0.957
11 carb          0.257     0.957     0.269    0.792
@spsanderson spsanderson added the enhancement New feature or request label Dec 5, 2022
@spsanderson spsanderson self-assigned this Dec 5, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Development

No branches or pull requests

1 participant