Skip to content

Commit

Permalink
Vectorize coords function (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
xrobin committed May 4, 2019
1 parent a812f94 commit b5b48e7
Showing 1 changed file with 61 additions and 51 deletions.
112 changes: 61 additions & 51 deletions R/coords.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,71 +223,81 @@ coords.roc <- function(roc, x, input=c("threshold", "specificity", "sensitivity"
}
}
else if (is.numeric(x)) {
if (length(x) > 1) { # make this function a vector function
if (as.list) {
res <- lapply(x, function(x) coords.roc(roc, x, input, ret, as.list))
names(res) <- x
}
else {
res <- sapply(x, function(x) coords.roc(roc, x, input, ret, as.list))
if (length(ret) == 1) {# sapply returns a vector instead of a matrix
res <- t(res)
rownames(res) <- ret
}
colnames(res) <- x
}
return(res)
}
if (input == "threshold") {
res <- c(x, as.vector(roc.utils.perfs(x, roc$controls, roc$cases, roc$direction)) * ifelse(roc$percent, 100, 1))
# We must match every threshold given by the user to one of our
# selected threshold. However we need to be careful to assign
# them to the right one around the exact data point values
cut_points <- sort(unique(roc$predictor))
if (roc$direction == "<") {
cut_points <- c(cut_points, Inf)
thr_idx <- sapply(x, function(t) min(which(cut_points >= t)))
}
else {
cut_points <- c(rev(cut_points), Inf)
thr_idx <- sapply(x, function(t) min(which(cut_points <= t)))
}
res <- rbind(
threshold = x, # roc$thresholds[thr_idx], # user-supplied vs ours.
specificity = roc$specificities[thr_idx],
sensitivity = roc$sensitivities[thr_idx]
)
}
if (input == "specificity") {
if (x < 0 || x > ifelse(roc$percent, 100, 1))
stop("Input specificity not within the ROC space.")
if (x %in% roc$sp) {
idx <- match(x, roc$sp)
res <- c(roc$thresholds[idx], roc$sp[idx], roc$se[idx])
}
else { # need to interpolate
idx.next <- match(TRUE, roc$sp > x)
proportion <- (x - roc$sp[idx.next - 1]) / (roc$sp[idx.next] - roc$sp[idx.next - 1])
int.se <- roc$se[idx.next - 1] - proportion * (roc$se[idx.next - 1] - roc$se[idx.next])
res <- c(NA, x, int.se)
}
if (any(x < 0) || any(x > ifelse(roc$percent, 100, 1))) {
stop("Input specificity not within the ROC space.")
}
res <- matrix(nrow=3, ncol=length(x))
for (i in seq_along(x)) {
sp <- x[i]
if (sp %in% roc$sp) {
idx <- match(sp, roc$sp)
res[, i] <- c(roc$thresholds[idx], roc$sp[idx], roc$se[idx])
}
else { # need to interpolate
idx.next <- match(TRUE, roc$sp > sp)
proportion <- (sp - roc$sp[idx.next - 1]) / (roc$sp[idx.next] - roc$sp[idx.next - 1])
int.se <- roc$se[idx.next - 1] - proportion * (roc$se[idx.next - 1] - roc$se[idx.next])
res[, i] <- c(NA, sp, int.se)
}
}
}
if (input == "sensitivity") {
if (x < 0 || x > ifelse(roc$percent, 100, 1))
stop("Input sensitivity not within the ROC space.")
if (x %in% roc$se) {
idx <- length(roc$se) + 1 - match(TRUE, rev(roc$se) == x)
res <- c(roc$thresholds[idx], roc$sp[idx], roc$se[idx])
}
else { # need to interpolate
idx.next <- match(TRUE, roc$se < x)
proportion <- (x - roc$se[idx.next]) / (roc$se[idx.next - 1] - roc$se[idx.next])
int.sp <- roc$sp[idx.next] + proportion * (roc$sp[idx.next - 1] - roc$sp[idx.next])
res <- c(NA, int.sp, x)
}
if (x < 0 || x > ifelse(roc$percent, 100, 1)) {
stop("Input sensitivity not within the ROC space.")
}
res <- matrix(nrow=3, ncol=length(x))
for (i in seq_along(x)) {
se <- x[i]
if (se %in% roc$se) {
idx <- length(roc$se) + 1 - match(TRUE, rev(roc$se) == se)
res[, i] <- c(roc$thresholds[idx], roc$sp[idx], roc$se[idx])
}
else { # need to interpolate
idx.next <- match(TRUE, roc$se < se)
proportion <- (se - roc$se[idx.next]) / (roc$se[idx.next - 1] - roc$se[idx.next])
int.sp <- roc$sp[idx.next] + proportion * (roc$sp[idx.next - 1] - roc$sp[idx.next])
res[, i] <- c(NA, int.sp, se)
}
}
}
# Deduce additional tn, tp, fn, fp, npv, ppv
ncases <- ifelse(methods::is(roc, "smooth.roc"), length(attr(roc, "roc")$cases), length(roc$cases))
ncontrols <- ifelse(methods::is(roc, "smooth.roc"), length(attr(roc, "roc")$controls), length(roc$controls))
se <- res[3]
sp <- res[2]
se <- res[3, ]
sp <- res[2, ]

substr.percent <- ifelse(roc$percent, 100, 1)
co <- roc.utils.calc.coords(substr.percent,
se, sp, ncases, ncontrols)
co <- cbind(threshold = res[1], co)

co <- cbind(threshold = res[1, ], co)
rownames(co) <- x

if (as.list) {
list <- as.list(co)
list <- list[ret]
if (drop == FALSE) {
list <- list(list)
names(list) <- x
}
return(list)
list <- apply(co[, ret, drop=FALSE], 1, as.list)
if (drop == TRUE && length(x) == 1) {
return(list[[1]])
}
return(list)
}
else {
res <- t(co)
Expand Down

0 comments on commit b5b48e7

Please sign in to comment.