Closed
Labels: feature (a feature request or enhancement)
Description
In tidymodels/tidymodels#58, @hfrick pointed out that it is pretty awkward to get the fitted workflow out of the output of last_fit(): you have to index into it with $.workflow[[1]]. 😣 Here is what it currently takes to get a fitted workflow you can use for prediction:
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from
#>   required_pkgs.model_spec parsnip
set.seed(6735)
car_split <- initial_split(mtcars)
spline_rec <- recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp)
lin_mod <- linear_reg() %>%
  set_engine("lm")
spline_wfl <-
  workflow() %>%
  add_recipe(spline_rec) %>%
  add_model(lin_mod)
final_fitted <- last_fit(spline_wfl, car_split)
collect_metrics(final_fitted)
#> # A tibble: 2 x 4
#>   .metric .estimator .estimate .config
#>   <chr>   <chr>          <dbl> <chr>
#> 1 rmse    standard       5.74  Preprocessor1_Model1
#> 2 rsq     standard       0.503 Preprocessor1_Model1
collect_predictions(final_fitted)
#> # A tibble: 8 x 5
#>   id               .pred  .row   mpg .config
#>   <chr>            <dbl> <int> <dbl> <chr>
#> 1 train/test split 18.1      5  18.7 Preprocessor1_Model1
#> 2 train/test split 31.1      9  22.8 Preprocessor1_Model1
#> 3 train/test split 16.9     14  15.2 Preprocessor1_Model1
#> 4 train/test split 20.4     15  10.4 Preprocessor1_Model1
#> 5 train/test split 18.0     16  10.4 Preprocessor1_Model1
#> 6 train/test split 28.3     19  30.4 Preprocessor1_Model1
#> 7 train/test split  9.04    24  13.3 Preprocessor1_Model1
#> 8 train/test split 23.0     30  19.7 Preprocessor1_Model1
final_fitted$.workflow[[1]]
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#>
#> • step_ns()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#>
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#>
#> Coefficients:
#> (Intercept)          cyl           hp         drat           wt         qsec
#>   -19.69774      1.09769     -0.02326     -3.05865     -4.28722      2.80367
#>          vs           am         gear         carb    disp_ns_1    disp_ns_2
#>    -3.37155      0.18869      4.33774      0.52292    -12.82725      4.55395

Created on 2021-05-03 by the reprex package (v2.0.0)
What do we think about a little helper function to take care of this, like the helpers we already have for metrics and predictions? We have collect_metrics() and collect_predictions(), so maybe collect_workflow()? It's nicely parallel and works on the same structure. pull_*() is another obvious choice, but that prefix is quickly getting overloaded for us.
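A helper like this could be little more than a checked wrapper around the $.workflow[[1]] indexing shown above. Below is a minimal R sketch under stated assumptions: collect_workflow() is only the name floated in this issue (hypothetical, not an existing tune function), and the sketch assumes the last_fit() result has a single-row .workflow list column, as in the reprex.

```r
# Hypothetical helper sketching the collect_workflow() idea from this issue.
# It is NOT an existing tune function; it only wraps x$.workflow[[1]] with checks.
collect_workflow <- function(x) {
  # Assumption: last_fit() output stores the fitted workflow in a `.workflow` list column
  if (!".workflow" %in% names(x)) {
    stop("`x` must have a `.workflow` column, as in the output of `last_fit()`.",
         call. = FALSE)
  }
  # Assumption: a single train/test split yields exactly one fitted workflow
  if (length(x$.workflow) != 1L) {
    stop("Expected exactly one fitted workflow in `x$.workflow`.", call. = FALSE)
  }
  x$.workflow[[1]]
}
```

With the objects from the reprex, fitted <- collect_workflow(final_fitted) followed by predict(fitted, new_data = testing(car_split)) would give predictions on the test set. Keeping the helper a plain accessor mirrors collect_metrics() and collect_predictions(), which also pull their pieces out of the same last_fit() result.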