fitted_wflw #31

Closed
Tracked by #27
spsanderson opened this issue Dec 8, 2022 · 0 comments
Labels
enhancement New feature or request

spsanderson commented Dec 8, 2022

Function:

# Safely make fitted workflow
internal_make_fitted_wflw <- function(.model_tbl, .splits_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  splits_obj <- .splits_obj
  col_nms <- colnames(model_tbl)
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  if (!"wflw" %in% col_nms){
    rlang::abort(
      message = "Missing the column 'wflw'",
      use_cli_format = TRUE
    )
  }
  
  if (!".model_id" %in% col_nms){
    rlang::abort(
      message = "Missing the column '.model_id'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation
  # Make a group split object list
  models_list <- model_tbl %>%
    dplyr::group_split(.model_id)
  
  # Make the fitted workflow object using purrr imap
  fitted_wflw_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the workflow column by name (not by position) and pluck the
        # workflow object out of the list column
        wflw <- obj %>% dplyr::pull("wflw") %>% purrr::pluck(1)
        
        # Create a safe parsnip::fit function
        safe_parsnip_fit <- purrr::safely(
          parsnip::fit,
          otherwise = "Error - Could not fit the workflow.",
          quiet = FALSE
        )
        
        # Return the fitted workflow
        ret <- safe_parsnip_fit(
          wflw, data = rsample::training(splits_obj$splits)
        )
        
        res <- ret %>% purrr::pluck("result")
        
        return(res)
      }
    )
    
  return(fitted_wflw_list)
  
}
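
The function above relies on `purrr::safely()` so that one failing model does not stop the whole loop. As a minimal sketch of that behaviour, the snippet below wraps base R's `log()` rather than a workflow fit; `safe_log` and its `otherwise` message are hypothetical names used only for illustration and are not part of the package.

```r
library(purrr)

# Wrap a function that may fail. On error, the wrapper returns the
# `otherwise` value as `result` instead of stopping execution.
# `safe_log` is a hypothetical stand-in for `safe_parsnip_fit`.
safe_log <- purrr::safely(
  log,
  otherwise = "Error - Could not compute.",
  quiet = TRUE
)

ok  <- safe_log(10)   # succeeds
bad <- safe_log("a")  # log() errors on a character input

ok$result   # the computed value, log(10)
ok$error    # NULL, because nothing went wrong

bad$result  # the `otherwise` string
bad$error   # the captured error condition
```

Because `internal_make_fitted_wflw()` keeps only the `result` element, a model that fails to fit would presumably show up in the returned list as the `otherwise` string rather than as a trained workflow, and the captured error object is dropped.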

Example:

> internal_make_fitted_wflw(mod_tbl, splits_obj)
[[1]]
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ─────────────────────────────────────────────────────────────────────────────────

Call:
stats::lm(formula = ..y ~ ., data = data)

Coefficients:
(Intercept)          cyl         disp           hp         drat           wt  
  28.907650    -0.664742    -0.009334    -0.014871     0.197659    -0.188327  
       qsec           vs           am         gear         carb  
  -0.190551     0.132323     1.732139     1.372764    -1.184251  


[[2]]
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ─────────────────────────────────────────────────────────────────────────────────

Call:  stats::glm(formula = ..y ~ ., family = stats::gaussian, data = data)

Coefficients:
(Intercept)          cyl         disp           hp         drat           wt  
  28.907650    -0.664742    -0.009334    -0.014871     0.197659    -0.188327  
       qsec           vs           am         gear         carb  
  -0.190551     0.132323     1.732139     1.372764    -1.184251  

Degrees of Freedom: 23 Total (i.e. Null);  13 Residual
Null Deviance:	    736.5 
Residual Deviance: 100.5 	AIC: 126.5