pred_wflw #32

Closed
Tracked by #27
spsanderson opened this issue Dec 8, 2022 · 0 comments
Labels
enhancement New feature or request

Comments

@spsanderson
Owner

spsanderson commented Dec 8, 2022

Function:

```r
# Safely make predictions on fitted workflow
internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){
  
  # Tidyeval ----
  model_tbl <- .model_tbl
  splits_obj <- .splits_obj
  col_nms <- colnames(model_tbl)
  
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }
  
  if (!"fitted_wflw" %in% col_nms){
    rlang::abort(
      message = "Missing the column 'fitted_wflw'",
      use_cli_format = TRUE
    )
  }
  
  if (!".model_id" %in% col_nms){
    rlang::abort(
      message = "Missing the column '.model_id'",
      use_cli_format = TRUE
    )
  }
  
  # Manipulation ----
  # Make a group split object list, one tibble per model
  model_factor_tbl <- model_tbl %>%
    dplyr::mutate(.model_id = forcats::as_factor(.model_id))
  
  models_list <- model_factor_tbl %>%
    dplyr::group_split(.model_id)
  
  # Create a safe stats::predict once; on failure it returns `otherwise`
  # instead of stopping, and quiet = FALSE still prints the error message
  safe_stats_predict <- purrr::safely(
    stats::predict,
    otherwise = "Error - Could not make predictions",
    quiet = FALSE
  )
  
  # Make the predictions on each fitted workflow object using purrr::imap
  wflw_preds_list <- models_list %>%
    purrr::imap(
      .f = function(obj, id){
        
        # Pull the fitted workflow list-column by name (not by position)
        # and pluck the single workflow object out of it
        fitted_wflw <- obj %>%
          dplyr::pull("fitted_wflw") %>%
          purrr::pluck(1)
        
        # Predict on the training data; safely() returns list(result, error)
        ret <- safe_stats_predict(
          fitted_wflw,
          new_data = rsample::training(splits_obj$splits)
        )
        
        # Keep only the predictions (or the `otherwise` string on failure)
        res <- ret %>% purrr::pluck("result")
        
        return(res)
      }
    )
  
  return(wflw_preds_list)
}
```
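For readers without a tidymodels setup, the same "safe predict per model" pattern can be sketched in base R. This is a minimal sketch, not tidyAML code: it swaps `purrr::safely()` for `tryCatch()`, uses `lm()` fits on the built-in `mtcars` data as hypothetical stand-ins for fitted workflows, and uses a character string to stand in for a model whose fit failed. Note that `stats::predict()` for `lm` objects takes `newdata`, whereas workflows take `new_data`.

```r
# Hypothetical stand-ins: two lm() fits on mtcars and a character string
# playing the role of a workflow whose fit failed upstream.
fits <- list(
  lm(mpg ~ wt, data = mtcars),
  "fit failed",                      # not a model object
  lm(mpg ~ wt + hp, data = mtcars)
)

# Mimic purrr::safely(): always return list(result = ..., error = ...)
safe_predict <- function(model, new_data,
                         otherwise = "Error - Could not make predictions") {
  tryCatch(
    list(result = stats::predict(model, newdata = new_data), error = NULL),
    error = function(e) list(result = otherwise, error = e)
  )
}

# One prediction attempt per model; a failure does not stop the loop
preds <- lapply(fits, safe_predict, new_data = head(mtcars))
```

The second element reproduces the failure shown in the example output: dispatching `predict()` on a character object raises "no applicable method", and the wrapper returns the `otherwise` string together with the captured condition.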

Example (the printed output includes both `$result` and `$error` for each model, while the function as written returns only `$result`):

```
> internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
Error: no applicable method for 'predict' applied to an object of class "character"
[[1]]
[[1]]$result
# A tibble: 24 × 1
   .pred
   <dbl>
 1  23.2
 2  18.9
 3  15.4
 4  17.7
 5  15.6
 6  16.8
 7  15.5
 8  19.7
 9  11.7
10  22.6
# … with 14 more rows
# ℹ Use `print(n = ...)` to see more rows

[[1]]$error
NULL


[[2]]
[[2]]$result
[1] "Error - Could not make predictions"

[[2]]$error
<simpleError in UseMethod("predict"): no applicable method for 'predict' applied to an object of class "character">


[[3]]
[[3]]$result
# A tibble: 24 × 1
   .pred
   <dbl>
 1  23.2
 2  18.9
 3  15.4
 4  17.7
 5  15.6
 6  16.8
 7  15.5
 8  19.7
 9  11.7
10  22.6
# … with 14 more rows
# ℹ Use `print(n = ...)` to see more rows

[[3]]$error
NULL
```
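When some models fail, a caller may want to separate successful predictions from failures. A minimal base-R sketch, assuming each element has the `list(result, error)` shape printed in the example; the `preds` list here is hypothetical data built with `data.frame()` in place of tibbles, mirroring the three models with the second one failing.

```r
# Hypothetical stand-in for the returned list: one list(result, error) per
# model, with the second model failing as in the example output.
preds <- list(
  list(result = data.frame(.pred = c(23.2, 18.9)), error = NULL),
  list(
    result = "Error - Could not make predictions",
    error  = simpleError("no applicable method for 'predict'")
  ),
  list(result = data.frame(.pred = c(23.2, 18.9)), error = NULL)
)

# TRUE where the prediction succeeded (no captured error)
succeeded <- vapply(preds, function(p) is.null(p$error), logical(1))

# Positions of the failed models, and the successful prediction tables
failed_idx <- which(!succeeded)
ok_results <- lapply(preds[succeeded], `[[`, "result")
```

Filtering on `$error` rather than on the `otherwise` string avoids depending on the exact fallback text.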
spsanderson self-assigned this Dec 8, 2022
spsanderson added the enhancement (New feature or request) label Dec 8, 2022
spsanderson added this to the tidyaml 0.0.1 milestone Dec 8, 2022