diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index ea3306e..72f0b5f 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -20,3 +20,4 @@ jobs: - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - run: echo "sirus.huijzer.xyz" > CNAME diff --git a/README.md b/README.md index 79229ed..65c1a39 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ CI - Documentation + Documentation Code Style Blue diff --git a/docs/make.jl b/docs/make.jl index 92ca3cc..45e4879 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -50,7 +50,7 @@ checkdocs = :none makedocs(; sitename, pages, format, modules, strict, checkdocs) deploydocs(; - branch="docs", + branch="docs-output", devbranch="main", repo="github.com/rikhuijzer/SIRUS.jl.git", push_preview=false diff --git a/docs/src/api.md b/docs/src/api.md index 2253597..ed235e7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,4 +14,6 @@ feature_names directions values(::SIRUS.Rule) satisfies +Cutpoints +cutpoints ``` diff --git a/docs/src/sirus.jl b/docs/src/sirus.jl index a0f4886..dd5e0f3 100644 --- a/docs/src/sirus.jl +++ b/docs/src/sirus.jl @@ -363,11 +363,11 @@ ST = SIRUS; function _plot_cutpoints(data::AbstractVector) fig = Figure(; resolution=(800, 100)) ax = Axis(fig[1, 1]) - cutpoints = Float64.(unique(ST._cutpoints(data, 10))) + cps = Float64.(unique(cutpoints(data, 10))) scatter!(ax, data, fill(1, length(data))) - vlines!(ax, cutpoints; color=:black, linestyle=:dash) - textlocs = [(c, 1.1) for c in cutpoints] - for cutpoint in cutpoints + vlines!(ax, cps; color=:black, linestyle=:dash) + textlocs = [(c, 1.1) for c in cps] + for cutpoint in cps annotation = string(round(cutpoint; digits=2))::String text!(ax, cutpoint + 0.2, 1.08; text=annotation, textsize=13) end diff --git a/src/SIRUS.jl b/src/SIRUS.jl index bf9b9e7..24f77aa 100644 --- a/src/SIRUS.jl +++ b/src/SIRUS.jl @@ -14,11 +14,16 @@ using Random: AbstractRNG, default_rng, seed!, 
shuffle using Statistics: mean, median using Tables: Tables, matrix -const Float = Float32 - export StableForestClassifier, StableRulesClassifier export feature_names, directions, satisfies +include("helpers.jl") +using .Helpers: nfeatures, view_feature + +include("empiricalquantiles.jl") +using .EmpiricalQuantiles: Cutpoints, cutpoints +export Cutpoints, cutpoints + include("forest.jl") include("rules.jl") include("weights.jl") diff --git a/src/empiricalquantiles.jl b/src/empiricalquantiles.jl new file mode 100644 index 0000000..82f3ca9 --- /dev/null +++ b/src/empiricalquantiles.jl @@ -0,0 +1,62 @@ +module EmpiricalQuantiles + +using ..Helpers: nfeatures, view_feature + +"Set of possible cutpoints, that is, empirical quantiles." +const Cutpoints = Vector{Float32} + +""" +Return a rough estimate for the index of the cutpoint. +Choose the highest suitable index if there is more than one suitable index. +The reason is that this will split the data nicely in combination with the `<` used later on. +For example, for [1, 2, 3, 4], both 2 and 3 satisfy the 0.5 quantile. +In this case, we pick the ceil, so 3. +Next, the tree will split on 3, causing left (<) to contain 1 and 2 and right (≥) to contain 3 and 4. +""" +function _rough_cutpoint_index_estimate(n::Int, quantile::Real) + Int(ceil(quantile * n)) +end + +"Return the empirical `quantile` for data `V`." +function _empirical_quantile(V::AbstractVector, quantile::Real) + @assert 0.0 ≤ quantile ≤ 1.0 + n = length(V) + index = _rough_cutpoint_index_estimate(n, quantile) + if index == 0 + index = 1 + end + if index == n + 1 + index = n + end + sorted = sort(V) + return Float32(sorted[index]) +end + +"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`." +function cutpoints(V::AbstractVector, q::Int) + @assert 2 ≤ q + # Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints. + # Tree on left and right have always respectively length 0 and 1 then anyway. 
+ length = q + 2 + quantiles = range(0.0; stop=1.0, length)[2:end-1] + return Float32[_empirical_quantile(V, quantile) for quantile in quantiles] +end + +""" +Return a vector of vectors containing +- one inner vector for each feature in the dataset and +- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V. + +Using unique here to avoid checking splits twice. +""" +function cutpoints(X, q::Int) + p = nfeatures(X) + cps = Vector{Cutpoints}(undef, p) + for feature in 1:p + V = view_feature(X, feature) + cps[feature] = unique(cutpoints(V, q)) + end + return cps +end + +end # module diff --git a/src/forest.jl b/src/forest.jl index c680ee8..944f2da 100644 --- a/src/forest.jl +++ b/src/forest.jl @@ -45,69 +45,7 @@ function _information_gain( end """ -Return a rough estimate for the index of the cutpoint. -Choose the highest suitable index if there is more than one suitable index. -The reason is that this will split the data nicely in combination with the `<` used later on. -For example, for [1, 2, 3, 4], both 2 and 3 satisfy the 0.5 quantile. -In this case, we pick the ceil, so 3. -Next, the tree will split on 3, causing left (<) to contain 1 and 2 and right (≥) to contain 3 and 4. -""" -_rough_cutpoint_index_estimate(n::Int, quantile::Real) = Int(ceil(quantile * n)) - -"Return the empirical `quantile` for data `V`." -function _empirical_quantile(V::AbstractVector, quantile::Real) - @assert 0.0 ≤ quantile ≤ 1.0 - n = length(V) - index = _rough_cutpoint_index_estimate(n, quantile) - if index == 0 - index = 1 - end - if index == n + 1 - index = n - end - sorted = sort(V) - return Float(sorted[index]) -end - -"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`." -function _cutpoints(V::AbstractVector, q::Int) - @assert 2 ≤ q - # Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints. - # Tree on left and right have always respectively length 0 and 1 then anyway. 
- length = q + 2 - quantiles = range(0.0; stop=1.0, length)[2:end-1] - return Float[_empirical_quantile(V, quantile) for quantile in quantiles] -end - -"Return the number of features `p` in a dataset `X`." -_p(X::AbstractMatrix) = size(X, 2) -_p(X) = length(Tables.columnnames(X)) - -"Set of possible cutpoints, that is, numbers from the empirical quantiles." -const Cutpoints = Vector{Float} - -_view_feature(X::AbstractMatrix, feature::Int) = view(X, :, feature) -_view_feature(X, feature::Int) = X[feature] - -""" -Return a vector of vectors containing -- one inner vector for each feature in the dataset and -- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V. - -Using unique here to avoid checking splits twice. -""" -function _cutpoints(X, q::Int) - p = _p(X) - cutpoints = Vector{Cutpoints}(undef, p) - for feature in 1:p - V = _view_feature(X, feature) - cutpoints[feature] = unique(_cutpoints(V, q)) - end - return cutpoints -end - -""" - SplitPoint(feature::Int, value::Float, feature_name::String) + SplitPoint(feature::Int, value::Float32, feature_name::String) A location where the tree splits. @@ -118,10 +56,10 @@ Arguments: """ struct SplitPoint feature::Int - value::Float + value::Float32 feature_name::String255 - function SplitPoint(feature::Int, value::Float, feature_name::String) + function SplitPoint(feature::Int, value::Float32, feature_name::String) return new(feature, value, String255(feature_name)) end end @@ -164,13 +102,13 @@ function _split( y::AbstractVector, classes::AbstractVector, colnames::Vector{String}, - cutpoints::Vector{Cutpoints}; - max_split_candidates::Int=_p(X) + cps::Vector{Cutpoints}; + max_split_candidates::Int=nfeatures(X) ) best_score = 0.0 best_score_feature = 0 best_score_cutpoint = 0.0 - p = _p(X) + p = nfeatures(X) mc = max_split_candidates possible_features = mc == p ? 
(1:p) : _rand_subset(rng, 1:p, mc) starting_impurity = _gini(y, classes) @@ -182,7 +120,7 @@ function _split( @inbounds for i in eachindex(feat_data) feat_data[i] = X[i, feature] end - for cutpoint in cutpoints[feature] + for cutpoint in cps[feature] vl = _view_y!(yl, feat_data, y, <, cutpoint) isempty(vl) && continue vr = _view_y!(yr, feat_data, y, ≥, cutpoint) @@ -258,7 +196,7 @@ end "Return `names(X)` if defined for `X` and string numbers otherwise." function _colnames(X)::Vector{String} - fallback() = string.(1:_p(X)) + fallback() = string.(1:nfeatures(X)) try names = collect(string.(Tables.columnnames(X))) if isempty(names) @@ -286,11 +224,11 @@ function _tree!( y::AbstractVector, classes::AbstractVector, colnames::Vector{String}=_colnames(X); - max_split_candidates=_p(X), + max_split_candidates=nfeatures(X), depth=0, max_depth=2, q=10, - cutpoints::Vector{Cutpoints}=_cutpoints(X, q), + cps::Vector{Cutpoints}=cutpoints(X, q), min_data_in_leaf=5 ) if X isa Tables.MatrixTable @@ -300,7 +238,7 @@ function _tree!( if depth == max_depth return Leaf(classes, y) end - sp = _split(rng, X, y, classes, colnames, cutpoints; max_split_candidates) + sp = _split(rng, X, y, classes, colnames, cps; max_split_candidates) if isnothing(sp) || length(y) ≤ min_data_in_leaf return Leaf(classes, y) end @@ -308,11 +246,11 @@ function _tree!( left = let _X, yl = _view_X_y!(mask, X, y, sp, <) - _tree!(rng, mask, _X, yl, classes, colnames; cutpoints, depth, max_depth) + _tree!(rng, mask, _X, yl, classes, colnames; cps, depth, max_depth) end right = let _X, yr = _view_X_y!(mask, X, y, sp, ≥) - _tree!(rng, mask, _X, yr, classes, colnames; cutpoints, depth, max_depth) + _tree!(rng, mask, _X, yr, classes, colnames; cps, depth, max_depth) end node = Node(sp, left, right) return node @@ -393,10 +331,10 @@ function _forest( error("Minimum tree depth is 1; got $max_depth") end # It is essential for the stability to determine the cutpoints over the whole dataset. 
- cutpoints = _cutpoints(X, q) + cps = cutpoints(X, q) classes = _classes(y) - max_split_candidates = round(Int, sqrt(_p(X))) + max_split_candidates = round(Int, sqrt(nfeatures(X))) n_samples = floor(Int, partial_sampling * length(y)) trees = Vector{Union{Node,Leaf}}(undef, n_trees) @@ -419,7 +357,7 @@ function _forest( max_split_candidates, max_depth, q, - cutpoints, + cps, min_data_in_leaf ) trees[i] = tree diff --git a/src/helpers.jl b/src/helpers.jl new file mode 100644 index 0000000..2b87a6b --- /dev/null +++ b/src/helpers.jl @@ -0,0 +1,14 @@ +module Helpers + +# Submodules do not inherit the parent module's imports, so `Tables` is loaded here. +using Tables: Tables + +"Return the number of features `p` in a dataset `X`." +nfeatures(X::AbstractMatrix) = size(X, 2) +nfeatures(X) = length(Tables.columnnames(X)) + +"Return column `feature` of `X`: a view for a matrix, `X[feature]` for anything else." +view_feature(X::AbstractMatrix, feature::Int) = view(X, :, feature) +view_feature(X, feature::Int) = X[feature] + +end diff --git a/src/rules.jl b/src/rules.jl index 404789e..e6c4b6a 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -9,7 +9,7 @@ struct Split direction::Symbol # :L or :R end -function Split(feature::Int, name::String, splitval::Float, direction::Symbol) +function Split(feature::Int, name::String, splitval::Float32, direction::Symbol) return Split(SplitPoint(feature, splitval, name), direction) end @@ -45,7 +45,7 @@ function TreePath(text::String) feature = parse(Int, feature_text) splitval = let start = direction == :L ? findfirst('<', c) + 1 : findfirst('≥', c) + 3 - parse(Float, c[start:end]) + parse(Float32, c[start:end]) end feature_name = string(feature)::String Split(feature, feature_name, splitval, direction) @@ -467,7 +467,7 @@ function _probability(row::AbstractVector, rule::Rule) return satisfies(row, rule) ? 
rule.then_probs : rule.else_probs end -function _predict(pair::Tuple{Rule,Float16}, row::AbstractVector) +function _predict(pair::Tuple{Rule, Float16}, row::AbstractVector) rule, weight = pair probs = _probability(row, rule) return weight .* probs diff --git a/test/empiricalquantiles.jl b/test/empiricalquantiles.jl new file mode 100644 index 0000000..eb4c8e9 --- /dev/null +++ b/test/empiricalquantiles.jl @@ -0,0 +1,11 @@ +X = [1 2; + 3 4] +y = [1, 2] + +@test cutpoints([3, 1, 2], 2) == Float32[1, 2] +@test cutpoints(1:9, 3) == Float32[3, 5, 7] +@test cutpoints(1:4, 3) == Float32[1, 2, 3] +@test cutpoints([1, 3, 5, 7], 2) == Float32[3, 5] + +@test cutpoints(X, 2) == [Float32[1, 3], Float32[2, 4]] +@test cutpoints([3 4; 1 5; 2 6], 2) == [Float32[1, 2], Float32[4, 5]] diff --git a/test/forest.jl b/test/forest.jl index b0183e2..8980bd5 100644 --- a/test/forest.jl +++ b/test/forest.jl @@ -12,26 +12,18 @@ feature = 1 @test collect(ST._view_y!(y_view, X[:, feature], [1 2], <, 2)) == [1] @test collect(ST._view_y!(y_view, X[:, feature], [1 2], >, 2)) == [2] -@test ST._cutpoints([3, 1, 2], 2) == Float[1, 2] -@test ST._cutpoints(1:9, 3) == Float[3, 5, 7] -@test ST._cutpoints(1:4, 3) == Float[1, 2, 3] -@test ST._cutpoints([1, 3, 5, 7], 2) == Float[3, 5] - -@test ST._cutpoints(X, 2) == [Float[1, 3], Float[2, 4]] -@test ST._cutpoints([3 4; 1 5; 2 6], 2) == [Float[1, 2], Float[4, 5]] - let X = [1 1; 1 3] classes = unique(y) colnames = ["A", "B"] - cutpoints = ST._cutpoints(X, 2) - splitpoint = ST._split(StableRNG(1), X, y, classes, colnames, cutpoints) + cp = cutpoints(X, 2) + splitpoint = ST._split(StableRNG(1), X, y, classes, colnames, cp) # Obviously, feature (column) 2 is more informative to split on than feature 1. @test splitpoint.feature == 2 @test splitpoint.feature_name == "B" # Given that the split does < and ≥, then 3 is the best place since it separates 1 (left) and 3 (right). 
- @test splitpoint.value == Float(3) + @test splitpoint.value == Float32(3) end let @@ -41,7 +33,7 @@ let classes = y mask = Vector{Bool}(undef, length(y)) node = ST._tree!(_rng(), mask, X, y, classes; min_data_in_leaf=1, q=2) - # @test node.splitpoint == ST.SplitPoint(1, Float(3)) + # @test node.splitpoint == ST.SplitPoint(1, Float32(3)) # @test node.left.probabilities == [1.0, 0.0] # @test node.right.probabilities == [0.0, 1.0] end @@ -72,7 +64,7 @@ stree = ST._tree!(_rng(), mask, data, y, classes, min_data_in_leaf=1, q=10) @testset "data_subset" begin n_features = round(Int, sqrt(p)) n_samples = round(Int, n/2) - cols = rand(_rng(), 1:ST._p(data), n_features) + cols = rand(_rng(), 1:ST.nfeatures(data), n_features) rows = rand(_rng(), 1:length(y), n_samples) _data = view(data, rows, cols) _y = view(y, rows) diff --git a/test/preliminaries.jl b/test/preliminaries.jl index 4dd0696..de41ba8 100644 --- a/test/preliminaries.jl +++ b/test/preliminaries.jl @@ -37,7 +37,6 @@ using Tables: Tables using Test ST = SIRUS -Float = ST.Float _rng(seed::Int=1) = StableRNG(seed) if !haskey(ENV, "REGISTERED_HABERMAN") @@ -85,7 +84,7 @@ function boston() return (X, y) end -function _Split(feature::Int, splitval::Float, direction::Symbol) +function _Split(feature::Int, splitval::Float32, direction::Symbol) return ST.Split(feature, string(feature), splitval, direction) end diff --git a/test/rules.jl b/test/rules.jl index 8840452..78ae629 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -8,11 +8,10 @@ let @test_throws ArgumentError repr(TreePath(text)) end -Float = ST.Float classes = [:a, :b, :c] left = ST.Leaf([1.0, 0.0, 0.0]) feature_name = "1" -splitpoint = ST.SplitPoint(1, Float(1), feature_name) +splitpoint = ST.SplitPoint(1, Float32(1), feature_name) right = ST.Node(splitpoint, ST.Leaf([0.0, 1.0, 0.0]), ST.Leaf([0.0, 0.0, 1.0])) left_rule = ST.Rule(ST.TreePath(" X[i, 1] < 32000 "), [0.61], [0.408]) @@ -38,7 +37,7 @@ end # @test ST._mode([[1, 2], [1, 6], [4, 6]]) == [1, 6] 
-splitpoint = ST.SplitPoint(1, ST.Float(4), feature_name) +splitpoint = ST.SplitPoint(1, Float32(4), feature_name) node = ST.Node(splitpoint, left, right) rules = ST._rules!(node) diff --git a/test/runtests.jl b/test/runtests.jl index b643ccb..9bd6f15 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,9 @@ include("preliminaries.jl") +@testset "empiricalquantiles" begin + include("empiricalquantiles.jl") +end + @testset "forest" begin include("forest.jl") end