Skip to content

Commit

Permalink
Remove rule sorting step (#68)
Browse files Browse the repository at this point in the history
It turns out that the whole rule sorting logic in `_sort_by_frequency`
wasn't necessary, so let's get rid of it. This PR also moves some code
around for clarity.
  • Loading branch information
rikhuijzer committed Nov 27, 2023
1 parent c13e1a9 commit 4f1ecd1
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 72 deletions.
6 changes: 3 additions & 3 deletions docs/src/binary-classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ Therefore, it makes more sense to truncate the rules to somewhere in the range 5
`max_depth` specifies how many levels the trees have.
For larger datasets, `max_depth=2` makes the most sense since it can find more complex patterns in the data.
For smaller datasets, `max_depth=1` makes more sense since it reduces the chance of overfitting.
It also simplifies the rules because with `max_depth=1`, the rule will contain only one conditional (for example, "if A then ...") versus two conditionals (for example, "if A & B then ...").
It also simplifies the rules because with `max_depth=1`, the rule will contain only one subclause (for example, "if A then ...") versus two subclauses (for example, "if A & B then ...").
In some cases, model accuracy can be improved by increasing `n_trees`.
The higher this number, the more trees are fitted and, hence, the higher the chance that the right rules are extracted from the trees.
"""
Expand All @@ -232,8 +232,8 @@ Since we know that the model performs well on the cross-validations, we can fit
md"""
## Visualization
Since our rules are relatively simple with only a binary outcome and only one clause in each rule, the following figure is a way to visualize the obtained rules per fold.
For multiple clauses, I would not know how to visualize the rules.
Since our rules are relatively simple with only a binary outcome and only one subclause in each rule, the following figure is a way to visualize the obtained rules per fold.
For multiple subclauses, I would not know how to visualize the rules.
Also, this plot is probably not perfect; let me know if you have suggestions.
This figure shows the model uncertainty by visualizing the obtained models for different cross-validation folds.
Expand Down
18 changes: 0 additions & 18 deletions src/dependent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,24 +224,6 @@ function _sort_by_gap_size(rules::Vector{Rule})
return sort(rules; alg, by=_gap_size, rev=true)
end

"""
Simplify the rules that contain a single split by only retaining rules that point left and
removing duplicates.
"""
function _simplify_single_rules(rules::Vector{Rule})::Vector{Rule}
    # `OrderedSet` deduplicates while keeping first-seen order.
    simplified = OrderedSet{Rule}()
    for rule in rules
        has_single_subclause = length(_subclauses(rule)) == 1
        push!(simplified, has_single_subclause ? _left_rule(rule) : rule)
    end
    return collect(simplified)
end

"""
Return a vector of rules that are not linearly dependent on any other rule.
Expand Down
75 changes: 38 additions & 37 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ struct Rule
otherwise::LeafContent
end

"""Return the `Clause` stored on `rule`."""
function _clause(rule::Rule)
    return rule.clause
end

"""Return the subclauses of the `Clause` stored on `rule`."""
function _subclauses(rule::Rule)
    return rule.clause.subclauses
end

"""
Expand Down Expand Up @@ -336,59 +337,59 @@ function _isempty_error(::StableRules)
throw(AssertionError("The rule model contains no rules"))
end

"""
    _remove_zero_weights(rules::AbstractVector, weights::AbstractVector{Float16})

Return `(rules, weights)` with every pair whose weight is zero removed.

The rule container is accepted as any `AbstractVector` since the filtering logic
does not depend on the element type; existing `Vector{Rule}` callers are
unaffected and still receive a `Vector{Rule}` back.
"""
function _remove_zero_weights(rules::AbstractVector, weights::AbstractVector{Float16})
    @assert length(rules) == length(weights)
    # `!iszero` matches the original `w != Float16(0.0)` test, including for
    # NaN (kept) and negative zero (dropped, since -0.0 == 0.0).
    keep = findall(!iszero, weights)
    return rules[keep], weights[keep]
end

"""
Return a `Dict` mapping each distinct value in `V` to its number of occurrences.
"""
function _count_unique(V::AbstractVector{T}) where T
    counts = Dict{T,Int}()
    for value in V
        # Start at 0 on first sight, then bump the tally.
        counts[value] = get(counts, value, 0) + 1
    end
    return counts
end

"""
Return a vector of unique values in `V` sorted by frequency.
Simplify the rules that contain a single split by only retaining rules that point left and
removing duplicates.
"""
function _sort_by_frequency(V::AbstractVector{T}) where T
counts = _count_unique(V)::Dict{T, Int}
alg = Helpers.STABLE_SORT_ALG
sorted = sort(collect(counts); alg, by=last, rev=true)
return first.(sorted)
# Simplify the rules that contain a single split by only retaining rules that
# point left and removing duplicates (insertion order is preserved).
function _simplify_single_rules(rules::Vector{Rule})::Vector{Rule}
    deduplicated = OrderedSet{Rule}()
    for rule in rules
        if length(_subclauses(rule)) == 1
            # Canonicalize single-split rules to their left-pointing form so
            # that mirrored duplicates collapse into one entry.
            push!(deduplicated, _left_rule(rule))
        else
            push!(deduplicated, rule)
        end
    end
    return collect(deduplicated)
end

"""
Apply _rule selection_ and _rule set post-treatment_
(Bénard et al., [2021](http://proceedings.mlr.press/v130/benard21a)).
Rule selection, here, denotes sorting the set by frequency.
Next, linearly dependent rules are removed from the set.
To ensure the size of the final set is equal to `max_rules` in most cases, we ignore the
p0 parameter and instead pass all rules directly to the linearly dependent filter.
This is possible because the filter for linear dependencies is quite fast.
We have a slight modification here:
we do not sort first, select some p0 first, and then remove linearly dependent
rules because our linearly dependent filter is quick enough to handle all rules.
This means we don't need to sort at all.
For the sorting, note that the paper talks about sorting by frequency of the
**path** (clause) and not the rule, that is, clause with then and otherwise
probabilities.
"""
"""
    _process_rules(rules::Vector{Rule}, max_rules::Int)::Vector{Rule}

Apply rule selection and rule-set post-treatment
(Bénard et al., [2021](http://proceedings.mlr.press/v130/benard21a)):
deduplicate single-subclause rules, drop linearly dependent rules, and
truncate the result to at most `max_rules` rules.

No frequency sorting happens here: the linear-dependency filter is fast
enough to receive all rules directly, so the paper's p0 pre-selection (and
the sorting it required) is unnecessary.
"""
function _process_rules(
        rules::Vector{Rule},
        max_rules::Int
    )::Vector{Rule}
    # NOTE(review): the previous body also assigned `sorted` via
    # `_sort_by_frequency` and then immediately overwrote the results —
    # dead code calling a function removed by this commit, so it is dropped.
    simplified = _simplify_single_rules(rules)::Vector{Rule}
    filtered = _filter_linearly_dependent(simplified)::Vector{Rule}
    return first(filtered, max_rules)
end

# Drop every (rule, weight) pair whose weight is exactly zero and return the
# surviving rules and weights as two parallel vectors.
function _remove_zero_weights(rules::Vector{Rule}, weights::Vector{Float16})
    @assert length(rules) == length(weights)
    kept_rules = Rule[]
    kept_weights = Float16[]
    for (rule, weight) in zip(rules, weights)
        if weight != Float16(0.0)
            push!(kept_rules, rule)
            push!(kept_weights, weight)
        end
    end
    return kept_rules, kept_weights
end

function StableRules(
rules::Vector{Rule},
algo::Algorithm,
Expand Down
10 changes: 10 additions & 0 deletions test/docs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Verify that every name exported by SIRUS is mentioned in the API reference.
# Checked manually because a strict doctest setting does not work with
# PlutoStaticHTML.
api_docs = read(joinpath(pkgdir(SIRUS), "docs", "src", "api.md"), String)

for exported_name in names(SIRUS)
    @test contains(api_docs, string(exported_name))
end

# `warn=false` suppresses warnings when the doc metadata keys already exist.
DocMeta.setdocmeta!(SIRUS, :DocTestSetup, :(using SIRUS); recursive=true, warn=false)
doctest(SIRUS)
10 changes: 0 additions & 10 deletions test/rules.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
@testset "sort is deterministic" begin
    # Seeded RNG keeps the input reproducible across runs; sorting the same
    # data twice must yield the same order (stable sort algorithm).
    rng = StableRNG(1)
    data = rand(rng, 1:10, 10_000)
    @test SIRUS._sort_by_frequency(data) == SIRUS._sort_by_frequency(data)
end

let
text = " X[i, 1] < 1.0 & X[i, 1] ≥ 4.0 "
@test repr(S.Clause(text)) == "Clause(\"$text\")"
Expand Down Expand Up @@ -59,10 +53,6 @@ rules = S._rules(forest)
# Equal rules must hash equally so set-based deduplication works.
@test hash(r1) == hash(r1b)
@test hash(r1.clause) == hash(r1b.clause)

# NOTE(review): the two tests below exercise `_count_unique` and
# `_sort_by_frequency`; this commit removes `_sort_by_frequency`, so these
# lines appear to be removed-side diff residue — confirm against the repo.
@test S._count_unique([1, 1, 1, 2]) == Dict(1 => 3, 2 => 1)

@test S._sort_by_frequency([r1, r5, r1]) == [r1, r5]

# A model holding zero rules must refuse to predict.
algo = SIRUS.Classification()
empty_model = S.StableRules(S.Rule[], algo, [1], Float16[0.1])
@test_throws AssertionError S._predict(empty_model, [31000])
Expand Down
6 changes: 2 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
include("preliminaries.jl")

@testset "doctests" begin
# warn suppresses warnings when keys already exist.
DocMeta.setdocmeta!(SIRUS, :DocTestSetup, :(using SIRUS); recursive=true, warn=false)
doctest(SIRUS)
@testset "docs" begin
include("docs.jl")
end

@testset "empiricalquantiles" begin
Expand Down

0 comments on commit 4f1ecd1

Please sign in to comment.