From e4c60999a534c61625891794a7d164ef0bd481e4 Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Tue, 27 Jun 2023 07:27:55 +0200 Subject: [PATCH] Cleanup --- src/dependent.jl | 7 +++---- src/rules.jl | 2 +- test/classification.jl | 4 ++-- test/dependent.jl | 11 +++++------ 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/dependent.jl b/src/dependent.jl index 8013586..676dc2d 100644 --- a/src/dependent.jl +++ b/src/dependent.jl @@ -254,8 +254,7 @@ If we don't do this, we might remove some rule `r` that causes another rule to b dependent in one related set, but then is removed in another related set. """ function _filter_linearly_dependent(rules::Vector{Rule})::Vector{Rule} - filtered = _simplify_single_rules(rules) - sorted = _sort_by_gap_size(filtered) + sorted = _sort_by_gap_size(rules) S = _unique_left_splits(sorted) pairs = _left_triangular_product(S) out = copy(sorted) @@ -293,9 +292,9 @@ Return a linearly independent subset of `rules` of length ≤ `max_rules`. """ function _process_rules( rules::Vector{Rule}, - algo::Algorithm, max_rules::Int )::Vector{Rule} - filtered = _filter_linearly_dependent(rules) + simplified = _simplify_single_rules(rules) + filtered = _filter_linearly_dependent(simplified) return first(filtered, max_rules) end diff --git a/src/rules.jl b/src/rules.jl index 9cc2da3..d721db1 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -328,7 +328,7 @@ function StableRules( outcome, model::Probabilistic )::StableRules - processed = _process_rules(rules, algo, model.max_rules) + processed = _process_rules(rules, model.max_rules) weights = _weights(processed, algo, classes, data, outcome, model) filtered_rules, filtered_weights = _remove_zero_weights(processed, weights) return StableRules(filtered_rules, algo, classes, filtered_weights) diff --git a/test/classification.jl b/test/classification.jl index b9e5920..119a605 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -93,11 +93,11 @@ end end fpreds = DecisionTree.apply_forest(dforest, data) -@show accuracy(fpreds, y) +# @show accuracy(fpreds, y) @test 0.95 < accuracy(fpreds, y) sfpreds = SIRUS._predict(sforest, data) -@show accuracy(mode.(sfpreds), y) +# @show accuracy(mode.(sfpreds), y) @test 0.95 < accuracy(mode.(sfpreds), y) empty_forest = SIRUS.StableForest(Union{SIRUS.Leaf, SIRUS.Node}[], algo, [1]) diff --git a/test/dependent.jl b/test/dependent.jl index 661b2e2..2f1fb3f 100644 --- a/test/dependent.jl +++ b/test/dependent.jl @@ -101,7 +101,7 @@ end @test S._linearly_dependent([rule], A, B) == Bool[0] end -@test S._filter_linearly_dependent(repeat([r1], 10)) == [r1] +@test S._process_rules(repeat([r1], 10), 10) == [r1] function _canonicalize(rules::Vector{SIRUS.Rule}) [length(r.path.splits) == 1 ? SIRUS._left_rule(r) : r for r in rules] @@ -118,10 +118,9 @@ actual = S._filter_linearly_dependent(canonical) expected = _canonicalize(expected) @test Set(actual) == Set(expected) -algo = SIRUS.Classification() -@test length(S._process_rules(allrules, algo, 9)) == 9 -@test length(S._process_rules(allrules, algo, 10)) == 9 -@test length(S._process_rules([r1], algo, 9)) == 1 -@test length(S._process_rules(repeat(allrules, 200), algo, 9)) == 9 +@test length(S._process_rules(allrules, 9)) == 9 +@test length(S._process_rules(allrules, 10)) == 9 +@test length(S._process_rules([r1], 9)) == 1 +@test length(S._process_rules(repeat(allrules, 200), 9)) == 9 nothing