Skip to content

Commit

Permalink
codomain assert is now configureable
Browse files Browse the repository at this point in the history
  • Loading branch information
pfistfl committed Nov 25, 2022
1 parent 3f65286 commit 629ddb6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 29 deletions.
2 changes: 1 addition & 1 deletion yahpo_gym_r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Suggests:
knitr,
rmarkdown,
future
RoxygenNote: 7.1.2
RoxygenNote: 7.2.1
VignetteBuilder: knitr
Config/testthat/edition: 3
Config/testthat/parallel: true
17 changes: 14 additions & 3 deletions yahpo_gym_r/R/BenchmarkSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",
#' @field check `logical` \cr
#' Check whether values coincide with `domain`.
check = NULL,

#' @field check_codomain `logical` \cr
#' Check whether returned values coincide with `codomain`.
check_codomain = NULL,

#' @field noisy `logical` \cr
#' Whether noisy surrogates should be used.
Expand All @@ -70,14 +74,17 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",
#' Check inputs for validity before passing to surrogate model? Default `FALSE`.
#' @param noisy `logical` \cr
#' Should noisy surrogates be used instead of deterministic ones?
initialize = function(scenario, instance = NULL, onnx_session = NULL, active_session = FALSE, multithread = FALSE, check = FALSE, noisy = FALSE) {
#' @param check_codomain `logical` \cr
#' Check outputs of surrogate model for validity? Default `FALSE`.
initialize = function(scenario, instance = NULL, onnx_session = NULL, active_session = FALSE, multithread = FALSE, check = FALSE, noisy = FALSE, check_codomain = FALSE) {
self$id = assert_string(scenario)
self$instance = assert_string(instance, null.ok = TRUE)
self$onnx_session = onnx_session
self$active_session = assert_flag(active_session)
self$multithread = assert_flag(multithread)
self$check = assert_flag(check)
self$noisy = assert_flag(noisy)
self$check_codomain = assert_flag(check_codomain)
},
#' @description
#' Printer with some additional information.
Expand Down Expand Up @@ -109,13 +116,16 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",
#' Should the ONNX session be allowed to leverage multithreading capabilities? Default `FALSE`.
#' @param seed `integer` \cr
#' Initial seed for the `onnxruntime.runtime`. Only relevant if `noisy = TRUE`. Default `NULL` (no seed).
#' @param check_codomain `logical` \cr
#' Check outputs of surrogate model for validity? Default `FALSE`.
#' @return
#' A [`Objective`][bbotk::Objective] containing "domain", "codomain" and a
#' functionality to evaluate the surrogates.
get_objective = function(instance, multifidelity = TRUE, check_values = TRUE, timed = FALSE, logging = FALSE, multithread = FALSE, seed = NULL) {
get_objective = function(instance, multifidelity = TRUE, check_values = TRUE, timed = FALSE, logging = FALSE, multithread = FALSE, seed = NULL, check_codomain = NULL) {
assert_choice(instance, self$instances)
assert_flag(check_values)
assert_int(seed, null.ok = TRUE)
assert_flag(check_codomain)
ObjectiveYAHPO$new(
instance,
multifidelity,
Expand All @@ -133,7 +143,8 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",
timed = timed,
logging = logging,
multithread = multithread,
seed = seed
seed = seed,
check_codomain = check_codomain
)
},
#' @description
Expand Down
7 changes: 5 additions & 2 deletions yahpo_gym_r/R/Objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ ObjectiveYAHPO = R6::R6Class("ObjectiveYAHPO",
logging = NULL,
multithread = NULL,
seed = NULL,
check_codomain = NULL,

initialize = function(instance, multifidelity = TRUE, py_instance_args, domain, codomain = NULL, check_values = TRUE, timed = FALSE, logging = FALSE, multithread = FALSE, seed = 0L) {
initialize = function(instance, multifidelity = TRUE, py_instance_args, domain, codomain = NULL, check_values = TRUE, timed = FALSE, logging = FALSE, multithread = FALSE, seed = 0L, check_codomain = FALSE) {
assert_flag(multifidelity)
assert_flag(check_values)
self$timed = assert_flag(timed)
self$logging = assert_flag(logging)
self$multithread = assert_flag(multithread)
self$seed = assert_int(seed, null.ok = TRUE)
self$check_codomain = assert_flag(check_codomain)

if (is.null(codomain)) {
codomain = ps(y = p_dbl(tags = "minimize"))
}
Expand Down Expand Up @@ -61,7 +64,7 @@ ObjectiveYAHPO = R6::R6Class("ObjectiveYAHPO",
}
res = invoke(private$.fun, list(xs), .args = self$constants$values)
res = res[[1]][self$codomain$ids()]
if (self$check_values) self$codomain$assert(as.list(res)[self$codomain$ids()])
if (self$check_codomain) self$codomain$assert(as.list(res)[self$codomain$ids()])
return(res)
},
export = function() {
Expand Down
57 changes: 34 additions & 23 deletions yahpo_gym_r/man/BenchmarkSet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 629ddb6

Please sign in to comment.