/
num-huber_loss.R
100 lines (90 loc) · 2.7 KB
/
num-huber_loss.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#' Huber loss
#'
#' Calculate the Huber loss, a loss function used in robust regression. This
#' loss function is less sensitive to outliers than [rmse()]. This function is
#' quadratic for small residual values and linear for large residual values.
#'
#' For a residual \eqn{a = truth - estimate}, the per-observation loss is
#' \eqn{0.5 a^2} when \eqn{|a| \le \delta}, and
#' \eqn{\delta (|a| - 0.5 \delta)} otherwise. These losses are then
#' aggregated across observations, taking `case_weights` into account when
#' they are supplied.
#'
#' @family numeric metrics
#' @family accuracy metrics
#' @templateVar fn huber_loss
#' @template return
#'
#' @inheritParams rmse
#'
#' @param delta A single non-negative `numeric` value. Defines the boundary
#' where the loss function transitions from quadratic to linear. Defaults
#' to 1.
#'
#' @author James Blair
#'
#' @references
#'
#' Huber, P. J. (1964). Robust Estimation of a Location Parameter.
#' _The Annals of Mathematical Statistics_, 35 (1), 73-101.
#'
#' @template examples-numeric
#'
#' @export
huber_loss <- function(data, ...) {
  UseMethod("huber_loss")
}

# Register the generic as a numeric metric for which smaller values are better.
huber_loss <- new_numeric_metric(
  huber_loss,
  direction = "minimize"
)
#' @rdname huber_loss
#' @export
huber_loss.data.frame <- function(data,
                                  truth,
                                  estimate,
                                  delta = 1,
                                  na_rm = TRUE,
                                  case_weights = NULL,
                                  ...) {
  # Capture the tidy-eval column arguments as quosures up front; they are
  # injected with `!!` into `numeric_metric_summarizer()` below.
  truth_quo <- enquo(truth)
  estimate_quo <- enquo(estimate)
  weights_quo <- enquo(case_weights)

  # `delta` is a plain option rather than a column selection, so it travels
  # in `fn_options` alongside `fn = huber_loss_vec` (presumably passed on to
  # that function as an extra argument -- confirm in the summarizer).
  numeric_metric_summarizer(
    name = "huber_loss",
    fn = huber_loss_vec,
    data = data,
    truth = !!truth_quo,
    estimate = !!estimate_quo,
    na_rm = na_rm,
    case_weights = !!weights_quo,
    fn_options = list(delta = delta)
  )
}
#' @export
#' @rdname huber_loss
huber_loss_vec <- function(truth,
                           estimate,
                           delta = 1,
                           na_rm = TRUE,
                           case_weights = NULL,
                           ...) {
  check_numeric_metric(truth, estimate, case_weights)

  # Validate `delta` before any early return. Otherwise, with `na_rm = FALSE`
  # and missing values present, the `NA_real_` shortcut below would skip the
  # check in `huber_loss_impl()` and an invalid `delta` would go unreported.
  check_number_decimal(delta, min = 0)

  if (na_rm) {
    # Drop observations where truth, estimate, or case weight is missing.
    result <- yardstick_remove_missing(truth, estimate, case_weights)

    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  huber_loss_impl(truth, estimate, delta, case_weights)
}
huber_loss_impl <- function(truth,
                            estimate,
                            delta,
                            case_weights,
                            call = caller_env()) {
  # Weighted Huber Loss implementation confirmed against matlab:
  # https://www.mathworks.com/help/deeplearning/ref/dlarray.huber.html
  #
  # `delta` must be a single non-negative number. Failures are reported
  # against `call` (by default, the environment of whoever called this
  # helper) rather than against this internal function.
  check_number_decimal(delta, min = 0, call = call)

  resid <- truth - estimate
  size <- abs(resid)

  # Compute both candidate losses for every observation, then keep the
  # quadratic one inside the `delta` band and the linear one outside it.
  quadratic <- 0.5 * resid^2
  linear <- delta * (size - 0.5 * delta)
  per_obs <- ifelse(size <= delta, quadratic, linear)

  yardstick_mean(per_obs, case_weights = case_weights)
}