Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 27, 2023
1 parent 0d162ee commit e4c6099
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 13 deletions.
7 changes: 3 additions & 4 deletions src/dependent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ If we don't do this, we might remove some rule `r` that causes another rule to b
dependent in one related set, but then is removed in another related set.
"""
function _filter_linearly_dependent(rules::Vector{Rule})::Vector{Rule}
filtered = _simplify_single_rules(rules)
sorted = _sort_by_gap_size(filtered)
sorted = _sort_by_gap_size(rules)
S = _unique_left_splits(sorted)
pairs = _left_triangular_product(S)
out = copy(sorted)
Expand Down Expand Up @@ -293,9 +292,9 @@ Return a linearly independent subset of `rules` of length ≤ `max_rules`.
"""
function _process_rules(
rules::Vector{Rule},
algo::Algorithm,
max_rules::Int
)::Vector{Rule}
filtered = _filter_linearly_dependent(rules)
simplified = _simplify_single_rules(rules)
filtered = _filter_linearly_dependent(simplified)
return first(filtered, max_rules)
end
2 changes: 1 addition & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function StableRules(
outcome,
model::Probabilistic
)::StableRules
processed = _process_rules(rules, algo, model.max_rules)
processed = _process_rules(rules, model.max_rules)
weights = _weights(processed, algo, classes, data, outcome, model)
filtered_rules, filtered_weights = _remove_zero_weights(processed, weights)
return StableRules(filtered_rules, algo, classes, filtered_weights)
Expand Down
4 changes: 2 additions & 2 deletions test/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ end
end

fpreds = DecisionTree.apply_forest(dforest, data)
@show accuracy(fpreds, y)
# @show accuracy(fpreds, y)
@test 0.95 < accuracy(fpreds, y)

sfpreds = SIRUS._predict(sforest, data)
@show accuracy(mode.(sfpreds), y)
# @show accuracy(mode.(sfpreds), y)
@test 0.95 < accuracy(mode.(sfpreds), y)

empty_forest = SIRUS.StableForest(Union{SIRUS.Leaf, SIRUS.Node}[], algo, [1])
Expand Down
11 changes: 5 additions & 6 deletions test/dependent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
@test S._linearly_dependent([rule], A, B) == Bool[0]
end

@test S._filter_linearly_dependent(repeat([r1], 10)) == [r1]
@test S._process_rules(repeat([r1], 10), 10) == [r1]

function _canonicalize(rules::Vector{SIRUS.Rule})
[length(r.path.splits) == 1 ? SIRUS._left_rule(r) : r for r in rules]
Expand All @@ -118,10 +118,9 @@ actual = S._filter_linearly_dependent(canonical)
expected = _canonicalize(expected)
@test Set(actual) == Set(expected)

algo = SIRUS.Classification()
@test length(S._process_rules(allrules, algo, 9)) == 9
@test length(S._process_rules(allrules, algo, 10)) == 9
@test length(S._process_rules([r1], algo, 9)) == 1
@test length(S._process_rules(repeat(allrules, 200), algo, 9)) == 9
@test length(S._process_rules(allrules, 9)) == 9
@test length(S._process_rules(allrules, 10)) == 9
@test length(S._process_rules([r1], 9)) == 1
@test length(S._process_rules(repeat(allrules, 200), 9)) == 9

nothing

0 comments on commit e4c6099

Please sign in to comment.