diff --git a/src/Trees/AbstractTrees_BetaML_interface.jl b/src/Trees/AbstractTrees_BetaML_interface.jl
index 313cf952..85ffb84a 100644
--- a/src/Trees/AbstractTrees_BetaML_interface.jl
+++ b/src/Trees/AbstractTrees_BetaML_interface.jl
@@ -42,7 +42,7 @@ In case of a `BetaML/DecisionTree` this is typically a list of feature names as
 wrap(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
 wrap(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
 wrap(mod::DecisionTreeEstimator, info::NamedTuple = NamedTuple()) = wrap(mod.par.tree, info)
-wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};feature_names=[]) = wrap(m,(featurenames=feature_names,))
+wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};features_names=[]) = wrap(m,(features_names=features_names,))
@@ -58,7 +58,7 @@ AbstractTrees.children(node::InfoLeaf) = ()
 function AbstractTrees.printnode(io::IO, node::InfoNode)
     q = node.node.question
     condition = isa(q.value, Number) ? ">=" : "=="
-    col = :featurenames ∈ keys(node.info) ? node.info.featurenames[q.column] : q.column
+    col = :features_names ∈ keys(node.info) ? node.info.features_names[q.column] : q.column
     print(io, "$(col) $condition $(q.value)?")
 end
diff --git a/src/Trees/DecisionTrees.jl b/src/Trees/DecisionTrees.jl
index b08c5f5d..eb34b7ec 100644
--- a/src/Trees/DecisionTrees.jl
+++ b/src/Trees/DecisionTrees.jl
@@ -253,8 +253,8 @@ Dict{String, Any}("job_is_regression" => 1, "fitted_records" => 6, "max_reached_
 --> False: 3.3999999999999995
 using Plots, TreeRecipe
-feature_names = ["Something", "Som else"]
-wrapped_tree = wrap(dtree, feature_names = feature_names) # feature_names is otional
+features_names = ["Something", "Som else"]
+wrapped_tree = wrap(dtree, features_names = features_names) # features_names is optional
 plot(wrapped_tree)
 ````
 ![DT plot](assets/dtplot.png)
diff --git a/test/Trees_tests_additional.jl b/test/Trees_tests_additional.jl
index d7e9c7b1..80a95c13 100644
--- a/test/Trees_tests_additional.jl
+++ b/test/Trees_tests_additional.jl
@@ -56,8 +56,8 @@ tree = Tree()
     yhat_train = Trees.fit!(model, xtrain, ytrain)
     println("--> add information about feature names")
-    feature_names = ["Color", "Size"]
-    wrapped_tree = wrap(model, feature_names = feature_names)
+    features_names = ["Color", "Size"]
+    wrapped_tree = wrap(model, features_names = features_names)
     println("--> plot the tree using the `TreeRecipe`")
     plt = plot(wrapped_tree) # this calls automatically the `TreeRecipe`