Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ set_mode.model_spec <- function(object, mode) {
# determine if the model specification could feasibly match any entry
# in the union of the parsnip model environment and model_info_table.
# if not, trigger an error based on the (possibly inferred) model spec slots.
if (!spec_is_possible(cls,
object$engine, object$user_specified_engine,
mode, user_specified_mode = TRUE)) {
if (!spec_is_possible(spec = object,
mode = mode, user_specified_mode = TRUE)) {
check_spec_mode_engine_val(cls, object$engine, mode)
}

Expand Down
5 changes: 2 additions & 3 deletions R/engines.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ set_engine.model_spec <- function(object, engine, ...) {
# determine if the model specification could feasibly match any entry
# in the union of the parsnip model environment and model_info_table.
# if not, trigger an error based on the (possibly inferred) model spec slots.
if (!spec_is_possible(mod_type,
object$engine, user_specified_engine = TRUE,
object$mode, object$user_specified_mode)) {
if (!spec_is_possible(spec = object,
engine = object$engine, user_specified_engine = TRUE)) {
check_spec_mode_engine_val(mod_type, object$engine, object$mode)
}

Expand Down
8 changes: 8 additions & 0 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ extract_fit_engine.model_fit <- function(x, ...) {
#' @export
#' @rdname extract-parsnip
extract_parameter_set_dials.model_spec <- function(x, ...) {
if (!spec_is_loaded(spec = x)) {
prompt_missing_implementation(
spec = x,
prompt = cli::cli_abort,
call = NULL
)
}

all_args <- generics::tunable(x)
tuning_param <- generics::tune_args(x)

Expand Down
6 changes: 1 addition & 5 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,7 @@ fit.model_spec <-

if (length(possible_engines(object)) == 0) {
prompt_missing_implementation(
cls = class(object)[1],
engine = object$engine,
user_specified_engine = object$user_specified_engine,
mode = object$mode,
user_specified_mode = object$user_specified_mode,
spec = object,
prompt = cli::cli_abort,
call = call2("fit")
)
Expand Down
47 changes: 28 additions & 19 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ mode_filter_condition <- function(mode, user_specified_mode) {
#'
#' The helpers `spec_is_possible()`, `spec_is_loaded()`, and
#' `prompt_missing_implementation()` provide tooling for checking
#' model specifications. In addition to the `cls`, `engine`, and `mode`
#' model specifications. In addition to the `spec`, `engine`, and `mode`
#' arguments, the functions take arguments `user_specified_engine` and
#' `user_specified_mode`, denoting whether the user themselves has
#' specified the engine or mode, respectively.
Expand Down Expand Up @@ -91,9 +91,13 @@ mode_filter_condition <- function(mode, user_specified_mode) {
#' @export
#' @keywords internal
#' @rdname add_on_exports
spec_is_possible <- function(cls,
engine, user_specified_engine,
mode, user_specified_mode) {
spec_is_possible <- function(spec,
engine = spec$engine,
user_specified_engine = spec$user_specified_engine,
mode = spec$mode,
user_specified_mode = spec$user_specified_mode) {
cls <- class(spec)[[1]]

all_model_info <-
dplyr::full_join(
read_model_info_table(),
Expand All @@ -119,9 +123,13 @@ spec_is_possible <- function(cls,
#' @export
#' @keywords internal
#' @rdname add_on_exports
spec_is_loaded <- function(cls,
engine, user_specified_engine,
mode, user_specified_mode) {
spec_is_loaded <- function(spec,
engine = spec$engine,
user_specified_engine = spec$user_specified_engine,
mode = spec$mode,
user_specified_mode = spec$user_specified_mode) {
cls <- class(spec)[[1]]

engine_condition <- engine_filter_condition(engine, user_specified_engine)
mode_condition <- mode_filter_condition(mode, user_specified_mode)

Expand All @@ -143,9 +151,7 @@ spec_is_loaded <- function(cls,

is_printable_spec <- function(x) {
!is.null(x$method$fit$args) &&
spec_is_loaded(class(x)[1],
x$engine, x$user_specified_engine,
x$mode, x$user_specified_mode)
spec_is_loaded(x)
}

# construct a message informing the user that there are no
Expand All @@ -158,10 +164,14 @@ is_printable_spec <- function(x) {
#' @export
#' @keywords internal
#' @rdname add_on_exports
prompt_missing_implementation <- function(cls,
engine, user_specified_engine,
mode, user_specified_mode,
prompt_missing_implementation <- function(spec,
engine = spec$engine,
user_specified_engine = spec$user_specified_engine,
mode = spec$mode,
user_specified_mode = spec$user_specified_mode,
prompt, ...) {
cls <- class(spec)[[1]]

engine_condition <- engine_filter_condition(engine, user_specified_engine)
mode_condition <- mode_filter_condition(mode, user_specified_mode)

Expand Down Expand Up @@ -303,18 +313,17 @@ new_model_spec <- function(cls, args, eng_args, mode, user_specified_mode = TRUE
# determine if the model specification could feasibly match any entry
# in the union of the parsnip model environment and model_info_table.
# if not, trigger an error based on the (possibly inferred) model spec slots.
if (!spec_is_possible(cls,
engine, user_specified_engine,
mode, user_specified_mode)) {
check_spec_mode_engine_val(cls, engine, mode)
}

out <- list(
args = args, eng_args = eng_args,
mode = mode, user_specified_mode = user_specified_mode, method = method,
engine = engine, user_specified_engine = user_specified_engine
)
class(out) <- make_classes(cls)

if (!spec_is_possible(spec = out)) {
check_spec_mode_engine_val(cls, engine, mode)
}

out
}

Expand Down
9 changes: 2 additions & 7 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,8 @@ print.model_spec <- function(x, ...) {
#' @rdname add_on_exports
#' @export
print_model_spec <- function(x, cls = class(x)[1], desc = get_model_desc(cls), ...) {
if (!spec_is_loaded(cls,
x$engine, x$user_specified_engine,
x$mode, x$user_specified_mode)) {
prompt_missing_implementation(cls,
x$engine, x$user_specified_engine,
x$mode, x$user_specified_mode,
prompt = cli::cli_inform)
if (!spec_is_loaded(spec = structure(x, class = cls))) {
prompt_missing_implementation(spec = structure(x, class = cls), prompt = cli::cli_inform)
}

cat(desc, " Model Specification (", x$mode, ")\n\n", sep = "")
Expand Down
28 changes: 20 additions & 8 deletions man/add_on_exports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions tests/testthat/_snaps/extract.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# extract parameter set from model with no loaded implementation

Code
extract_parameter_set_dials(bt_mod)
Condition
Error:
! parsnip could not locate an implementation for `bag_tree` regression model specifications.
i The parsnip extension package baguette implements support for this specification.
i Please install (if needed) and load to continue.

---

Code
extract_parameter_dials(bt_mod, parameter = "min_n")
Condition
Error:
! parsnip could not locate an implementation for `bag_tree` regression model specifications.
i The parsnip extension package baguette implements support for this specification.
i Please install (if needed) and load to continue.

8 changes: 8 additions & 0 deletions tests/testthat/test_extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ test_that('extract parameter set from model with main and engine parameters', {
expect_equal(c5_info$object[[2]], NA)
})

test_that('extract parameter set from model with no loaded implementation', {
bt_mod <- bag_tree(min_n = tune()) %>%
set_mode("regression")

expect_snapshot(error = TRUE, extract_parameter_set_dials(bt_mod))
expect_snapshot(error = TRUE, extract_parameter_dials(bt_mod, parameter = "min_n"))
})

# ------------------------------------------------------------------------------

test_that('extract single parameter from model with no parameters', {
Expand Down