/
KNNTrainer.R
203 lines (175 loc) · 6.77 KB
/
KNNTrainer.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#' K Nearest Neighbours Trainer
#' @description Trains a k nearest neighbour model using the fast search
#' algorithms of the \code{FNN} package. KNN is a supervised learning algorithm
#' which is used for both regression and classification problems. Predictions
#' for the test data are computed during \code{$fit()}; \code{$predict()} only
#' returns the stored result.
#' @format \code{\link{R6Class}} object.
#' @section Usage:
#' For usage details see \bold{Methods, Arguments and Examples} sections.
#' \preformatted{
#' bst <- KNNTrainer$new(k=1, prob=FALSE, algorithm=NULL, type="class")
#' bst$fit(X_train, X_test, "target")
#' bst$predict(type)
#' }
#' @section Methods:
#' \describe{
#' \item{\code{$new()}}{Initialise the instance of the trainer}
#' \item{\code{$fit()}}{trains the knn model and stores the test prediction}
#' \item{\code{$predict()}}{returns the stored test predictions}
#' }
#' @section Arguments:
#' \describe{
#' \item{k}{number of neighbours used for prediction}
#' \item{prob}{if probability should be computed, default=FALSE}
#' \item{algorithm}{algorithm used to search for neighbours, possible values are 'kd_tree','cover_tree','brute'; NULL lets FNN use its default}
#' \item{type}{type of problem to solve i.e. regression or classification, possible values are 'reg' or 'class'}
#' }
#' @export
#' @examples
#' data("iris")
#'
#' iris$Species <- as.integer(as.factor(iris$Species))
#'
#' xtrain <- iris[1:100,]
#' xtest <- iris[101:150,]
#'
#' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
#' bst$fit(xtrain, xtest, 'Species')
#' pred <- bst$predict(type="raw")
KNNTrainer <- R6Class("KNNTrainer", public = list(
  #' @field k number of neighbours used for prediction
  k = 1,
  #' @field prob if probability should be computed, default=FALSE
  prob = FALSE,
  #' @field algorithm algorithm used to search for neighbours, possible values are 'kd_tree','cover_tree','brute'
  algorithm = NULL,
  #' @field type type of problem to solve i.e. regression or classification, possible values are 'reg' or 'class'
  type = "class",
  #' @field model for internal use; the FNN result stored by \code{$fit()} (NA until fitted)
  model = NA,
  #' @details
  #' Create a new `KNNTrainer` object.
  #'
  #' @param k number of neighbours used for prediction
  #' @param prob if probability should be computed, default=FALSE
  #' @param algorithm algorithm used to search for neighbours, possible values are 'kd_tree','cover_tree','brute'
  #' @param type type of problem to solve i.e. regression or classification, possible values are 'reg' or 'class'
  #' @return A `KNNTrainer` object.
  #'
  #' @examples
  #' data("iris")
  #'
  #' iris$Species <- as.integer(as.factor(iris$Species))
  #'
  #' xtrain <- iris[1:100,]
  #' xtest <- iris[101:150,]
  #'
  #' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
  #' bst$fit(xtrain, xtest, 'Species')
  #' pred <- bst$predict(type="raw")
  initialize = function(k, prob, algorithm, type) {
    # only override the field defaults for arguments the caller supplied
    if (!missing(k)) self$k <- k
    if (!missing(prob)) self$prob <- prob
    if (!missing(algorithm)) self$algorithm <- algorithm
    if (!missing(type)) self$type <- type
    superml::check_package("FNN")
  },
  #' @details
  #' Trains the KNNTrainer model and computes predictions for `test`.
  #' If the target column `y` is also present in `test`, it is dropped.
  #'
  #' @param train data.frame or matrix
  #' @param test data.frame or matrix
  #' @param y character, name of target variable
  #' @return Invisibly, the fitted FNN result (also stored in the `model` field).
  #'
  #' @examples
  #' data("iris")
  #'
  #' iris$Species <- as.integer(as.factor(iris$Species))
  #'
  #' xtrain <- iris[1:100,]
  #' xtest <- iris[101:150,]
  #'
  #' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
  #' bst$fit(xtrain, xtest, 'Species')
  fit = function(train, test, y) {
    # fail loudly instead of silently fitting nothing for an unknown type
    private$check_type()
    data <- private$prepare_data(train, test, y)
    if (self$type == "class") {
      self$model <- FNN::knn(train = data$train,
                             test = data$test,
                             cl = data$y,
                             k = self$k,
                             prob = self$prob,
                             algorithm = self$algorithm)
    } else {
      self$model <- FNN::knn.reg(train = data$train,
                                 test = data$test,
                                 y = data$y,
                                 k = self$k,
                                 algorithm = self$algorithm)
    }
    invisible(self$model)
  },
  #' @details
  #' Returns the predictions for the test data computed by \code{$fit()}.
  #'
  #' @param type character, 'raw' for labels or 'prob' for the probability
  #' attribute computed by FNN (requires prob=TRUE at fit time).
  #' Ignored when the trainer type is 'reg'.
  #' @return For type 'class': with 'raw', a numeric vector of predicted labels
  #' (a character vector if the labels are not numeric); with 'prob', the
  #' numeric `prob` attribute of the FNN result. For type 'reg': the numeric
  #' vector of predicted values.
  #'
  #' @examples
  #' data("iris")
  #'
  #' iris$Species <- as.integer(as.factor(iris$Species))
  #'
  #' xtrain <- iris[1:100,]
  #' xtest <- iris[101:150,]
  #'
  #' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
  #' bst$fit(xtrain, xtest, 'Species')
  #' pred <- bst$predict(type="raw")
  predict = function(type = "raw") {
    private$check_type()
    # `model` keeps its NA default until fit() has run
    if (identical(self$model, NA)) {
      stop("Model has not been trained yet. Call $fit() before $predict().",
           call. = FALSE)
    }
    if (self$type == "reg") {
      return(self$model$pred)
    }
    if (!(is.character(type) && length(type) == 1L &&
          type %in% c("raw", "prob"))) {
      stop("type must be either 'raw' or 'prob'.", call. = FALSE)
    }
    if (type == "prob") {
      probs <- attr(self$model, "prob")
      if (is.null(probs)) {
        stop(paste("Probabilities are not available.",
                   "Set prob = TRUE before calling $fit()."),
             call. = FALSE)
      }
      return(probs)
    }
    labels <- as.character(self$model)
    # keep the numeric output for numeric labels; otherwise return the labels
    # themselves rather than a vector of NAs
    numeric_labels <- suppressWarnings(as.numeric(labels))
    if (anyNA(numeric_labels)) {
      return(labels)
    }
    numeric_labels
  }),
  private = list(
    # Stop unless `type` is exactly one of "class" or "reg".
    check_type = function() {
      if (!(is.character(self$type) && length(self$type) == 1L &&
            self$type %in% c("class", "reg"))) {
        stop("type must be either 'class' or 'reg'.", call. = FALSE)
      }
      invisible(TRUE)
    },
    # Split the target from the features and validate both data sets.
    # Returns list(train, test, y) with numeric-only feature tables.
    prepare_data = function(train, test, y) {
      if (!(is.character(y) && length(y) == 1L)) {
        stop("y must be a single column name (character string).",
             call. = FALSE)
      }
      train <- as.data.table(train)
      test <- as.data.table(test)
      if (!(y %in% names(train)))
        stop(sprintf("%s not available in training data", y))
      # keep the target, then remove it from the features
      target <- train[[y]]
      train <- train[, setdiff(names(train), y), with = FALSE]
      # the target may also be present in test; drop it there too
      test <- test[, setdiff(names(test), y), with = FALSE]
      if (ncol(train) == 0L)
        stop("Training data has no independent variables.", call. = FALSE)
      if (ncol(test) != ncol(train))
        stop(paste("Train and test data have",
                   "unequal independent variables.", sep = "\n"))
      # FNN matches columns by position: realign test to the train column
      # order when both have the same (unique) column names
      if (!anyDuplicated(names(train)) &&
          setequal(names(test), names(train))) {
        test <- test[, names(train), with = FALSE]
      }
      has_non_numeric <- function(dt) {
        any(vapply(dt, function(col) is.factor(col) || is.character(col),
                   logical(1)))
      }
      if (has_non_numeric(train))
        stop(paste("Train data contains non-numeric variables.",
                   "Please convert them into integer.", sep = "\n"))
      if (has_non_numeric(test))
        stop(paste("Test data contains non-numeric variables.",
                   "Please convert them into integer.", sep = "\n"))
      if (any(is.na(target)))
        stop("The target variable contains NA values.")
      # class labels must be whole numbers when the target is numeric
      if (self$type == "class" && is.numeric(target) &&
          !all(target == floor(target)))
        stop("The target variable contains float values")
      list(train = train, test = test, y = target)
    }
  )
)