Skip to content

Commit

Permalink
minor: replaced feature_names with features_name in DT plot wrap
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvaticus committed Jan 4, 2023
1 parent 109d710 commit 86a003f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/Trees/AbstractTrees_BetaML_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ In case of a `BetaML/DecisionTree` this is typically a list of feature names as
wrap(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
wrap(mod::DecisionTreeEstimator, info::NamedTuple = NamedTuple()) = wrap(mod.par.tree, info)
wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};feature_names=[]) = wrap(m,(featurenames=feature_names,))
wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};features_names=[]) = wrap(m,(features_names=features_names,))



Expand All @@ -58,7 +58,7 @@ AbstractTrees.children(node::InfoLeaf) = ()
function AbstractTrees.printnode(io::IO, node::InfoNode)
q = node.node.question
condition = isa(q.value, Number) ? ">=" : "=="
col = :featurenames keys(node.info) ? node.info.featurenames[q.column] : q.column
col = :features_names keys(node.info) ? node.info.features_names[q.column] : q.column
print(io, "$(col) $condition $(q.value)?")
end

Expand Down
4 changes: 2 additions & 2 deletions src/Trees/DecisionTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ Dict{String, Any}("job_is_regression" => 1, "fitted_records" => 6, "max_reached_
--> False: 3.3999999999999995
using Plots, TreeRecipe
feature_names = ["Something", "Som else"]
wrapped_tree = wrap(dtree, feature_names = feature_names) # feature_names is otional
features_names = ["Something", "Som else"]
wrapped_tree = wrap(dtree, features_names = features_names) # feature_names is otional
plot(wrapped_tree)
````
![DT plot](assets/dtplot.png)
Expand Down
4 changes: 2 additions & 2 deletions test/Trees_tests_additional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ tree = Tree()
yhat_train = Trees.fit!(model, xtrain, ytrain)

println("--> add information about feature names")
feature_names = ["Color", "Size"]
wrapped_tree = wrap(model, feature_names = feature_names)
features_names = ["Color", "Size"]
wrapped_tree = wrap(model, features_names = features_names)

println("--> plot the tree using the `TreeRecipe`")
plt = plot(wrapped_tree) # this calls automatically the `TreeRecipe`
Expand Down

0 comments on commit 86a003f

Please sign in to comment.