-
Notifications
You must be signed in to change notification settings - Fork 3.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R-package] add a tree plotting function #6729
base: master
Are you sure you want to change the base?
Changes from 1 commit
6862821
0a7ea0e
5206b11
757dc84
55aba68
85ff97a
b4b648a
ed62441
2710705
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Added a warning to functions and shorter stop message to make tests work.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,8 +62,9 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) { | |
# extract data.table model structure | ||
modelDT <- lgb.model.dt.tree(model) | ||
# check that tree is less than or equal to the maximum tree index in the model | ||
if (tree > max(modelDT$tree_index)) { | ||
stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".") | ||
if (tree > max(modelDT$tree_index) || tree < 1) { | ||
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".") | ||
stop("lgb.plot.tree: Invalid tree number") | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please modify this error message so that it has enough information for someone to quickly debug the issue, like the provided value of Something like this:
|
||
# filter modelDT to just the rows for the selected tree | ||
modelDT <- modelDT[tree_index == tree, ] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,24 +64,17 @@ models <- list( | |
|
||
for (model_name in names(models)){ | ||
model <- models[[model_name]] | ||
expected_n_trees <- NROUNDS | ||
if (model_name == "multi") { | ||
expected_n_trees <- NROUNDS * NCLASS | ||
} | ||
df <- as.data.frame(lgb.model.dt.tree(model)) | ||
df_list <- split(df, f = df$tree_index, drop = TRUE) | ||
df_leaf <- df[!is.na(df$leaf_index), ] | ||
df_internal <- df[is.na(df$leaf_index), ] | ||
modelDT <- lgb.model.dt.tree(model) | ||
|
||
test_that("lgb.plot.tree fails when a non existing tree is selected", { | ||
expect_error({ | ||
lgb.plot.tree(model, 0) | ||
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model") | ||
}, regexp = paste0("lgb.plot.tree: Invalid tree number")) | ||
}) | ||
test_that("lgb.plot.tree fails when a non existing tree is selected", { | ||
expect_error({ | ||
lgb.plot.tree(model, 999) | ||
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model") | ||
}, regexp = paste0("lgb.plot.tree: Invalid tree number")) | ||
}) | ||
test_that("lgb.plot.tree fails when a non numeric tree is selected", { | ||
expect_error({ | ||
|
@@ -96,7 +89,7 @@ for (model_name in names(models)){ | |
test_that("lgb.plot.tree fails when a non lgb.Booster model is passed", { | ||
expect_error({ | ||
lgb.plot.tree(1, 0) | ||
}, regexp = "lgb.plot.tree: model should be an 'lgb.Booster'") | ||
}, regexp = paste0("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))) | ||
}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For every use of That way, the test will be able to catch the case where some other unexpected issue causes this code path to fail. |
||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason for having all of the information in a
warning()
and then immediately raising an error after? If there isn't a specific reason, then let's please combine these for simplicity and to make the logs easier for users to understand.