Skip to content

Commit

Permalink
call_standardise_formals() now attempts to standardise formals of S…
Browse files Browse the repository at this point in the history
…3 methods (#339)
  • Loading branch information
rossellhayes committed Apr 5, 2023
1 parent 02f23cd commit 35dbf6a
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 12 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: gradethis
Title: Automated Feedback for Student Exercises in 'learnr' Tutorials
Version: 0.2.12.9001
Version: 0.2.12.9002
Authors@R: c(
person("Garrick", "Aden-Buie", , "garrick@posit.co", role = "aut",
comment = c(ORCID = "0000-0002-7111-0077")),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# gradethis 0.2.12.9002

* `call_standardise_formals()` now attempts to standardize the arguments of calls to S3 generics (#339).

# gradethis 0.2.12.9001

* `pass_if()` and `fail_if()` now produce more informative error messages if their `cond` argument is invalid (#341).
Expand Down
56 changes: 55 additions & 1 deletion R/call_standarise_formals.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ call_standardise_formals_recursive <- function( # nolint
call_standardise_formals(code)
}

call_fn <- function(code, env) {
call_fn <- function(code, env = parent.frame()) {
if (rlang::is_quosure(code) || rlang::is_formula(code)) {
code <- rlang::get_expr(code)
}
Expand All @@ -110,5 +110,59 @@ call_fn <- function(code, env) {

fn <- rlang::eval_bare(head, env)
stopifnot(rlang::is_function(fn))

try_is_s3 <- purrr::possibly(utils::isS3stdGeneric, otherwise = FALSE)
fn_is_s3_generic <- try_is_s3(fn)

if (fn_is_s3_generic) {
fn_name <- names(fn_is_s3_generic) %||% head
try_get_s3_method <- purrr::possibly(get_s3_method, otherwise = NULL)
fn <- try_get_s3_method(fn_name, arg = code[[2]], env = env) %||% fn
}

fn
}

get_s3_method <- function(fn_name, arg, env = parent.frame()) {
class <- expand_class(arg, env)

while (length(class) > 0) {
method <- utils::getS3method(
fn_name,
class[[1]],
optional = TRUE,
envir = env
)

if (!is.null(method)) {
break
}

class <- class[-1]
}

method
}

expand_class <- function(arg, env) {
arg <- rlang::eval_bare(arg, env)
class <- unique(class(arg))

if ("array" %in% class) {
non_array_arg <- arg
dim(non_array_arg) <- NULL
non_array_class <- class(non_array_arg)
class <- unique(append(class, non_array_class))
}

if ("numeric" %in% class) {
class <- unique(append(class, "double", which(class == "numeric") - 1))
}

if ("integer" %in% class) {
class <- unique(append(class, "numeric", which(class == "integer")))
}

class <- unique(append(class, "default"))
class
}
50 changes: 47 additions & 3 deletions tests/testthat/test-call_standarise_formals.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
test_that("Standarize call with formals primitive function", {
test_that("Standarize call with formals S3 function", {
user <- rlang::get_expr(quote(mean(1:3, na.rm = TRUE)))
user_stand <- call_standardise_formals(user)

expect_equal(user_stand, quote(mean(x = 1:3, na.rm = TRUE)))
expect_equal(user_stand, quote(mean(x = 1:3, trim = 0, na.rm = TRUE)))

user <- quote(mean(1:3, 0, TRUE))
user_stand <- call_standardise_formals(user)

expect_equal(user_stand, quote(mean(x = 1:3, 0, TRUE)))
expect_equal(user_stand, quote(mean(x = 1:3, trim = 0, na.rm = TRUE)))
})

test_that("Standarize call with formals user function", {
Expand All @@ -27,6 +27,50 @@ test_that("Standarize call with formals user function", {
)
})

test_that("Standarize call with formals user S3 function", {
my_func <- function(x, ...) {
UseMethod("my_func")
}

my_func.numeric <- function(x, y, z = 100, ...) {
x + y + z
}

my_func.character <- function(x, a, b = 3.14, c = "s", ...) {
paste(x, a, b, c)
}

user_numeric <- rlang::get_expr(quote(my_func(x = 1, 20)))
user_numeric_stand <- call_standardise_formals(
user_numeric,
env = rlang::env(
my_func = my_func,
my_func.numeric = my_func.numeric,
my_func.character = my_func.character
)
)

testthat::expect_equal(
user_numeric_stand,
quote(my_func(x = 1, y = 20, z = 100))
)

user_character <- rlang::get_expr(quote(my_func(x = "1", 20)))
user_character_stand <- call_standardise_formals(
user_character,
env = rlang::env(
my_func = my_func,
my_func.numeric = my_func.numeric,
my_func.character = my_func.character
)
)

testthat::expect_equal(
user_character_stand,
quote(my_func(x = "1", a = 20, b = 3.14, c = "s"))
)
})

test_that("Standarize call with ... and kwargs", {
a <- quote(vapply(list(1:3, 4:6), mean, numeric(1), 0, TRUE))
b <- quote(vapply(list(1:3, 4:6), mean, numeric(1), trim = 0, TRUE))
Expand Down
11 changes: 4 additions & 7 deletions tests/testthat/test-detect_mistakes.R
Original file line number Diff line number Diff line change
Expand Up @@ -727,19 +727,16 @@ test_that("detect_mistakes handles argument names correctly", {
)
)

# This user code looks correct (and runs!) but na.rm is an argument passed to
# This user code looks correct (and runs!) but invalid is an argument passed to
# ... that does not appear in the solution, and so should be flagged wrong.
user <- quote(mean(1:10, cut = 1, na.rm = TRUE))
solution <- quote(mean(1:10, TRUE, cut = 1))
user <- quote(mean(1:10, cut = 1, invalid = TRUE))
solution <- quote(mean(1:10, cut = 1))
expect_equal(
detect_mistakes(user, solution),
# message_wrong_value(this = quote(1),
# this_name = "cut",
# that = quote(TRUE))
message_surplus_argument(
submitted_call = quote(mean()),
submitted = quote(TRUE),
submitted_name = "na.rm"
submitted_name = "invalid"
)
)

Expand Down

0 comments on commit 35dbf6a

Please sign in to comment.