Skip to content

Commit

Permalink
Merge pull request #453 from tidymodels/groupwise-vignette
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Nov 2, 2023
2 parents ce03a94 + 9145ba1 commit d9cee7d
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 16 deletions.
4 changes: 2 additions & 2 deletions R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -624,12 +624,12 @@ validate_function_class <- function(fns) {
}
}

# Special case unevaluated group-wise metric factories
# Special case unevaluated groupwise metric factories
if ("metric_factory" %in% fn_cls) {
factories <- fn_cls[fn_cls == "metric_factory"]
cli::cli_abort(
c("{cli::qty(factories)}The input{?s} {.arg {names(factories)}} \\
{?is a/are} {.help [group-wise metric](yardstick::new_groupwise_metric)} \\
{?is a/are} {.help [groupwise metric](yardstick::new_groupwise_metric)} \\
{?factory/factories} and must be passed a data-column before
addition to a metric set.",
"i" = "Did you mean to type e.g. `{names(factories)[1]}(col_name)`?"),
Expand Down
11 changes: 6 additions & 5 deletions R/fair-aaa.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
#' Create group-wise metrics
#' Create groupwise metrics
#'
#' Group-wise metrics quantify the disparity in value of a metric across a
#' number of groups. Group-wise metrics with a value of zero indicate that the
#' Groupwise metrics quantify the disparity in value of a metric across a
#' number of groups. Groupwise metrics with a value of zero indicate that the
#' underlying metric is equal across groups. yardstick defines
#' several common fairness metrics using this function, such as
#' [demographic_parity()], [equal_opportunity()], and [equalized_odds()].
#'
#' Note that _all_ yardstick metrics are group-aware in that, when passed
#' grouped data, they will return metric values calculated for each group.
#' When passed grouped data, group-wise metrics also return metric values
#' When passed grouped data, groupwise metrics also return metric values
#' for each group, but those metric values are calculated by first additionally
#' grouping by the variable passed to `by` and then summarizing the per-group
#' metric estimates across groups using the function passed as the
#' `aggregate` argument.
#' `aggregate` argument. Learn more about grouping behavior in yardstick using
#' `vignette("grouping", "yardstick")`.
#'
#' @param fn A yardstick metric function or metric set.
#' @param name The name of the metric to place in the `.metric` column
Expand Down
11 changes: 6 additions & 5 deletions man/new_groupwise_metric.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/_snaps/aaa-metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
metric_set(demographic_parity)
Condition
Error in `metric_set()`:
! The input `demographic_parity` is a group-wise metric (`?yardstick::new_groupwise_metric()`) factory and must be passed a data-column before addition to a metric set.
! The input `demographic_parity` is a groupwise metric (`?yardstick::new_groupwise_metric()`) factory and must be passed a data-column before addition to a metric set.
i Did you mean to type e.g. `demographic_parity(col_name)`?

---
Expand All @@ -137,7 +137,7 @@
metric_set(demographic_parity, roc_auc)
Condition
Error in `metric_set()`:
! The input `demographic_parity` is a group-wise metric (`?yardstick::new_groupwise_metric()`) factory and must be passed a data-column before addition to a metric set.
! The input `demographic_parity` is a groupwise metric (`?yardstick::new_groupwise_metric()`) factory and must be passed a data-column before addition to a metric set.
i Did you mean to type e.g. `demographic_parity(col_name)`?

---
Expand All @@ -146,7 +146,7 @@
metric_set(demographic_parity, equal_opportunity)
Condition
Error in `metric_set()`:
! The inputs `demographic_parity` and `equal_opportunity` are group-wise metric (`?yardstick::new_groupwise_metric()`) factories and must be passed a data-column before addition to a metric set.
! The inputs `demographic_parity` and `equal_opportunity` are groupwise metric (`?yardstick::new_groupwise_metric()`) factories and must be passed a data-column before addition to a metric set.
i Did you mean to type e.g. `demographic_parity(col_name)`?

---
Expand All @@ -155,7 +155,7 @@
metric_set(demographic_parity, equal_opportunity, roc_auc)
Condition
Error in `metric_set()`:
! The inputs `demographic_parity` and `equal_opportunity` are group-wise metric (`?yardstick::new_groupwise_metric()`) factories and must be passed a data-column before addition to a metric set.
! The inputs `demographic_parity` and `equal_opportunity` are groupwise metric (`?yardstick::new_groupwise_metric()`) factories and must be passed a data-column before addition to a metric set.
i Did you mean to type e.g. `demographic_parity(col_name)`?

# propagates 'caused by' error message when specifying the wrong column name
Expand Down
172 changes: 172 additions & 0 deletions vignettes/grouping.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
---
title: "Grouping behavior in yardstick"
author: "Simon Couch"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Grouping behavior in yardstick}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>"
)
```

The 1.3.0 release of yardstick introduced an implementation for _groupwise metrics_. The use case motivating the implementation of this functionality is _fairness metrics_, though groupwise metrics have applications beyond that domain. Fairness metrics quantify the degree of disparity in a metric value across groups. To learn more about carrying out fairness-oriented analyses with tidymodels, see the blog post on the tidymodels website. This vignette will instead focus on groupwise metrics generally, clarifying the meaning of "groupwise" and demonstrating functionality with an example dataset.

<!-- TODO: link to forthcoming tidymodels blog post -->

```{r pkgs, message = FALSE}
library(yardstick)
library(dplyr)
data("hpc_cv")
```

# Group-awareness

Even before the implementation of groupwise metrics, _all_ yardstick metrics had been _group-aware_. When grouped data is passed to a group-aware metric, it will return metric values calculated for each group.

To demonstrate, we'll make use of the `hpc_cv` data set, containing class probabilities and class predictions for a linear discriminant analysis fit to the HPC data set of Kuhn and Johnson (2013). The model is evaluated via 10-fold cross-validation, and the predictions for all folds are included.

```{r hpc-cv}
tibble(hpc_cv)
```

For the purposes of this vignette, we'll also add a column `batch` to the data and select off the columns for the class probabilities, which we don't need.

```{r hpc-modify}
set.seed(1)
hpc <-
tibble(hpc_cv) %>%
mutate(batch = sample(c("a", "b"), nrow(.), replace = TRUE)) %>%
select(-c(VF, F, M, L))
hpc
```

If we wanted to compute the accuracy of the first resampled model, we could write:

```{r acc-1}
hpc %>%
filter(Resample == "Fold01") %>%
accuracy(obs, pred)
```

The metric function returns one row, giving the `.metric`, `.estimator`, and `.estimate` for the whole data set it is passed.

If we instead group the data by fold, metric functions like `accuracy` will know to compute values for each group; in the output, each row will correspond to a Resample.

```{r hpc-cv-2}
hpc %>%
group_by(Resample) %>%
accuracy(obs, pred)
```

Note that the first row, corresponding to `Fold01`, gives the same value as manually filtering for the observations corresponding to the first resample and then computing the accuracy.

This behavior is what we mean by group-awareness. When grouped data is passed to group-aware metric functions, they will return values for each group.

# Groupwise metrics

Groupwise metrics are associated with a data-column such that, when passed data with that column, the metric will temporarily group by that column, compute values for each of the groups defined by the column, and then aggregate the values computed for the temporary grouping back to the level of the input data's grouping.

More concretely, let's turn to an example where there is no pre-existing grouping in the data. Consider the portion of the HPC data pertaining to the first resample:

```{r res-1}
hpc %>%
filter(Resample == "Fold01")
```

Suppose that the `batch`es in the data represent two groups for which model performance ought not to differ. To quantify the degree to which model performance differs for these two groups, we could compute accuracy values for either group separately, and then take their difference. First, computing accuracies:

```{r acc-by-group}
acc_by_group <-
hpc %>%
filter(Resample == "Fold01") %>%
group_by(batch) %>%
accuracy(obs, pred)
acc_by_group
```

Now, taking the difference:

```{r diff-acc}
diff(c(acc_by_group$.estimate[2], acc_by_group$.estimate[1]))
```

Groupwise metrics encode the `group_by()` and aggregation step (in this case, subtraction) shown above into a yardstick metric. We can define a new groupwise metric with the `new_groupwise_metric()` function:

```{r}
accuracy_diff <-
new_groupwise_metric(
fn = accuracy,
name = "accuracy_diff",
aggregate = function(acc_by_group) {
diff(c(acc_by_group$.estimate[2], acc_by_group$.estimate[1]))
}
)
```

* The `fn` argument is the yardstick metric that will be computed for each data group.
* The `name` argument gives the name of the new metric we've created; we'll call ours "accuracy difference."
* The `aggregate` argument is a function defining how to go from `fn` output by group to a single numeric value.

The output, `accuracy_diff`, is a function subclass called a `metric_factory`:

```{r acc-diff-class}
class(accuracy_diff)
```

`accuracy_diff` now knows to take accuracy values for each group and then return the difference between the accuracy for the first and second result as output. The last thing we need to associate with the object is the name of the grouping variable to pass to `group_by()`; we can pass that variable name to `accuracy_diff` to do so:

```{r acc-diff-by}
accuracy_diff_by_batch <- accuracy_diff(batch)
```

The output, `accuracy_diff_by_batch`, is a yardstick metric function like any other:

```{r metric-classes}
class(accuracy)
class(accuracy_diff_by_batch)
```

<!-- TODO: once a print method is added, we can print this out and the meaning of "this is just a yardstick metric" will be clearer -->

We can use the `accuracy_diff_by_batch()` metric in the same way that we would use `accuracy()`. On its own:

```{r ex-acc-diff-by-batch}
hpc %>%
filter(Resample == "Fold01") %>%
accuracy_diff_by_batch(obs, pred)
```

We can also add `accuracy_diff_by_batch()` to metric sets:

```{r ex-acc-diff-by-batch-ms}
acc_ms <- metric_set(accuracy, accuracy_diff_by_batch)
hpc %>%
filter(Resample == "Fold01") %>%
acc_ms(truth = obs, estimate = pred)
```

_Groupwise metrics are group-aware._ When passed data with any grouping variables other than the column passed as the first argument to `accuracy_diff()`---in this case, `group`---`accuracy_diff_by_batch()` will behave like any other yardstick metric. For example:

```{r ex-acc-diff-by-batch-2}
hpc %>%
group_by(Resample) %>%
accuracy_diff_by_batch(obs, pred)
```

Groupwise metrics form the backend of fairness metrics in tidymodels. To learn more about groupwise metrics and their applications in fairness problems, see `new_groupwise_metric()`.

<!-- TODO: link to tidyverse blog post and tidymodels articles. -->

0 comments on commit d9cee7d

Please sign in to comment.