Skip to content

Commit e4fea2a

Browse files
authored
Merge branch 'master' into varying_functions
2 parents c37e49b + 7b74051 commit e4fea2a

File tree

11 files changed

+862
-5
lines changed

11 files changed

+862
-5
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ S3method(print,mlp)
2525
S3method(print,model_fit)
2626
S3method(print,model_spec)
2727
S3method(print,multinom_reg)
28+
S3method(print,nearest_neighbor)
2829
S3method(print,rand_forest)
2930
S3method(print,surv_reg)
3031
S3method(translate,boost_tree)
@@ -41,6 +42,7 @@ S3method(update,logistic_reg)
4142
S3method(update,mars)
4243
S3method(update,mlp)
4344
S3method(update,multinom_reg)
45+
S3method(update,nearest_neighbor)
4446
S3method(update,rand_forest)
4547
S3method(update,surv_reg)
4648
S3method(varying_args,model_spec)
@@ -62,6 +64,7 @@ export(mlp)
6264
export(model_printer)
6365
export(multi_predict)
6466
export(multinom_reg)
67+
export(nearest_neighbor)
6568
export(predict.model_fit)
6669
export(predict_class)
6770
export(predict_class.model_fit)

R/nearest_neighbor.R

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# TODO) If implementing `class::knn()`, mention that it does not have
2+
# the distance param because it uses Euclidean distance. And no `weight_func`
3+
# param.
4+
5+
#' General Interface for K-Nearest Neighbor Models
6+
#'
7+
#' `nearest_neighbor()` is a way to generate a _specification_ of a model
8+
#' before fitting and allows the model to be created using
9+
#' different packages in R. The main arguments for the
10+
#' model are:
11+
#' \itemize{
12+
#' \item \code{neighbors}: The number of neighbors considered at
13+
#' each prediction.
14+
#' \item \code{weight_func}: The type of kernel function that weights the
15+
#' distances between samples.
16+
#' \item \code{dist_power}: The parameter used when calculating the Minkowski
17+
#' distance. This corresponds to the Manhattan distance with `dist_power = 1`
18+
#' and the Euclidean distance with `dist_power = 2`.
19+
#' }
20+
#' These arguments are converted to their specific names at the
21+
#' time that the model is fit. Other options and argument can be
22+
#' set using the `others` argument. If left to their defaults
23+
#' here (`NULL`), the values are taken from the underlying model
24+
#' functions. If parameters need to be modified, `update()` can be used
25+
#' in lieu of recreating the object from scratch.
26+
#'
27+
#' @param mode A single character string for the type of model.
28+
#' Possible values for this model are `"unknown"`, `"regression"`, or
29+
#' `"classification"`.
30+
#'
31+
#' @param neighbors A single integer for the number of neighbors
32+
#' to consider (often called `k`).
33+
#'
34+
#' @param weight_func A *single* character for the type of kernel function used
35+
#' to weight distances between samples. Valid choices are: `"rectangular"`,
36+
#' `"triangular"`, `"epanechnikov"`, `"biweight"`, `"triweight"`,
37+
#' `"cos"`, `"inv"`, `"gaussian"`, `"rank"`, or `"optimal"`.
38+
#'
39+
#' @param dist_power A single number for the parameter used in
40+
#' calculating Minkowski distance.
41+
#'
42+
#' @param others A named list of arguments to be used by the
43+
#' underlying models (e.g., `kknn::train.kknn`). These are not evaluated
44+
#' until the model is fit and will be substituted into the model
45+
#' fit expression.
46+
#'
47+
#' @param ... Used for S3 method consistency. Any arguments passed to
48+
#' the ellipses will result in an error. Use `others` instead.
49+
#'
50+
#' @details
51+
#' The model can be created using the `fit()` function using the
52+
#' following _engines_:
53+
#' \itemize{
54+
#' \item \pkg{R}: `"kknn"`
55+
#' }
56+
#'
57+
#' Engines may have pre-set default arguments when executing the
58+
#' model fit call. These can be changed by using the `others`
59+
#' argument to pass in the preferred values. For this type of
60+
#' model, the template of the fit calls are:
61+
#'
62+
#' \pkg{kknn} (classification or regression)
63+
#'
64+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(), "kknn")}
65+
#'
66+
#' @note
67+
#' For `kknn`, the underlying modeling function used is a restricted
68+
#' version of `train.kknn()` and not `kknn()`. It is set up in this way so that
69+
#' `parsnip` can utilize the underlying `predict.train.kknn` method to predict
70+
#' on new data. This also means that a single value of that function's
71+
#' `kernel` argument (a.k.a `weight_func` here) can be supplied
72+
#'
73+
#' @seealso [varying()], [fit()]
74+
#'
75+
#' @examples
76+
#' nearest_neighbor()
77+
#'
78+
#' @export
79+
nearest_neighbor <- function(mode = "unknown",
80+
neighbors = NULL,
81+
weight_func = NULL,
82+
dist_power = NULL,
83+
others = list(),
84+
...) {
85+
86+
check_empty_ellipse(...)
87+
88+
## TODO: make a utility function here
89+
if (!(mode %in% nearest_neighbor_modes)) {
90+
stop("`mode` should be one of: ",
91+
paste0("'", nearest_neighbor_modes, "'", collapse = ", "),
92+
call. = FALSE)
93+
}
94+
95+
if(is.numeric(neighbors) && !positive_int_scalar(neighbors)) {
96+
stop("`neighbors` must be a length 1 positive integer.", call. = FALSE)
97+
}
98+
99+
if(is.character(weight_func) && length(weight_func) > 1) {
100+
stop("The length of `weight_func` must be 1.", call. = FALSE)
101+
}
102+
103+
args <- list(
104+
neighbors = neighbors,
105+
weight_func = weight_func,
106+
dist_power = dist_power
107+
)
108+
109+
no_value <- !vapply(others, is.null, logical(1))
110+
others <- others[no_value]
111+
112+
# write a constructor function
113+
out <- list(args = args, others = others,
114+
mode = mode, method = NULL, engine = NULL)
115+
# TODO: make_classes has wrong order; go from specific to general
116+
class(out) <- make_classes("nearest_neighbor")
117+
out
118+
}
119+
120+
#' @export
121+
print.nearest_neighbor <- function(x, ...) {
122+
cat("K-Nearest Neighbor Model Specification (", x$mode, ")\n\n", sep = "")
123+
model_printer(x, ...)
124+
125+
if(!is.null(x$method$fit$args)) {
126+
cat("Model fit template:\n")
127+
print(show_call(x))
128+
}
129+
invisible(x)
130+
}
131+
132+
# ------------------------------------------------------------------------------
133+
134+
#' @export
135+
update.nearest_neighbor <- function(object,
136+
neighbors = NULL,
137+
weight_func = NULL,
138+
dist_power = NULL,
139+
others = list(),
140+
fresh = FALSE,
141+
...) {
142+
143+
check_empty_ellipse(...)
144+
145+
if(is.numeric(neighbors) && !positive_int_scalar(neighbors)) {
146+
stop("`neighbors` must be a length 1 positive integer.", call. = FALSE)
147+
}
148+
149+
if(is.character(weight_func) && length(weight_func) > 1) {
150+
stop("The length of `weight_func` must be 1.", call. = FALSE)
151+
}
152+
153+
args <- list(
154+
neighbors = neighbors,
155+
weight_func = weight_func,
156+
dist_power = dist_power
157+
)
158+
159+
if (fresh) {
160+
object$args <- args
161+
} else {
162+
null_args <- map_lgl(args, null_value)
163+
if (any(null_args))
164+
args <- args[!null_args]
165+
if (length(args) > 0)
166+
object$args[names(args)] <- args
167+
}
168+
169+
if (length(others) > 0) {
170+
if (fresh)
171+
object$others <- others
172+
else
173+
object$others[names(others)] <- others
174+
}
175+
176+
object
177+
}
178+
179+
180+
positive_int_scalar <- function(x) {
181+
(length(x) == 1) && (x > 0) && (x %% 1 == 0)
182+
}

R/nearest_neighbor_data.R

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
nearest_neighbor_arg_key <- data.frame(
2+
kknn = c("ks", "kernel", "distance"),
3+
row.names = c("neighbors", "weight_func", "dist_power"),
4+
stringsAsFactors = FALSE
5+
)
6+
7+
nearest_neighbor_modes <- c("classification", "regression", "unknown")
8+
9+
nearest_neighbor_engines <- data.frame(
10+
kknn = c(TRUE, TRUE, FALSE),
11+
row.names = c("classification", "regression", "unknown")
12+
)
13+
14+
# ------------------------------------------------------------------------------
15+
16+
nearest_neighbor_kknn_data <-
17+
list(
18+
libs = "kknn",
19+
fit = list(
20+
interface = "formula",
21+
protect = c("formula", "data", "kmax"), # kmax is not allowed
22+
func = c(pkg = "kknn", fun = "train.kknn"),
23+
defaults = list()
24+
),
25+
pred = list(
26+
# seems unnecessary here as the predict_num catches it based on the
27+
# model mode
28+
pre = function(x, object) {
29+
if (object$fit$response != "continuous") {
30+
stop("`kknn` model does not appear to use numeric predictions. Was ",
31+
"the model fit with a continuous response variable?",
32+
call. = FALSE)
33+
}
34+
x
35+
},
36+
post = NULL,
37+
func = c(fun = "predict"),
38+
args =
39+
list(
40+
object = quote(object$fit),
41+
newdata = quote(new_data),
42+
type = "raw"
43+
)
44+
),
45+
classes = list(
46+
pre = function(x, object) {
47+
if (!(object$fit$response %in% c("ordinal", "nominal"))) {
48+
stop("`kknn` model does not appear to use class predictions. Was ",
49+
"the model fit with a factor response variable?",
50+
call. = FALSE)
51+
}
52+
x
53+
},
54+
post = NULL,
55+
func = c(fun = "predict"),
56+
args =
57+
list(
58+
object = quote(object$fit),
59+
newdata = quote(new_data),
60+
type = "raw"
61+
)
62+
),
63+
prob = list(
64+
pre = function(x, object) {
65+
if (!(object$fit$response %in% c("ordinal", "nominal"))) {
66+
stop("`kknn` model does not appear to use class predictions. Was ",
67+
"the model fit with a factor response variable?",
68+
call. = FALSE)
69+
}
70+
x
71+
},
72+
post = function(result, object) as_tibble(result),
73+
func = c(fun = "predict"),
74+
args =
75+
list(
76+
object = quote(object$fit),
77+
newdata = quote(new_data),
78+
type = "prob"
79+
)
80+
),
81+
raw = list(
82+
pre = NULL,
83+
func = c(fun = "predict"),
84+
args =
85+
list(
86+
object = quote(object$fit),
87+
newdata = quote(new_data)
88+
)
89+
)
90+
)

_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ reference:
2020
- mars
2121
- mlp
2222
- multinom_reg
23+
- nearest_neighbor
2324
- rand_forest
2425
- surv_reg
2526
- title: Infrastructure

docs/articles/articles/Classification.html

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/articles/articles/Models.html

Lines changed: 38 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)