-
Notifications
You must be signed in to change notification settings - Fork 3.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R-package] add a tree plotting function #6729
base: master
Are you sure you want to change the base?
Changes from 1 commit
6862821
0a7ea0e
5206b11
757dc84
55aba68
85ff97a
b4b648a
ed62441
2710705
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,59 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||
test_that("lgb.plot.tree works as expected"){ | ||||||||||||||||||||||||||||||||||||||||||||||||
data(agaricus.train, package = "lightgbm") | ||||||||||||||||||||||||||||||||||||||||||||||||
train <- agaricus.train | ||||||||||||||||||||||||||||||||||||||||||||||||
dtrain <- lgb.Dataset(train$data, label = train$label) | ||||||||||||||||||||||||||||||||||||||||||||||||
data(agaricus.test, package = "lightgbm") | ||||||||||||||||||||||||||||||||||||||||||||||||
test <- agaricus.test | ||||||||||||||||||||||||||||||||||||||||||||||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label) | ||||||||||||||||||||||||||||||||||||||||||||||||
# define model parameters and build a single tree | ||||||||||||||||||||||||||||||||||||||||||||||||
params <- list( | ||||||||||||||||||||||||||||||||||||||||||||||||
objective = "regression" | ||||||||||||||||||||||||||||||||||||||||||||||||
, metric = "l2" | ||||||||||||||||||||||||||||||||||||||||||||||||
, min_data = 1L | ||||||||||||||||||||||||||||||||||||||||||||||||
, learning_rate = 1.0 | ||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||
valids <- list(test = dtest) | ||||||||||||||||||||||||||||||||||||||||||||||||
model <- lgb.train( | ||||||||||||||||||||||||||||||||||||||||||||||||
params = params | ||||||||||||||||||||||||||||||||||||||||||||||||
, data = dtrain | ||||||||||||||||||||||||||||||||||||||||||||||||
, nrounds = 1L | ||||||||||||||||||||||||||||||||||||||||||||||||
, valids = valids | ||||||||||||||||||||||||||||||||||||||||||||||||
, early_stopping_rounds = 1L | ||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This is part of some suggested changes, please apply other changes following from it and to other examples and tests.
|
||||||||||||||||||||||||||||||||||||||||||||||||
# plot the tree and compare to the tree table | ||||||||||||||||||||||||||||||||||||||||||||||||
# trees start from 0 in lgb.model.dt.tree | ||||||||||||||||||||||||||||||||||||||||||||||||
tree_table <- lgb.model.dt.tree(model) | ||||||||||||||||||||||||||||||||||||||||||||||||
expect_true({ | ||||||||||||||||||||||||||||||||||||||||||||||||
lgb.plot.tree(model, 0)TRUE | ||||||||||||||||||||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
test_that("lgb.plot.tree fails when a non existing tree is selected"){ | ||||||||||||||||||||||||||||||||||||||||||||||||
data(agaricus.train, package = "lightgbm") | ||||||||||||||||||||||||||||||||||||||||||||||||
train <- agaricus.train | ||||||||||||||||||||||||||||||||||||||||||||||||
dtrain <- lgb.Dataset(train$data, label = train$label) | ||||||||||||||||||||||||||||||||||||||||||||||||
data(agaricus.test, package = "lightgbm") | ||||||||||||||||||||||||||||||||||||||||||||||||
test <- agaricus.test | ||||||||||||||||||||||||||||||||||||||||||||||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label) | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please remove all these uses of a validation set? This feature is about plotting the trained model, and you are not using early stopping, so all of this work to create validation sets is unnecessary. Keeping the tests and examples as small and simple as possible makes the code easier to read / develop, and makes it clearer how test cases differ from each other. |
||||||||||||||||||||||||||||||||||||||||||||||||
# define model parameters and build a single tree | ||||||||||||||||||||||||||||||||||||||||||||||||
params <- list( | ||||||||||||||||||||||||||||||||||||||||||||||||
objective = "regression" | ||||||||||||||||||||||||||||||||||||||||||||||||
, metric = "l2" | ||||||||||||||||||||||||||||||||||||||||||||||||
, min_data = 1L | ||||||||||||||||||||||||||||||||||||||||||||||||
, learning_rate = 1.0 | ||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Similar to my comments on the docs... I strongly suspect we could just use default parameters here. |
||||||||||||||||||||||||||||||||||||||||||||||||
valids <- list(test = dtest) | ||||||||||||||||||||||||||||||||||||||||||||||||
model <- lgb.train( | ||||||||||||||||||||||||||||||||||||||||||||||||
params = params | ||||||||||||||||||||||||||||||||||||||||||||||||
, data = dtrain | ||||||||||||||||||||||||||||||||||||||||||||||||
, nrounds = 1L | ||||||||||||||||||||||||||||||||||||||||||||||||
, valids = valids | ||||||||||||||||||||||||||||||||||||||||||||||||
, early_stopping_rounds = 1L | ||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||
# plot the tree and compare to the tree table | ||||||||||||||||||||||||||||||||||||||||||||||||
# trees start from 0 in lgb.model.dt.tree | ||||||||||||||||||||||||||||||||||||||||||||||||
tree_table <- lgb.model.dt.tree(model) | ||||||||||||||||||||||||||||||||||||||||||||||||
expect_error({ | ||||||||||||||||||||||||||||||||||||||||||||||||
lgb.plot.tree(model, 999)TRUE | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This looks like it was included accidentally? |
||||||||||||||||||||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For every use of That way, the test will be able to catch the case where some other unexpected issue causes this code path to fail. |
||||||||||||||||||||||||||||||||||||||||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also add tests for the other types of machine learning tasks LightGBM can be used for:
num_classes
trees produced per iteration)And for the following model situations:
These are all cases that could affect the code as written... for example, categorical features have different splitting rules.