Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Nov 21, 2023
1 parent 747893a commit a4055b0
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
1 change: 0 additions & 1 deletion src/SIRUS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ include("rules.jl")
export StableRules, feature_names, directions, satisfies
include("ruleshow.jl")
include("weights.jl")
export TreePath
include("dependent.jl")
include("extract.jl")
export sum_weights
Expand Down
13 changes: 10 additions & 3 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ struct SubClause
end
end

function SubClause(
sp::SplitPoint,
direction::Symbol
)::SubClause
return SubClause(sp.feature, sp.feature_name, sp.value, direction)
end

_feature(s::SubClause) = s.feature
_feature_name(s::SubClause) = s.feature_name
_splitval(s::SubClause) = s.splitval
Expand Down Expand Up @@ -158,7 +165,7 @@ end
Return a vector split values; one for each subclause in `rule`.
"""
function Base.values(rule::Rule)::Vector{Float64}
return Float64[Float64(_value(s)) for s in _subclauses(rule)]
return Float64[Float64(_splitval(s)) for s in _subclauses(rule)]
end

"""
Expand Down Expand Up @@ -250,12 +257,12 @@ function Rule(
node::Union{Node, Leaf},
subclauses::Vector{SubClause}
)::Rule
path = Clause(subclauses)
clause = Clause(subclauses)
then_output = _then_output!(node, Vector{LeafContent}())
then = _mean(then_output)
else_output = _else_output!(node, root, Vector{LeafContent}())
otherwise = _mean(else_output)
return Rule(path, then, otherwise)
return Rule(clause, then, otherwise)
end

function _rules!(
Expand Down
14 changes: 7 additions & 7 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ end

let
text = " X[i, 1] < 1.0 & X[i, 1] ≥ 4.0 "
@test repr(TreePath(text)) == "TreePath(\"$text\")"
@test repr(S.Clause(text)) == "Clause(\"$text\")"
end

let
text = " X[i, :A] < 1.0 "
@test_throws ArgumentError repr(TreePath(text))
@test_throws ArgumentError repr(S.Clause(text))
end

classes = [:a, :b, :c]
Expand All @@ -24,18 +24,18 @@ right = S.Node(
S.ClassificationLeaf([0.0, 0.0, 1.0])
)

left_rule = S.Rule(S.TreePath(" X[i, 1] < 32000 "), [0.61], [0.408])
left_rule = S.Rule(S.Clause(" X[i, 1] < 32000 "), [0.61], [0.408])

@testset "exported functions" begin
@test feature_names(left_rule) == ["1"]
@test directions(left_rule) == [:L]
@test values(left_rule) == [32000]
end

r1 = S.Rule(S.TreePath(" X[i, 1] < 32000 "), [0.61], [0.408])
r1b = S.Rule(S.TreePath(" X[i, 1] < 32000 "), [0.61], [0.408])
r1c = S.Rule(S.TreePath(" X[i, 1] < 32000 "), [0.0], [0.408])
r5 = S.Rule(S.TreePath(" X[i, 3] < 64 "), [0.56], [0.334])
r1 = S.Rule(S.Clause(" X[i, 1] < 32000 "), [0.61], [0.408])
r1b = S.Rule(S.Clause(" X[i, 1] < 32000 "), [0.61], [0.408])
r1c = S.Rule(S.Clause(" X[i, 1] < 32000 "), [0.0], [0.408])
r5 = S.Rule(S.Clause(" X[i, 3] < 64 "), [0.56], [0.334])

algo = SIRUS.Classification()
@test S._mean([[1, 4], [2, 4]]) == [1.5, 4.0]
Expand Down

0 comments on commit a4055b0

Please sign in to comment.