
Helper function to extract fitted workflow from last_fit() output #378

@juliasilge


In tidymodels/tidymodels#58, @hfrick pointed out that it is pretty awkward to get the fitted workflow out of the output of last_fit(); it's all $.workflow[[1]]. 😣 This is what you currently have to do to get a fitted workflow you can use for prediction:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

set.seed(6735)
car_split <- initial_split(mtcars)

spline_rec <- recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp)

lin_mod <- linear_reg() %>%
  set_engine("lm")

spline_wfl <-
  workflow() %>%
  add_recipe(spline_rec) %>%
  add_model(lin_mod)

final_fitted <- last_fit(spline_wfl, car_split)

collect_metrics(final_fitted)
#> # A tibble: 2 x 4
#>   .metric .estimator .estimate .config             
#>   <chr>   <chr>          <dbl> <chr>               
#> 1 rmse    standard       5.74  Preprocessor1_Model1
#> 2 rsq     standard       0.503 Preprocessor1_Model1
collect_predictions(final_fitted)
#> # A tibble: 8 x 5
#>   id               .pred  .row   mpg .config             
#>   <chr>            <dbl> <int> <dbl> <chr>               
#> 1 train/test split 18.1      5  18.7 Preprocessor1_Model1
#> 2 train/test split 31.1      9  22.8 Preprocessor1_Model1
#> 3 train/test split 16.9     14  15.2 Preprocessor1_Model1
#> 4 train/test split 20.4     15  10.4 Preprocessor1_Model1
#> 5 train/test split 18.0     16  10.4 Preprocessor1_Model1
#> 6 train/test split 28.3     19  30.4 Preprocessor1_Model1
#> 7 train/test split  9.04    24  13.3 Preprocessor1_Model1
#> 8 train/test split 23.0     30  19.7 Preprocessor1_Model1


final_fitted$.workflow[[1]]
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> • step_ns()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> 
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#> 
#> Coefficients:
#> (Intercept)          cyl           hp         drat           wt         qsec  
#>   -19.69774      1.09769     -0.02326     -3.05865     -4.28722      2.80367  
#>          vs           am         gear         carb    disp_ns_1    disp_ns_2  
#>    -3.37155      0.18869      4.33774      0.52292    -12.82725      4.55395

Created on 2021-05-03 by the reprex package (v2.0.0)

What do we think about a little helper function to take care of this, like the helper functions we have for the metrics and predictions? We have collect_metrics() and collect_predictions() -- maybe collect_workflow()? It's nicely parallel and pulls from the same structure. I feel like pull_xx is another obvious choice, but that prefix is getting overloaded quickly for us.
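
To make the proposal concrete, here is a minimal sketch of what such a helper might look like. The name collect_workflow() is only the suggestion from this issue, not an existing function, and the input checks and the mock_fit stand-in object are assumptions added for illustration. It uses base R only, so it runs without tidymodels loaded.

```r
# Hypothetical helper sketched from the proposal in this issue; the name and
# the input checks are assumptions, not an existing tune function.
collect_workflow <- function(x) {
  # last_fit() output keeps the fitted workflow in a `.workflow` list-column
  if (!".workflow" %in% names(x)) {
    stop("`x` must have a `.workflow` column, like last_fit() output.",
         call. = FALSE)
  }
  # last_fit() output has a single row (one train/test split)
  if (nrow(x) != 1) {
    stop("`x` must have exactly one row, like last_fit() output.",
         call. = FALSE)
  }
  x$.workflow[[1]]
}

# Stand-in for last_fit() output (assumption): one row with a list-column
# holding a placeholder object instead of a real fitted workflow.
mock_fit <- data.frame(id = "train/test split")
mock_fit$.workflow <- list(structure(list(trained = TRUE), class = "mock_workflow"))

fitted_wf <- collect_workflow(mock_fit)
class(fitted_wf)
```

With the reprex above, the call would then read collect_workflow(final_fitted) instead of final_fitted$.workflow[[1]].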

Labels: feature (a feature request or enhancement)
