diff --git a/README.md b/README.md index 3da939b..b5403b1 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@

SIRUS.jl

- Explainable machine learning via rule extraction + Interpretable Machine Learning via Rule Extraction

diff --git a/docs/src/implementation-overview.md b/docs/src/implementation-overview.md index c1e4349..ee46d23 100644 --- a/docs/src/implementation-overview.md +++ b/docs/src/implementation-overview.md @@ -198,7 +198,7 @@ In other words, this step is ignored because it seems like a premature optimizat The second step is more important and more involved. As said before, the second step is to remove the least important linear combinations of other rules. -An example of this is shown in the original paper (Bénard et al., [2021](http://proceedings.mlr.press/v130/benard21a.html), Table 3 (the second) in Section 4 _Post-treatment Illustration_ of the Supplementary PDF), which we repeat here: +An example of this is shown in the original paper (Bénard et al., [2021](http://proceedings.mlr.press/v130/benard21a.html), Table 3 (the second) in Section 4 _Post-treatment Illustration_ of the Supplementary PDF), which is repeated here: Rule Number | If Clause | Then | Else | Remove | Reason --- | --- | --- | --- | --- | --- @@ -223,13 +223,15 @@ Rule Number | If Clause | Then | Else | Remove | Reason Compared to the example from the Supplementary PDF, the features are renamed such that 2 = MMAX, 2 = MMIN, 3 = CACH, 4 = CHMIN, and 5 = MYCT and one sign was flipped in rule 14 after email correspondence with Clément. From this set of rules, the algorithm should remove rule 2, 4, 6, 9, 11, 12, 15, and 17. This is because rule 2, 4, 6, 9, and 11 are the reverse of an earlier rule and because 12, 15, and 17 are linearly dependent. +For the complex linearly dependent duplicates, remove the rule with the widest gap in the outputs. +In the example above, rule 7 has a wider gap than rule 12, which implies that it has a larger CART-splitting criterion and a higher occurrence frequency. The implementation for this can be done by converting the training data to a feature space in which each rule becomes a binary feature indicating whether the data point satisfies the constraint or not. This is quite computationally intensive since there is a lot of duplication in the data and it doesn't guarantee that all cases of duplication will be found since some may not be in the training set. Luckily, D.W. on StackExchange (https://cs.stackexchange.com/questions/152803) has provided a solution, which I will repeat here. The idea is to remove each rule ``r`` when it is linearly dependent on the preceding rules. -To do this, observe that a rule of the form ``A \: \& \: B`` can only depend on rules that use some combination of ``A``, ``!A``, ``B``, and/or ``!B``. +To do this, observe that a rule of the form ``A`` can only depend on rules ``A`` or ``!A``, and ``A \: \& \: B`` can only depend on rules that use some combination of ``A``, ``!A``, ``B``, and/or ``!B``. This works by iteratively calculating the rank and seeing whether the rank increases. We can assume that we are limited to a set of rules where either `A & B`, `A & !B`, `!A & B`, `!A & !B`, `A`, `!A`, `B`, `!B` or `True`. diff --git a/src/dependent.jl b/src/dependent.jl index 88afcbd..cd394fd 100644 --- a/src/dependent.jl +++ b/src/dependent.jl @@ -1,92 +1,136 @@ -_unique_features(split::Split) = split.splitpoint.feature -_unique_features(rule::Rule) = unique(_unique_features.(rule.path.splits)) -_unique_features(rules::Vector{Rule}) = unique(reduce(vcat, _unique_features.(rules))) +"Return whether `clause1` implies `clause2`." +function _implies(clause1::Split, clause2::Split)::Bool + if _feature(clause1) == _feature(clause2) + if _direction(clause1) == :L + if _direction(clause2) == :L + return _value(clause1) ≤ _value(clause2) + else + return false + end + else + if _direction(clause2) == :R + return _value(clause1) ≥ _value(clause2) + else + return false + end + end + else + return false + end +end """ -Return a point which satisifies `A` and `B`. -This assumes that `A` and `B` contain the features in the same order as `_unique_features`. -Basically, this point generation is a way to encode the information such that the constraints can be consistently answered. +Return whether `condition` implies `rule`, that is, whether `A & B => rule`. """ -function _point(A::Split, B::Split) - va = _value(A) - a = _direction(A) == :L ? va - 1 : va - vb = _value(B) - b = _direction(B) == :L ? vb - 1 : vb - return [a, b] -end - -"Return whether `point` satisifies `rule`." -function _satisfies(unique_features::Vector{Int}, point::Vector, rule::Rule) - for split in _splits(rule) - index = findfirst(==(_feature(split)), unique_features) - value = point[index] - threshold = _value(split) - if !(_direction(split) == :L ? value < threshold : value ≥ threshold) - return false - end +function _implies(condition::Tuple{Split, Split}, rule::Rule) + A, B = condition + splits = _splits(rule) + implied = map(splits) do split + _implies(A, split) || _implies(B, split) end - return true + return all(implied) end """ +Return a binary space which can be used to determine whether rules are linearly dependent. + +For example, for the conditions + +- u: A ≥ 3 & B ≥ 2 (U & V) +- v: A ≥ 3 & B < 2 (U & !V) +- w: A < 3 & B ≥ 2 (!U & V) +- x: A < 3 & B < 2 (!U & !V) + +For example, given the following rules: + +Rule 1: x[i, 1] < 32000 +Rule 5: x[i, 3] < 64 +Rule 7: x[i, 1] ≥ 32000 & x[i, 3] < 64 +Rule 12: x[i, 1] < 32000 & x[i, 3] < 64 + +and the following clauses + +A: x[i, 1] < 32000 +B: x[i, 3] < 64 + +This function generates a matrix containing a row for + +- A && B (x[i, 1] < 32000 & x[i, 3] < 64) +- A && !B (x[i, 1] < 32000 & x[i, 3] ≥ 64) +- !A && B (x[i, 1] ≥ 32000 & x[i, 3] < 64) +- !A && !B (x[i, 1] ≥ 32000 & x[i, 3] ≥ 64) + +and one zeroes column: + +| Condition | Ones | R1 | R5 | R7 | R12 | +| ---------- | ---- | -- | -- | -- | --- | +| A && B | 1 | 1 | 1 | 0 | 0 | +| A && !B | 1 | 1 | 0 | 0 | 0 | +| !A && B | 1 | 0 | 1 | 0 | 1 | +| !A && !B | 1 | 0 | 0 | 1 | 0 | + +In other words, the matrix represents which rules are implied by each syntetic datapoint +(conditions in the rows). +Next, this can be used to determine which rules are linearly dependent by checking whether +the rank increases when adding rules. + +# Example + +```jldoctest +julia> A = SIRUS.Split(SIRUS.SplitPoint(1, 32000.0f0, "1"), :L); + +julia> B = SIRUS.Split(SIRUS.SplitPoint(3, 64.0f0, "3"), :L); + +julia> r1 = SIRUS.Rule(TreePath(" X[i, 1] < 32000.0 "), [0.061], [0.408]); + +julia> r5 = SIRUS.Rule(TreePath(" X[i, 3] < 64.0 "), [0.056], [0.334]); + +julia> r7 = SIRUS.Rule(TreePath(" X[i, 1] ≥ 32000.0 & X[i, 3] ≥ 64.0 "), [0.517], [0.067]); + +julia> r12 = SIRUS.Rule(TreePath(" X[i, 1] ≥ 32000.0 & X[i, 3] < 64.0 "), [0.192], [0.102]); + +julia> SIRUS.rank(SIRUS._feature_space([r1, r5], A, B)) +3 + +julia> SIRUS.rank(SIRUS._feature_space([r1, r5, r7], A, B)) +4 + +julia> SIRUS.rank(SIRUS._feature_space([r1, r5, r7, r12], A, B)) +4 +``` """ -function _feature_space(rules::AbstractVector{Rule}, A::Split, B::Split) +function _feature_space(rules::AbstractVector{Rule}, A::Split, B::Split)::BitMatrix l = length(rules) - data = BitArray(undef, 4, l + 1) + data = BitMatrix(undef, 4, l + 1) for i in 1:4 data[i, 1] = 1 end - F = [_feature(A), _feature(B)] nA = _reverse(A) nB = _reverse(B) for col in 2:l+1 rule = rules[col-1] - data[1, col] = _satisfies(F, _point(A, B), rule) - data[2, col] = _satisfies(F, _point(A, nB), rule) - data[3, col] = _satisfies(F, _point(nA, B), rule) - data[4, col] = _satisfies(F, _point(nA, nB), rule) + data[1, col] = _implies((A, B), rule) + data[2, col] = _implies((A, nB), rule) + data[3, col] = _implies((nA, B), rule) + data[4, col] = _implies((nA, nB), rule) end return data end -""" -Return a vector of booleans with a true for every rule in `rules` that is linearly dependent on a combination of the previous rules. -To find rules for this method, collect all rules containing some feature for each pair of features. -That should be a fairly quick way to find subsets that are easy to process. -""" -function _linearly_dependent( - rules::AbstractVector{Rule}, - A::Split, - B::Split - )::BitArray - data = _feature_space(rules, A, B) - l = length(rules) - dependent = BitArray(undef, l) - result = 1 - for i in 1:l - new_result = rank(view(data, :, 1:i+1)) - rank_increased = new_result == result + 1 - if rank_increased - result = new_result - dependent[i] = false - else - result = new_result - dependent[i] = true - end - end - return dependent -end +_left_split(s::Split) = _direction(s) == :L ? s : _reverse(s) """ Return a vector of unique left splits for `rules`. -These splits are required to form `[A, B]` pairs in the next step. +These splits will be used to form `(A, B)` pairs and generate the feature space. +For example, the pair `x[i, 1] < 32000` (A) and `x[i, 3] < 64` (B) will be used to generate +the feature space `A & B`, `A & !B`, `!A & B`, `!A & !B`. """ function _unique_left_splits(rules::Vector{Rule}) splits = Split[] for rule in rules for split in _splits(rule) - left_split = _direction(split) == :L ? split : _reverse(split) + left_split = _left_split(split) if !(left_split in splits) push!(splits, left_split) end @@ -95,16 +139,18 @@ function _unique_left_splits(rules::Vector{Rule}) return splits end -"Return the product of `V` and `V` for all pairs (v_i, v_j) where i < j." +""" +Return all unique pairs of elements in `V`. +More formally, return all pairs (v_i, v_j) where i < j. +""" function _left_triangular_product(V::Vector{T}) where {T} l = length(V) - nl = l - 1 product = Tuple{T,T}[] for i in 1:l left = V[i] for j in 1:l - right = V[j] if i < j + right = V[j] push!(product, (left, right)) end end @@ -112,8 +158,6 @@ function _left_triangular_product(V::Vector{T}) where {T} return product end -_left_split(s::Split) = _direction(s) == :L ? s : _reverse(s) - """ Return whether some rule is either related to `A` or `B` or both. Here, it is very important to get rid of rules which are about the same feature but different thresholds. @@ -123,51 +167,119 @@ function _related_rule(rule::Rule, A::Split, B::Split)::Bool @assert _direction(A) == :L @assert _direction(B) == :L splits = _splits(rule) - fa = _feature(A) - fb = _feature(B) if length(splits) == 1 split = only(splits) left_split = _left_split(split) return left_split == A || left_split == B - else + elseif length(splits) == 2 l1 = _left_split(splits[1]) l2 = _left_split(splits[2]) return (l1 == A && l2 == B) || (l1 == B && l2 == A) + else + @error "Rule $rule has more than two splits; this is not supported." end end -function _linearly_dependent(rules::Vector{Rule})::BitVector - S = _unique_left_splits(rules) - P = _left_triangular_product(S) - # A `BitVector(undef, length(rules))` here will cause randomness. - dependent = falses(length(rules)) - for (A, B) in P - indexes = filter(i -> _related_rule(rules[i], A, B), 1:length(rules)) - subset = view(rules, indexes) - dependent_subset = _linearly_dependent(subset, A, B) - # Only allow setting true to avoid setting things to false. - for i in 1:length(dependent_subset) - if dependent_subset[i] - dependent[indexes[i]] = true - end +""" +Return a vector of booleans with a true for every rule in `rules` that is linearly dependent on a combination of the previous rules. +To find rules for this method, collect all rules containing some feature for each pair of features. +That should be a fairly quick way to find subsets that are easy to process. +""" +function _linearly_dependent( + rules::AbstractVector{Rule}, + A::Split, + B::Split + )::BitArray + data = _feature_space(rules, A, B) + l = length(rules) + dependent = BitArray(undef, l) + atol = 1e-6 + current_rank = rank(data[:, 1:1]; atol) + for i in 1:l + new_rank = rank(view(data, :, 1:i+1); atol) + if current_rank < new_rank + dependent[i] = false + current_rank = new_rank + else + dependent[i] = true end end return dependent end +function _gap_size(rule::Rule) + @assert length(rule.then) == length(rule.otherwise) + gap_size_per_class = abs.(rule.then .- rule.otherwise) + sum(gap_size_per_class) +end + """ -Return the subset of `rules` which are not linearly dependent. -This is based on a complex heuristic involving calculating the rank of the matrix, see above StackExchange link for more information. -Also note that this method assumes that the rules are assumed to be in ordered by frequency of occurence in the trees. -This assumption is used to filter less common rules when finding linearly dependent rules. +Return the vector rule sorted by decreasing gap size. +This allows the linearly dependent filter to remove the rules further down the list since +they have a smaller gap. """ -function _filter_linearly_dependent(rules::Vector{Rule})::Vector{Rule} - dependent = _linearly_dependent(rules) - out = Rule[] - for i in 1:length(dependent) - if !dependent[i] - push!(out, rules[i]) +function _sort_by_gap_size(rules::Vector{Rule}) + return sort(rules; by=_gap_size, rev=true) +end + +""" +Simplify the rules that contain a single split by only retaining rules that point left and +removing duplicates. +""" +function _simplify_single_rules(rules::Vector{Rule})::Vector{Rule} + out = Set{Rule}() + for rule in rules + splits = _splits(rule) + if length(splits) == 1 + left_rule = _left_rule(rule) + push!(out, left_rule) + else + push!(out, rule) end end + return collect(out) +end + +""" +Return a vector of rules that are not linearly dependent on any other rule. + +This is done by considering each pair of splits. +For example, considers the pair `x[i, 1] < 32000` (A) and `x[i, 3] < 64` (B). +Then, for each rule, it checks whether the rule is linearly dependent on the pair. +As soon as a dependent rule is found, it is removed from the set to avoid considering it again. +If we don't do this, we might remove some rule `r` that causes another rule to be linearly +dependent in one related set, but then is removed in another related set. +""" +function _filter_linearly_dependent(rules::Vector{Rule})::Vector{Rule} + sorted = _sort_by_gap_size(rules) + S = _unique_left_splits(sorted) + pairs = _left_triangular_product(S) + out = copy(sorted) + for (A, B) in pairs + indexes = filter(i -> _related_rule(out[i], A, B), 1:length(out)) + subset = view(out, indexes) + dependent_subset = _linearly_dependent(subset, A, B) + @assert length(indexes) == length(subset) + @assert length(dependent_subset) == length(subset) + dependent_indexes = indexes[dependent_subset] + deleteat!(out, sort(dependent_indexes)) + end return out end + +""" +Return a linearly independent subset of `rules` of length ≤ `max_rules`. + +!!! note + This doesn't use p0 like is done in the paper. + The problem, IMO, with p0 is that it is very difficult to decide beforehand what p0 is suitable and so it requires hyperparameter tuning. + Instead, luckily, the linearly dependent filter is quite fast here, so passing a load of rules into that and then selecting the first `max_rules` is feasible. +""" +function _process_rules( + rules::Vector{Rule}, + max_rules::Int + )::Vector{Rule} + simplified = _simplify_single_rules(rules) + filtered = _filter_linearly_dependent(simplified) + return first(filtered, max_rules) +end diff --git a/src/rules.jl b/src/rules.jl index 0ced07b..d721db1 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,8 +1,11 @@ """ - Split + Split(splitpoint::SplitPoint, direction::Symbol) -> Split + Split(feature::Int, name::String, splitval::Float32, direction::Symbol) -> Split A split in a tree. Each rule is based on one or more splits. + +Data can be accessed via `_feature`, `_value`, `_feature_name`, `_direction`, and `_reverse`. """ struct Split splitpoint::SplitPoint @@ -20,17 +23,24 @@ _direction(split::Split) = split.direction _reverse(split::Split) = Split(split.splitpoint, split.direction == :L ? :R : :L) """ - TreePath + TreePath(splits::Vector{Split}) -> TreePath + TreePath(text::String) -> TreePath A path of length `d` is defined as consisting of `d` splits. See SIRUS paper page 434. Typically, `d ≤ 2`. Note that a path can also be a path to a node; not necessarily a leaf. +Another term for a treepath is a _condition_. +For example, `X[i, 1] < 3 & X[i, 2] < 1` is a condition. + +Data can be accessed via `_splits`. """ struct TreePath splits::Vector{Split} end +_splits(path::TreePath) = path.splits + function TreePath(text::String) try comparisons = split(strip(text), '&') @@ -286,35 +296,6 @@ function _count_unique(V::AbstractVector{T}) where T return counts end -""" -Return a linearly independent subset of `rules` of length ≤ `max_rules`. - -!!! note - This doesn't use p0 like is done in the paper. - The problem, IMO, with p0 is that it is very difficult to decide beforehand what p0 is suitable and so it requires hyperparameter tuning. - Instead, luckily, the linearly dependent filter is quite fast here, so passing a load of rules into that and then selecting the first `max_rules` is feasible. -""" -function _process_rules( - rules::Vector{Rule}, - algo::Algorithm, - max_rules::Int - )::Vector{Rule} - # This loop is an optimization which manually takes a p0 and checks whether we end up with - # enough rules. If not, we loop again with more rules. - for i in 1:3 - required_rule_guess = i^2 * 10 * max_rules - before = first(rules, required_rule_guess) - filtered = _filter_linearly_dependent(before) - too_few = length(filtered) < max_rules - more_possible = required_rule_guess < length(rules) - if i < 3 && too_few && more_possible - continue - end - return first(filtered, max_rules) - end - @error "This should never happen" -end - struct StableRules{T} <: StableModel rules::Vector{Rule} algo::Algorithm @@ -347,7 +328,7 @@ function StableRules( outcome, model::Probabilistic )::StableRules - processed = _process_rules(rules, algo, model.max_rules) + processed = _process_rules(rules, model.max_rules) weights = _weights(processed, algo, classes, data, outcome, model) filtered_rules, filtered_weights = _remove_zero_weights(processed, weights) return StableRules(filtered_rules, algo, classes, filtered_weights) diff --git a/test/Project.toml b/test/Project.toml index a002acb..7fd4029 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,11 +4,13 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LightGBM = "7acf609c-83a4-11e9-1ffb-b912bcd3b04a" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" diff --git a/test/classification.jl b/test/classification.jl index b9e5920..119a605 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -93,11 +93,11 @@ end end fpreds = DecisionTree.apply_forest(dforest, data) -@show accuracy(fpreds, y) +# @show accuracy(fpreds, y) @test 0.95 < accuracy(fpreds, y) sfpreds = SIRUS._predict(sforest, data) -@show accuracy(mode.(sfpreds), y) +# @show accuracy(mode.(sfpreds), y) @test 0.95 < accuracy(mode.(sfpreds), y) empty_forest = SIRUS.StableForest(Union{SIRUS.Leaf, SIRUS.Node}[], algo, [1]) diff --git a/test/dependent.jl b/test/dependent.jl index 3ec8a97..1d2586d 100644 --- a/test/dependent.jl +++ b/test/dependent.jl @@ -6,28 +6,25 @@ # 3: CACH # 4: CHMIN # 5: MYCT -r1 = S.Rule(S.TreePath(" X[i, 1] < 32000 "), [0.61], [0.408]) -r2 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 "), [0.408], [0.61]) - -r3 = S.Rule(S.TreePath(" X[i, 2] < 8000 "), [0.62], [0.386]) -r4 = S.Rule(S.TreePath(" X[i, 2] ≥ 8000 "), [0.386], [0.62]) -r5 = S.Rule(S.TreePath(" X[i, 3] < 64 "), [0.56], [0.334]) -r6 = S.Rule(S.TreePath(" X[i, 3] ≥ 64 "), [0.334], [0.56]) -r7 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 & X[i, 3] ≥ 64 "), [0.517], [0.67]) -r8 = S.Rule(S.TreePath(" X[i, 4] < 8 "), [0.50], [0.312]) -r9 = S.Rule(S.TreePath(" X[i, 4] ≥ 8 "), [0.312], [0.50]) -r10 = S.Rule(S.TreePath(" X[i, 5] < 50 "), [0.335], [0.58]) -r11 = S.Rule(S.TreePath(" X[i, 5] ≥ 50 "), [0.58], [0.335]) +r1 = S.Rule(S.TreePath(" X[i, 1] < 32000 "), [0.061], [0.408]) +r2 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 "), [0.408], [0.061]) + +r3 = S.Rule(S.TreePath(" X[i, 2] < 8000 "), [0.062], [0.386]) +r4 = S.Rule(S.TreePath(" X[i, 2] ≥ 8000 "), [0.386], [0.062]) +r5 = S.Rule(S.TreePath(" X[i, 3] < 64 "), [0.056], [0.334]) +r6 = S.Rule(S.TreePath(" X[i, 3] ≥ 64 "), [0.334], [0.056]) +r7 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 & X[i, 3] ≥ 64 "), [0.517], [0.067]) +r8 = S.Rule(S.TreePath(" X[i, 4] < 8 "), [0.050], [0.312]) +r9 = S.Rule(S.TreePath(" X[i, 4] ≥ 8 "), [0.312], [0.050]) +r10 = S.Rule(S.TreePath(" X[i, 5] < 50 "), [0.335], [0.058]) +r11 = S.Rule(S.TreePath(" X[i, 5] ≥ 50 "), [0.058], [0.335]) r12 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 & X[i, 3] < 64 "), [0.192], [0.102]) -r13 = S.Rule(S.TreePath(" X[i, 1] < 32000 & X[i, 4] ≥ 8 "), [0.554], [0.73]) +r13 = S.Rule(S.TreePath(" X[i, 1] < 32000 & X[i, 4] ≥ 8 "), [0.157], [0.100]) # First constraint is updated based on a comment from Clément via email. -r14 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 & X[i, 4] ≥ 12 "), [0.192], [0.102]) -r15 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 & X[i, 4] < 12 "), [0.192], [0.102]) -r16 = S.Rule(S.TreePath(" X[i, 2] ≥ 8000 & X[i, 4] ≥ 12 "), [0.586], [0.76]) -r17 = S.Rule(S.TreePath(" X[i, 2] ≥ 8000 & X[i, 4] < 12 "), [0.236], [0.94]) - -@test S._unique_features([r1, r7, r12]) == [1, 3] -@test sort(S._unique_features([r1, r7, r12, r17])) == [1, 2, 3, 4] +r14 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 & X[i, 4] ≥ 12 "), [0.554], [0.073]) +r15 = S.Rule(S.TreePath(" X[i, 1] ≥ 32000 & X[i, 4] < 12 "), [0.192], [0.096]) +r16 = S.Rule(S.TreePath(" X[i, 2] ≥ 8000 & X[i, 4] ≥ 12 "), [0.586], [0.076]) +r17 = S.Rule(S.TreePath(" X[i, 2] ≥ 8000 & X[i, 4] < 12 "), [0.236], [0.094]) @test S._filter_linearly_dependent([r1, r2, r3, r5]) == [r1, r3, r5] @@ -86,32 +83,58 @@ let @test !(S._related_rule(r1, _Split(1, 31000.0f0, :L), B)) end -@test S._linearly_dependent([r1, r3]) == Bool[0, 0] -@test S._linearly_dependent([r1, r5, r7, r12]) == Bool[0, 0, 0, 1] +@test S._filter_linearly_dependent([r1, r3]) == [r1, r3] -@test S._filter_linearly_dependent([r1, r5, r7, r12]) == [r1, r5, r7] +@testset "r12 is removed because r7 has a wider gap" begin + @test Set(S._filter_linearly_dependent([r1, r5, r7, r12])) == Set([r1, r5, r7]) + @test Set(S._filter_linearly_dependent([r1, r5, r12, r7])) == Set([r1, r5, r7]) +end -@test S._linearly_dependent([r3, r16, r17]) == Bool[0, 0, 1] -@test S._linearly_dependent([r3, r16, r13]) == Bool[0, 0, 0] +@test Set(S._filter_linearly_dependent([r3, r16, r17])) == Set([r3, r16]) +@test Set(S._filter_linearly_dependent([r3, r16, r13])) == Set([r3, r13, r16]) -@testset "r12 is removed because r7 has a wider gap" begin - @test S._filter_linearly_dependent([r1, r5, r7, r12]) == [r1, r5, r7] - @test S._filter_linearly_dependent([r1, r5, r12, r7]) == [r1, r5, r7] +@testset "single rule is not linearly dependent" begin + A = S.Split(S.SplitPoint(4, 12.0f0, "4"), :L) + B = S.Split(S.SplitPoint(4, 8.0f0, "4"), :L) + rule = SIRUS.Rule(TreePath(" X[i, 4] < 8.0 "), [0.05], [0.312]) + @test S._feature_space([rule], A, B)[:, 2] == Bool[1, 0, 1, 0] + @test S._linearly_dependent([rule], A, B) == Bool[0] end -let - allrules = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17] - expected = [r1, r3, r5, r7, r8, r10, r13, r14, r16] - @test S._filter_linearly_dependent(allrules) == expected - - # allrules = shuffle(_rng(), allrules) - # @test Set(S._filter_linearly_dependent(allrules)) == Set(expected) - - algo = SIRUS.Classification() - @test length(S._process_rules(allrules, algo, 9)) == 9 - @test length(S._process_rules(allrules, algo, 10)) == 9 - @test length(S._process_rules([r1], algo, 9)) == 1 - @test length(S._process_rules(repeat(allrules, 200), algo, 9)) == 9 +@test S._process_rules(repeat([r1], 10), 10) == [r1] + +@testset "rank calculation is precise enough" begin + A = S.Split(S.SplitPoint(2, 8000.0f0, "2"), :L) + B = S.Split(S.SplitPoint(1, 32000.0f0, "1"), :L) + n = 34 + dependent = S._linearly_dependent([repeat([r2, r1], 34); r4], A, B) + expected = Bool[0; repeat([true], 2n-1); 0] + @test length(dependent) == length(expected) + @test dependent == expected + + n = 1_000 + dependent = S._linearly_dependent([repeat([r2, r1], n); r4], A, B) + expected = Bool[0; repeat([true], 2n-1); 0] + @test length(dependent) == length(expected) + @test dependent == expected end +function _canonicalize(rules::Vector{SIRUS.Rule}) + [length(r.path.splits) == 1 ? SIRUS._left_rule(r) : r for r in rules] +end + +allrules = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17] +expected = [r1, r3, r5, r7, r8, r10, r13, r14, r16] +actual = S._filter_linearly_dependent(allrules) +@test Set(actual) == Set(expected) + +allrules = shuffle(_rng(), allrules) +actual = S._process_rules(allrules, 100) +@test Set(actual) == Set(expected) + +@test length(S._filter_linearly_dependent(allrules)) == 9 +@test length(S._filter_linearly_dependent(allrules)) == 9 +@test length(S._filter_linearly_dependent([r1])) == 1 +@test length(S._filter_linearly_dependent(repeat(allrules, 200))) == 9 + nothing diff --git a/test/mlj.jl b/test/mlj.jl index c20e44e..a6c96d5 100644 --- a/test/mlj.jl +++ b/test/mlj.jl @@ -25,7 +25,7 @@ datasets = Dict{String,Tuple}( end, "boston" => boston(), "make_regression" => let - make_regression(200, 3; noise=0.0, sparse=0.0, outliers=0.0) + make_regression(200, 3; noise=0.0, sparse=0.0, outliers=0.0, rng=_rng()) end ) @@ -156,7 +156,7 @@ let @test 0.80 < _score(e) e = _evaluate!(results, "titanic", StableRulesClassifier, hyper) - @test 0.80 < _score(e) + @test 0.79 < _score(e) end @testset "y as String" begin diff --git a/test/preliminaries.jl b/test/preliminaries.jl index 2121fa6..6d26896 100644 --- a/test/preliminaries.jl +++ b/test/preliminaries.jl @@ -5,6 +5,7 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" using CategoricalArrays: CategoricalValue, categorical, unwrap using CSV: CSV using DataDeps: DataDeps, DataDep, @datadep_str +using Documenter: DocMeta, doctest using MLDatasets: BostonHousing, Titanic using DataFrames: DataFrames, diff --git a/test/rules.jl b/test/rules.jl index 9c3b845..cf46fe5 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -68,8 +68,6 @@ let @test S._predict(model, [33000, 0, 61]) == [mean([0.408, 0.56])] end -@test Set(S._process_rules([r5, r1, r1], algo, 10)) == Set([r5, r1]) - function generate_rules() algo = S.Classification() forest = S._forest(_rng(), algo, X, y) diff --git a/test/runtests.jl b/test/runtests.jl index 4bc5c13..8545ef0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,11 @@ include("preliminaries.jl") +@testset "doctests" begin + # warn suppresses warnings when keys already exist. + DocMeta.setdocmeta!(SIRUS, :DocTestSetup, :(using SIRUS); recursive=true, warn=false) + doctest(SIRUS) +end + @testset "empiricalquantiles" begin include("empiricalquantiles.jl") end