Skip to content

Commit

Permalink
Rewrite Split to SubClause
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Nov 14, 2023
1 parent 71abde0 commit dcdb99e
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions src/rules.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
"""
Split(splitpoint::SplitPoint, direction::Symbol) -> Split
Split(feature::Int, name::String, splitval::Float32, direction::Symbol) -> Split
SubClause
A split in a tree.
Each rule is based on one or more splits.
A subclause denotes a conditional on one feature.
Each rule contains a clause with one or more subclauses.
For example, the rule `if X[i, 1] > 3 & X[i, 2] < 4, then ...` contains two subclauses.
Data can be accessed via `_feature`, `_value`, `_feature_name`, `_direction`, and `_reverse`.
A subclause is equivalent to a split in a decision tree.
In other words, each rule is based on one or more splits.
In pratice, a rule is based on at most two splits (has at most two subclauses).
The reason for this is that rules with more than two subclauses will not end up
in the final model, as is discussed in the original SIRUS paper.
The data inside a `SubClause` can be accessed via
- `_feature`,
- `_feature_name`,
- `_splitval`, and
- `_direction`.
To obtain the reverse, use `_reverse`.
"""
struct SubClause
# Removed splitpoint
# splitpoint::SplitPoint
feature::Int,
feature_name::String,
splitval::Float32,
direction::Symbol # :L or :R
end

function Split(feature::Int, name::String, splitval::Float32, direction::Symbol)
return Split(SplitPoint(feature, splitval, name), direction)
end
_feature(s::SubClause) = s.feature
_feature_name(s::SubClause) = s.feature_name
_splitval(s::SubClause) = s.splitval
_direction(s::SubClause) = s.direction

_feature(split::Split) = _feature(split.splitpoint)
_value(split::Split) = _value(split.splitpoint)
_feature_name(split::Split) = _feature_name(split.splitpoint)
_direction(split::Split) = split.direction
_reverse(split::Split) = Split(split.splitpoint, split.direction == :L ? :R : :L)
function _reverse(s::SubClause)
direction = s.direction == :L ? :R : :L
return SubClause(s.feature, s.feature_name, s.splitval, direction)
end

"""
TreePath(splits::Vector{Split}) -> TreePath
Expand Down

0 comments on commit dcdb99e

Please sign in to comment.