Skip to content

Commit

Permalink
added precision and recall statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 13, 2016
1 parent 1928a43 commit 7d056de
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 4 deletions.
1 change: 1 addition & 0 deletions pkg/caret/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Suggests:
mda,
mgcv,
mlbench,
MLmetrics,
nnet,
party (>= 0.9-99992),
pls,
Expand Down
16 changes: 16 additions & 0 deletions pkg/caret/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ export(anovaScores,
expoTrans,
extractPrediction,
extractProb,
F_meas,
F_meas.default,
F_meas.table,
featurePlot,
filterVarImp,
findCorrelation,
Expand Down Expand Up @@ -162,6 +165,9 @@ export(anovaScores,
posPredValue.default,
posPredValue.table,
postResample,
precision,
precision.default,
precision.table,
predict.avNNet,
predict.bag,
predict.bagEarth,
Expand All @@ -186,7 +192,11 @@ export(anovaScores,
print.train,
probFunction,
progress,
prSummary,
R2,
recall,
recall.default,
recall.table,
resampleHist,
resamples,
resamples.default,
Expand Down Expand Up @@ -287,6 +297,9 @@ S3method(dummyVars, default)
S3method(BoxCoxTrans, default)
S3method(cluster, default)
S3method(expoTrans, default)
S3method(precision, default)
S3method(recall, default)
S3method(F_meas, default)

S3method(calibration, formula)
S3method(lift, formula)
Expand Down Expand Up @@ -452,6 +465,9 @@ S3method(confusionMatrix, table)

S3method(sensitivity, table)
S3method(specificity, table)
S3method(precision, table)
S3method(recall, table)
S3method(F_meas, table)

S3method(posPredValue, table)
S3method(negPredValue, table)
Expand Down
131 changes: 131 additions & 0 deletions pkg/caret/R/prec_rec.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# S3 generic for recall; dispatches on the class of `data`.
recall <- function(data, ...) UseMethod("recall")

# Recall computed from a square cross-tabulation whose rows and columns carry
# the same group names in the same order (recall.default builds it as
# table(data, reference), i.e. rows = predictions, columns = reference).
#
# Arguments:
#   data      a square table / matrix of counts with matching dimnames.
#   relevant  name(s) of the group(s) treated as "relevant"; defaults to the
#             first row name of `data`.
#   ...       not used.
#
# Value:
#   count in the relevant/relevant cell divided by the total of the relevant
#   column, or NA when that column total is not greater than zero.
#
# Errors:
#   stops when the table is not square or when the row and column names differ.
recall.table <- function(data, relevant = rownames(data)[1], ...) {
  # The default for `relevant` is an expression in `data`. It has to be
  # evaluated now, before `data` is replaced by the collapsed 2x2 table below.
  force(relevant)

  # all.equal() returns a character description (not FALSE) on a mismatch, so
  # it must be wrapped in isTRUE() before being negated.
  if (!isTRUE(all.equal(nrow(data), ncol(data))))
    stop("the table must have nrow = ncol")
  if (!isTRUE(all.equal(rownames(data), colnames(data))))
    stop("the table must the same groups in the same order")

  if (nrow(data) > 2) {
    # Collapse a multi-group table into relevant-vs-everything-else.
    full <- data
    is_rel <- colnames(full) %in% relevant
    rel_idx <- which(is_rel)
    irrel_idx <- which(!is_rel)

    # Values are supplied column by column: [1,1], [2,1], [1,2], [2,2].
    data <- as.table(matrix(
      c(sum(full[rel_idx, rel_idx]), sum(full[irrel_idx, rel_idx]),
        sum(full[rel_idx, irrel_idx]), sum(full[irrel_idx, irrel_idx])),
      nrow = 2, ncol = 2,
      dimnames = list(c("rel", "irrel"), c("rel", "irrel"))
    ))
    relevant <- "rel"
  }

  numer <- data[relevant, relevant]
  denom <- sum(data[, relevant])
  # isTRUE() keeps the NA result for a zero or missing denominator.
  if (isTRUE(denom > 0)) numer / denom else NA
}

# Recall for two factors: cross-tabulates `data` against `reference` and hands
# the resulting table to recall.table().
#
# Arguments:
#   data       factor of measurements / predictions.
#   reference  factor of reference values.
#   relevant   level treated as "relevant"; defaults to the first level of
#              `reference`.
#   na.rm      when TRUE, pairs where either value is missing are dropped
#              before tabulating.
#   ...        not used.
#
# Errors:
#   stops when either input is not a factor, or when the two factors together
#   do not have exactly two distinct levels.
recall.default <- function(data, reference, relevant = levels(reference)[1],
                           na.rm = TRUE, ...) {
  both_factors <- is.factor(reference) && is.factor(data)
  if (!both_factors) {
    stop("input data must be a factor")
  }

  all_levels <- union(levels(reference), levels(data))
  if (length(all_levels) != 2) {
    stop("input data must have the same two levels")
  }

  if (na.rm) {
    keep <- complete.cases(data) & complete.cases(reference)
    if (!all(keep)) {
      data <- data[keep]
      reference <- reference[keep]
    }
  }

  xtab <- table(data, reference)
  recall.table(xtab, relevant = relevant)
}

# S3 generic for precision; the method is chosen from the class of `data`.
precision <- function(data, ...) {
  UseMethod("precision")
}

# Precision for two factors: cross-tabulates `data` against `reference` and
# hands the resulting table to precision.table().
#
# Arguments:
#   data       factor of measurements / predictions.
#   reference  factor of reference values.
#   relevant   level treated as "relevant"; defaults to the first level of
#              `reference`.
#   na.rm      when TRUE, pairs where either value is missing are dropped
#              before tabulating.
#   ...        not used.
#
# Errors:
#   stops when either input is not a factor, or when the two factors together
#   do not have exactly two distinct levels.
precision.default <- function(data, reference, relevant = levels(reference)[1],
                              na.rm = TRUE, ...) {
  both_factors <- is.factor(reference) && is.factor(data)
  if (!both_factors) {
    stop("input data must be a factor")
  }

  all_levels <- union(levels(reference), levels(data))
  if (length(all_levels) != 2) {
    stop("input data must have the same two levels")
  }

  if (na.rm) {
    keep <- complete.cases(data) & complete.cases(reference)
    if (!all(keep)) {
      data <- data[keep]
      reference <- reference[keep]
    }
  }

  xtab <- table(data, reference)
  precision.table(xtab, relevant = relevant)
}

# Precision computed from a square cross-tabulation whose rows and columns
# carry the same group names in the same order (precision.default builds it as
# table(data, reference), i.e. rows = predictions, columns = reference).
#
# Arguments:
#   data      a square table / matrix of counts with matching dimnames.
#   relevant  name(s) of the group(s) treated as "relevant"; defaults to the
#             first row name of `data`.
#   ...       not used.
#
# Value:
#   count in the relevant/relevant cell divided by the total of the relevant
#   row, or NA when that row total is not greater than zero.
#
# Errors:
#   stops when the table is not square or when the row and column names differ.
precision.table <- function (data, relevant = rownames(data)[1], ...) {
  # The default for `relevant` is an expression in `data`. Evaluate it now:
  # once `data` is replaced by the collapsed 2x2 table below, the lazy default
  # would resolve to "rel" and no original column would match it.
  force(relevant)

  # all.equal() returns a character description (not FALSE) on a mismatch, so
  # it must be wrapped in isTRUE() before being negated.
  if (!isTRUE(all.equal(nrow(data), ncol(data))))
    stop("the table must have nrow = ncol")
  if (!isTRUE(all.equal(rownames(data), colnames(data))))
    stop("the table must the same groups in the same order")

  if (nrow(data) > 2) {
    # Collapse a multi-group table into relevant-vs-everything-else.
    full <- data
    is_rel <- colnames(full) %in% relevant
    rel_idx <- which(is_rel)
    irrel_idx <- which(!is_rel)

    # Values are supplied column by column: [1,1], [2,1], [1,2], [2,2].
    data <- as.table(matrix(
      c(sum(full[rel_idx, rel_idx]), sum(full[irrel_idx, rel_idx]),
        sum(full[rel_idx, irrel_idx]), sum(full[irrel_idx, irrel_idx])),
      nrow = 2, ncol = 2,
      dimnames = list(c("rel", "irrel"), c("rel", "irrel"))
    ))
    relevant <- "rel"
  }

  numer <- data[relevant, relevant]
  denom <- sum(data[relevant, ])
  # isTRUE() keeps the NA result for a zero or missing denominator.
  if (isTRUE(denom > 0)) numer / denom else NA
}

# S3 generic for the F measure; the method is chosen from the class of `data`.
F_meas <- function(data, ...) {
  UseMethod("F_meas")
}

# F measure for two factors: cross-tabulates `data` against `reference` and
# hands the resulting table to F_meas.table().
#
# Arguments:
#   data       factor of measurements / predictions.
#   reference  factor of reference values.
#   relevant   level treated as "relevant"; defaults to the first level of
#              `reference`.
#   beta       weight passed through to F_meas.table().
#   na.rm      when TRUE, pairs where either value is missing are dropped
#              before tabulating.
#   ...        not used.
#
# Errors:
#   stops when either input is not a factor, or when the two factors together
#   do not have exactly two distinct levels.
F_meas.default <- function(data, reference, relevant = levels(reference)[1],
                           beta = 1, na.rm = TRUE, ...) {
  both_factors <- is.factor(reference) && is.factor(data)
  if (!both_factors) {
    stop("input data must be a factor")
  }

  all_levels <- union(levels(reference), levels(data))
  if (length(all_levels) != 2) {
    stop("input data must have the same two levels")
  }

  if (na.rm) {
    keep <- complete.cases(data) & complete.cases(reference)
    if (!all(keep)) {
      data <- data[keep]
      reference <- reference[keep]
    }
  }

  xtab <- table(data, reference)
  F_meas.table(xtab, relevant = relevant, beta = beta)
}

# F measure from a cross-tabulation: combines precision.table() and
# recall.table() for the same `relevant` group as
#   (1 + beta^2) * precision * recall / (beta^2 * precision + recall).
#
# Arguments:
#   data      table passed unchanged to precision.table() and recall.table().
#   relevant  group treated as "relevant"; defaults to the first row name.
#   beta      weighting constant in the formula above (1 by default).
#   ...       not used.
F_meas.table <- function (data, relevant = rownames(data)[1], beta = 1, ...) {
  p <- precision.table(data, relevant = relevant)
  r <- recall.table(data, relevant = relevant)
  beta_sq <- beta^2
  weighted_product <- (1 + beta_sq) * p * r
  weighted_product / ((beta_sq * p) + r)
}

# Summary function returning precision/recall statistics for a two-class
# outcome.
#
# Arguments:
#   data   data frame with factor columns `obs` and `pred` plus a column named
#          after the relevant level (presumably the predicted class
#          probabilities for that level -- confirm against the caller).
#   lev    character vector of outcome levels; the first element is treated
#          as the relevant class. When NULL (the default) the levels of
#          `data$obs` are used.
#   model  not used.
#
# Value:
#   named numeric vector with elements AUC (from MLmetrics::PRAUC),
#   Precision, Recall and F.
#
# Errors:
#   stops when the outcome has more than two levels or when the observed and
#   predicted factors do not share the same levels.
prSummary <- function (data, lev = NULL, model = NULL) {

  requireNamespaceQuietStop("MLmetrics")
  if (length(levels(data$obs)) > 2)
    stop(paste("Your outcome has", length(levels(data$obs)),
               "levels. The prSummary() function isn't appropriate."))
  # identical() avoids the silent recycling of `==` when the two level sets
  # have different lengths.
  if (!identical(levels(data[, "pred"]), levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")

  # With the default lev = NULL, lev[1] would be NULL and every statistic
  # below would fail; fall back to the levels of the observed outcome.
  if (is.null(lev)) lev <- levels(data$obs)
  relevant <- lev[1]

  c(AUC = MLmetrics::PRAUC(y_pred = data[, relevant],
                           y_true = ifelse(data$obs == relevant, 1, 0)),
    Precision = precision.default(data = data$pred, reference = data$obs,
                                  relevant = relevant),
    Recall = recall.default(data = data$pred, reference = data$obs,
                            relevant = relevant),
    F = F_meas.default(data = data$pred, reference = data$obs,
                       relevant = relevant))
}

9 changes: 5 additions & 4 deletions pkg/caret/man/postResample.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
\alias{postResample}
\alias{defaultSummary}
\alias{twoClassSummary}
\alias{prSummary}
\alias{getTrainPerf}
\alias{mnLogLoss}
\alias{R2}
Expand All @@ -18,6 +19,7 @@ postResample(pred, obs)
defaultSummary(data, lev = NULL, model = NULL)

twoClassSummary(data, lev = NULL, model = NULL)
prSummary(data, lev = NULL, model = NULL)

mnLogLoss(data, lev = NULL, model = NULL)
multiClassSummary(data, lev = NULL, model = NULL)
Expand Down Expand Up @@ -59,6 +61,8 @@ For \code{defaultSummary} is the default function to compute performance metrics
}
where the \code{y} values are binary indicators for the classes and \code{p} are the predicted class probabilities.

\code{prSummary} (for precision and recall) computes values for the default 0.50 probability cutoff as well as the area under the precision-recall curve across all cutoffs and is labelled as \code{"AUC"} in the output. It assumes that the first level of the factor variables corresponds to a relevant result but the \code{lev} argument can be used to change this.

\code{multiClassSummary} computes some overall measures of performance (e.g. overall accuracy and the Kappa statistic) and several averages of statistics calculated from "one-versus-all" configurations. For example, if there are three classes, three sets of sensitivity values are determined and the average is reported with the name ("Mean_Sensitivity"). The same is true for a number of statistics generated by \code{\link{confusionMatrix}}. With two classes, the basic sensitivity is reported with the name "Sensitivity".

To use \code{twoClassSummary} and/or \code{mnLogLoss}, the \code{classProbs} argument of \code{\link{trainControl}} should be \code{TRUE}. \code{multiClassSummary} can be used without class probabilities but some statistics (e.g. overall log loss and the average of per-class area under the ROC curves) will not be in the result set.
Expand Down Expand Up @@ -89,11 +93,8 @@ dat <- data.frame(obs = factor(sample(classes, 50, replace = TRUE)),

defaultSummary(dat, lev = classes)
twoClassSummary(dat, lev = classes)
prSummary(dat, lev = classes)
mnLogLoss(dat, lev = classes)




}
\keyword{utilities}

Expand Down
115 changes: 115 additions & 0 deletions pkg/caret/man/prec_recall.Rd
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
\name{recall}
\alias{recall}
\alias{recall.default}
\alias{recall.table}
\alias{precision}
\alias{precision.default}
\alias{precision.table}
\alias{precision.matrix}
\alias{F_meas}
\alias{F_meas.default}
\alias{F_meas.table}
\title{Calculate recall, precision and F values}
\description{
These functions calculate the recall, precision or F values of a measurement system
for finding/retrieving relevant documents compared to
reference results (the truth regarding relevance). The measurement and "truth"
data must have the same two possible outcomes and one of the outcomes
must be thought of as a "relevant" result.

The recall (aka sensitivity) is defined as the proportion of relevant results out of the number of
samples which were actually relevant. When there are no relevant results, recall is
not defined and a value of \code{NA} is returned.

The precision is the proportion of predicted truly relevant results of the total number of
predicted relevant results
and characterizes the "purity in retrieval performance" (Buckland and Gey, 1994)

The measure "F" is a combination of precision and recall (see below).

}
\usage{
recall(data, ...)
\method{recall}{default}(data, reference, relevant = levels(reference)[1], na.rm = TRUE, ...)
\method{recall}{table}(data, relevant = rownames(data)[1], ...)

precision(data, ...)
\method{precision}{default}(data, reference, relevant = levels(reference)[1], na.rm = TRUE, ...)
\method{precision}{table}(data, relevant = rownames(data)[1], ...)

F_meas(data, ...)
\method{F_meas}{default}(data, reference, relevant = levels(reference)[1],
beta = 1, na.rm = TRUE, ...)
\method{F_meas}{table}(data, relevant = rownames(data)[1], beta = 1, ...)
}

\arguments{
\item{data}{for the default functions, a factor containing the discrete measurements. For the \code{table} function, a table.}
\item{reference}{a factor containing the reference values (i.e. truth)}
\item{relevant}{a character string that defines the factor level corresponding to
the "relevant" results}
\item{beta}{a numeric value used to weight precision and recall. A value of 1 is traditionally used and corresponds to the harmonic mean of the two values but other values weight recall beta times more important than precision. }
\item{na.rm}{a logical value indicating whether \code{NA} values should be stripped before the computation proceeds}
\item{...}{not currently used}
}

\details{
Suppose a 2x2 table with notation

\tabular{rcc}{
\tab Reference \tab \cr
Predicted \tab relevant \tab Irrelevant \cr
relevant \tab A \tab B \cr
Irrelevant \tab C \tab D \cr
}

The formulas used here are:
\deqn{recall = A/(A+C)}
\deqn{precision = A/(A+B)}
\deqn{F_i = (1+i^2)*precision*recall/((i^2 * precision)+recall)}

See the references for discussions of the statistics.
}

\value{
A number between 0 and 1 (or NA).
}

\author{Max Kuhn}
\references{Kuhn, M. (2008), ``Building predictive models in R using the caret package, '' \emph{Journal of Statistical Software}, (\url{http://www.jstatsoft.org/article/view/v028i05/v28i05.pdf}).

Buckland, M., & Gey, F. (1994). The relationship between Recall and Precision. \emph{Journal of the American Society for Information Science}, 45(1), 12--19.

Powers, D. (2007). Evaluation: From Precision, Recall and F Factor to ROC, Informedness, Markedness and Correlation. Technical Report SIE-07-001, Flinders University
}
\seealso{\code{\link{confusionMatrix}}}

\examples{
###################
## Data in Table 2 of Powers (2007)

lvs <- c("Relevant", "Irrelevant")
tbl_2_1_pred <- factor(rep(lvs, times = c(42, 58)), levels = lvs)
tbl_2_1_truth <- factor(c(rep(lvs, times = c(30, 12)),
rep(lvs, times = c(30, 28))),
levels = lvs)
tbl_2_1 <- table(tbl_2_1_pred, tbl_2_1_truth)

precision(tbl_2_1)
precision(data = tbl_2_1_pred, reference = tbl_2_1_truth, relevant = "Relevant")
recall(tbl_2_1)
recall(data = tbl_2_1_pred, reference = tbl_2_1_truth, relevant = "Relevant")


tbl_2_2_pred <- factor(rep(lvs, times = c(76, 24)), levels = lvs)
tbl_2_2_truth <- factor(c(rep(lvs, times = c(56, 20)),
rep(lvs, times = c(12, 12))),
levels = lvs)
tbl_2_2 <- table(tbl_2_2_pred, tbl_2_2_truth)

precision(tbl_2_2)
precision(data = tbl_2_2_pred, reference = tbl_2_2_truth, relevant = "Relevant")
recall(tbl_2_2)
recall(data = tbl_2_2_pred, reference = tbl_2_2_truth, relevant = "Relevant")
}
\keyword{manip}

0 comments on commit 7d056de

Please sign in to comment.