diff --git a/src/dependent.jl b/src/dependent.jl index 3973798..236cc70 100644 --- a/src/dependent.jl +++ b/src/dependent.jl @@ -193,9 +193,10 @@ function _linearly_dependent( data = _feature_space(rules, A, B) l = length(rules) dependent = BitArray(undef, l) - current_rank = rank(data[:, 1:1]) + atol = 1e-6 + current_rank = rank(data[:, 1:1]; atol) for i in 1:l - new_rank = rank(view(data, :, 1:i+1)) + new_rank = rank(view(data, :, 1:i+1); atol) if current_rank < new_rank dependent[i] = false current_rank = new_rank @@ -266,33 +267,6 @@ function _filter_linearly_dependent(rules::Vector{Rule})::Vector{Rule} @assert length(indexes) == length(subset) @assert length(dependent_subset) == length(subset) dependent_indexes = indexes[dependent_subset] - r3 = Rule(TreePath(" X[i, 2] < 8000.0 "), [0.062], [0.386]) - if r3 in subset - for i in 1:length(dependent_subset) - println(dependent_subset[i], ": ", out[indexes[i]]) - end - # @show dependent_subset - # @show A - # @show B - # @show subset - retainables = out[indexes[.!dependent_subset]] - # @show r3 in retainables - # @show retainables - after_remove = deleteat!(copy(out), sort(dependent_indexes)) - # This is wrong. It should be in there still. - # @show r3 in after_remove - end - r4 = Rule(TreePath(" X[i, 2] ≥ 8000 "), [0.386], [0.062]) - if r4 in subset - retainables = out[indexes[.!dependent_subset]] - if !(r4 in retainables) - @show A - @show B - for r in subset - @show r - end - end - end deleteat!(out, sort(dependent_indexes)) end return out diff --git a/test/dependent.jl b/test/dependent.jl index deabc02..8606622 100644 --- a/test/dependent.jl +++ b/test/dependent.jl @@ -103,6 +103,22 @@ end @test S._process_rules(repeat([r1], 10), 10) == [r1] +@testset "rank calculation is precise enough" begin + A = S.Split(S.SplitPoint(2, 8000.0f0, "2"), :L) + B = S.Split(S.SplitPoint(1, 32000.0f0, "1"), :L) + n = 34 + dependent = S._linearly_dependent([repeat([r2, r1], 34); r4], A, B) + expected = Bool[0; repeat([true], 2n-1); 0] + @test length(dependent) == length(expected) + @test dependent == expected + + n = 1_000 + dependent = S._linearly_dependent([repeat([r2, r1], n); r4], A, B) + expected = Bool[0; repeat([true], 2n-1); 0] + @test length(dependent) == length(expected) + @test dependent == expected +end + function _canonicalize(rules::Vector{SIRUS.Rule}) [length(r.path.splits) == 1 ? SIRUS._left_rule(r) : r for r in rules] end @@ -121,47 +137,6 @@ expected = _canonicalize(expected) @test length(S._filter_linearly_dependent(allrules)) == 9 @test length(S._filter_linearly_dependent(allrules)) == 9 @test length(S._filter_linearly_dependent([r1])) == 1 - -actual = S._filter_linearly_dependent(repeat(allrules, 34)) -for r in actual - @show r in expected, r -end -for r in expected - @show r in actual, r -end -@show r3 in actual, r3 -# @test r3 in actual -# @test length(actual) == 9 - -# @test length(S._filter_linearly_dependent(repeat(allrules, 200))) == 9 - -A = S.Split(S.SplitPoint(2, 8000.0f0, "2"), :L) -B = S.Split(S.SplitPoint(1, 32000.0f0, "1"), :L) -dependent = S._linearly_dependent([repeat([r2, r1], 34); r4], A, B) -# THIS IS A CLEAR BUG! -@test dependent == Bool[0; repeat([true], 67); 0] - -A = S.Split(S.SplitPoint(2, 8000.0f0, "2"), :L) -B = S.Split(S.SplitPoint(4, 12.0f0, "4"), :L) -rules = [ - repeat([r16], 34); - repeat([r4, r3], 34); - repeat([r17], 34); -] -dependent = S._linearly_dependent(rules, A, B) -for i in 1:length(dependent) - println(dependent[i], ": ", rules[i]) -end -filtered = rules[findall(.!dependent)] -@test r3 in filtered -@test !(r4 in filtered) -@test r16 in filtered -@test !(r17 in filtered) - -filtered = S._filter_linearly_dependent(rules) -@test r3 in filtered -@test !(r4 in filtered) -@test r16 in filtered -@test !(r17 in filtered) +@test length(S._filter_linearly_dependent(repeat(allrules, 200))) == 9 nothing