[R-package] add a tree plotting function #6729
Commented code.
@@ -40,12 +40,12 @@
 #' }
 #'
 #' @export
 lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
   # check model is lgb.Booster
   if (!.is_Booster(x = model)) {
     stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
   }
   # check DiagrammeR is available
   if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
     stop("lgb.plot.tree: DiagrammeR package is required",
       call. = FALSE
@@ -63,26 +63,36 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
   modelDT <- lgb.model.dt.tree(model)
   # check that tree is less than or equal to the maximum tree index in the model
   if (tree > max(modelDT$tree_index)) {
-    stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index, "). Got: ," tree, ".")
+    stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
   }
Please modify this error message so that it has enough information for someone to quickly debug the issue, such as the value that was provided. Something like this:
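To illustrate the kind of message requested above, here is a minimal sketch. It is not the reviewer's original snippet, which is not reproduced here. The helper name `check_tree_index` and its `n_trees` argument are hypothetical and exist only for this example; the point is that the message reports both the valid range and the value received.

```r
# Hypothetical helper, for illustration only: validates a 1-based tree number
# against an assumed total number of trees (n_trees).
check_tree_index <- function(tree, n_trees) {
  if (!is.numeric(tree) || length(tree) != 1L || is.na(tree)) {
    stop("check_tree_index: 'tree' should be a single number. Got: ", toString(tree))
  }
  if (tree < 1L || tree > n_trees) {
    stop(
      "check_tree_index: Value of 'tree' should be between 1 and the total number of trees in the model ("
      , n_trees
      , "). Got: "
      , tree
      , "."
    )
  }
  invisible(tree)
}

# Capture the message instead of aborting, so it can be inspected
msg <- tryCatch(
  check_tree_index(7L, n_trees = 3L)
  , error = function(e) conditionMessage(e)
)
print(msg)
```

Including the offending value in the message means a caller does not have to re-run the code under a debugger just to find out what was passed in.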
   # filter modelDT to just the rows for the selected tree
   modelDT <- modelDT[tree_index == tree, ]
   # change the column names to shorter more diagram friendly versions
-  data.table::setnames(modelDT, old = c("tree_index", "split_feature", "threshold", "split_gain"), new = c("Tree", "Feature", "Split", "Gain"))
-  modelDT[, Value := 0.0]
+  data.table::setnames(modelDT
+    , old = c("tree_index", "split_feature", "threshold", "split_gain")
+    , new = c("Tree", "Feature", "Split", "Gain"))
Comment on lines +72 to +74:
Please follow the style that the rest of the project uses. I suspect that the linting configuration here would have caught this (not sure, as I haven't run it myself and it failed in CI for other unrelated reasons). From this point forward, before you push a commit, please run the R-code linting and fix any issues it reports. From the root of the repo:
    Rscript ./.ci/lint-r-code.R ./R-package
   # assign leaf_value to the Value column in modelDT
   modelDT[, Value := leaf_value]
   # assign new values if NA
   modelDT[is.na(Value), Value := internal_value]
   modelDT[is.na(Gain), Gain := leaf_value]
   modelDT[is.na(Feature), Feature := "Leaf"]
   # assign internal_count to Cover, and if Feature is "Leaf", assign leaf_count to Cover
   modelDT[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
   # remove unnecessary columns
   modelDT[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
   # assign split_index to Node
   modelDT[, Node := split_index]
   # find the maximum value of Node, if Node is NA, assign max_node + leaf_index + 1 to Node
   max_node <- max(modelDT[["Node"]], na.rm = TRUE)
   modelDT[is.na(Node), Node := max_node + leaf_index + 1]
   # adding ID column
   modelDT[, ID := paste(Tree, Node, sep = "-")]
   # remove unnecessary columns
   modelDT[, c("depth", "leaf_index") := NULL]
   modelDT[, parent := node_parent][is.na(parent), parent := leaf_parent]
   modelDT[, c("node_parent", "leaf_parent", "split_index") := NULL]
   # assign the IDs of the matching parent nodes to Yes and No
   modelDT[, Yes := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
   modelDT <- modelDT[nrow(modelDT):1, ]
   modelDT[, No := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
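The Yes/No assignment above relies on `match()` returning only the first matching position, so reversing the row order turns "first child" into "last child". Below is a small base-R sketch of that idea. It uses a made-up five-node tree in a plain data frame rather than real `lgb.model.dt.tree()` output, and it assumes the first listed child is the "Yes" branch, which may not hold for every real model.

```r
# Made-up tree: node 0 has children 1 and 2, node 1 has children 3 and 4.
nodes <- data.frame(
  Node = c(0, 1, 2, 3, 4)
  , parent = c(NA, 0, 0, 1, 1)
  , stringsAsFactors = FALSE
)
nodes$ID <- paste(0, nodes$Node, sep = "-")

# match() gives the position of the FIRST row whose parent equals each Node,
# i.e. the first child in the current row order
nodes$Yes <- nodes$ID[match(nodes$Node, nodes$parent)]

# After reversing the rows, the first match is the LAST child in the original order
nodes <- nodes[nrow(nodes):1, ]
nodes$No <- nodes$ID[match(nodes$Node, nodes$parent)]

# Restore the original order for readability
nodes <- nodes[order(nodes$Node), ]
print(nodes[, c("ID", "Yes", "No")])
```

Nodes without children (the leaves) get `NA` in both columns, because `match()` finds no row that names them as a parent.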
@@ -91,14 +101,16 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
   modelDT[default_left == TRUE, Missing := Yes]
   modelDT[default_left == FALSE, Missing := No]
   modelDT[.zero_present(Split), Missing := Yes]
-  # modelDT[, c('parent', 'default_left') := NULL]
-  # data.table::setcolorder(modelDT, c('Tree','Node','ID','Feature','decision_type','Split','Yes','No','Missing','Gain','Cover','Value'))
   # create the label text
   modelDT[, label := paste0(
-    Feature,
-    "\nCover: ", Cover,
-    ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf", "", round(Gain, 4)),
-    "\nValue: ", round(Value, 4)
+    Feature
+    , "\nCover: "
+    , Cover
+    , ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf"
+    , ""
+    , round(Gain, 4))
+    , "\nValue: "
+    , round(Value, 4)
   )]
   # style the nodes - same format as xgboost
   modelDT[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
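As a rough illustration of what the label construction above produces, the following base-R sketch applies the same `paste0()` and `ifelse()` pattern to two made-up rows, one split node and one leaf. The vectors `feature`, `cover`, `gain`, and `value` are invented for the example and are not taken from a real model.

```r
# Invented values for one split node and one leaf
feature <- c("x1", "Leaf")
cover <- c(100, 40)
gain <- c(1.23456, NA)
value <- c(0.5, -0.25)

# Same pattern as the diff: the Gain line is dropped entirely for leaves
label <- paste0(
  feature
  , "\nCover: "
  , cover
  , ifelse(feature == "Leaf", "", "\nGain: ")
  , ifelse(feature == "Leaf", "", round(gain, 4))
  , "\nValue: "
  , round(value, 4)
)
cat(label, sep = "\n---\n")
```

Both `ifelse()` calls are needed: the first suppresses the "Gain: " caption for leaves and the second suppresses the number, so a leaf label carries only Cover and Value.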
I can't think of any situation where it would be ok for `model` or `tree` to be `NULL`, can you? If not, let's please require callers to provide values explicitly.
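One way to read this request, sketched under assumptions: drop the `NULL` defaults so that R itself reports a missing argument. The function name `plot_tree_sketch` below is a hypothetical stand-in, not the PR's actual code, and its body only echoes its inputs.

```r
# Hypothetical stand-in for the signature under discussion:
# 'model' and 'tree' have no defaults, 'rules' stays optional.
plot_tree_sketch <- function(model, tree, rules = NULL) {
  # force() evaluates each argument, so a missing one errors right away
  # instead of failing later, deep inside the function body
  force(model)
  force(tree)
  invisible(list(model = model, tree = tree, rules = rules))
}

# Omitting 'model' now fails with R's built-in "missing argument" error
err <- tryCatch(
  plot_tree_sketch(tree = 1L)
  , error = function(e) conditionMessage(e)
)
print(err)

# Supplying both required arguments works as before
res <- plot_tree_sketch(model = "placeholder model", tree = 1L)
```

With required arguments, the existing `.is_Booster()` check still guards against a wrong type, while a forgotten argument is caught before any of the function's own logic runs.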