Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] add a tree plotting function #6729

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Updated tests. (based on R-package/tests/testthat/test_lgb.model.dt.tree.R)

Now tests regression, binary classification, multiclass classification, and ranking models.
  • Loading branch information
fboudry committed Dec 25, 2024
commit 55aba68c7fac9273fc9e125eac8faa0aef222c41
137 changes: 98 additions & 39 deletions R-package/tests/testthat/test_lgb.plot.tree.R
Original file line number Diff line number Diff line change
@@ -1,43 +1,102 @@
test_that("lgb.plot.tree works as expected"){
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
# define model parameters and build a single tree
model <- lgb.train(
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
)
, data = dtrain
, nrounds = 1L
, verbose = .LGB_VERBOSITY
# Shared fixtures for the lgb.plot.tree tests, all derived from the built-in
# 'iris' data set.

# number of boosting rounds and depth limit passed to every model below
NROUNDS <- 10L
MAX_DEPTH <- 3L

# feature matrix: columns 2 to 4 of iris, plus its column names
X <- data.matrix(iris[, 2L:4L, drop = FALSE])
FEAT <- dimnames(X)[[2L]]

# number of rows in iris and number of levels in its 5th column
N <- dim(iris)[1L]
NCLASS <- length(levels(iris[[5L]]))

model_reg <- lgb.train(
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
)
# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_true({
lgb.plot.tree(model, 0)
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
}
, data = lgb.Dataset(X, label = iris[, 1L])
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)

# Binary-objective model: the label is TRUE where the 5th iris column equals
# "setosa" and FALSE otherwise.
# NOTE(review): LightGBM documents this parameter as `max_depth`; confirm that
# the `max.depth` spelling is actually recognised.
model_binary <- lgb.train(
    data = lgb.Dataset(data = X, label = iris[[5L]] == "setosa")
    , params = list(
        objective = "binary"
        , num_threads = .LGB_MAX_THREADS
        , max.depth = MAX_DEPTH
    )
    , nrounds = NROUNDS
    , verbose = .LGB_VERBOSITY
)

# Multiclass-objective model: the label is the integer code of the 5th iris
# column shifted to start at 0, with num_classes set to NCLASS.
# NOTE(review): LightGBM documents this parameter as `max_depth`; confirm that
# the `max.depth` spelling is actually recognised.
model_multiclass <- lgb.train(
    data = lgb.Dataset(data = X, label = as.integer(iris[[5L]]) - 1L)
    , params = list(
        objective = "multiclass"
        , num_threads = .LGB_MAX_THREADS
        , max.depth = MAX_DEPTH
        , num_classes = NCLASS
    )
    , nrounds = NROUNDS
    , verbose = .LGB_VERBOSITY
)

test_that("lgb.plot.tree fails when a non existing tree is selected"){
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
# define model parameters and build a single tree
model <- lgb.train(
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
)
, data = dtrain
, nrounds = 1L
, verbose = .LGB_VERBOSITY
model_rank <- lgb.train(
params = list(
objective = "lambdarank"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
, lambdarank_truncation_level = 3L
)
# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_error({
lgb.plot.tree(model, 999)
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
, data = lgb.Dataset(
X
, label = as.integer(iris[, 1L] > 5.8)
, group = rep(10L, times = 15L)
)
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)

# Named collection of the four fitted models; the names are the keys used
# when iterating over the models.
models <- setNames(
    list(model_reg, model_binary, model_multiclass, model_rank)
    , c("reg", "bin", "multi", "rank")
)

# Run the tests that follow once for every entry of 'models'.
for (model_name in names(models)){
model <- models[[model_name]]
# Expected tree count: NROUNDS, or NROUNDS * NCLASS for the "multi" model.
expected_n_trees <- NROUNDS
if (model_name == "multi") {
expected_n_trees <- NROUNDS * NCLASS
}
# Output of lgb.model.dt.tree() converted to a plain data.frame.
df <- as.data.frame(lgb.model.dt.tree(model))
# One data.frame per distinct tree_index value.
df_list <- split(df, f = df$tree_index, drop = TRUE)
# Rows with a non-missing leaf_index, and rows whose leaf_index is NA.
# NOTE(review): expected_n_trees, df_list, df_leaf and df_internal are not
# referenced by the tests visible right after this setup - confirm they are used.
df_leaf <- df[!is.na(df$leaf_index), ]
df_internal <- df[is.na(df$leaf_index), ]

# The two out-of-range cases get distinct descriptions so that a failure report
# identifies which bound was violated (both previously shared one description).

# Index 0 lies below the lower bound of 1 stated in the expected error message.
test_that("lgb.plot.tree fails when the selected tree index is below 1", {
    expect_error({
        lgb.plot.tree(model, 0)
    }, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
})
# Index 999 is far above the NROUNDS * NCLASS trees any of these models can hold.
test_that("lgb.plot.tree fails when the selected tree index exceeds the number of trees", {
    expect_error({
        lgb.plot.tree(model, 999)
    }, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
})
# A character value passed as the tree index must raise the "integer numeric" error.
test_that("lgb.plot.tree fails when a non numeric tree is selected", {
    expect_error(
        lgb.plot.tree(model, "a")
        , regexp = "lgb.plot.tree: Has to be an integer numeric"
    )
})
# A fractional tree index (1.5) must raise the "integer numeric" error.
test_that("lgb.plot.tree fails when a non integer tree is selected", {
    expect_error(
        lgb.plot.tree(model, 1.5)
        , regexp = "lgb.plot.tree: Has to be an integer numeric"
    )
})
# Passing a plain number where a model is expected must raise the
# "should be an 'lgb.Booster'" error.
# NOTE(review): this check does not use the current 'model'; confirm whether it
# needs to run for every model.
test_that("lgb.plot.tree fails when a non lgb.Booster model is passed", {
    expect_error(
        lgb.plot.tree(1, 0)
        , regexp = "lgb.plot.tree: model should be an 'lgb.Booster'"
    )
})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For every use of expect_error() here, please check for the specific error you are expecting, like this:

https://github.com/microsoft/LightGBM/blob/83c0ff3de1925b0e2d4831a9ccb6ffc196aa795b/R-package/tests/testthat/test_lgb.importance.R#L33-35

That way, the test will be able to catch the case where some other unexpected issue causes this code path to fail.

}