Skip to content

Commit

Permalink
support for DALEX explainers
Browse files Browse the repository at this point in the history
  • Loading branch information
pbiecek committed Sep 14, 2018
1 parent ba50fe9 commit 8bceea6
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 44 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
breakDown 0.2.0
----------------------------------------------------------------
* `break_down` function identifies inteactions
* `break_down` function supports DALEX explainers
* `break_down` function has complexity O(2p) for models without interactions, much faster than the old version

breakDown 0.1.6
Expand Down
68 changes: 45 additions & 23 deletions R/break_interactions.R
Expand Up @@ -2,11 +2,14 @@
#'
#' This function implements decomposition of model predictions with identification
#' of interactions.
#' The complexity of this function is O(2*p) for additive models and O(2*p^2) for interactions
#' The complexity of this function is O(2*p) for additive models and O(2*p^2) for interactions.
#' This function works in similar way to step-up and step-down greedy approaximations,
#' the main difference is that in the fisrt step the order of variables is determied.
#' And in the second step the impact is calculated.
#'
#' @param explainer a model to be explained, preprocessed by function `DALEX::explain()`.
#' @param new_observation a new observation with columns that corresponds to variables used in the model
#' @param check_interactions the orgin/baseline for the breakDown plots, where the rectangles start. It may be a number or a character "Intercept". In the latter case the orgin will be set to model intercept.
#' @param check_interactions the orgin/baseline for the `breakDown`` plots, where the rectangles start. It may be a number or a character "Intercept". In the latter case the orgin will be set to model intercept.
#' @param keep_distributions if TRUE, then the distribution of partial predictions is stored in addition to the average.
#'
#' @return an object of the broken class
Expand All @@ -15,26 +18,44 @@
#' \dontrun{
#' library("breakDown")
#' library("randomForest")
#' library("ggplot2")
#' set.seed(1313)
#' model <- randomForest(factor(left)~., data = HR_data, family = "binomial", maxnodes = 5)
#' predict.function <- function(model, new_observation)
#' predict(model, new_observation, type="prob")[,2]
#' predict.function(model, HR_data[11,-7])
#' explain_1 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
#' predict.function = predict.function, direction = "down")
#' explain_1
#' plot(explain_1) + ggtitle("breakDown plot (direction=down) for randomForest model")
#' # example with interaction
#' # classification for HR data
#' model <- randomForest(status ~ . , data = HR)
#' new_observation <- HRTest[1,]
#' data <- HR[1:1000,]
#' predict.function <- function(m,x) predict(m,x, type = "prob")[,1]
#'
#' explain_2 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
#' predict.function = predict.function, direction = "down", keep_distributions = TRUE)
#' plot(explain_2, plot_distributions = TRUE) +
#' ggtitle("breakDown distributions (direction=down) for randomForest model")
#' explainer_rf_fired <- explain(model,
#' data = HR[1:1000,1:5],
#' y = HR$status[1:1000] == "fired",
#' predict_function = function(m,x) predict(m,x, type = "prob")[,1],
#' label = "fired")
#'
#' explain_3 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
#' predict.function = predict.function, direction = "up", keep_distributions = TRUE)
#' plot(explain_3, plot_distributions = TRUE) +
#' ggtitle("breakDown distributions (direction=up) for randomForest model")
#' bd_rf <- break_down(explainer_rf_fired,
#' new_observation,
#' keep_distributions = TRUE)
#'
#' bd_rf
#' plot(bd_rf)
#' plot(bd_rf, plot_distributions = TRUE)
#'
#' # example for regression - apartment prices
#' # here we do not have intreactions
#' model <- randomForest(m2.price ~ . , data = apartments)
#' explainer_rf <- explain(model,
#' data = apartmentsTest[1:1000,2:6],
#' y = apartmentsTest$m2.price[1:1000],
#' label = "rf")
#'
#' bd_rf <- break_down(explainer_rf,
#' apartmentsTest[1,],
#' check_interactions = FALSE,
#' keep_distributions = TRUE)
#'
#' bd_rf
#' plot(bd_rf)
#' plot(bd_rf, plot_distributions = TRUE)
#' }
#' @export

Expand All @@ -50,9 +71,10 @@ break_down <- function(explainer, new_observation,
# this will work only for data.frames
if ("data.frame" %in% class(data)) {
common_variables <- intersect(colnames(new_observation), colnames(data))
new_observation <- new_observation[,common_variables, drop = FALSE]
new_observation <- new_observation[, common_variables, drop = FALSE]
data <- data[,common_variables, drop = FALSE]
}
p <- ncol(data)

# set target
target_yhat <- predict_function(model, new_observation)
Expand Down Expand Up @@ -97,7 +119,7 @@ break_down <- function(explainer, new_observation,

# Now we know the path, so we can calculate contributions
# set variable indicators
open_variables <- 1:ncol(data)
open_variables <- 1:p
current_data <- data

step <- 0
Expand Down Expand Up @@ -159,7 +181,7 @@ break_down <- function(explainer, new_observation,
yhats0 <- data.frame(variable = "all data",
label = "all data",
id = 1:nrow(data),
prediction = predict_function(model, current_data)
prediction = predict_function(model, data)
)

yhats_distribution <- rbind(yhats0, do.call(rbind, yhats))
Expand Down Expand Up @@ -217,6 +239,6 @@ calculate_2d_changes <- function(model, new_observation, data, predict_function,
}
names(average_yhats) <- paste(colnames(data)[inds[,1]],
colnames(data)[inds[,2]],
sep=":")
sep = ":")
list(average_yhats = average_yhats, average_yhats_norm = average_yhats_norm)
}
63 changes: 42 additions & 21 deletions man/break_down.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8bceea6

Please sign in to comment.