Skip to content

Commit

Permalink
Use .env pronoun (fixes #382)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemahoney218 committed Mar 7, 2023
1 parent 3099e99 commit 6b7cc20
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 15 deletions.
45 changes: 33 additions & 12 deletions R/template.R
Expand Up @@ -116,8 +116,11 @@ numeric_metric_summarizer <- function(name,

out <- dplyr::summarise(
data,
.metric = name,
.estimator = finalize_estimator(.data[[truth]], metric_class = name),
.metric = .env[["name"]],
.estimator = finalize_estimator(
.data[[truth]],
metric_class = .env[["name"]]
),
.estimate = fn(
truth = .data[[truth]],
estimate = .data[[estimate]],
Expand Down Expand Up @@ -176,8 +179,12 @@ class_metric_summarizer <- function(name,

out <- dplyr::summarise(
data,
.metric = name,
.estimator = finalize_estimator(.data[[truth]], estimator, name),
.metric = .env[["name"]],
.estimator = finalize_estimator(
.data[[truth]],
.env[["estimator"]],
.env[["name"]]
),
.estimate = fn(
truth = .data[[truth]],
estimate = .data[[estimate]],
Expand Down Expand Up @@ -233,8 +240,12 @@ prob_metric_summarizer <- function(name,

out <- dplyr::summarise(
data,
.metric = name,
.estimator = finalize_estimator(.data[[truth]], estimator, name),
.metric = .env[["name"]],
.estimator = finalize_estimator(
.data[[truth]],
.env[["estimator"]],
.env[["name"]]
),
.estimate = fn(
truth = .data[[truth]],
estimate = {
Expand Down Expand Up @@ -293,8 +304,12 @@ curve_metric_summarizer <- function(name,

out <- dplyr::reframe(
data,
.metric = name,
.estimator = finalize_estimator(.data[[truth]], estimator, name),
.metric = .env[["name"]],
.estimator = finalize_estimator(
.data[[truth]],
.env[["estimator"]],
.env[["name"]]
),
.estimate = fn(
truth = .data[[truth]],
estimate = {
Expand Down Expand Up @@ -372,8 +387,11 @@ dynamic_survival_metric_summarizer <- function(name,

out <- dplyr::summarise(
data,
.metric = name,
.estimator = finalize_estimator(.data[[truth]], metric_class = name),
.metric = .env[["name"]],
.estimator = finalize_estimator(
.data[[truth]],
metric_class = .env[["name"]]
),
.estimate = fn(
truth = .data[[truth]],
estimate = .data[[estimate]],
Expand Down Expand Up @@ -448,8 +466,11 @@ curve_survival_metric_summarizer <- function(name,

out <- dplyr::reframe(
data,
.metric = name,
.estimator = finalize_estimator(.data[[truth]], metric_class = name),
.metric = .env[["name"]],
.estimator = finalize_estimator(
.data[[truth]],
metric_class = .env[["name"]]
),
.estimate = fn(
truth = .data[[truth]],
estimate = .data[[estimate]],
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/_snaps/error-handling.md
Expand Up @@ -51,7 +51,7 @@
sens(pathology, pathology, scan, estimator = "blah")
Condition
Error in `dplyr::summarise()`:
i In argument: `.estimator = finalize_estimator(.data[["pathology"]], estimator, name)`.
i In argument: `.estimator = finalize_estimator(...)`.
Caused by error in `validate_estimator()`:
! `estimator` must be one of: "binary", "macro", "micro", "macro_weighted". Not "blah".

Expand All @@ -71,7 +71,7 @@
sens(hpc_cv, obs, pred, estimator = 1)
Condition
Error in `dplyr::summarise()`:
i In argument: `.estimator = finalize_estimator(.data[["obs"]], estimator, name)`.
i In argument: `.estimator = finalize_estimator(.data[["obs"]], .env[["estimator"]], .env[["name"]])`.
Caused by error in `validate_estimator()`:
! `estimator` must be a character, not a numeric.

Expand All @@ -81,7 +81,7 @@
sens(hpc_cv, obs, pred, estimator = c("1", "2"))
Condition
Error in `dplyr::summarise()`:
i In argument: `.estimator = finalize_estimator(.data[["obs"]], estimator, name)`.
i In argument: `.estimator = finalize_estimator(.data[["obs"]], .env[["estimator"]], .env[["name"]])`.
Caused by error in `validate_estimator()`:
! `estimator` must be length 1, not 2.

Expand Down
173 changes: 173 additions & 0 deletions tests/testthat/test-template.R
Expand Up @@ -150,6 +150,32 @@ test_that("numeric_metric_summarizer() deals with characters in truth and estima
expect_identical(rmse_res, rmse_exp)
})

test_that("numeric_metric_summarizer() handles column name collisions", {
new_mtcars <- mtcars

new_mtcars$name <- mtcars$mpg
new_mtcars$estimator <- mtcars$mpg
new_mtcars$event_level <- mtcars$mpg

rmse_res <- numeric_metric_summarizer(
name = "rmse",
fn = rmse_vec,
data = new_mtcars,
truth = mpg,
estimate = disp,
na_rm = TRUE,
case_weights = NULL
)

rmse_exp <- dplyr::tibble(
.metric = "rmse",
.estimator = "standard",
.estimate = rmse_vec(new_mtcars$mpg, new_mtcars$disp)
)

expect_identical(rmse_res, rmse_exp)
})

## class_metric_summarizer --------------------------------------------------

test_that("class_metric_summarizer() works as expected", {
Expand Down Expand Up @@ -362,6 +388,31 @@ test_that("class_metric_summarizer() deals with characters in truth and estimate
expect_identical(accuracy_res, accuracy_exp)
})

test_that("class_metric_summarizer() handles column name collisions", {
three_class <- data_three_class()$three_class

new_three_class <- three_class
new_three_class$name <- three_class$obs
new_three_class$estimator <- three_class$obs
new_three_class$event_level <- three_class$obs

accuracy_res <- class_metric_summarizer(
name = "accuracy",
fn = accuracy_vec,
data = new_three_class,
truth = "obs",
estimate = "pred"
)

accuracy_exp <- dplyr::tibble(
.metric = "accuracy",
.estimator = "multiclass",
.estimate = accuracy_vec(three_class$obs, three_class$pred)
)

expect_identical(accuracy_res, accuracy_exp)
})

## prob_metric_summarizer --------------------------------------------------

test_that("prob_metric_summarizer() works as expected", {
Expand Down Expand Up @@ -580,6 +631,33 @@ test_that("prob_metric_summarizer() deals with characters in truth", {
expect_identical(roc_auc_res, roc_auc_exp)
})

test_that("prob_metric_summarizer() handles column name collisions", {
hpc_f1 <- data_hpc_fold1()

new_hpc_f1 <- hpc_f1
new_hpc_f1$name <- hpc_f1$VF
new_hpc_f1$estimator <- hpc_f1$VF
new_hpc_f1$event_level <- hpc_f1$VF

roc_auc_res <- prob_metric_summarizer(
name = "roc_auc",
fn = roc_auc_vec,
data = new_hpc_f1,
truth = "obs",
VF:L,
na_rm = TRUE,
case_weights = NULL
)

roc_auc_exp <- dplyr::tibble(
.metric = "roc_auc",
.estimator = "hand_till",
.estimate = roc_auc_vec(hpc_f1$obs, as.matrix(hpc_f1[3:6]))
)

expect_identical(roc_auc_res, roc_auc_exp)
})

## curve_metric_summarizer --------------------------------------------------

test_that("curve_metric_summarizer() works as expected", {
Expand Down Expand Up @@ -779,6 +857,33 @@ test_that("curve_metric_summarizer() deals with characters in truth", {
expect_identical(roc_curve_res, roc_curve_exp)
})

test_that("curve_metric_summarizer() handles column name collisions", {
hpc_f1 <- data_hpc_fold1()

new_hpc_f1 <- hpc_f1
new_hpc_f1$name <- hpc_f1$VF
new_hpc_f1$estimator <- hpc_f1$VF
new_hpc_f1$event_level <- hpc_f1$VF

roc_curve_res <- curve_metric_summarizer(
name = "roc_curve",
fn = roc_curve_vec,
data = new_hpc_f1,
truth = "obs",
VF:L,
na_rm = TRUE,
case_weights = NULL
)

roc_curve_exp <- dplyr::tibble(
.metric = "roc_curve",
.estimator = "multiclass",
.estimate = roc_curve_vec(hpc_f1$obs, as.matrix(hpc_f1[3:6]))
)

expect_identical(roc_curve_res, roc_curve_exp)
})

## dynamic_survival_metric_summarizer -----------------------------------------

test_that("dynamic_survival_metric_summarizer() works as expected", {
Expand Down Expand Up @@ -982,6 +1087,40 @@ test_that("dynamic_survival_metric_summarizer() deals with characters in truth a
expect_identical(brier_survival_res, brier_survival_exp)
})

test_that("dynamic_survival_metric_summarizer() handles column name collisions", {
lung_surv <- data_lung_surv() %>% dplyr::filter(.time == 100)

new_lung_surv <- lung_surv
new_lung_surv$name <- lung_surv$.time
new_lung_surv$estimator <- lung_surv$.time
new_lung_surv$event_level <- lung_surv$.time

brier_survival_res <- dynamic_survival_metric_summarizer(
name = "brier_survival",
fn = brier_survival_vec,
data = new_lung_surv,
truth = "surv_obj",
estimate = ".pred_survival",
censoring_weights = "ipcw",
eval_time = ".time",
na_rm = TRUE,
case_weights = NULL
)

brier_survival_exp <- dplyr::tibble(
.metric = "brier_survival",
.estimator = "standard",
.estimate = brier_survival_vec(
truth = lung_surv$surv_obj,
estimate = lung_surv$.pred_survival,
censoring_weights = lung_surv$ipcw,
eval_time = lung_surv$.time
)
)

expect_identical(brier_survival_res, brier_survival_exp)
})

## curve_survival_metric_summarizer -----------------------------------------

# To be removed once roc_survival_curve() is added
Expand Down Expand Up @@ -1228,3 +1367,37 @@ test_that("curve_survival_metric_summarizer() deals with characters in truth and

expect_identical(roc_survival_curve_res, roc_survival_curve_exp)
})

test_that("curve_survival_metric_summarizer() handles column name collisions", {
lung_surv <- data_lung_surv() %>% dplyr::filter(.time == 100)

new_lung_surv <- lung_surv
new_lung_surv$name <- lung_surv$.time
new_lung_surv$estimator <- lung_surv$.time
new_lung_surv$event_level <- lung_surv$.time

roc_survival_curve_res <- curve_survival_metric_summarizer(
name = "roc_survival_curve",
fn = roc_survival_curve_vec,
data = new_lung_surv,
truth = "surv_obj",
estimate = ".pred_survival",
censoring_weights = "ipcw",
eval_time = ".time",
na_rm = TRUE,
case_weights = NULL
)

roc_survival_curve_exp <- dplyr::tibble(
.metric = "roc_survival_curve",
.estimator = "standard",
.estimate = roc_survival_curve_vec(
truth = lung_surv$surv_obj,
estimate = lung_surv$.pred_survival,
censoring_weights = lung_surv$ipcw,
eval_time = lung_surv$.time
)
)

expect_identical(roc_survival_curve_res, roc_survival_curve_exp)
})

0 comments on commit 6b7cc20

Please sign in to comment.