-
Notifications
You must be signed in to change notification settings - Fork 1
/
Function_Predict.R
55 lines (49 loc) · 1.93 KB
/
Function_Predict.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Predict Function
PredictFunction <- function(i.model = ptf_model,
                            pre.proc = usksat.pre,
                            in_data = soil.dt,
                            Alg = Model_Alg) {
  # Predicts Ks for a table of soil variables with a fitted RF or BRT model.
  #
  # Args:
  #   i.model: Fitted model object. For Alg = "RF" it must carry a
  #            `finalModel` element; for Alg = "BRT" its `n.trees` element
  #            is forwarded to predict().
  #   pre.proc: The centering and scaling object; applied to in_data via
  #             predict(pre.proc, in_data).
  #   in_data: Data frame of soil variables with column names that exist in
  #            the model.
  #   Alg: Machine learning algorithm, either "RF" or "BRT".
  #
  # Returns:
  #   An unnamed list containing
  #   1. Data frame of predictions. If Alg = "RF" each row is the summary()
  #      (min, quartiles, median, mean, max) of the individual tree
  #      predictions for one input row; if Alg = "BRT" it is a single
  #      column named Ks_predicted.
  #   2. Comma-separated string of the predictors used by the model (useful
  #      to verify the model hierarchy used).
  #
  # Raises:
  #   An error if Alg is not a single value equal to "RF" or "BRT".

  # Fail early with a clear message instead of reaching the end of the
  # function with no prediction object defined.
  if (length(Alg) != 1L || is.na(Alg) || !(Alg %in% c("RF", "BRT"))) {
    stop("Alg must be either \"RF\" or \"BRT\"", call. = FALSE)
  }

  # Pre-process data by centering and scaling.
  # The centering and scaling object has these extra columns, so they are
  # padded with NA when absent. Columns the caller already supplied are kept
  # rather than being overwritten with NA.
  extra_cols <- c("d10", "d50", "d60", "logCU", "d10_2", "d50_2",
                  "d60_2", "logCU_2")
  missing_cols <- setdiff(extra_cols, names(in_data))
  if (length(missing_cols) > 0) {
    in_data[, missing_cols] <- NA
  }
  in_scaled.dt <- predict(pre.proc, in_data)

  # Subset data to the predictors required by the selected model.
  model_p <- predictors(i.model)
  in_sub.dt <- subset(in_scaled.dt, select = model_p)

  if (Alg == "RF") {
    # predict.all = TRUE exposes the per-tree predictions in `$individual`.
    rf_out <- predict(i.model$finalModel, newdata = in_sub.dt,
                      predict.all = TRUE)
    # After transposing, each column holds all tree predictions for one
    # input row, so summarising column-wise gives one summary per row.
    per_row <- data.frame(t(rf_out$individual))
    # vapply() pins the result to six numeric summary statistics per row.
    Ks_predicted <- data.frame(t(vapply(per_row, summary, numeric(6))))
  } else {
    # Get number of trees from model.
    # NOTE(review): assumes i.model exposes `n.trees` at the top level;
    # confirm against the objects actually passed in.
    n_trees <- i.model$n.trees
    Ks_predicted <- predict(i.model,
                            newdata = in_sub.dt,
                            n.trees = n_trees)
    Ks_predicted <- data.frame(Ks_predicted)
  }

  # Prepare output.
  model_p_print <- paste(model_p, collapse = ", ")
  list(Ks_predicted, model_p_print)
}
#