Commit 109d710

Removed BetaML from test environment and improved API for wrap(decision tree) for plotting

sylvaticus committed Jan 4, 2023
1 parent a8a506e commit 109d710
Showing 6 changed files with 146 additions and 12 deletions.
Binary file added assets/dtplot.png
101 changes: 101 additions & 0 deletions assets/dtplot.svg
14 changes: 14 additions & 0 deletions src/Trees/AbstractTrees_BetaML_interface.jl
@@ -38,8 +38,13 @@ In case of a `BetaML/DecisionTree` this is typically a list of feature names as
`wdc = wrap(dc, (featurenames = feature_names, ))`
"""

wrap(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
wrap(mod::DecisionTreeEstimator, info::NamedTuple = NamedTuple()) = wrap(mod.par.tree, info)
wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator}; feature_names=[]) = wrap(m, (featurenames=feature_names,))




#### Implementation of the `AbstractTrees`-interface
@@ -61,4 +66,13 @@ function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
for p in leaf.leaf.predictions
println(io, p)
end
end

function show(io::IO, node::Union{InfoNode,InfoLeaf})
#print(io, "Is col $(question.column) $condition $(question.value) ?")
print(io, "A wrapped Decision Tree")
end

function show(io::IO, ::MIME"text/plain", node::Union{InfoNode,InfoLeaf})
print(io, "A wrapped Decision Tree")
end
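
For orientation, a minimal sketch of how the improved `wrap` API reads after this commit (the toy data and feature names below are illustrative, not from the repository):

```julia
using BetaML
import AbstractTrees

# Toy regression data (illustrative only)
X = [1.0 10.0; 2.0 20.0; 3.0 30.0; 4.0 40.0; 5.0 50.0; 6.0 60.0]
y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

model = DecisionTreeEstimator()
fit!(model, X, y)

# The keyword form added in this commit wraps the fitted model directly,
# with no need to reach into `model.par.tree` or build a NamedTuple by hand:
wtree = wrap(model, feature_names = ["Color", "Size"])

# Via the `AbstractTrees` interface implemented in this file, the wrapped
# tree should also print textually:
AbstractTrees.print_tree(wtree)
```

The two `show` overloads added above keep the REPL display of a wrapped tree to a compact "A wrapped Decision Tree" rather than dumping the whole recursive structure.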
35 changes: 28 additions & 7 deletions src/Trees/DecisionTrees.jl
@@ -231,6 +231,34 @@ Dict{String, Any}("job_is_regression" => 1, "fitted_records" => 6, "max_reached_depth" => 4, "avg_depth" => 3.25, "xndims" => 2)
--> False: -13.8
--> False: 3.3999999999999995
```
- Visualisation...
You can either text-print or plot a decision tree:
```julia
julia> println(mod)
DecisionTreeEstimator - A Decision Tree regressor (fitted on 6 records)
Dict{String, Any}("job_is_regression" => 1, "fitted_records" => 6, "max_reached_depth" => 4, "avg_depth" => 3.25, "xndims" => 2)
*** Printing Decision Tree: ***
1. Is col 2 >= 18.0 ?
--> True :
1.2. Is col 2 >= 31.0 ?
--> True : -27.2
--> False:
1.2.3. Is col 2 >= 20.5 ?
--> True : -17.450000000000003
--> False: -13.8
--> False: 3.3999999999999995

using Plots, TreeRecipe
feature_names = ["Something", "Som else"]
wrapped_tree = wrap(dtree, feature_names = feature_names) # feature_names is otional
plot(wrapped_tree)
```
![DT plot](assets/dtplot.png)
"""
mutable struct DecisionTreeEstimator <: BetaMLSupervisedModel
hpar::DTHyperParametersSet
@@ -758,13 +786,6 @@ function computeDepths(node::AbstractNode)
return (mean(leafDepths),maximum(leafDepths))
end

function show(io::IO,question::Question)
condition = "=="
if isa(question.value, Number)
condition = ">="
end
print(io, "Is col $(question.column) $condition $(question.value) ?")
end

"""
print(node)
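As an aside, the `dtplot.png`/`dtplot.svg` assets added in this commit could plausibly be regenerated along these lines (a sketch assuming `Plots` and `TreeRecipe` are installed and `model` is a fitted `DecisionTreeEstimator`, as in the sketch above; not necessarily how the assets were actually produced):

```julia
using Plots, TreeRecipe

wtree = wrap(model, feature_names = ["Something", "Something else"])
plt = plot(wtree)   # TreeRecipe supplies the plot recipe for the wrapped tree
savefig(plt, "assets/dtplot.png")
savefig(plt, "assets/dtplot.svg")
```
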
1 change: 0 additions & 1 deletion test/Project.toml
@@ -1,6 +1,5 @@
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
BetaML = "024491cd-cc6b-443e-8034-08ea7eb7db2b"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
7 changes: 3 additions & 4 deletions test/Trees_tests_additional.jl
@@ -4,7 +4,7 @@ import MLJBase
const Mlj = MLJBase
using StableRNGs
#rng = StableRNG(123)
using BetaML.Trees
using BetaML

println("*** Additional testing for the Testing Decision trees/Random Forest algorithms...")

@@ -54,11 +54,10 @@ tree = Tree()

model = DecisionTreeEstimator()
yhat_train = Trees.fit!(model, xtrain, ytrain)
dtree = model.par.tree

println("--> add information about feature names")
feature_names = ["Color", "Intensity"]
wrapped_tree = Trees.wrap(dtree, (featurenames = feature_names, ))
feature_names = ["Color", "Size"]
wrapped_tree = wrap(model, feature_names = feature_names)

println("--> plot the tree using the `TreeRecipe`")
plt = plot(wrapped_tree) # this calls automatically the `TreeRecipe`
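With `BetaML` dropped from `test/Project.toml` (the package under test is, presumably, made available in its own test environment automatically, so listing it was redundant), the suite still runs the usual way:

```julia
using Pkg
Pkg.test("BetaML")
```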
