Skip to content

Commit

Permalink
Include option to predict 95% confidence interval with raw prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
slagtermaarten committed Jan 11, 2022
1 parent 5775df9 commit 1c1c922
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions R/kknn.R
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ kknn <- function (formula = formula(train), train, test, na.action=na.omit(),
fit[i] <- min((1:l)[weightClass[i, ] >= 0.5])
}
fit <- ordered(fit, levels = 1:l, labels = lev)
CI_table <- NULL
} else if (response == "nominal") {
fit <- apply(weightClass, 1, order, decreasing = TRUE)[1,]
fit <- factor(fit, levels = 1:l, labels = lev)
Expand All @@ -301,17 +302,36 @@ kknn <- function (formula = formula(train), train, test, na.action=na.omit(),
fitv[i] <- which(nM[i,]==min(nM[i,]))
}
fit[indices] <- factor(fitv[indices], levels = 1:l, labels = lev)
CI_table <- NULL
}
} else if (response == "continuous") {
fit <- rowSums(W*CL)/pmax(rowSums(W), 1e-6)
weight_sum <- pmax(rowSums(W), 1e-6)
fit <- rowSums(W*CL) / weight_sum
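## Kish design effect-based SE of the weighted mean, as per:
## https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
## For weights w: deff = mean(w^2) / mean(w)^2; Var(fit) ~ sigma^2 / k * deff.
## NOTE(review): the code below takes these moments over CL, not W, divides by
## the mean (not mean^2), and divides by k again in sigma_y -- confirm intent.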
sigma2 = apply(CL, 1, var)
wm <- rowMeans(CL)
wm2 <- rowMeans(CL^2)
k <- ncol(CL)
var_y <- sigma2 / k * pmax(wm2, 1e-6) / pmax(wm, 1e-6)
sigma_y <- sqrt(var_y) / sqrt(k)
CI_table <-
data.frame(
y = fit,
CI_l = fit - 1.96 * sigma_y,
CI_h = fit + 1.96 * sigma_y
)
}
options('contrasts'=old.contrasts)

result <- list(
fitted.values=fit, CL=CL, W=W, D=D, C=C,
fitted.values=fit,
CL=CL, W=W, D=D, C=C,
prob=weightClass,
response=response, distance=distance, call=ca, terms=mt
)
if (!is.null(CI_table))
result[['CI_table']] <- CI_table
class(result) <- 'kknn'
result
}
Expand Down Expand Up @@ -372,14 +392,15 @@ summary.kknn <- function(object, ...)
}


predict.kknn <- function(object, type = c("raw", "prob"), ...)
predict.kknn <- function(object, type = c("raw", "raw_CI", "prob", "all"), ...)
{
call <- object$call
extras <- match.call(expand.dots = FALSE)$...
if (length(extras)) {
names(extras)[names(extras) == "new.data"] = "test"
existing <- !is.na(match(names(extras), c("test", "k", "distance",
"kernel", "ykernel", "scale", "contrasts")))
existing <- !is.na(match(names(extras),
c("test", "k", "distance", "kernel", "ykernel", "scale",
"contrasts")))
for (a in names(extras)[existing]) call[[a]] <- extras[[a]]
# if (any(!existing)) {
# call <- c(as.list(call), extras[!existing])
Expand All @@ -388,8 +409,13 @@ predict.kknn <- function(object, type = c("raw", "prob"), ...)
object <- eval(call, object, parent.frame())
}
type <- match.arg(type)
if(type=="raw") return(object$fit)
if(type=="prob") return(object$prob)
if (type == "raw") return(object$fit)
if (type == "raw_CI") {
stopifnot(object$response == 'continuous')
return(object$CI_table)
}
if (type == "prob") return(object$prob)
if (type == "all") return(object)
return(NULL)
}

Expand Down

0 comments on commit 1c1c922

Please sign in to comment.