Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] add a tree plotting function #6729

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Corrected tests.
Added a warning to functions and shorter stop message to make tests work.
  • Loading branch information
fboudry committed Dec 25, 2024
commit b4b648ab6522d305ce9e2c9f30c494c82b4e285f
5 changes: 3 additions & 2 deletions R-package/R/lgb.plot.tree.R
Original file line number Diff line number Diff line change
@@ -62,8 +62,9 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
# extract data.table model structure
modelDT <- lgb.model.dt.tree(model)
# check that tree is less than or equal to the maximum tree index in the model
if (tree > max(modelDT$tree_index)) {
stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
if (tree > max(modelDT$tree_index) || tree < 1) {
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
stop("lgb.plot.tree: Invalid tree number")
Comment on lines +66 to +67
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
stop("lgb.plot.tree: Invalid tree number")
stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")

What's the reason for having all of the information in a warning() and then immediately raising an error after? If there isn't a specific reason, then let's please combine these for simplicity and to make the logs easier for users to understand.

}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please modify this error message so that it has enough information for someone to quickly debug the issue, like the provided value of tree and the number of trees in the model. And please combine it with the other check that the value is `>=01.

Something like this:

lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (125). Got: 181.

# filter modelDT to just the rows for the selected tree
modelDT <- modelDT[tree_index == tree, ]
15 changes: 4 additions & 11 deletions R-package/tests/testthat/test_lgb.plot.tree.R
Original file line number Diff line number Diff line change
@@ -64,24 +64,17 @@ models <- list(

for (model_name in names(models)){
model <- models[[model_name]]
expected_n_trees <- NROUNDS
if (model_name == "multi") {
expected_n_trees <- NROUNDS * NCLASS
}
df <- as.data.frame(lgb.model.dt.tree(model))
df_list <- split(df, f = df$tree_index, drop = TRUE)
df_leaf <- df[!is.na(df$leaf_index), ]
df_internal <- df[is.na(df$leaf_index), ]
modelDT <- lgb.model.dt.tree(model)

test_that("lgb.plot.tree fails when a non existing tree is selected", {
expect_error({
lgb.plot.tree(model, 0)
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
}, regexp = paste0("lgb.plot.tree: Invalid tree number"))
})
test_that("lgb.plot.tree fails when a non existing tree is selected", {
expect_error({
lgb.plot.tree(model, 999)
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
}, regexp = paste0("lgb.plot.tree: Invalid tree number"))
})
test_that("lgb.plot.tree fails when a non numeric tree is selected", {
expect_error({
@@ -96,7 +89,7 @@ for (model_name in names(models)){
test_that("lgb.plot.tree fails when a non lgb.Booster model is passed", {
expect_error({
lgb.plot.tree(1, 0)
}, regexp = "lgb.plot.tree: model should be an 'lgb.Booster'")
}, regexp = paste0("lgb.plot.tree: model should be an ", sQuote("lgb.Booster")))
})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For every use of expect_error() here, please check for the specific error you are expecting, like this:

https://github.com/microsoft/LightGBM/blob/83c0ff3de1925b0e2d4831a9ccb6ffc196aa795b/R-package/tests/testthat/test_lgb.importance.R#L33-35

That way, the test will be able to catch the case where some other unexpected issue causes this code path to fail.

}

Loading
Oops, something went wrong.