Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error if grouped dataframes are used in functions #106

Merged
merged 9 commits into from
May 12, 2023
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,34 @@ S3method(cal_apply,cal_object)
S3method(cal_apply,data.frame)
S3method(cal_apply,tune_results)
S3method(cal_estimate_beta,data.frame)
S3method(cal_estimate_beta,grouped_df)
S3method(cal_estimate_beta,tune_results)
S3method(cal_estimate_isotonic,data.frame)
S3method(cal_estimate_isotonic,grouped_df)
S3method(cal_estimate_isotonic,tune_results)
S3method(cal_estimate_isotonic_boot,data.frame)
S3method(cal_estimate_isotonic_boot,grouped_df)
S3method(cal_estimate_isotonic_boot,tune_results)
S3method(cal_estimate_linear,data.frame)
S3method(cal_estimate_linear,grouped_df)
S3method(cal_estimate_linear,tune_results)
S3method(cal_estimate_logistic,data.frame)
S3method(cal_estimate_logistic,grouped_df)
S3method(cal_estimate_logistic,tune_results)
S3method(cal_estimate_multinomial,data.frame)
S3method(cal_estimate_multinomial,grouped_df)
S3method(cal_estimate_multinomial,tune_results)
S3method(cal_plot_breaks,data.frame)
S3method(cal_plot_breaks,grouped_df)
S3method(cal_plot_breaks,tune_results)
S3method(cal_plot_logistic,data.frame)
S3method(cal_plot_logistic,grouped_df)
S3method(cal_plot_logistic,tune_results)
S3method(cal_plot_regression,data.frame)
S3method(cal_plot_regression,grouped_df)
S3method(cal_plot_regression,tune_results)
S3method(cal_plot_windowed,data.frame)
S3method(cal_plot_windowed,grouped_df)
S3method(cal_plot_windowed,tune_results)
S3method(cal_validate_beta,resample_results)
S3method(cal_validate_beta,rset)
Expand Down
12 changes: 12 additions & 0 deletions R/cal-estimate-beta.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ cal_estimate_beta.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_estimate_beta
cal_estimate_beta.grouped_df <- function(.data,
truth = NULL,
shape_params = 2,
location_params = 1,
estimate = NULL,
parameters = NULL,
...) {
abort_if_grouped_df()
}

#' @rdname required_pkgs.cal_object
#' @keywords internal
#' @export
Expand Down
21 changes: 21 additions & 0 deletions R/cal-estimate-isotonic.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ cal_estimate_isotonic.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_estimate_isotonic
cal_estimate_isotonic.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
parameters = NULL,
...) {
abort_if_grouped_df()
}

#------------------ >> Bootstrapped Isotonic Regression------------------------
#' Uses a bootstrapped Isotonic regression model to calibrate probabilities
#' @param times Number of bootstraps.
Expand Down Expand Up @@ -166,6 +176,17 @@ cal_estimate_isotonic_boot.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_estimate_isotonic_boot
cal_estimate_isotonic_boot.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
times = 10,
parameters = NULL,
...) {
abort_if_grouped_df()
}

#------------------------------ Implementation ---------------------------------
cal_isoreg_impl <- function(.data,
truth,
Expand Down
15 changes: 13 additions & 2 deletions R/cal-estimate-linear.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#------------------------------- Methods ---------------------------------------
#' Uses a linear regression model to calibrate numeric predictions
#' @inheritParams cal_estimate_logistic
#' @param .data A `data.frame` object, or `tune_results` object, that contains
#' predictions and probability columns.
#' @param .data Am ungrouped `data.frame` object, or `tune_results` object,
#' that contains a prediction column.
#' @param truth The column identifier for the observed outcome data (that is
#' numeric). This should be an unquoted column name.
#' @param estimate Column identifier for the predicted values
Expand Down Expand Up @@ -125,6 +125,17 @@ cal_estimate_linear.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_estimate_linear
cal_estimate_linear.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
smooth = TRUE,
parameters = NULL,
...) {
abort_if_grouped_df()
}

#' @rdname required_pkgs.cal_object
#' @keywords internal
#' @export
Expand Down
16 changes: 14 additions & 2 deletions R/cal-estimate-logistic.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#------------------------------- Methods ---------------------------------------
#' Uses a logistic regression model to calibrate probabilities
#' @param .data A `data.frame` object, or `tune_results` object, that contains
#' predictions and probability columns.
#' @param .data An ungrouped `data.frame` object, or `tune_results` object,
#' that contains predictions and probability columns.
#' @param truth The column identifier for the true class results
#' (that is a factor). This should be an unquoted column name.
#' @param estimate A vector of column identifiers, or one of `dplyr` selector
Expand Down Expand Up @@ -102,6 +102,18 @@ cal_estimate_logistic.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_estimate_logistic
cal_estimate_logistic.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
smooth = TRUE,
parameters = NULL,
...) {
abort_if_grouped_df()
}


#' @rdname required_pkgs.cal_object
#' @keywords internal
#' @export
Expand Down
11 changes: 11 additions & 0 deletions R/cal-estimate-multinom.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ cal_estimate_multinomial.tune_results <-
)
}

#' @export
#' @rdname cal_estimate_multinomial
cal_estimate_multinomial.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
smooth = TRUE,
parameters = NULL,
...) {
abort_if_grouped_df()
}

#' @rdname required_pkgs.cal_object
#' @keywords internal
#' @export
Expand Down
24 changes: 19 additions & 5 deletions R/cal-plot-breaks.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
#' If the predictions are well calibrated, the fitted curve should align with
#' the diagonal line.
#'
#' @param .data A data.frame object containing predictions and probability columns.
#' @param .data An ungrouped data frame object containing predictions and
#' probability columns.
#' @param truth The column identifier for the true class results
#' (that is a factor). This should be an unquoted column name.
#' @param estimate A vector of column identifiers, or one of `dplyr` selector
Expand Down Expand Up @@ -74,13 +75,11 @@
#' combined <- bind_rows(mutate(segment_logistic, source = "original"), gl)
#'
#' combined %>%
#' group_by(source) %>%
#' cal_plot_logistic(Class, .pred_good)
#' cal_plot_logistic(Class, .pred_good, group = source)
#'
#' # The grouping can be faceted in ggplot2
#' combined %>%
#' group_by(source) %>%
#' cal_plot_logistic(Class, .pred_good) +
#' cal_plot_logistic(Class, .pred_good, group = source) +
#' facet_wrap(~source) +
#' theme(legend.position = "")
#' @seealso [cal_plot_logistic()], [cal_plot_windowed()]
Expand Down Expand Up @@ -163,6 +162,21 @@ cal_plot_breaks.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_plot_breaks
cal_plot_breaks.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
num_breaks = 10,
conf_level = 0.90,
include_ribbon = TRUE,
include_rug = TRUE,
include_points = TRUE,
event_level = c("auto", "first", "second"),
...) {
abort_if_grouped_df()
}

#--------------------------- >> Implementation ---------------------------------

cal_plot_breaks_impl <- function(.data,
Expand Down
14 changes: 14 additions & 0 deletions R/cal-plot-logistic.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ cal_plot_logistic.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_plot_logistic
cal_plot_logistic.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
conf_level = 0.90,
smooth = TRUE,
include_rug = TRUE,
include_ribbon = TRUE,
event_level = c("auto", "first", "second"),
...) {
abort_if_grouped_df()
}

#--------------------------- >> Implementation ---------------------------------
cal_plot_logistic_impl <- function(.data,
truth = NULL,
Expand Down
13 changes: 12 additions & 1 deletion R/cal-plot-regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#' is shown. If the predictions are well calibrated, the fitted curve should align with
#' the diagonal line.
#'
#' @param .data A data.frame object containing prediction and truth columns.
#' @param .data An ungrouped data frame object containing a prediction
#' column.
#' @param truth The column identifier for the true results
#' (numeric). This should be an unquoted column name.
#' @param estimate The column identifier for the predictions.
Expand Down Expand Up @@ -85,6 +86,16 @@ cal_plot_regression.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_plot_regression
cal_plot_regression.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
smooth = TRUE,
...) {
abort_if_grouped_df()
}

regression_plot_impl <- function(.data, truth, estimate, group,
smooth, ...) {
truth <- enquo(truth)
Expand Down
16 changes: 16 additions & 0 deletions R/cal-plot-windowed.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,22 @@ cal_plot_windowed.tune_results <- function(.data,
)
}

#' @export
#' @rdname cal_plot_windowed
cal_plot_windowed.grouped_df <- function(.data,
truth = NULL,
estimate = NULL,
window_size = 0.1,
step_size = window_size / 2,
conf_level = 0.90,
include_ribbon = TRUE,
include_rug = TRUE,
include_points = TRUE,
event_level = c("auto", "first", "second"),
...) {
abort_if_grouped_df()
}

#--------------------------- >> Implementation ---------------------------------
cal_plot_windowed_impl <- function(.data,
truth = NULL,
Expand Down
10 changes: 10 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,13 @@ get_group_argument <- function(group, .data, call = rlang::env_parent()) {

return(group)
}

abort_if_grouped_df <- function(call = rlang::caller_env()) {
cli::cli_abort(
c(
"x" = "This function does not work with grouped data frames.",
"i" = "Apply {.fn dplyr::ungroup} and use the {.arg .by} argument."
),
call = call
)
}
3 changes: 2 additions & 1 deletion man/cal_binary_tables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 13 additions & 2 deletions man/cal_estimate_beta.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions man/cal_estimate_isotonic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 12 additions & 2 deletions man/cal_estimate_isotonic_boot.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading