[R-package] add a tree plotting function #6729

Open
wants to merge 9 commits into
base: master
Corrected error (missing comma) in the selected tree check (L66).
Commented code.
fboudry committed Dec 25, 2024
commit 85ff97aa733001c06085c6057422d4b8bc582aa4
32 changes: 22 additions & 10 deletions R-package/R/lgb.plot.tree.R
@@ -40,12 +40,12 @@
#' }
#'
#' @export

lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
Collaborator


Suggested change
- lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
+ lgb.plot.tree <- function(model, tree, rules = NULL) {

I can't think of any situation where it would be ok for model or tree to be NULL, can you?

If not, let's please require callers to provide values explicitly.

# check model is lgb.Booster
if (!.is_Booster(x = model)) {
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
}
# check DiagrammeR is available
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("lgb.plot.tree: DiagrammeR package is required",
call. = FALSE
@@ -63,26 +63,36 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
modelDT <- lgb.model.dt.tree(model)
# check that tree is less than or equal to the maximum tree index in the model
if (tree > max(modelDT$tree_index)) {
- stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index, "). Got: ," tree, ".")
+ stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
}
Collaborator


Please modify this error message so that it has enough information for someone to quickly debug the issue, like the provided value of `tree` and the number of trees in the model. And please combine it with the other check that the value is `>= 1`.

Something like this:

lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (125). Got: 181.
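
As a rough illustration of the combined check described above, the lower and upper bounds could be tested together before raising a single error. This is only a sketch: `.check_tree_index()` and its `num_trees` argument are hypothetical names introduced here, not part of the PR or of the package, and how the number of trees is obtained from the model is left to the caller.

```r
# Hypothetical helper (not part of the PR): validates a 1-based tree index
# against an assumed total number of trees supplied by the caller.
.check_tree_index <- function(tree, num_trees) {
    is_valid <- is.numeric(tree) &&
        length(tree) == 1L &&
        !is.na(tree) &&
        tree >= 1L &&
        tree <= num_trees
    if (!is_valid) {
        stop(
            "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model ("
            , num_trees
            , "). Got: "
            , tree
            , "."
            , call. = FALSE
        )
    }
    return(invisible(TRUE))
}

# Capture the message raised for an out-of-range index.
msg <- tryCatch(
    .check_tree_index(tree = 181, num_trees = 125)
    , error = function(e) conditionMessage(e)
)
print(msg)
```

Doing both bound checks in one place means a caller sees the same message whether the index is too small or too large.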

# filter modelDT to just the rows for the selected tree
modelDT <- modelDT[tree_index == tree, ]
# change the column names to shorter more diagram friendly versions
- data.table::setnames(modelDT, old = c("tree_index", "split_feature", "threshold", "split_gain"), new = c("Tree", "Feature", "Split", "Gain"))
modelDT[, Value := 0.0]
+ data.table::setnames(modelDT
+ , old = c("tree_index", "split_feature", "threshold", "split_gain")
+ , new = c("Tree", "Feature", "Split", "Gain"))
Comment on lines +72 to +74
Collaborator


Suggested change
- data.table::setnames(modelDT
- , old = c("tree_index", "split_feature", "threshold", "split_gain")
- , new = c("Tree", "Feature", "Split", "Gain"))
+ data.table::setnames(
+     modelDT
+     , old = c("tree_index", "split_feature", "threshold", "split_gain")
+     , new = c("Tree", "Feature", "Split", "Gain")
+ )

Please follow the style the rest of the project uses. I suspect that the linting configuration here would have caught this (not sure, as I haven't run it myself and it failed in CI for other, unrelated reasons).

From this point forward, before you push a commit please run the R-code linting and fix any issues it reports.

From the root of the repo:

Rscript ./.ci/lint-r-code.R ./R-package

# assign leaf_value to the Value column in modelDT
modelDT[, Value := leaf_value]
# assign new values if NA
modelDT[is.na(Value), Value := internal_value]
modelDT[is.na(Gain), Gain := leaf_value]
modelDT[is.na(Feature), Feature := "Leaf"]
# assign internal_count to Cover, and if Feature is "Leaf", assign leaf_count to Cover
modelDT[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
# remove unnecessary columns
modelDT[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
# assign split_index to Node
modelDT[, Node := split_index]
# find the maximum value of Node, if Node is NA, assign max_node + leaf_index + 1 to Node
max_node <- max(modelDT[["Node"]], na.rm = TRUE)
modelDT[is.na(Node), Node := max_node + leaf_index + 1]
# adding ID column
modelDT[, ID := paste(Tree, Node, sep = "-")]
# remove unnecessary columns
modelDT[, c("depth", "leaf_index") := NULL]
modelDT[, parent := node_parent][is.na(parent), parent := leaf_parent]
modelDT[, c("node_parent", "leaf_parent", "split_index") := NULL]
# assign the IDs of the matching parent nodes to Yes and No
modelDT[, Yes := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
modelDT <- modelDT[nrow(modelDT):1, ]
modelDT[, No := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
@@ -91,14 +101,16 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
modelDT[default_left == TRUE, Missing := Yes]
modelDT[default_left == FALSE, Missing := No]
modelDT[.zero_present(Split), Missing := Yes]
# modelDT[, c('parent', 'default_left') := NULL]
# data.table::setcolorder(modelDT, c('Tree','Node','ID','Feature','decision_type','Split','Yes','No','Missing','Gain','Cover','Value'))
# create the label text
modelDT[, label := paste0(
- Feature,
- "\nCover: ", Cover,
- ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf", "", round(Gain, 4)),
- "\nValue: ", round(Value, 4)
+ Feature
+ , "\nCover: "
+ , Cover
+ , ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf"
+ , ""
+ , round(Gain, 4))
+ , "\nValue: "
+ , round(Value, 4)
)]
# style the nodes - same format as xgboost
modelDT[Node == 0, label := paste0("Tree ", Tree, "\n", label)]