-
Notifications
You must be signed in to change notification settings - Fork 0
/
set_pred.R
154 lines (141 loc) · 4.78 KB
/
set_pred.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#' Register Prediction Method for Model
#'
#' This function is used to register prediction method information for a model,
#' mode, and engine combination.
#'
#' @param model A single character string for the model type (e.g. `"k_means"`,
#' etc).
#' @param mode A single character string for the model mode (e.g.
#' `"partition"`).
#' @param eng A single character string for the model engine.
#' @param type A single character value for the type of prediction. Possible
#' values are: `cluster` and `raw`.
#' @param value A list of values, described in the Details.
#' @details
#' The list passed to `value` needs the following values:
#'
#' - **pre** and **post** are optional functions that can preprocess the data
#' being fed to the prediction code and to postprocess the raw output of the
#' predictions. These won’t be needed for this example, but a section below
#' has examples of how these can be used when the model code is not easy to
#' use. If the data being predicted has a simple type requirement, you can
#' avoid using a **pre** function with the **args** below.
#' - **func** is the prediction function (in the same format as above). In many
#' cases, packages have a predict method for their model’s class but this is
#' typically not exported. In this case (and the example below), it is simple
#' enough to make a generic call to `predict()` with no associated package.
#' - **args** is a list of arguments to pass to the prediction function. These
#' will most likely be wrapped in `rlang::expr()` so that they are not
#' evaluated when defining the method. For mda, the code would be
#' `predict(object, newdata, type = "class")`. What is actually given to the
#' function is the model fit object, which includes a sub-object
#' called `fit` that houses the mda model object. If the data need to be a
#' matrix or data frame, you could also use
#' `newdata = quote(as.data.frame(newdata))` or similar.
#'
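#' @return `set_pred()` is called for its side effect of registering the
#' prediction information and returns `NULL` invisibly. `get_pred_type()`
#' returns the rows of the registered prediction table (a tibble) whose `type`
#' matches the requested prediction type.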
#' @examplesIf FALSE
#' set_new_model("shallow_learning_model")
#' set_model_mode("shallow_learning_model", "partition")
#' set_model_engine("shallow_learning_model", "partition", "stats")
#'
#' set_pred(
#' model = "shallow_learning_model",
#' eng = "stats",
#' mode = "partition",
#' type = "cluster",
#' value = list(
#' pre = NULL,
#' post = NULL,
#' func = c(fun = "predict"),
#' args =
#' list(
#' object = rlang::expr(object$fit),
#' newdata = rlang::expr(new_data),
#' type = "response"
#' )
#' )
#' )
#'
#' get_pred_type("shallow_learning_model", "cluster")
#' get_pred_type("shallow_learning_model", "cluster")$value
#' @export
set_pred <- function(model, mode, eng, type, value) {
  # Validate all inputs up front; each helper aborts on a bad value.
  check_model_val(model)
  check_mode_val(mode)
  check_eng_val(eng)
  check_spec_mode_engine_val(model, mode, eng)
  check_pred_info(value, type)

  # Candidate entry; `value` is wrapped in a list so it is stored as a
  # list-column rather than being spread across rows.
  candidate <- tibble::tibble(
    engine = eng,
    mode = mode,
    type = type,
    value = list(value)
  )

  # `is_discordant_info()` reports whether the candidate still has to be
  # appended; when it does not, there is nothing left to do.
  needs_append <- is_discordant_info(
    model = model,
    mode = mode,
    eng = eng,
    candidate = candidate,
    pred_type = type,
    component = "predict"
  )
  if (!needs_append) {
    return(invisible(NULL))
  }

  # Prediction info for a model is kept under the key "<model>_predict".
  registry_key <- paste0(model, "_predict")
  existing <- get_from_env(registry_key)
  combined <- vctrs::vec_rbind(existing, candidate)
  set_env_val(registry_key, combined)

  invisible(NULL)
}
#' @rdname set_pred
#' @export
get_pred_type <- function(model, type) {
  check_model_val(model)

  # Prediction info for a model is stored under the key "<model>_predict" in
  # the model environment.
  registry <- rlang::env_get(get_model_env(), paste0(model, "_predict"))

  # Keep only the entries registered for the requested prediction type.
  is_requested <- registry$type == type
  vctrs::vec_slice(registry, is_requested)
}
# Validate the prediction type and prediction specification given to
# `set_pred()`.
#
# Arguments:
#   pred_obj  The prediction specification (the `value` argument of
#             `set_pred()`): a list with exactly the elements `args`, `func`,
#             `post` and `pre`.
#   type      The prediction type; must be a single value found in
#             `pred_types`.
#   call      Environment passed on to `rlang::abort()` so errors are reported
#             as coming from the caller.
#
# Aborts with `rlang::abort()` on the first failed check; otherwise returns
# `NULL` invisibly.
check_pred_info <- function(pred_obj, type, call = rlang::caller_env()) {
  if (rlang::is_missing(pred_obj)) {
    rlang::abort(
      "Argument `value` is missing, with no default.",
      call = call
    )
  }

  # Require exactly one value and test membership with `%in%`: unlike `!=`,
  # it never yields `NA` (so an `NA` type gets the informative message below
  # instead of an opaque `if` failure) and a multi-element `type` can no
  # longer slip through via recycling.
  if (length(type) != 1L || !(type %in% pred_types)) {
    rlang::abort(
      glue::glue(
        "The prediction type should be one of: ",
        glue::glue_collapse(glue::glue("'{pred_types}'"), sep = ", ")
      ),
      call = call
    )
  }

  # The specification must carry exactly these elements; names are sorted
  # before comparing so the order they were supplied in does not matter.
  exp_nms <- c("args", "func", "post", "pre")
  if (!isTRUE(all.equal(sort(names(pred_obj)), exp_nms))) {
    rlang::abort(
      glue::glue(
        "The `predict` module should have elements: ",
        glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", ")
      ),
      call = call
    )
  }

  # `pre` and `post` are optional hooks: each is either NULL or a function.
  # Scalar conditions in `if` use the short-circuit `&&`.
  if (!is.null(pred_obj$pre) && !is.function(pred_obj$pre)) {
    rlang::abort(
      "The `pre` module should be null or a function: ",
      call = call
    )
  }
  if (!is.null(pred_obj$post) && !is.function(pred_obj$post)) {
    rlang::abort(
      "The `post` module should be null or a function: ",
      call = call
    )
  }

  check_func_val(pred_obj$func, call = call)

  if (!is.list(pred_obj$args)) {
    rlang::abort("The `args` element should be a list.", call = call)
  }

  invisible(NULL)
}