From 6b7cc2050e85fd4a44f9e619bff38c637a5587b3 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Tue, 7 Mar 2023 09:50:38 -0500 Subject: [PATCH] Use .env pronoun (fixes #382) --- R/template.R | 45 ++++-- tests/testthat/_snaps/error-handling.md | 6 +- tests/testthat/test-template.R | 173 ++++++++++++++++++++++++ 3 files changed, 209 insertions(+), 15 deletions(-) diff --git a/R/template.R b/R/template.R index 816b7315..bdd8d270 100644 --- a/R/template.R +++ b/R/template.R @@ -116,8 +116,11 @@ numeric_metric_summarizer <- function(name, out <- dplyr::summarise( data, - .metric = name, - .estimator = finalize_estimator(.data[[truth]], metric_class = name), + .metric = .env[["name"]], + .estimator = finalize_estimator( + .data[[truth]], + metric_class = .env[["name"]] + ), .estimate = fn( truth = .data[[truth]], estimate = .data[[estimate]], @@ -176,8 +179,12 @@ class_metric_summarizer <- function(name, out <- dplyr::summarise( data, - .metric = name, - .estimator = finalize_estimator(.data[[truth]], estimator, name), + .metric = .env[["name"]], + .estimator = finalize_estimator( + .data[[truth]], + .env[["estimator"]], + .env[["name"]] + ), .estimate = fn( truth = .data[[truth]], estimate = .data[[estimate]], @@ -233,8 +240,12 @@ prob_metric_summarizer <- function(name, out <- dplyr::summarise( data, - .metric = name, - .estimator = finalize_estimator(.data[[truth]], estimator, name), + .metric = .env[["name"]], + .estimator = finalize_estimator( + .data[[truth]], + .env[["estimator"]], + .env[["name"]] + ), .estimate = fn( truth = .data[[truth]], estimate = { @@ -293,8 +304,12 @@ curve_metric_summarizer <- function(name, out <- dplyr::reframe( data, - .metric = name, - .estimator = finalize_estimator(.data[[truth]], estimator, name), + .metric = .env[["name"]], + .estimator = finalize_estimator( + .data[[truth]], + .env[["estimator"]], + .env[["name"]] + ), .estimate = fn( truth = .data[[truth]], estimate = { @@ -372,8 +387,11 @@ dynamic_survival_metric_summarizer <- function(name, out <- dplyr::summarise( data, - .metric = name, - .estimator = finalize_estimator(.data[[truth]], metric_class = name), + .metric = .env[["name"]], + .estimator = finalize_estimator( + .data[[truth]], + metric_class = .env[["name"]] + ), .estimate = fn( truth = .data[[truth]], estimate = .data[[estimate]], @@ -448,8 +466,11 @@ curve_survival_metric_summarizer <- function(name, out <- dplyr::reframe( data, - .metric = name, - .estimator = finalize_estimator(.data[[truth]], metric_class = name), + .metric = .env[["name"]], + .estimator = finalize_estimator( + .data[[truth]], + metric_class = .env[["name"]] + ), .estimate = fn( truth = .data[[truth]], estimate = .data[[estimate]], diff --git a/tests/testthat/_snaps/error-handling.md b/tests/testthat/_snaps/error-handling.md index e7f86592..30bc9600 100644 --- a/tests/testthat/_snaps/error-handling.md +++ b/tests/testthat/_snaps/error-handling.md @@ -51,7 +51,7 @@ sens(pathology, pathology, scan, estimator = "blah") Condition Error in `dplyr::summarise()`: - i In argument: `.estimator = finalize_estimator(.data[["pathology"]], estimator, name)`. + i In argument: `.estimator = finalize_estimator(...)`. Caused by error in `validate_estimator()`: ! `estimator` must be one of: "binary", "macro", "micro", "macro_weighted". Not "blah". @@ -71,7 +71,7 @@ sens(hpc_cv, obs, pred, estimator = 1) Condition Error in `dplyr::summarise()`: - i In argument: `.estimator = finalize_estimator(.data[["obs"]], estimator, name)`. + i In argument: `.estimator = finalize_estimator(.data[["obs"]], .env[["estimator"]], .env[["name"]])`. Caused by error in `validate_estimator()`: ! `estimator` must be a character, not a numeric. @@ -81,7 +81,7 @@ sens(hpc_cv, obs, pred, estimator = c("1", "2")) Condition Error in `dplyr::summarise()`: - i In argument: `.estimator = finalize_estimator(.data[["obs"]], estimator, name)`. + i In argument: `.estimator = finalize_estimator(.data[["obs"]], .env[["estimator"]], .env[["name"]])`. Caused by error in `validate_estimator()`: ! `estimator` must be length 1, not 2. diff --git a/tests/testthat/test-template.R b/tests/testthat/test-template.R index 122dcbc6..767a10e7 100644 --- a/tests/testthat/test-template.R +++ b/tests/testthat/test-template.R @@ -150,6 +150,32 @@ test_that("numeric_metric_summarizer() deals with characters in truth and estima expect_identical(rmse_res, rmse_exp) }) +test_that("numeric_metric_summarizer() handles column name collisions", { + new_mtcars <- mtcars + + new_mtcars$name <- mtcars$mpg + new_mtcars$estimator <- mtcars$mpg + new_mtcars$event_level <- mtcars$mpg + + rmse_res <- numeric_metric_summarizer( + name = "rmse", + fn = rmse_vec, + data = new_mtcars, + truth = mpg, + estimate = disp, + na_rm = TRUE, + case_weights = NULL + ) + + rmse_exp <- dplyr::tibble( + .metric = "rmse", + .estimator = "standard", + .estimate = rmse_vec(new_mtcars$mpg, new_mtcars$disp) + ) + + expect_identical(rmse_res, rmse_exp) +}) + ## class_metric_summarizer -------------------------------------------------- test_that("class_metric_summarizer() works as expected", { @@ -362,6 +388,31 @@ test_that("class_metric_summarizer() deals with characters in truth and estimate expect_identical(accuracy_res, accuracy_exp) }) +test_that("class_metric_summarizer() handles column name collisions", { + three_class <- data_three_class()$three_class + + new_three_class <- three_class + new_three_class$name <- three_class$obs + new_three_class$estimator <- three_class$obs + new_three_class$event_level <- three_class$obs + + accuracy_res <- class_metric_summarizer( + name = "accuracy", + fn = accuracy_vec, + data = new_three_class, + truth = "obs", + estimate = "pred" + ) + + accuracy_exp <- dplyr::tibble( + .metric = "accuracy", + .estimator = "multiclass", + .estimate = accuracy_vec(three_class$obs, three_class$pred) + ) + + expect_identical(accuracy_res, accuracy_exp) +}) + ## prob_metric_summarizer -------------------------------------------------- test_that("prob_metric_summarizer() works as expected", { @@ -580,6 +631,33 @@ test_that("prob_metric_summarizer() deals with characters in truth", { expect_identical(roc_auc_res, roc_auc_exp) }) +test_that("prob_metric_summarizer() handles column name collisions", { + hpc_f1 <- data_hpc_fold1() + + new_hpc_f1 <- hpc_f1 + new_hpc_f1$name <- hpc_f1$VF + new_hpc_f1$estimator <- hpc_f1$VF + new_hpc_f1$event_level <- hpc_f1$VF + + roc_auc_res <- prob_metric_summarizer( + name = "roc_auc", + fn = roc_auc_vec, + data = new_hpc_f1, + truth = "obs", + VF:L, + na_rm = TRUE, + case_weights = NULL + ) + + roc_auc_exp <- dplyr::tibble( + .metric = "roc_auc", + .estimator = "hand_till", + .estimate = roc_auc_vec(hpc_f1$obs, as.matrix(hpc_f1[3:6])) + ) + + expect_identical(roc_auc_res, roc_auc_exp) +}) + ## curve_metric_summarizer -------------------------------------------------- test_that("curve_metric_summarizer() works as expected", { @@ -779,6 +857,33 @@ test_that("curve_metric_summarizer() deals with characters in truth", { expect_identical(roc_curve_res, roc_curve_exp) }) +test_that("curve_metric_summarizer() handles column name collisions", { + hpc_f1 <- data_hpc_fold1() + + new_hpc_f1 <- hpc_f1 + new_hpc_f1$name <- hpc_f1$VF + new_hpc_f1$estimator <- hpc_f1$VF + new_hpc_f1$event_level <- hpc_f1$VF + + roc_curve_res <- curve_metric_summarizer( + name = "roc_curve", + fn = roc_curve_vec, + data = new_hpc_f1, + truth = "obs", + VF:L, + na_rm = TRUE, + case_weights = NULL + ) + + roc_curve_exp <- dplyr::tibble( + .metric = "roc_curve", + .estimator = "multiclass", + .estimate = roc_curve_vec(hpc_f1$obs, as.matrix(hpc_f1[3:6])) + ) + + expect_identical(roc_curve_res, roc_curve_exp) +}) + ## dynamic_survival_metric_summarizer ----------------------------------------- test_that("dynamic_survival_metric_summarizer() works as expected", { @@ -982,6 +1087,40 @@ test_that("dynamic_survival_metric_summarizer() deals with characters in truth a expect_identical(brier_survival_res, brier_survival_exp) }) +test_that("dynamic_survival_metric_summarizer() handles column name collisions", { + lung_surv <- data_lung_surv() %>% dplyr::filter(.time == 100) + + new_lung_surv <- lung_surv + new_lung_surv$name <- lung_surv$.time + new_lung_surv$estimator <- lung_surv$.time + new_lung_surv$event_level <- lung_surv$.time + + brier_survival_res <- dynamic_survival_metric_summarizer( + name = "brier_survival", + fn = brier_survival_vec, + data = new_lung_surv, + truth = "surv_obj", + estimate = ".pred_survival", + censoring_weights = "ipcw", + eval_time = ".time", + na_rm = TRUE, + case_weights = NULL + ) + + brier_survival_exp <- dplyr::tibble( + .metric = "brier_survival", + .estimator = "standard", + .estimate = brier_survival_vec( + truth = lung_surv$surv_obj, + estimate = lung_surv$.pred_survival, + censoring_weights = lung_surv$ipcw, + eval_time = lung_surv$.time + ) + ) + + expect_identical(brier_survival_res, brier_survival_exp) +}) + ## curve_survival_metric_summarizer ----------------------------------------- # To be removed once roc_survival_curve() is added @@ -1228,3 +1367,37 @@ test_that("curve_survival_metric_summarizer() deals with characters in truth and expect_identical(roc_survival_curve_res, roc_survival_curve_exp) }) + +test_that("curve_survival_metric_summarizer() handles column name collisions", { + lung_surv <- data_lung_surv() %>% dplyr::filter(.time == 100) + + new_lung_surv <- lung_surv + new_lung_surv$name <- lung_surv$.time + new_lung_surv$estimator <- lung_surv$.time + new_lung_surv$event_level <- lung_surv$.time + + roc_survival_curve_res <- curve_survival_metric_summarizer( + name = "roc_survival_curve", + fn = roc_survival_curve_vec, + data = new_lung_surv, + truth = "surv_obj", + estimate = ".pred_survival", + censoring_weights = "ipcw", + eval_time = ".time", + na_rm = TRUE, + case_weights = NULL + ) + + roc_survival_curve_exp <- dplyr::tibble( + .metric = "roc_survival_curve", + .estimator = "standard", + .estimate = roc_survival_curve_vec( + truth = lung_surv$surv_obj, + estimate = lung_surv$.pred_survival, + censoring_weights = lung_surv$ipcw, + eval_time = lung_surv$.time + ) + ) + + expect_identical(roc_survival_curve_res, roc_survival_curve_exp) +})