diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml
index ea3306e..72f0b5f 100644
--- a/.github/workflows/Docs.yml
+++ b/.github/workflows/Docs.yml
@@ -20,3 +20,4 @@ jobs:
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ - run: echo "sirus.huijzer.xyz" > CNAME
diff --git a/README.md b/README.md
index 79229ed..65c1a39 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
-
+
diff --git a/docs/make.jl b/docs/make.jl
index 92ca3cc..45e4879 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -50,7 +50,7 @@ checkdocs = :none
makedocs(; sitename, pages, format, modules, strict, checkdocs)
deploydocs(;
- branch="docs",
+ branch="docs-output",
devbranch="main",
repo="github.com/rikhuijzer/SIRUS.jl.git",
push_preview=false
diff --git a/docs/src/api.md b/docs/src/api.md
index 2253597..ed235e7 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -14,4 +14,6 @@ feature_names
directions
values(::SIRUS.Rule)
satisfies
+Cutpoints
+cutpoints
```
diff --git a/docs/src/sirus.jl b/docs/src/sirus.jl
index a0f4886..dd5e0f3 100644
--- a/docs/src/sirus.jl
+++ b/docs/src/sirus.jl
@@ -363,11 +363,11 @@ ST = SIRUS;
function _plot_cutpoints(data::AbstractVector)
fig = Figure(; resolution=(800, 100))
ax = Axis(fig[1, 1])
- cutpoints = Float64.(unique(ST._cutpoints(data, 10)))
+ cps = Float64.(unique(cutpoints(data, 10)))
scatter!(ax, data, fill(1, length(data)))
- vlines!(ax, cutpoints; color=:black, linestyle=:dash)
- textlocs = [(c, 1.1) for c in cutpoints]
- for cutpoint in cutpoints
+ vlines!(ax, cps; color=:black, linestyle=:dash)
+ textlocs = [(c, 1.1) for c in cps]
+ for cutpoint in cps
annotation = string(round(cutpoint; digits=2))::String
text!(ax, cutpoint + 0.2, 1.08; text=annotation, textsize=13)
end
diff --git a/src/SIRUS.jl b/src/SIRUS.jl
index bf9b9e7..24f77aa 100644
--- a/src/SIRUS.jl
+++ b/src/SIRUS.jl
@@ -14,11 +14,16 @@ using Random: AbstractRNG, default_rng, seed!, shuffle
using Statistics: mean, median
using Tables: Tables, matrix
-const Float = Float32
-
export StableForestClassifier, StableRulesClassifier
export feature_names, directions, satisfies
+include("helpers.jl")
+using .Helpers: nfeatures, view_feature
+
+include("empiricalquantiles.jl")
+using .EmpiricalQuantiles: Cutpoints, cutpoints
+export Cutpoints, cutpoints
+
include("forest.jl")
include("rules.jl")
include("weights.jl")
diff --git a/src/empiricalquantiles.jl b/src/empiricalquantiles.jl
new file mode 100644
index 0000000..82f3ca9
--- /dev/null
+++ b/src/empiricalquantiles.jl
@@ -0,0 +1,62 @@
+module EmpiricalQuantiles
+
+using ..Helpers: nfeatures, view_feature
+
+"Set of possible cutpoints, that is, empirical quantiles."
+const Cutpoints = Vector{Float32}
+
+"""
+Return a rough estimate for the index of the cutpoint.
+Choose the highest suitable index if there is more than one suitable index.
+The reason is that this will split the data nicely in combination with the `<` used later on.
+For example, for [1, 2, 3, 4], both 2 and 3 satisfy the 0.5 quantile.
+In this case, we pick the ceil, so 3.
+Next, the tree will split on 3, causing left (<) to contain 1 and 2 and right (≥) to contain 3 and 4.
+"""
+function _rough_cutpoint_index_estimate(n::Int, quantile::Real)
+ Int(ceil(quantile * n))
+end
+
+"Return the empirical `quantile` for data `V`."
+function _empirical_quantile(V::AbstractVector, quantile::Real)
+ @assert 0.0 ≤ quantile ≤ 1.0
+ n = length(V)
+ index = _rough_cutpoint_index_estimate(n, quantile)
+ if index == 0
+ index = 1
+ end
+ if index == n + 1
+ index = n
+ end
+ sorted = sort(V)
+ return Float32(sorted[index])
+end
+
+"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`."
+function cutpoints(V::AbstractVector, q::Int)
+ @assert 2 ≤ q
+ # Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints.
+ # Tree on left and right have always respectively length 0 and 1 then anyway.
+ length = q + 2
+ quantiles = range(0.0; stop=1.0, length)[2:end-1]
+ return Float32[_empirical_quantile(V, quantile) for quantile in quantiles]
+end
+
+"""
+Return a vector of vectors containing
+- one inner vector for each feature in the dataset and
+- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V.
+
+Using unique here to avoid checking splits twice.
+"""
+function cutpoints(X, q::Int)
+ p = nfeatures(X)
+ cps = Vector{Cutpoints}(undef, p)
+ for feature in 1:p
+ V = view_feature(X, feature)
+ cps[feature] = unique(cutpoints(V, q))
+ end
+ return cps
+end
+
+end # module
diff --git a/src/forest.jl b/src/forest.jl
index c680ee8..944f2da 100644
--- a/src/forest.jl
+++ b/src/forest.jl
@@ -45,69 +45,7 @@ function _information_gain(
end
"""
-Return a rough estimate for the index of the cutpoint.
-Choose the highest suitable index if there is more than one suitable index.
-The reason is that this will split the data nicely in combination with the `<` used later on.
-For example, for [1, 2, 3, 4], both 2 and 3 satisfy the 0.5 quantile.
-In this case, we pick the ceil, so 3.
-Next, the tree will split on 3, causing left (<) to contain 1 and 2 and right (≥) to contain 3 and 4.
-"""
-_rough_cutpoint_index_estimate(n::Int, quantile::Real) = Int(ceil(quantile * n))
-
-"Return the empirical `quantile` for data `V`."
-function _empirical_quantile(V::AbstractVector, quantile::Real)
- @assert 0.0 ≤ quantile ≤ 1.0
- n = length(V)
- index = _rough_cutpoint_index_estimate(n, quantile)
- if index == 0
- index = 1
- end
- if index == n + 1
- index = n
- end
- sorted = sort(V)
- return Float(sorted[index])
-end
-
-"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`."
-function _cutpoints(V::AbstractVector, q::Int)
- @assert 2 ≤ q
- # Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints.
- # Tree on left and right have always respectively length 0 and 1 then anyway.
- length = q + 2
- quantiles = range(0.0; stop=1.0, length)[2:end-1]
- return Float[_empirical_quantile(V, quantile) for quantile in quantiles]
-end
-
-"Return the number of features `p` in a dataset `X`."
-_p(X::AbstractMatrix) = size(X, 2)
-_p(X) = length(Tables.columnnames(X))
-
-"Set of possible cutpoints, that is, numbers from the empirical quantiles."
-const Cutpoints = Vector{Float}
-
-_view_feature(X::AbstractMatrix, feature::Int) = view(X, :, feature)
-_view_feature(X, feature::Int) = X[feature]
-
-"""
-Return a vector of vectors containing
-- one inner vector for each feature in the dataset and
-- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V.
-
-Using unique here to avoid checking splits twice.
-"""
-function _cutpoints(X, q::Int)
- p = _p(X)
- cutpoints = Vector{Cutpoints}(undef, p)
- for feature in 1:p
- V = _view_feature(X, feature)
- cutpoints[feature] = unique(_cutpoints(V, q))
- end
- return cutpoints
-end
-
-"""
- SplitPoint(feature::Int, value::Float, feature_name::String)
+ SplitPoint(feature::Int, value::Float32, feature_name::String)
A location where the tree splits.
@@ -118,10 +56,10 @@ Arguments:
"""
struct SplitPoint
feature::Int
- value::Float
+ value::Float32
feature_name::String255
- function SplitPoint(feature::Int, value::Float, feature_name::String)
+ function SplitPoint(feature::Int, value::Float32, feature_name::String)
return new(feature, value, String255(feature_name))
end
end
@@ -164,13 +102,13 @@ function _split(
y::AbstractVector,
classes::AbstractVector,
colnames::Vector{String},
- cutpoints::Vector{Cutpoints};
- max_split_candidates::Int=_p(X)
+ cps::Vector{Cutpoints};
+ max_split_candidates::Int=nfeatures(X)
)
best_score = 0.0
best_score_feature = 0
best_score_cutpoint = 0.0
- p = _p(X)
+ p = nfeatures(X)
mc = max_split_candidates
possible_features = mc == p ? (1:p) : _rand_subset(rng, 1:p, mc)
starting_impurity = _gini(y, classes)
@@ -182,7 +120,7 @@ function _split(
@inbounds for i in eachindex(feat_data)
feat_data[i] = X[i, feature]
end
- for cutpoint in cutpoints[feature]
+ for cutpoint in cps[feature]
vl = _view_y!(yl, feat_data, y, <, cutpoint)
isempty(vl) && continue
vr = _view_y!(yr, feat_data, y, ≥, cutpoint)
@@ -258,7 +196,7 @@ end
"Return `names(X)` if defined for `X` and string numbers otherwise."
function _colnames(X)::Vector{String}
- fallback() = string.(1:_p(X))
+ fallback() = string.(1:nfeatures(X))
try
names = collect(string.(Tables.columnnames(X)))
if isempty(names)
@@ -286,11 +224,11 @@ function _tree!(
y::AbstractVector,
classes::AbstractVector,
colnames::Vector{String}=_colnames(X);
- max_split_candidates=_p(X),
+ max_split_candidates=nfeatures(X),
depth=0,
max_depth=2,
q=10,
- cutpoints::Vector{Cutpoints}=_cutpoints(X, q),
+ cps::Vector{Cutpoints}=cutpoints(X, q),
min_data_in_leaf=5
)
if X isa Tables.MatrixTable
@@ -300,7 +238,7 @@ function _tree!(
if depth == max_depth
return Leaf(classes, y)
end
- sp = _split(rng, X, y, classes, colnames, cutpoints; max_split_candidates)
+ sp = _split(rng, X, y, classes, colnames, cps; max_split_candidates)
if isnothing(sp) || length(y) ≤ min_data_in_leaf
return Leaf(classes, y)
end
@@ -308,11 +246,11 @@ function _tree!(
left = let
_X, yl = _view_X_y!(mask, X, y, sp, <)
- _tree!(rng, mask, _X, yl, classes, colnames; cutpoints, depth, max_depth)
+ _tree!(rng, mask, _X, yl, classes, colnames; cps, depth, max_depth)
end
right = let
_X, yr = _view_X_y!(mask, X, y, sp, ≥)
- _tree!(rng, mask, _X, yr, classes, colnames; cutpoints, depth, max_depth)
+ _tree!(rng, mask, _X, yr, classes, colnames; cps, depth, max_depth)
end
node = Node(sp, left, right)
return node
@@ -393,10 +331,10 @@ function _forest(
error("Minimum tree depth is 1; got $max_depth")
end
# It is essential for the stability to determine the cutpoints over the whole dataset.
- cutpoints = _cutpoints(X, q)
+ cps = cutpoints(X, q)
classes = _classes(y)
- max_split_candidates = round(Int, sqrt(_p(X)))
+ max_split_candidates = round(Int, sqrt(nfeatures(X)))
n_samples = floor(Int, partial_sampling * length(y))
trees = Vector{Union{Node,Leaf}}(undef, n_trees)
@@ -419,7 +357,7 @@ function _forest(
max_split_candidates,
max_depth,
q,
- cutpoints,
+ cps,
min_data_in_leaf
)
trees[i] = tree
diff --git a/src/helpers.jl b/src/helpers.jl
new file mode 100644
index 0000000..2b87a6b
--- /dev/null
+++ b/src/helpers.jl
@@ -0,0 +1,10 @@
+module Helpers
+
+"Return the number of features `p` in a dataset `X`."
+nfeatures(X::AbstractMatrix) = size(X, 2)
+nfeatures(X) = length(Tables.columnnames(X))
+
+view_feature(X::AbstractMatrix, feature::Int) = view(X, :, feature)
+view_feature(X, feature::Int) = view(X, feature)
+
+end
diff --git a/src/rules.jl b/src/rules.jl
index 404789e..e6c4b6a 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -9,7 +9,7 @@ struct Split
direction::Symbol # :L or :R
end
-function Split(feature::Int, name::String, splitval::Float, direction::Symbol)
+function Split(feature::Int, name::String, splitval::Float32, direction::Symbol)
return Split(SplitPoint(feature, splitval, name), direction)
end
@@ -45,7 +45,7 @@ function TreePath(text::String)
feature = parse(Int, feature_text)
splitval = let
start = direction == :L ? findfirst('<', c) + 1 : findfirst('≥', c) + 3
- parse(Float, c[start:end])
+ parse(Float32, c[start:end])
end
feature_name = string(feature)::String
Split(feature, feature_name, splitval, direction)
@@ -467,7 +467,7 @@ function _probability(row::AbstractVector, rule::Rule)
return satisfies(row, rule) ? rule.then_probs : rule.else_probs
end
-function _predict(pair::Tuple{Rule,Float16}, row::AbstractVector)
+function _predict(pair::Tuple{Rule, Float16}, row::AbstractVector)
rule, weight = pair
probs = _probability(row, rule)
return weight .* probs
diff --git a/test/empiricalquantiles.jl b/test/empiricalquantiles.jl
new file mode 100644
index 0000000..eb4c8e9
--- /dev/null
+++ b/test/empiricalquantiles.jl
@@ -0,0 +1,11 @@
+X = [1 2;
+ 3 4]
+y = [1, 2]
+
+@test cutpoints([3, 1, 2], 2) == Float32[1, 2]
+@test cutpoints(1:9, 3) == Float32[3, 5, 7]
+@test cutpoints(1:4, 3) == Float32[1, 2, 3]
+@test cutpoints([1, 3, 5, 7], 2) == Float32[3, 5]
+
+@test cutpoints(X, 2) == [Float32[1, 3], Float32[2, 4]]
+@test cutpoints([3 4; 1 5; 2 6], 2) == [Float32[1, 2], Float32[4, 5]]
diff --git a/test/forest.jl b/test/forest.jl
index b0183e2..8980bd5 100644
--- a/test/forest.jl
+++ b/test/forest.jl
@@ -12,26 +12,18 @@ feature = 1
@test collect(ST._view_y!(y_view, X[:, feature], [1 2], <, 2)) == [1]
@test collect(ST._view_y!(y_view, X[:, feature], [1 2], >, 2)) == [2]
-@test ST._cutpoints([3, 1, 2], 2) == Float[1, 2]
-@test ST._cutpoints(1:9, 3) == Float[3, 5, 7]
-@test ST._cutpoints(1:4, 3) == Float[1, 2, 3]
-@test ST._cutpoints([1, 3, 5, 7], 2) == Float[3, 5]
-
-@test ST._cutpoints(X, 2) == [Float[1, 3], Float[2, 4]]
-@test ST._cutpoints([3 4; 1 5; 2 6], 2) == [Float[1, 2], Float[4, 5]]
-
let
X = [1 1;
1 3]
classes = unique(y)
colnames = ["A", "B"]
- cutpoints = ST._cutpoints(X, 2)
- splitpoint = ST._split(StableRNG(1), X, y, classes, colnames, cutpoints)
+ cp = cutpoints(X, 2)
+ splitpoint = ST._split(StableRNG(1), X, y, classes, colnames, cp)
# Obviously, feature (column) 2 is more informative to split on than feature 1.
@test splitpoint.feature == 2
@test splitpoint.feature_name == "B"
# Given that the split does < and ≥, then 3 is the best place since it separates 1 (left) and 3 (right).
- @test splitpoint.value == Float(3)
+ @test splitpoint.value == Float32(3)
end
let
@@ -41,7 +33,7 @@ let
classes = y
mask = Vector{Bool}(undef, length(y))
node = ST._tree!(_rng(), mask, X, y, classes; min_data_in_leaf=1, q=2)
- # @test node.splitpoint == ST.SplitPoint(1, Float(3))
+ # @test node.splitpoint == ST.SplitPoint(1, Float32(3))
# @test node.left.probabilities == [1.0, 0.0]
# @test node.right.probabilities == [0.0, 1.0]
end
@@ -72,7 +64,7 @@ stree = ST._tree!(_rng(), mask, data, y, classes, min_data_in_leaf=1, q=10)
@testset "data_subset" begin
n_features = round(Int, sqrt(p))
n_samples = round(Int, n/2)
- cols = rand(_rng(), 1:ST._p(data), n_features)
+ cols = rand(_rng(), 1:ST.nfeatures(data), n_features)
rows = rand(_rng(), 1:length(y), n_samples)
_data = view(data, rows, cols)
_y = view(y, rows)
diff --git a/test/preliminaries.jl b/test/preliminaries.jl
index 4dd0696..de41ba8 100644
--- a/test/preliminaries.jl
+++ b/test/preliminaries.jl
@@ -37,7 +37,6 @@ using Tables: Tables
using Test
ST = SIRUS
-Float = ST.Float
_rng(seed::Int=1) = StableRNG(seed)
if !haskey(ENV, "REGISTERED_HABERMAN")
@@ -85,7 +84,7 @@ function boston()
return (X, y)
end
-function _Split(feature::Int, splitval::Float, direction::Symbol)
+function _Split(feature::Int, splitval::Float32, direction::Symbol)
return ST.Split(feature, string(feature), splitval, direction)
end
diff --git a/test/rules.jl b/test/rules.jl
index 8840452..78ae629 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -8,11 +8,10 @@ let
@test_throws ArgumentError repr(TreePath(text))
end
-Float = ST.Float
classes = [:a, :b, :c]
left = ST.Leaf([1.0, 0.0, 0.0])
feature_name = "1"
-splitpoint = ST.SplitPoint(1, Float(1), feature_name)
+splitpoint = ST.SplitPoint(1, Float32(1), feature_name)
right = ST.Node(splitpoint, ST.Leaf([0.0, 1.0, 0.0]), ST.Leaf([0.0, 0.0, 1.0]))
left_rule = ST.Rule(ST.TreePath(" X[i, 1] < 32000 "), [0.61], [0.408])
@@ -38,7 +37,7 @@ end
# @test ST._mode([[1, 2], [1, 6], [4, 6]]) == [1, 6]
-splitpoint = ST.SplitPoint(1, ST.Float(4), feature_name)
+splitpoint = ST.SplitPoint(1, Float32(4), feature_name)
node = ST.Node(splitpoint, left, right)
rules = ST._rules!(node)
diff --git a/test/runtests.jl b/test/runtests.jl
index b643ccb..9bd6f15 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,9 @@
include("preliminaries.jl")
+@testset "empiricalquantiles" begin
+ include("empiricalquantiles.jl")
+end
+
@testset "forest" begin
include("forest.jl")
end