-
Notifications
You must be signed in to change notification settings - Fork 12
/
conformal_infer_quantile.R
155 lines (137 loc) · 5.3 KB
/
conformal_infer_quantile.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#' Prediction intervals via conformal inference and quantile regression
#'
#' Nonparametric prediction intervals can be computed for fitted regression
#' workflow objects using the split conformal inference method described by
#' Romano _et al_ (2019). To compute quantiles, this function uses Quantile
#' Random Forests instead of classic quantile regression.
#' @param object A fitted [workflows::workflow()] object.
#' @param train_data,cal_data Data frames with the _predictor and outcome data_.
#' `train_data` should be the same data used to produce `object` and `cal_data` is
#' used to produce predictions (and residuals). If the workflow used a recipe,
#' these should be the data that were inputs to the recipe (and not the product
#' of a recipe).
#' @param level The confidence level for the intervals.
#' @param ... Options to pass to [quantregForest::quantregForest()] (such as the
#' number of trees).
#' @return An object of class `"int_conformal_quantile"` containing the
#' information to create intervals (which includes `object`).
#' The `predict()` method is used to produce the intervals.
#' @details
#' Note that the significance level should be specified in this function
#' (instead of the `predict()` method).
#'
#' `cal_data` should be large enough to get a good estimate of an extreme
#' quantile (e.g., the 95th for a 95% interval) and should not include rows that
#' were in the original training set.
#'
#' Note that, because of the method used to construct the interval, it is
#' possible that the prediction intervals will not include the predicted value.
#' @seealso [predict.int_conformal_quantile()]
#' @references
#' Romano, Yaniv, Evan Patterson, and Emmanuel Candes. "Conformalized quantile
#' regression." _Advances in neural information processing systems_ 32 (2019).
#' @examplesIf !probably:::is_cran_check()
#' library(workflows)
#' library(dplyr)
#' library(parsnip)
#' library(rsample)
#' library(tune)
#' library(modeldata)
#'
#' set.seed(2)
#' sim_train <- sim_regression(500)
#' sim_cal <- sim_regression(200)
#' sim_new <- sim_regression(5) %>% select(-outcome)
#'
#' # We'll use a neural network model
#' mlp_spec <-
#'   mlp(hidden_units = 5, penalty = 0.01) %>%
#'   set_mode("regression")
#'
#' mlp_wflow <-
#'   workflow() %>%
#'   add_model(mlp_spec) %>%
#'   add_formula(outcome ~ .)
#'
#' mlp_fit <- fit(mlp_wflow, data = sim_train)
#'
#' mlp_int <- int_conformal_quantile(mlp_fit, sim_train, sim_cal,
#'   level = 0.90
#' )
#' mlp_int
#'
#' predict(mlp_int, sim_new)
#' @export
int_conformal_quantile <- function(object, ...) {
# S3 generic: dispatches on the class of `object`.
UseMethod("int_conformal_quantile")
}
#' @export
#' @rdname int_conformal_quantile
int_conformal_quantile.workflow <-
  function(object, train_data, cal_data, level = 0.95, ...) {
    # Both data sets are validated against the fitted workflow before use.
    check_data_all(train_data, object)
    check_data_all(cal_data, object)

    # --------------------------------------------------------------------------
    # Fit the quantile model on the training data and predict the lower/upper
    # quantiles for the calibration set. `...` is forwarded to the quantile
    # model fit.

    y_name <- names(hardhat::extract_mold(object)$outcomes)
    quant_fit <- quant_train(train_data, y_name, ...)
    quant_pred <- quant_predict(quant_fit, cal_data, level)

    # --------------------------------------------------------------------------
    # Conformity scores: how far each observed outcome falls outside the
    # predicted quantile band. `below_lower` is positive when the outcome is
    # under the lower bound, `above_upper` when it is over the upper bound; the
    # score is negative when the outcome lies inside the band.

    below_lower <- quant_pred$.pred_lower - cal_data[[y_name]]
    above_upper <- cal_data[[y_name]] - quant_pred$.pred_upper

    # sort() drops missing values by default, so the stored vector can be
    # shorter than nrow(cal_data) if any score is NA.
    resid <- sort(pmax(below_lower, above_upper))

    # --------------------------------------------------------------------------
    # `n` is taken from the residuals that were actually kept so that the
    # quantile index computed at prediction time lines up with `resid`. When no
    # score is missing this equals nrow(cal_data).

    res <-
      list(
        resid = resid,
        wflow = object,
        n = length(resid),
        quant = quant_fit,
        level = level
      )
    class(res) <- c("conformal_reg_quantile", "int_conformal_quantile")
    res
  }
#' @export
print.int_conformal_quantile <- function(x, ...) {
  # Header line first, then one summary line per stored component.
  cat("Split Conformal inference via Quantile Regression\n")

  pre_type <- .get_pre_type(x$wflow)
  cat("preprocessor:", pre_type, "\n")

  fit_type <- .get_fit_type(x$wflow)
  cat("model:", fit_type, "\n")

  # Thousands separator for the calibration set size.
  cal_size <- format(x$n, big.mark = ",")
  cat("calibration set size:", cal_size, "\n")

  # Level is shown to three significant digits, followed by a blank line.
  conf_level <- format(x$level, digits = 3)
  cat("confidence level:", conf_level, "\n\n")

  cat("Use `predict(object, new_data)` to compute prediction intervals\n")

  # Return the object invisibly so the print method can be used in a pipeline.
  invisible(x)
}
#' @export
#' @rdname predict.int_conformal_full
predict.int_conformal_quantile <- function(object, new_data, ...) {
  check_data(new_data, object$wflow)
  rlang::check_dots_empty()

  # Point predictions from the workflow and raw quantile bounds from the
  # stored quantile model, both for the new samples.
  new_pred <- predict(object$wflow, new_data)
  quant_pred <- quant_predict(object$quant, new_data, object$level)

  # The correction is the ceiling((1 - alpha) * (n + 1))-th smallest
  # calibration residual (`object$resid` is stored in increasing order).
  alpha <- (1 - object$level)
  q_ind <- ceiling((1 - alpha) * (object$n + 1))

  if (q_ind > length(object$resid)) {
    # Too few calibration residuals for this level: the required order
    # statistic does not exist. Indexing past the end would silently give NA
    # bounds; the finite-sample conformal quantile is infinite in this case,
    # so return unbounded intervals and tell the user why.
    rlang::warn(
      paste0(
        "The calibration set is too small for a confidence level of ",
        format(object$level, digits = 3),
        "; the prediction intervals are unbounded."
      )
    )
    q_val <- Inf
  } else {
    q_val <- object$resid[q_ind]
  }

  # Widen (or, for a negative correction, shrink) the quantile band
  # symmetrically by the calibrated amount.
  quant_pred$.pred_lower <- quant_pred$.pred_lower - q_val
  quant_pred$.pred_upper <- quant_pred$.pred_upper + q_val

  dplyr::bind_cols(new_pred, quant_pred)
}
# Fit a quantile regression forest to the training data.
#
# `train_data` is a data frame holding predictors and the outcome, `y_name` is
# the outcome column name, and `...` are extra arguments spliced into the
# `quantregForest::quantregForest()` call. Returns the fitted model object.
quant_train <- function(train_data, y_name, ...) {
  rlang::check_installed("quantregForest", reason = "to compute quantiles")

  # `x` and `y` are passed as quoted expressions rather than values, so the
  # constructed call refers to `train_data`/`y_name` by name; they are resolved
  # when the call is evaluated below (eval_tidy() evaluates in the calling
  # frame, i.e. this function, by default).
  #
  # Every function in the quoted expression is namespace-qualified so that
  # evaluation does not depend on `select()` or a pipe operator being imported
  # or attached wherever this code runs.
  cl <- rlang::call2(
    "quantregForest",
    .ns = "quantregForest",
    x = quote(dplyr::select(train_data, -dplyr::all_of(y_name))),
    y = quote(train_data[[y_name]]),
    ...
  )
  rlang::eval_tidy(cl)
}
# Predict the lower and upper quantiles for `new_data` from a fitted quantile
# model.
#
# The two quantiles are placed symmetrically around the centre of the
# distribution: (1 - level) / 2 and 1 - (1 - level) / 2. The result is a
# tibble whose columns are named `.pred_lower` and `.pred_upper`.
quant_predict <- function(fit, new_data, level) {
  rlang::check_installed("quantregForest", reason = "to compute quantiles")

  # Probability mass left in each tail outside the interval.
  tail_prob <- (1 - level) / 2
  probs <- c(tail_prob, 1 - tail_prob)

  bounds <- predict(fit, new_data, what = probs)

  stats::setNames(
    dplyr::as_tibble(bounds),
    c(".pred_lower", ".pred_upper")
  )
}