[R-package] add a tree plotting function #6729

Open
wants to merge 9 commits into
base: master
Added tests.
fboudry committed Nov 23, 2024
commit 0a7ea0e433c067b633913694d30b4917336f9db9
1 change: 1 addition & 0 deletions R-package/NAMESPACE
@@ -29,6 +29,7 @@ export(lgb.make_serializable)
export(lgb.model.dt.tree)
export(lgb.plot.importance)
export(lgb.plot.interpretation)
export(lgb.plot.tree)
export(lgb.restore_handle)
export(lgb.save)
export(lgb.slice.Dataset)
59 changes: 59 additions & 0 deletions R-package/tests/testthat/test_lgb.plot.tree.R
@@ -0,0 +1,59 @@
test_that("lgb.plot.tree works as expected"){
Collaborator

Please also add tests for the other types of machine learning tasks LightGBM can be used for:

  • binary classification
  • multiclass classification (where, please note, there are num_classes trees produced per iteration)
  • learning-to-rank

And for the following model situation:

  • a model trained with categorical features

These are all cases that could affect the code as written... for example, categorical features have different splitting rules.
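
To illustrate why the multiclass case deserves its own test, here is a minimal sketch of how the tree count grows with the number of classes. It assumes the built-in iris data and illustrative values for num_class and nrounds, none of which come from this PR, and it uses only lgb.train() and lgb.model.dt.tree().

```r
library(lightgbm)

# iris has 3 classes; LightGBM expects integer class labels starting at 0
X <- as.matrix(iris[, 1L:4L])
y <- as.integer(iris$Species) - 1L
dtrain <- lgb.Dataset(X, label = y)

# illustrative values, chosen only to keep the example small
num_class <- 3L
nrounds <- 2L

model <- lgb.train(
    params = list(
        objective = "multiclass"
        , num_class = num_class
        , num_threads = 1L
    )
    , data = dtrain
    , nrounds = nrounds
    , verbose = -1L
)

# one row per node or leaf; tree_index identifies the tree each row belongs to
tree_table <- lgb.model.dt.tree(model)
num_trees <- length(unique(tree_table$tree_index))

# one tree per class per iteration is expected here
print(num_trees)
print(num_class * nrounds)
```

A plotting test for multiclass models would therefore probably need to pick tree indices beyond 0 to cover every class.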

data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
# define model parameters and build a single tree
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
valids <- list(test = dtest)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 1L
, valids = valids
, early_stopping_rounds = 1L
)
Collaborator

Suggested change
- params <- list(
-     objective = "regression"
-     , metric = "l2"
-     , min_data = 1L
-     , learning_rate = 1.0
- )
- valids <- list(test = dtest)
- model <- lgb.train(
-     params = params
-     , data = dtrain
-     , nrounds = 1L
-     , valids = valids
-     , early_stopping_rounds = 1L
- )
+ model <- lgb.train(
+     params = list(
+         objective = "regression"
+         , num_threads = .LGB_MAX_THREADS
+     )
+     , data = dtrain
+     , nrounds = 1L
+     , verbose = .LGB_VERBOSITY
+ )

This is only part of the suggested changes. Please apply the other changes that follow from it, and make the same changes in the other examples and tests.

  1. every call to lgb.train() should set num_threads = .LGB_MAX_THREADS in params, to avoid using too many CPUs on the CRAN check machines (see #5988, "[R-package] limit number of threads used in tests and examples (fixes #5987)", for background)
  2. every call to lightgbm functions should set verbosity to .LGB_VERBOSITY to allow globally controlling the amount of log messages produced across all tests (see https://github.com/microsoft/LightGBM/blob/master/R-package/README.md#running-the-tests)
  3. since params is small and only being used once in this test code, just define it inline
  4. only specify things in params which are necessary for the test to be effective (e.g., no need to set learning_rate to a non-default value)
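
Putting the four points together, a test might look roughly like the sketch below. The two constants are stand-ins with assumed values so that the snippet runs on its own; in the real test suite they would come from the shared test helper rather than being redefined. The final assertion is only a placeholder for whatever the plotting test actually checks.

```r
library(lightgbm)
library(testthat)

# Stand-ins for the constants referenced in the review;
# the values here are assumptions for a standalone run.
.LGB_MAX_THREADS <- 2L
.LGB_VERBOSITY <- -1L

test_that("a single regression tree can be trained with minimal params", {
    data(agaricus.train, package = "lightgbm")
    dtrain <- lgb.Dataset(
        agaricus.train$data
        , label = agaricus.train$label
    )
    # params is defined inline and contains only what the test needs
    model <- lgb.train(
        params = list(
            objective = "regression"
            , num_threads = .LGB_MAX_THREADS
        )
        , data = dtrain
        , nrounds = 1L
        , verbose = .LGB_VERBOSITY
    )
    tree_table <- lgb.model.dt.tree(model)
    # trees are indexed from 0, so one round of regression yields only tree 0
    expect_true(all(tree_table$tree_index == 0))
})
```

Note that no validation set, metric, or learning rate is set, since none of them affects what gets plotted.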

# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_true({
lgb.plot.tree(model, 0)TRUE
})
}

test_that("lgb.plot.tree fails when a non existing tree is selected"){
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
Collaborator

Can you please remove all these uses of a validation set? This feature is about plotting the trained model, and you are not using early stopping, so all of this work to create validation sets is unnecessary.

Keeping the tests and examples as small and simple as possible makes the code easier to read / develop, and makes it clearer how test cases differ from each other.

# define model parameters and build a single tree
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
Collaborator

Suggested change
- params <- list(
-     objective = "regression"
-     , metric = "l2"
-     , min_data = 1L
-     , learning_rate = 1.0
- )
+ params <- list(
+     objective = "regression"
+ )

Similar to my comments on the docs... I strongly suspect we could just use default parameters here.

valids <- list(test = dtest)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 1L
, valids = valids
, early_stopping_rounds = 1L
)
# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_error({
lgb.plot.tree(model, 999)TRUE
Collaborator

Suggested change
- lgb.plot.tree(model, 999)TRUE
+ lgb.plot.tree(model, 999)

This looks like it was included accidentally?

})
Collaborator

For every use of expect_error() here, please check for the specific error you are expecting, like this:

https://github.com/microsoft/LightGBM/blob/83c0ff3de1925b0e2d4831a9ccb6ffc196aa795b/R-package/tests/testthat/test_lgb.importance.R#L33-35

That way, the test will be able to catch the case where some other unexpected issue causes this code path to fail.
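
As a minimal sketch of that pattern: plot_tree_stub() below is a hypothetical stand-in, not the function from this PR, and its error message is invented for illustration. The point is only that expect_error() is given a regexp, so an unrelated failure would not satisfy the test.

```r
library(testthat)

# Hypothetical stand-in for a plotting function that validates its tree index.
# The name, signature, and message are illustrative only.
plot_tree_stub <- function(num_trees, tree_index) {
    if (tree_index < 0L || tree_index >= num_trees) {
        stop(sprintf(
            "tree_index %d is out of range: model has %d tree(s)"
            , tree_index
            , num_trees
        ))
    }
    return(invisible(NULL))
}

test_that("selecting a non-existent tree raises the expected error", {
    expect_error(
        plot_tree_stub(num_trees = 1L, tree_index = 999L)
        , regexp = "tree_index 999 is out of range"
        , fixed = TRUE
    )
})
```

With fixed = TRUE the message is matched literally, which avoids having to escape characters such as parentheses in the expected text.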

}