Skip to content

Commit

Permalink
adds asserts to pd
Browse files (browse the repository at this point in the history)
 - checks all inputs
 - adds NA checking for plotting functions
 - updated roxygen2
  • Loading branch information
zmjones committed Oct 3, 2014
1 parent 9765948 commit 8439b8c
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 12 deletions.
7 changes: 6 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Generated by roxygen2 (4.0.1): do not edit by hand
# Generated by roxygen2 (4.0.2): do not edit by hand

export(ivar_points)
export(partial_dependence)
export(plot_imp)
export(plot_twoway_partial)
import(ggplot2)
importFrom(assertthat,assert_that)
importFrom(assertthat,is.count)
importFrom(assertthat,is.flag)
importFrom(assertthat,is.string)
importFrom(assertthat,noNA)
importFrom(assertthat,on_failure)
importFrom(foreach,"%do%")
importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
14 changes: 11 additions & 3 deletions R/pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#' from a fitted random forest object from the Party, randomForest, or randomForestSRC packages
#'
#' @importFrom foreach foreach %dopar% %do%
#' @importFrom assertthat assert_that on_failure is.string is.count is.flag noNA
#'
#' @param fit an object of class 'RandomForest-class' returned from \code{cforest}, an object
#' of class 'randomForest' returned from \code{randomForest}, or an object of class 'rfsrc'
Expand Down Expand Up @@ -74,13 +75,20 @@
#' }
#' @export
partial_dependence <- function(fit, df, var, cutoff = 10, empirical = TRUE) {
assert_that(any(class(fit) %in% c("RandomForest", "randomForest", "randomForestSRC")))
assert_that(is.data.frame(df))
assert_that(is.string(var))
assert_that(is.count(cutoff))
assert_that(is.flag(empirical))
assert_that(cutoff > nrow)

if (any(class(fit) == "RandomForest")) {
df <- data.frame(get("input", fit@data@env), get("response", fit@data@env))
type <- class(df[, ncol(df)])
y <- colnames(df[, ncol(df)])
pkg <- "party"
}
else if (any(class(fit) == "randomForest")) {
} else if (any(class(fit) == "randomForest")) {
assert_that(noNA(df))
type <- attr(fit$terms, "dataClasses")[1]
y <- attr(attr(fit$terms, "dataClasses"), "names")[1]
pkg <- "randomForest"
Expand All @@ -95,7 +103,7 @@ partial_dependence <- function(fit, df, var, cutoff = 10, empirical = TRUE) {
type <- "survival"
else
type <- class(df[, y])
} else stop("Unsupported fit object class")
}

rng <- expand.grid(lapply(var, function(x) ivar_points(df, x, cutoff, empirical)))

Expand Down
11 changes: 9 additions & 2 deletions R/plot.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Plot permutation importance from random
#' Plot permutation importance from random forests
#'
#' @import ggplot2
#' @importFrom assertthat assert_that
#' @importFrom assertthat assert_that noNA
#'
#' @param var character or factor vector of variable labels
#' @param imp numeric vector of variable permutation importance estimates
Expand All @@ -16,6 +16,8 @@ plot_imp <- function(var, imp, ylab = NULL, xlab = NULL, title = NULL) {
assert_that(is.factor(var) | is.character(var))
assert_that(length(var) == length(imp))
assert_that(!(any(is.na(var) | any(is.na(imp)))))
assert_that(noNA(var))
assert_that(noNA(imp))

df <- data.frame(imp, var)
df$var <- factor(df$var, levels = df$var[order(df$imp)])
Expand Down Expand Up @@ -56,7 +58,12 @@ plot_twoway_partial <- function(var, pred, var_lab, grid, smooth = FALSE,
assert_that(is.numeric(pred) | is.factor(pred) | is.integer(pred))
assert_that(length(var) == length(pred))
assert_that(!(any(is.na(var) | any(is.na(pred)))))
assert_that(noNA(var))
assert_that(noNA(pred))
assert_that(noNA(grid))
assert_that(is.flag(smooth))
if (!missing(var_lab)) {
assert_that(noNA(var_lab))
assert_that(is.factor(var_lab) | is.character(var_lab))
assert_that(!missing(grid))
assert_that(is.integer(grid) & length(grid == 2))
Expand Down
2 changes: 1 addition & 1 deletion man/ivar_points.Rd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
% Generated by roxygen2 (4.0.1): do not edit by hand
% Generated by roxygen2 (4.0.2): do not edit by hand
\name{ivar_points}
\alias{ivar_points}
\title{Creates a prediction vector for variables to decrease computation time}
Expand Down
2 changes: 1 addition & 1 deletion man/partial_dependence.Rd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
% Generated by roxygen2 (4.0.1): do not edit by hand
% Generated by roxygen2 (4.0.2): do not edit by hand
\name{partial_dependence}
\alias{partial_dependence}
\title{Partial dependence using random forests}
Expand Down
6 changes: 3 additions & 3 deletions man/plot_imp.Rd
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
% Generated by roxygen2 (4.0.1): do not edit by hand
% Generated by roxygen2 (4.0.2): do not edit by hand
\name{plot_imp}
\alias{plot_imp}
\title{Plot permutation importance from random}
\title{Plot permutation importance from random forests}
\usage{
plot_imp(var, imp, ylab = NULL, xlab = NULL, title = NULL)
}
Expand All @@ -20,6 +20,6 @@ plot_imp(var, imp, ylab = NULL, xlab = NULL, title = NULL)
a ggplot2 object
}
\description{
Plot permutation importance from random
Plot permutation importance from random forests
}

2 changes: 1 addition & 1 deletion man/plot_twoway_partial.Rd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
% Generated by roxygen2 (4.0.1): do not edit by hand
% Generated by roxygen2 (4.0.2): do not edit by hand
\name{plot_twoway_partial}
\alias{plot_twoway_partial}
\title{Plot two-way partial dependence from random forests}
Expand Down

0 comments on commit 8439b8c

Please sign in to comment.