---
title: "Metric types"
author: "Davis Vaughan"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Metric types}
  %\VignetteEncoding{UTF-8}
  %\VignetteEngine{knitr::rmarkdown}
editor_options:
  chunk_output_type: console
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

## Metric types

There are five main metric types in `yardstick`: class, class probability,
numeric, static survival, and dynamic survival. Each type of metric has a
standardized argument syntax, and all metrics return the same kind of output:
a tibble with the columns `.metric`, `.estimator`, and `.estimate` (dynamic
survival metrics add an `.eval_time` column). This standardization makes it
easy to combine metrics and to use them with grouped data frames, for example
to compute a metric on multiple resamples at once. Below are the five types of
metrics, along with the types of the inputs they take.

1) **Class metrics** (hard predictions)
    - `truth` - factor
    - `estimate` - factor

2) **Class probability metrics** (soft predictions)
    - `truth` - factor
    - `estimate / ...` - multiple numeric columns containing class probabilities

3) **Numeric metrics**
    - `truth` - numeric
    - `estimate` - numeric

4) **Static survival metrics**
    - `truth` - Surv
    - `estimate` - numeric

5) **Dynamic survival metrics**
    - `truth` - Surv
    - `...` - list of data frames, each containing the 3 columns `.eval_time`, `.pred_survival`, and `.weight_censored`
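
As a small illustration of the input types listed above, the chunk below calls
one class probability metric and one numeric metric. It assumes the `hpc_cv`
data (described in the next section) and the `solubility_test` data, both of
which ship with `yardstick`; the object names `prob_res` and `num_res` are only
illustrative. Both calls return the same three-column tibble.

```{r}
library(yardstick)
library(dplyr)

# Class probability metric: `truth` is a factor, and the class probability
# columns (VF, F, M, and L in `hpc_cv`) are passed through `...`.
prob_res <- hpc_cv %>%
  roc_auc(obs, VF:L)
prob_res

# Numeric metric: both `truth` and `estimate` are numeric columns.
num_res <- solubility_test %>%
  rmse(truth = solubility, estimate = prediction)
num_res
```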

## Example

In the following example, the `hpc_cv` data set is used. It contains class
probabilities and class predictions for a linear discriminant analysis fit to
the HPC data set of Kuhn and Johnson (2013). The model was fit with 10-fold
cross-validation, and the predictions for all folds are included.

```{r, warning = FALSE, message = FALSE}
library(yardstick)
library(dplyr)

data("hpc_cv")

hpc_cv %>%
  group_by(Resample) %>%
  slice(1:3)
```

1 metric, 1 resample

```{r}
hpc_cv %>%
  filter(Resample == "Fold01") %>%
  accuracy(obs, pred)
```
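
Each data frame metric in `yardstick` also has a vector counterpart with a
`_vec` suffix (here `accuracy_vec()`) that takes the columns directly and
returns a single number. The sketch below computes accuracy for the same fold
both ways; the helper object names are hypothetical and chosen only for
readability.

```{r}
library(yardstick)
library(dplyr)

# Keep only the first resample, as in the previous chunk.
fold01 <- hpc_cv %>%
  filter(Resample == "Fold01")

# Data frame interface: returns a one-row tibble.
acc_df <- accuracy(fold01, obs, pred)

# Vector interface: returns a bare numeric value.
acc_vec <- accuracy_vec(fold01$obs, fold01$pred)

# The two interfaces should agree.
all.equal(acc_df$.estimate, acc_vec)
```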

1 metric, 10 resamples

```{r}
hpc_cv %>%
  group_by(Resample) %>%
  accuracy(obs, pred)
```
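
Because the grouped result is an ordinary tibble with one row per resample, it
can be summarized with regular `dplyr` verbs. The sketch below averages the
per-fold accuracy; the summary column names are illustrative, and the
`ungroup()` call is a precaution so that the summary collapses to a single row.

```{r}
library(yardstick)
library(dplyr)

# Accuracy for each of the 10 folds.
fold_acc <- hpc_cv %>%
  group_by(Resample) %>%
  accuracy(obs, pred)

# Mean and standard deviation of accuracy across folds.
acc_summary <- fold_acc %>%
  ungroup() %>%
  summarise(mean_accuracy = mean(.estimate), sd_accuracy = sd(.estimate))
acc_summary
```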

2 metrics, 10 resamples

```{r}
class_metrics <- metric_set(accuracy, kap)

hpc_cv %>%
  group_by(Resample) %>%
  class_metrics(obs, estimate = pred)
```
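
A metric set can also mix class and class probability metrics. In that case the
probability columns are passed through `...` and the hard predictions must be
supplied by name as `estimate`. A minimal sketch, reusing the same data; the
name `mixed_metrics` is only illustrative:

```{r}
library(yardstick)
library(dplyr)

# One class metric and one class probability metric in the same set.
mixed_metrics <- metric_set(accuracy, roc_auc)

# Probability columns go through `...`; hard predictions are named `estimate`.
mixed_res <- hpc_cv %>%
  group_by(Resample) %>%
  mixed_metrics(obs, VF:L, estimate = pred)
mixed_res
```

The result stacks both metrics for every resample, so it has one row per
metric per fold.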

## Metrics

Below is a table of all of the metrics available in `yardstick`, grouped
by type.

```{r, echo=FALSE, warning=FALSE, message=FALSE, results='asis'}
library(knitr)
library(dplyr)

# Collect every object in the yardstick namespace, keyed by name
yardns <- asNamespace("yardstick")
fns <- lapply(names(yardns), get, envir = yardns)
names(fns) <- names(yardns)

# Return the sorted names of the functions that inherit from a metric class
get_metrics <- function(fns, type) {
  where <- vapply(fns, inherits, what = type, FUN.VALUE = logical(1))
  paste0("`", sort(names(fns[where])), "()`")
}

all_metrics <- bind_rows(
  tibble(type = "class", metric = get_metrics(fns, "class_metric")),
  tibble(type = "class prob", metric = get_metrics(fns, "prob_metric")),
  tibble(type = "numeric", metric = get_metrics(fns, "numeric_metric")),
  tibble(type = "dynamic survival", metric = get_metrics(fns, "dynamic_survival_metric")),
  tibble(type = "static survival", metric = get_metrics(fns, "static_survival_metric"))
)

kable(all_metrics, format = "html")
```