Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a6fab99
move function to new file
topepo Dec 4, 2023
8c2aeb1
change function order for docs
topepo Dec 4, 2023
252db2b
documentation start
topepo Dec 4, 2023
f891bf0
updates to the show/select functions
topepo Dec 4, 2023
6790f6d
updates to select/show functions
topepo Dec 4, 2023
16caebe
updates for selecting eval times
topepo Dec 5, 2023
8a813c8
remove commented out code
topepo Dec 5, 2023
a340985
bug fix
topepo Dec 5, 2023
620d048
metric test cases
topepo Dec 5, 2023
a00c951
Merge branch 'new-metric-selections' into new-time-selections
topepo Dec 5, 2023
9f633d4
add a survival model object
topepo Dec 5, 2023
ff06ced
note for next PR
topepo Dec 5, 2023
1aea76b
select/show test cases
topepo Dec 5, 2023
70afeeb
small set of direct tests
topepo Dec 5, 2023
5bcbb7b
update snapshot
topepo Dec 5, 2023
0d5332b
Apply suggestions from code review
topepo Dec 5, 2023
538b343
Apply suggestions from code review
topepo Dec 5, 2023
1dc65ca
Merge branch 'new-metric-selections' into new-time-selections
topepo Dec 5, 2023
5580d7a
updates from previous review
topepo Dec 5, 2023
9afab6c
Merge branch 'main' into new-time-selections
topepo Dec 5, 2023
870ce13
small cli update
topepo Dec 5, 2023
0b366d3
doc update
topepo Dec 5, 2023
fbadcd9
refresh snapshots
topepo Dec 5, 2023
86b37fd
modularize a check
topepo Dec 5, 2023
6247b96
Remake with newest CRAN version of scales for #775
topepo Dec 5, 2023
972ffc2
Apply suggestions from code review
topepo Dec 6, 2023
55fc24e
Apply suggestions from code review
topepo Dec 6, 2023
c0066d3
add dot when function is invoked
topepo Dec 6, 2023
8ee1c98
add a warning for eval times with non-survival models
topepo Dec 6, 2023
f032da4
go back to enquos
topepo Dec 6, 2023
453c23a
rework warning text
topepo Dec 6, 2023
c377d3d
rework warning text pt 2
topepo Dec 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
Version: 1.1.2.9001
Version: 1.1.2.9002
Authors@R: c(
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ export(.catch_and_log)
export(.catch_and_log_fit)
export(.config_key_from_metrics)
export(.estimate_metrics)
export(.filter_perf_metrics)
export(.get_extra_col_names)
export(.get_fingerprint)
export(.get_tune_eval_times)
Expand All @@ -156,6 +157,7 @@ export(check_parameters)
export(check_rset)
export(check_time)
export(check_workflow)
export(choose_eval_time)
export(choose_metric)
export(collect_extracts)
export(collect_metrics)
Expand Down
111 changes: 104 additions & 7 deletions R/metric-selection.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#' @param metric A character value for which metric is being used.
#' @param eval_time An optional vector of times to compute dynamic and/or
#' integrated metrics.
#' @param x An object with class `tune_results`.
#' @param call The call to be displayed in warnings or errors.
#' @description
#' These are developer-facing functions used to compute and validate choices
#' for performance metrics. For survival analysis models, there are similar
Expand All @@ -15,6 +17,14 @@
#' no value is given by the user, the first metric value is used (with a
#' warning).
#'
#' For evaluation times, one is only required when the metric type is dynamic
#' (e.g. [yardstick::brier_survival()] or [yardstick::roc_auc_survival()]). For
#' these metrics, we require a single numeric value that was originally given
#' to the function used to produce `x` (such as [tune_grid()]).
#'
#' If a time is required and none is given, the first value in the vector
#' originally given in the `eval_time` argument is used (with a warning).
#'
#' @keywords internal
#' @export
choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {
Expand All @@ -25,10 +35,12 @@ choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {

if (is.null(metric)) {
metric <- mtr_info$metric[1]
cli::cli_warn("No value of {.arg metric} was given; {.val {metric}} will be used.", call = call)
cli::cli_warn("No value of {.arg metric} was given; {.val {metric}}
will be used.",
call = call)
} else {
metric <- check_mult_metrics(metric, call = call)
check_right_metric(mtr_info, metric, call = call)
check_metric_in_tune_results(mtr_info, metric, call = call)
}

mtr_info[mtr_info$metric == metric,]
Expand All @@ -40,16 +52,19 @@ check_mult_metrics <- function(metric, ..., call = rlang::caller_env()) {
num_metrics <- length(metric)
metric <- metric[1]
if (num_metrics > 1) {
cli::cli_warn("{num_metrics} metric{?s} were given; {.val {metric}} will be used.", call = call)
cli::cli_warn("{num_metrics} metric{?s} were given; {.val {metric}} will
be used.",
call = call)
}
metric
}

check_right_metric <- function(mtr_info, metric, ..., call = rlang::caller_env()) {
check_metric_in_tune_results <- function(mtr_info, metric, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()

if (!any(mtr_info$metric == metric)) {
cli::cli_abort("{.val {metric}} was not in the metric set. Please choose from: {.val {mtr_info$metric}}.", call = call)
cli::cli_abort("{.val {metric}} was not in the metric set. Please choose
from: {.val {mtr_info$metric}}.", call = call)
}
invisible(NULL)
}
Expand All @@ -58,6 +73,53 @@ contains_survival_metric <- function(mtr_info) {
any(grepl("_survival", mtr_info$class))
}

#' @rdname choose_metric
#' @export
choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()

mtr_set <- .get_tune_metrics(x)
mtr_info <- tibble::as_tibble(mtr_set)

if (!contains_survival_metric(mtr_info)) {
if (!is.null(eval_time)) {
cli::cli_warn("Evaluation times are only required when the model
mode is {.val censored regression} (and will be ignored).")
}
return(NULL)
}

# If we need an eval time, set it to the possible values so that
# we can choose the first value
if (is_dyn(mtr_set, metric) && is.null(eval_time)) {
eval_time <- .get_tune_eval_times(x)
}

eval_time <- first_eval_time(mtr_set, metric = metric, eval_time = eval_time)

check_eval_time_in_tune_results(x, eval_time, call = call)

eval_time
}

is_dyn <- function(mtr_set, metric) {
mtr_info <- tibble::as_tibble(mtr_set)
mtr_cls <- mtr_info$class[mtr_info$metric == metric]
mtr_cls == "dynamic_survival_metric"
}

check_eval_time_in_tune_results <- function(x, eval_time, call = rlang::caller_env()) {
given_times <- .get_tune_eval_times(x)
if (!is.null(eval_time)) {
if (!any(eval_time == given_times)) {
print_time <- format(eval_time, digits = 3)
cli::cli_abort("Evaluation time {print_time} is not in the results.",
call = call)
}
}
invisible(NULL)
}

# ------------------------------------------------------------------------------

#' @rdname choose_metric
Expand Down Expand Up @@ -88,7 +150,9 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) {
no_time_req <- c("static_survival_metric", "integrated_survival_metric")
if (mtr_info$class %in% no_time_req) {
if (num_times > 0) {
cli::cli_warn("Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric.")
cli::cli_warn("Evaluation times are only required when dynmanic or
integrated metrics are selected as the primary metric
(and will be ignored).")
}
return(NULL)
}
Expand All @@ -99,8 +163,41 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) {
} else if ( num_times > 1 ) {
eval_time <- eval_time[1]
print_time <- format(eval_time, digits = 3)
cli::cli_warn("{num_times} evaluation times were available; the first ({print_time}) will be used.")
cli::cli_warn("{.val {num_times}} evaluation times were specified during
tuning; the first ({print_time}) will be used.")
}

eval_time
}

# ------------------------------------------------------------------------------

#' @rdname choose_metric
#' @export
.filter_perf_metrics <- function(x, metric, eval_time) {
summary_res <- estimate_tune_results(x)
summary_res <- summary_res[summary_res$.metric == metric, ]
is_missing_mean <- is.na(summary_res$mean)
summary_res <- summary_res[!is_missing_mean, ]

if (!is.null(eval_time) && any(colnames(summary_res) == ".eval_time")) {
summary_res <- summary_res[summary_res$.eval_time == eval_time, ]
}
if (nrow(summary_res) == 0) {
cli::cli_abort("No results are available. Please use {.fun collect_metrics}
to see if there were any issues.")
}

summary_res
}

# TODO will be removed shortly

middle_eval_time <- function(x) {
x <- x[!is.na(x)]
times <- unique(x)
med_time <- median(x, na.rm = TRUE)
ind <- which.min(abs(times - med_time))
eval_time <- times[ind]
eval_time
}
1 change: 1 addition & 0 deletions R/plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ get_param_label <- function(x, id_val) {
res
}

# TODO remove this.
default_eval_time <- function(eval_time, x, call = rlang::caller_env()) {
if (!any(names(x) == ".eval_time")) {
if (!is.null(eval_time)) {
Expand Down
Loading