Skip to content

Commit

Permalink
as requested in #17
Browse files Browse the repository at this point in the history
  • Loading branch information
pbiecek committed May 17, 2018
1 parent 5081b24 commit af54d4b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 4 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: breakDown
Title: Break Down Plots
Version: 0.1.5
Version: 0.1.6
Authors@R: person("Przemyslaw", "Biecek", email = "przemyslaw.biecek@gmail.com", role = c("aut", "cre"))
Description: Break Down Plots are inspired by waterfall plots created by 'xgboostExplainer' package
(see <https://github.com/AppliedDataSciencePartners/xgboostExplainer>).
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
breakDown 0.1.6
----------------------------------------------------------------
* `broken.default` now has the `keep_distributions` argument. If `TRUE`, the whole distribution of conditional residuals is remembered and available for plotting [#17](https://github.com/pbiecek/breakDown/issues/17)

breakDown 0.1.5
----------------------------------------------------------------
* small changes in `broken.default` to make it work with `xgboost` and other non `data.frame` data
Expand Down
40 changes: 38 additions & 2 deletions R/break.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ broken.glm <- function(model, new_observation, ..., baseline = 0, predict.functi
#' @param ... other parameters
#' @param baseline the origin/baseline for the breakDown plots, where the rectangles start. It may be a number or a character "Intercept". In the latter case the origin will be set to model intercept.
#' @param predict.function function that will calculate predictions out of model. It shall return a single numeric value per observation. For classification it may be a probability of the default class.
#' @param keep_distributions if TRUE, then the distribution of partial predictions is stored in addition to the average.
#'
#' @return an object of the broken class
#'
Expand All @@ -181,10 +182,18 @@ broken.glm <- function(model, new_observation, ..., baseline = 0, predict.functi
#' explain_1
#' plot(explain_1) + ggtitle("breakDown plot (direction=down) for randomForest model")
#'
#' explain_2 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
#' predict.function = predict.function, direction = "down", keep_distributions = TRUE)
#' plot(explain_2, plot_distributions = TRUE) + ggtitle("breakDown distributions (direction=down) for randomForest model")
#'
#' explain_3 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
#' predict.function = predict.function, direction = "up", keep_distributions = TRUE)
#' plot(explain_3, plot_distributions = TRUE) + ggtitle("breakDown distributions (direction=up) for randomForest model")
#'
#' @export

broken.default <- function(model, new_observation, data, direction = "up", ..., baseline = 0,
predict.function = predict) {
keep_distributions = FALSE, predict.function = predict) {
# just in case only some variables are specified
# this will work only for data.frames
if ("data.frame" %in% class(data)) {
Expand All @@ -201,6 +210,27 @@ broken.default <- function(model, new_observation, data, direction = "up", ...,
predict.function, ...)
}

if (keep_distributions) {
# calculate distribution for partial predictions
open_variables <- as.character(broken_sorted$variable_name)
current_data <- data
yhats_distribution <- list(data.frame(variable = "all data",
label = "all data",
id = 1:nrow(current_data),
prediction = predict.function(model, current_data, ...)))
for (i in seq_along(open_variables)) {
tmp_variable <- open_variables[i]
current_data[,tmp_variable] <- new_observation[,tmp_variable]
yhats_distribution[[tmp_variable]] <- data.frame(variable = tmp_variable,
label = as.character(broken_sorted$variable[i]),
id = 1:nrow(current_data),
prediction = predict.function(model, current_data, ...)
)
}
yhats_df <- do.call(rbind, yhats_distribution)
}


if (tolower(baseline) == "intercept") {
baseline <- mean(predict.function(model, data, ...))
broken_sorted <- rbind(
Expand All @@ -218,7 +248,13 @@ broken.default <- function(model, new_observation, data, direction = "up", ...,
broken_sorted)
}

create.broken(broken_sorted, baseline)
result <- create.broken(broken_sorted, baseline)

if (keep_distributions) {
attr(result, "yhats_distribution") <- yhats_df
}

result
}

broken_go_up <- function(model, new_observation, data,
Expand Down
12 changes: 11 additions & 1 deletion man/broken.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit af54d4b

Please sign in to comment.