Skip to content

Commit

Permalink
Fix a bug in the rank calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 27, 2023
1 parent eb74b60 commit 20624a5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 71 deletions.
32 changes: 3 additions & 29 deletions src/dependent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ function _linearly_dependent(
data = _feature_space(rules, A, B)
l = length(rules)
dependent = BitArray(undef, l)
current_rank = rank(data[:, 1:1])
atol = 1e-6
current_rank = rank(data[:, 1:1]; atol)
for i in 1:l
new_rank = rank(view(data, :, 1:i+1))
new_rank = rank(view(data, :, 1:i+1); atol)
if current_rank < new_rank
dependent[i] = false
current_rank = new_rank
Expand Down Expand Up @@ -266,33 +267,6 @@ function _filter_linearly_dependent(rules::Vector{Rule})::Vector{Rule}
@assert length(indexes) == length(subset)
@assert length(dependent_subset) == length(subset)
dependent_indexes = indexes[dependent_subset]
r3 = Rule(TreePath(" X[i, 2] < 8000.0 "), [0.062], [0.386])
if r3 in subset
for i in 1:length(dependent_subset)
println(dependent_subset[i], ": ", out[indexes[i]])
end
# @show dependent_subset
# @show A
# @show B
# @show subset
retainables = out[indexes[.!dependent_subset]]
# @show r3 in retainables
# @show retainables
after_remove = deleteat!(copy(out), sort(dependent_indexes))
# This is wrong. It should be in there still.
# @show r3 in after_remove
end
r4 = Rule(TreePath(" X[i, 2] ≥ 8000 "), [0.386], [0.062])
if r4 in subset
retainables = out[indexes[.!dependent_subset]]
if !(r4 in retainables)
@show A
@show B
for r in subset
@show r
end
end
end
deleteat!(out, sort(dependent_indexes))
end
return out
Expand Down
59 changes: 17 additions & 42 deletions test/dependent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ end

@test S._process_rules(repeat([r1], 10), 10) == [r1]

# Regression test for the rank-based dependence check. `S._linearly_dependent` is
# called on `n` repetitions of the pair `[r2, r1]` followed by a single `r4`, using
# the splits `A` and `B`. The expected mask flags only the first and the last rule
# as not dependent; the `2n - 1` rules in between are expected to be dependent.
@testset "rank calculation is precise enough" begin
    A = S.Split(S.SplitPoint(2, 8000.0f0, "2"), :L)
    B = S.Split(S.SplitPoint(1, 32000.0f0, "1"), :L)

    # One shared body for the small and the large case, so that the number of
    # repetitions and the length of the expected mask are always derived from the
    # same `n` (instead of a literal that has to be kept in sync by hand).
    for n in (34, 1_000)
        rules = [repeat([r2, r1], n); r4]
        dependent = S._linearly_dependent(rules, A, B)
        expected = Bool[0; repeat([true], 2n-1); 0]
        # Checked separately so a size mismatch is reported as such, not only as
        # an element-wise inequality.
        @test length(dependent) == length(expected)
        @test dependent == expected
    end
end

# Bring `rules` into a comparable form: a rule whose path holds exactly one split is
# replaced by `SIRUS._left_rule(rule)`; every other rule is passed through unchanged.
# Returns a new collection; `rules` itself is not modified.
function _canonicalize(rules::Vector{SIRUS.Rule})
    return map(rules) do rule
        if length(rule.path.splits) == 1
            SIRUS._left_rule(rule)
        else
            rule
        end
    end
end
Expand All @@ -121,47 +137,6 @@ expected = _canonicalize(expected)
@test length(S._filter_linearly_dependent(allrules)) == 9
@test length(S._filter_linearly_dependent(allrules)) == 9
@test length(S._filter_linearly_dependent([r1])) == 1

actual = S._filter_linearly_dependent(repeat(allrules, 34))
for r in actual
@show r in expected, r
end
for r in expected
@show r in actual, r
end
@show r3 in actual, r3
# @test r3 in actual
# @test length(actual) == 9

# @test length(S._filter_linearly_dependent(repeat(allrules, 200))) == 9

A = S.Split(S.SplitPoint(2, 8000.0f0, "2"), :L)
B = S.Split(S.SplitPoint(1, 32000.0f0, "1"), :L)
dependent = S._linearly_dependent([repeat([r2, r1], 34); r4], A, B)
# THIS IS A CLEAR BUG!
@test dependent == Bool[0; repeat([true], 67); 0]

A = S.Split(S.SplitPoint(2, 8000.0f0, "2"), :L)
B = S.Split(S.SplitPoint(4, 12.0f0, "4"), :L)
rules = [
repeat([r16], 34);
repeat([r4, r3], 34);
repeat([r17], 34);
]
dependent = S._linearly_dependent(rules, A, B)
for i in 1:length(dependent)
println(dependent[i], ": ", rules[i])
end
filtered = rules[findall(.!dependent)]
@test r3 in filtered
@test !(r4 in filtered)
@test r16 in filtered
@test !(r17 in filtered)

filtered = S._filter_linearly_dependent(rules)
@test r3 in filtered
@test !(r4 in filtered)
@test r16 in filtered
@test !(r17 in filtered)
@test length(S._filter_linearly_dependent(repeat(allrules, 200))) == 9

nothing

0 comments on commit 20624a5

Please sign in to comment.