diff --git a/pkg/caret/R/findCorrelation.R b/pkg/caret/R/findCorrelation.R index 356e30aa0..dec9f39d3 100644 --- a/pkg/caret/R/findCorrelation.R +++ b/pkg/caret/R/findCorrelation.R @@ -1,4 +1,32 @@ -findCorrelation <- function(x, cutoff = 0.90, verbose = FALSE) + +findCorrelation_fast <- function(x, cutoff = .90, verbose = FALSE){ + averageCorr <- colMeans(abs(x)) + averageCorr <- as.numeric(as.factor(averageCorr)) + x[lower.tri(x, diag = TRUE)] <- NA + combsAboveCutoff <- which(abs(x) > cutoff) + + colsToCheck <- ceiling(combsAboveCutoff / nrow(x)) + rowsToCheck <- combsAboveCutoff %% nrow(x) + + colsToDiscard <- averageCorr[colsToCheck] > averageCorr[rowsToCheck] + rowsToDiscard <- !colsToDiscard + + if(verbose){ + colsFlagged <- pmin(ifelse(colsToDiscard, colsToCheck, NA), + ifelse(rowsToDiscard, rowsToCheck, NA), na.rm = TRUE) + values <- round(x[combsAboveCutoff], 3) + cat('\n',paste('Combination row', rowsToCheck, 'and column', colsToCheck, + 'is above the cut-off, value =', values, + '\n \t Flagging column', colsFlagged, '\n' + )) + } + + deletecol <- c(colsToCheck[colsToDiscard], rowsToCheck[rowsToDiscard]) + deletecol <- unique(deletecol) + deletecol +} + +findCorrelation_exact <- function(x, cutoff = 0.90, verbose = FALSE) { varnum <- dim(x)[1] @@ -58,3 +86,18 @@ findCorrelation <- function(x, cutoff = 0.90, verbose = FALSE) } newOrder[which(deletecol)] } + + +findCorrelation <- function(x, cutoff = 0.90, verbose = FALSE, names = FALSE, exact = ncol(x) < 100) { + if(names & is.null(colnames(x))) + stop("'x' must have column names when `names = TRUE`") + out <- if(exact) + findCorrelation_exact(x = x, cutoff = cutoff, verbose = verbose) else + findCorrelation_fast(x = x, cutoff = cutoff, verbose = verbose) + out + if(names) out <- colnames(x)[out] + out +} + + + diff --git a/pkg/caret/man/findCorrelation.Rd b/pkg/caret/man/findCorrelation.Rd index 11dae7ed7..89b694ac8 100644 --- a/pkg/caret/man/findCorrelation.Rd +++ b/pkg/caret/man/findCorrelation.Rd @@ -6,47 +6,62 @@ This function searches through a correlation matrix and returns a vector of inte corresponding to columns to remove to reduce pair-wise correlations. } \usage{ -findCorrelation(x, cutoff = .90, verbose = FALSE) +findCorrelation(x, cutoff = .90, verbose = FALSE, + names = FALSE, exact = ncol(x) < 100) } \arguments{ \item{x}{A correlation matrix} \item{cutoff}{A numeric value for the pair-wise absolute correlation cutoff} \item{verbose}{A boolean for printing the details} + \item{names}{a logical; should the column names be returned (\code{TRUE}) or the column index (\code{FALSE})?} + \item{exact}{a logical; should the average correlations be recomputed at each step? See Details below.} } \details{ The absolute values of pair-wise correlations are considered. If two variables have a high correlation, the function looks at the mean absolute correlation of each variable and removes the variable with the largest mean absolute correlation. + Using \code{exact = TRUE} will cause the function to re-evaluate the average correlations at each step + while \code{exact = FALSE} uses all the correlations regardless of whether they have been eliminated or + not. The exact calculations will remove a smaller number of predictors but can be much slower + when the problem dimensions are "big". + There are several function in the \pkg{subselect} package (\code{\link[subselect:eleaps]{leaps}}, \code{\link[subselect:genetic]{genetic}}, \code{\link[subselect:anneal]{anneal}}) that can also be used - to accomplish the same goal. + to accomplish the same goal but tend to retain more predictors. } \value{ - A vector of indices denoting the columns to remove. If no correlations meet the criteria, \code{numeric(0)} is returned. + A vector of indices denoting the columns to remove (when \code{names = TRUE}) otherwise a vector of column names. If no correlations meet the criteria, \code{integer(0)} is returned. } \author{Original R code by Dong Li, modified by Max Kuhn} \seealso{\code{\link[subselect:eleaps]{leaps}}, \code{\link[subselect:genetic]{genetic}}, \code{\link[subselect:anneal]{anneal}}, \code{\link{findLinearCombos}}} \examples{ -corrMatrix <- diag(rep(1, 5)) -corrMatrix[2, 3] <- corrMatrix[3, 2] <- .7 -corrMatrix[5, 3] <- corrMatrix[3, 5] <- -.7 -corrMatrix[4, 1] <- corrMatrix[1, 4] <- -.67 +R1 <- structure(c(1, 0.86, 0.56, 0.32, 0.85, 0.86, 1, 0.01, 0.74, 0.32, + 0.56, 0.01, 1, 0.65, 0.91, 0.32, 0.74, 0.65, 1, 0.36, + 0.85, 0.32, 0.91, 0.36, 1), + .Dim = c(5L, 5L)) +colnames(R1) <- rownames(R1) <- paste0("x", 1:ncol(R1)) +R1 + +findCorrelation(R1, cutoff = .6, exact = FALSE) +findCorrelation(R1, cutoff = .6, exact = TRUE) +findCorrelation(R1, cutoff = .6, exact = TRUE, names = FALSE) -corrDF <- expand.grid(row = 1:5, col = 1:5) -corrDF$correlation <- as.vector(corrMatrix) -levelplot(correlation ~ row+ col, corrDF) -findCorrelation(corrMatrix, cutoff = .65, verbose = TRUE) +R2 <- diag(rep(1, 5)) +R2[2, 3] <- R2[3, 2] <- .7 +R2[5, 3] <- R2[3, 5] <- -.7 +R2[4, 1] <- R2[1, 4] <- -.67 -findCorrelation(corrMatrix, cutoff = .99, verbose = TRUE) +corrDF <- expand.grid(row = 1:5, col = 1:5) +corrDF$correlation <- as.vector(R2) +levelplot(correlation ~ row + col, corrDF) + +findCorrelation(R2, cutoff = .65, verbose = TRUE) -\dontshow{ - removeCols <- findCorrelation(corrMatrix, cutoff = .65, verbose = FALSE) - if(!isTRUE(all.equal(corrMatrix[-removeCols, -removeCols], diag(rep(1, 3))))) stop("test 1 failed") - if(!isTRUE(all.equal( findCorrelation(corrMatrix, .99, verbose = FALSE), numeric(0)))) stop("test 2 failed") - } +findCorrelation(R2, cutoff = .99, verbose = TRUE) } + \keyword{manip} diff --git a/pkg/caret/vignettes/caret.Rnw b/pkg/caret/vignettes/caret.Rnw index d156c1582..5c42a0c99 100644 --- a/pkg/caret/vignettes/caret.Rnw +++ b/pkg/caret/vignettes/caret.Rnw @@ -91,6 +91,7 @@ <>= +library(MASS) library(caret) library(mlbench) data(Sonar)