Skip to content

Commit

Permalink
fixes #27 and does much work on check()
Browse files Browse the repository at this point in the history
  • Loading branch information
zmjones committed Mar 23, 2015
1 parent 74ecc90 commit 5afbba8
Show file tree
Hide file tree
Showing 14 changed files with 297 additions and 148 deletions.
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,5 @@ Suggests:
doParallel,
testthat
LinkingTo: Rcpp, RcppArmadillo
Depends: RcppArmadillo
BugReports: https://github.com/zmjones/edarf

1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ importFrom(party,proximity)
importFrom(party,varimp)
importFrom(party,varimpAUC)
importFrom(reshape2,melt)
importFrom(stats,predict)
useDynLib(edarf)
8 changes: 1 addition & 7 deletions R/edarf.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
#' edarf: Exploratory Data Analysis Using Random Forests
#'
#' This package provides utilities for exploratory data analysis (EDA) with random forests. It allows the calculation and visualization of partial dependence for single and multiple predictors, interactiondetection and proximity of observations.
#' This package provides utilities for exploratory data analysis (EDA) with random forests. It allows the calculation and visualization of partial dependence for single and multiple predictors, interaction detection and proximity of observations.
#'
#' @section Partial Dependence
#'
#' @section Interaction Detection
#'
#' @section Proximity
#'
#' @docType package
#' @name edarf
#' @useDynLib edarf
Expand Down
20 changes: 10 additions & 10 deletions R/imp.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ variable_importance <- function(fit, ...) UseMethod("variable_importance")
#' plot_imp(imp)
#' }
#' @export
variable_importance.randomForest <- function(fit, type = "accuracy", class_levels, ...) {
variable_importance.randomForest <- function(fit, type = "accuracy", class_levels = FALSE, ...) {
if (ncol(fit$importance) == 1 & type != "gini")
stop("set importance = TRUE in call to randomForest")
if (is.null(fit$localImportance) & type == "local")
Expand All @@ -44,15 +44,16 @@ variable_importance.randomForest <- function(fit, type = "accuracy", class_level
else if (type == "gini")
out <- fit$importance[, "MeanDecreaseGini"]
else if (type == "local") {
out <- fit$localImportance
row.names(out) <- NULL
out <- t(fit$localImportance)
} else
stop("Invalid type or fit input combination")

if (is.matrix(out)) {
out <- as.data.frame(out)
out$labels <- row.names(out)
row.names(out) <- NULL
if (type != "local") {
out$labels <- row.names(out)
row.names(out) <- NULL
}
} else
out <- data.frame(value = unname(out), labels = names(out))

Expand Down Expand Up @@ -85,18 +86,17 @@ variable_importance.randomForest <- function(fit, type = "accuracy", class_level
#' }
#' @export
variable_importance.RandomForest <- function(fit, conditional = FALSE, auc = FALSE, ...) {
if (auc & !(class(fit@responses@variables[, 1]) == "factor" &
length(levels(fit@responses@variables[, 1])) == 2))
if (auc & !(nrow(unique(fit@responses@variables)) == 2))
stop("auc only applicable to binary classification")

if (conditional)
conditional <- TRUE
else conditional <- FALSE

if (auc)
out <- varimpAUC(fit, conditional = conditional, ...)
out <- party::varimpAUC(fit, conditional = conditional, ...)
else
out <- varimp(fit, conditional = conditional, ...)
out <- party::varimp(fit, conditional = conditional, ...)

out <- data.frame("value" = out, "labels" = names(out), row.names = 1:length(out))

Expand Down Expand Up @@ -128,7 +128,7 @@ variable_importance.RandomForest <- function(fit, conditional = FALSE, auc = FAL
#' variable_importance(fit, "random", TRUE)
#' }
#' @export
variable_importance.rfsrc <- function(fit, ..., type = "permute", class_levels = FALSE) {
variable_importance.rfsrc <- function(fit, type = "permute", class_levels = FALSE, ...) {
if (!type %in% as.character(fit$call))
stop(paste("call rfsrc with importance =", type))

Expand Down
43 changes: 28 additions & 15 deletions R/pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#' from a fitted random forest object from the party, randomForest, or randomForestSRC packages
#'
#' @importFrom foreach foreach %dopar% %do% %:% getDoParWorkers
#' @importFrom stats predict
#' @param fit object of class 'RandomForest', 'randomForest', or 'rfsrc'
#' @param ... arguments to be passed to \code{partial_dependence}
#'
Expand Down Expand Up @@ -127,13 +128,14 @@ partial_dependence.randomForest <- function(fit, df, var, cutoff = 10, interacti
pred$low <- pred[, names(y_class)] - cl * se
pred$high <- pred[, names(y_class)] + cl * se
}
if (length(var) == 1 & type != "prob") pred <- fix_classes(c(var, names(y_class)), df, pred)

attr(pred, "class") <- c("pd", "data.frame")
attr(pred, "prob") <- type == "prob"
attr(pred, "interaction") <- length(var) > 1
attr(pred, "multivariate") <- FALSE
attr(pred, "var") <- var
attr(pred, "ci") <- ci
pred <- fix_classes(df, pred)
pred
}
#' Partial dependence for RandomForest objects from package \code{party}
Expand Down Expand Up @@ -212,6 +214,9 @@ partial_dependence.RandomForest <- function(fit, var, cutoff = 10, interaction =
inner_loop <- function(df, rng, idx, var, var_class) {
## fix var predictors
df[, var] <- rng[idx, ]
## fixme
## should figure out what is generating the coercion
## that necessitates the below code
if (length(var) == 1) {
if (class(df[, var]) != var_class) class(df[, var]) <- var_class
} else if (any(sapply(df[, var], class) != var_class)) {
Expand Down Expand Up @@ -248,14 +253,17 @@ partial_dependence.RandomForest <- function(fit, var, cutoff = 10, interaction =
}
i <- x <- idx <- out <- NULL ## initialize to avoid R CMD check errors
if (is.data.frame(rng)) {
## fixme
## should figure out what is generating the coercion
## that necessitates the below code
if (length(var) > 1) {
var_class <- sapply(df[, var], class)
} else var_class <- class(df[, var])
pred <- foreach(i = 1:nrow(rng), .packages = pkg) %op% inner_loop(df, rng, i, var, var_class)
pred <- as.data.frame(do.call(rbind, lapply(pred, unlist)), stringsAsFactors = FALSE)
colnames(pred)[1:length(var)] <- var
if (type != "prob" & (!ci | !(class(y[, 1]) %in% c("numeric", "integer"))) & ncol(y) == 1)
colnames(pred)[ncol(pred)] <- colnames(y)
colnames(pred)[ncol(pred)] <- colnames(y)
} else {
pred <- foreach(x = var, .packages = pkg) %:%
foreach(idx = 1:nrow(rng[[x]]), .combine = rbind) %op% inner_loop(df, rng[[x]], idx, x, class(df[, x]))
Expand Down Expand Up @@ -287,6 +295,7 @@ partial_dependence.RandomForest <- function(fit, var, cutoff = 10, interaction =
attr(pred, "multivariate") <- dim(y)[2] != 1
attr(pred, "var") <- var
attr(pred, "ci") <- ci
pred <- fix_classes(df, pred)
pred
}
#' Partial dependence for rfsrc objects from package \code{randomForestSRC}
Expand Down Expand Up @@ -345,7 +354,7 @@ partial_dependence.rfsrc <- function(fit, var, cutoff = 10, interaction = FALSE,
y <- fit$yvar
if (!(class(y) %in% c("numeric", "integer"))) ci <- FALSE
if (length(var) == 1) interaction <- FALSE
df <- data.frame(fit$xvar, y)
df <- data.frame(fit$xvar, y) ## rfsrc casts integers to numerics
if (!is.data.frame(y)) colnames(df)[ncol(df)] <- fit$yvar.names
if (interaction) {
rng <- expand.grid(lapply(var, function(x) ivar_points(df, x, cutoff, empirical)))
Expand Down Expand Up @@ -398,11 +407,9 @@ partial_dependence.rfsrc <- function(fit, var, cutoff = 10, interaction = FALSE,
}
if (((length(var) > 1 & interaction) | length(var) == 1) &
(!ci & !(type == "prob")) & !is.data.frame(fit$yvar)) {
pred <- fix_classes(c(var, fit$yvar.names), df, pred)
} else if (is.data.frame(fit$yvar)) {
if (length(var) == 1 | interaction) {
colnames(pred)[ncol(pred)] <- "chf"
pred[, -ncol(pred)] <- fix_classes(var, df, pred[, -ncol(pred)])
} else colnames(pred)[ncol(pred) - 1] <- "chf"
} else if (ci & class(y) %in% c("numeric", "integer")) {
if (length(var) == 1 | interaction) colnames(pred)[ncol(pred) - 1] <- fit$yvar.names
Expand All @@ -418,6 +425,7 @@ partial_dependence.rfsrc <- function(fit, var, cutoff = 10, interaction = FALSE,
attr(pred, "multivariate") <- FALSE
attr(pred, "var") <- var
attr(pred, "ci") <- ci
pred <- fix_classes(df, pred)
pred
}
#' Creates a prediction vector for variables to decrease computation time
Expand Down Expand Up @@ -445,22 +453,27 @@ ivar_points <- function(df, x, cutoff = 10, empirical = TRUE) {
}
#' Matches column classes of the input data frame to the output
#'
#' @param var character vector of column names to match
#' @param df input dataframe
#' @param pred output dataframe
#'
#' @return dataframe \code{pred} with \code{var} column classes matched to those in \code{df}
#' @return dataframe \code{pred} with column classes matched to those in \code{df}
#'
#' @export
fix_classes <- function(var, df, pred) {
for (x in var) {
if (class(df[, x]) == "factor")
pred[, x] <- factor(pred[, x])
else if (class(df[, x]) == "numeric") {
if (any(df[, x] %% 1 != 0))
fix_classes <- function(df, pred) {
for (x in colnames(pred)) {
if (x %in% colnames(df)) {
if (class(df[, x]) == "factor")
pred[, x] <- factor(pred[, x])
else if (class(df[, x]) == "numeric") {
pred[, x] <- as.numeric(pred[, x])
else pred[, x] <- as.integer(pred[, x])
}
} else if (class(df[, x]) == "integer") {
pred[, x] <- as.integer(pred[, x])
} else if (class(df[, x]) == "character") {
pred[, x] <- as.character(pred[, x])
} else {
stop("column of unsupported type input")
}
}
}
pred
}
1 change: 1 addition & 0 deletions R/prox.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#'
#' Extracts proximity matrices from random forest objects from the party, randomForest or randomForestSRC packages
#'
#' @importFrom stats predict
#' @param fit object of class 'RandomForest', 'randomForest', or 'rfsrc'
#' @param newdata new data with the same columns as the data used for \code{fit}
#' @param ... arguments to be passed to \code{extract_proximity}
Expand Down
13 changes: 1 addition & 12 deletions man/edarf.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,6 @@
\alias{edarf-package}
\title{edarf: Exploratory Data Analysis Using Random Forests}
\description{
This package provides utilities for exploratory data analysis (EDA) with random forests. It allows the calculation and visualization of partial dependence for single and multiple predictors, interactiondetection and proximity of observations.
}
\section{Partial Dependence}{

}

\section{Interaction Detection}{

}

\section{Proximity}{

This package provides utilities for exploratory data analysis (EDA) with random forests. It allows the calculation and visualization of partial dependence for single and multiple predictors, interaction detection and proximity of observations.
}

6 changes: 2 additions & 4 deletions man/fix_classes.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@
\alias{fix_classes}
\title{Matches column classes of the input data frame to the output}
\usage{
fix_classes(var, df, pred)
fix_classes(df, pred)
}
\arguments{
\item{var}{character vector of column names to match}

\item{df}{input dataframe}

\item{pred}{output dataframe}
}
\value{
dataframe \code{pred} with \code{var} column classes matched to those in \code{df}
dataframe \code{pred} with column classes matched to those in \code{df}
}
\description{
Matches column classes of the input data frame to the output
Expand Down
73 changes: 0 additions & 73 deletions man/partial_dependence.RandomForest.Rd

This file was deleted.

8 changes: 4 additions & 4 deletions man/variable_importance.rfsrc.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
\alias{variable_importance.rfsrc}
\title{Variable importance for rfsrc objects}
\usage{
\method{variable_importance}{rfsrc}(fit, ..., type = "permute",
class_levels = FALSE)
\method{variable_importance}{rfsrc}(fit, type = "permute",
class_levels = FALSE, ...)
}
\arguments{
\item{fit}{an object of class 'rfsrc' returned from \code{rfsrc}}

\item{...}{further arguments to be passed to nothing}

\item{type}{character equal to "permute", "random", "permute.ensemble", or "random.ensemble";
the \code{importance} argument in the call to rfsrc must equal this value}

\item{class_levels}{logical, when TRUE class level specific importances are returned otherwise the overall importance is returned}

\item{...}{further arguments to be passed to nothing}
}
\value{
a data.frame of class "importance"
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,19 @@ test_that("ivar_points works correctly", {
expect_that(length(unique(ivar_points(df, "x", 10))), equals(length(ivar_points(df, "x", 10))))
expect_that(ivar_points(df, "x", nrow(df)), equals(df[, "x"]))
})

test_that("fix_classes works correctly", {
  ## reference frame: one column for each class fix_classes is expected to
  ## restore (character, factor, numeric, integer)
  n <- length(letters)
  df <- data.frame(
    x = letters,
    y = factor(letters),
    z = seq(0, 1, length.out = n),
    a = seq_len(n),
    stringsAsFactors = FALSE
  )

  ## same values as df, but every column is deliberately stored under a
  ## different class than its counterpart in df
  ndf <- data.frame(
    x = as.factor(df$x),
    y = as.character(df$y),
    z = as.character(df$z),
    a = as.numeric(df$a),
    stringsAsFactors = FALSE
  )

  out <- fix_classes(df, ndf)

  ## each output column should come back with the class it has in df
  expected <- c(x = "character", y = "factor", z = "numeric", a = "integer")
  for (col in names(expected)) {
    expect_that(out[[col]], is_a(expected[[col]]))
  }
})
Loading

0 comments on commit 5afbba8

Please sign in to comment.