Move cutpoints in separate module (#14)
rikhuijzer committed Jun 9, 2023
1 parent 7ef3cb5 commit b8c3119
Showing 15 changed files with 130 additions and 107 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Docs.yml
@@ -20,3 +20,4 @@ jobs:
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: echo "sirus.huijzer.xyz" > CNAME
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@
<img src="https://github.com/rikhuijzer/SIRUS.jl/workflows/CI/badge.svg" alt="CI">
</a>
<a href="https://huijzer.xyz/StableTrees.jl/dev/">
<img src="https://img.shields.io/badge/Documentation-main-blue" alt="Documentation">
<img src="https://img.shields.io/badge/Docs-main-blue" alt="Documentation">
</a>
<a href="https://github.com/invenia/BlueStyle">
<img src="https://img.shields.io/badge/Code%20Style-Blue-4495d1.svg" alt="Code Style Blue">
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -50,7 +50,7 @@ checkdocs = :none
makedocs(; sitename, pages, format, modules, strict, checkdocs)

deploydocs(;
branch="docs",
branch="docs-output",
devbranch="main",
repo="github.com/rikhuijzer/SIRUS.jl.git",
push_preview=false
2 changes: 2 additions & 0 deletions docs/src/api.md
@@ -14,4 +14,6 @@ feature_names
directions
values(::SIRUS.Rule)
satisfies
Cutpoints
cutpoints
```
8 changes: 4 additions & 4 deletions docs/src/sirus.jl
@@ -363,11 +363,11 @@ ST = SIRUS;
function _plot_cutpoints(data::AbstractVector)
fig = Figure(; resolution=(800, 100))
ax = Axis(fig[1, 1])
cutpoints = Float64.(unique(ST._cutpoints(data, 10)))
cps = Float64.(unique(cutpoints(data, 10)))
scatter!(ax, data, fill(1, length(data)))
vlines!(ax, cutpoints; color=:black, linestyle=:dash)
textlocs = [(c, 1.1) for c in cutpoints]
for cutpoint in cutpoints
vlines!(ax, cps; color=:black, linestyle=:dash)
textlocs = [(c, 1.1) for c in cps]
for cutpoint in cps
annotation = string(round(cutpoint; digits=2))::String
text!(ax, cutpoint + 0.2, 1.08; text=annotation, textsize=13)
end
9 changes: 7 additions & 2 deletions src/SIRUS.jl
@@ -14,11 +14,16 @@ using Random: AbstractRNG, default_rng, seed!, shuffle
using Statistics: mean, median
using Tables: Tables, matrix

const Float = Float32

export StableForestClassifier, StableRulesClassifier
export feature_names, directions, satisfies

include("helpers.jl")
using .Helpers: nfeatures, view_feature

include("empiricalquantiles.jl")
using .EmpiricalQuantiles: Cutpoints, cutpoints
export Cutpoints, cutpoints

include("forest.jl")
include("rules.jl")
include("weights.jl")
62 changes: 62 additions & 0 deletions src/empiricalquantiles.jl
@@ -0,0 +1,62 @@
module EmpiricalQuantiles

using ..Helpers: nfeatures, view_feature

"Set of possible cutpoints, that is, empirical quantiles."
const Cutpoints = Vector{Float32}

"""
Return a rough estimate for the index of the cutpoint.
Choose the highest suitable index if there is more than one suitable index.
The reason is that this will split the data nicely in combination with the `<` used later on.
For example, for [1, 2, 3, 4], both 2 and 3 satisfy the 0.5 quantile.
In this case, we pick the ceil, so 3.
Next, the tree will split on 3, causing left (<) to contain 1 and 2 and right (≥) to contain 3 and 4.
"""
function _rough_cutpoint_index_estimate(n::Int, quantile::Real)
Int(ceil(quantile * n))
end

"Return the empirical `quantile` for data `V`."
function _empirical_quantile(V::AbstractVector, quantile::Real)
@assert 0.0 ≤ quantile ≤ 1.0
n = length(V)
index = _rough_cutpoint_index_estimate(n, quantile)
if index == 0
index = 1
end
if index == n + 1
index = n
end
sorted = sort(V)
return Float32(sorted[index])
end

"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`."
function cutpoints(V::AbstractVector, q::Int)
@assert 2 ≤ q
# Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints.
# Tree on left and right have always respectively length 0 and 1 then anyway.
length = q + 2
quantiles = range(0.0; stop=1.0, length)[2:end-1]
return Float32[_empirical_quantile(V, quantile) for quantile in quantiles]
end

"""
Return a vector of vectors containing
- one inner vector for each feature in the dataset and
- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V.
Using unique here to avoid checking splits twice.
"""
function cutpoints(X, q::Int)
p = nfeatures(X)
cps = Vector{Cutpoints}(undef, p)
for feature in 1:p
V = view_feature(X, feature)
cps[feature] = unique(cutpoints(V, q))
end
return cps
end

end # module
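
For a sense of how the `q + 2` quantile grid and the ceiling-based index work out in practice, here is a small sketch (assuming `using SIRUS: cutpoints`, which works because `SIRUS.jl` above re-exports the name; the outputs mirror the tests further down):

using SIRUS: cutpoints

V = collect(1:9)
# q = 3 builds 5 evenly spaced quantiles (0.0, 0.25, 0.5, 0.75, 1.0) and drops the endpoints.
# With n = 9: ceil(0.25 * 9) = 3, ceil(0.5 * 9) = 5, ceil(0.75 * 9) = 7, so sorted[3], sorted[5], sorted[7].
cutpoints(V, 3)   # Float32[3.0, 5.0, 7.0]

X = [3 4; 1 5; 2 6]
cutpoints(X, 2)   # [Float32[1.0, 2.0], Float32[4.0, 5.0]]  (one vector of unique cutpoints per feature)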
94 changes: 16 additions & 78 deletions src/forest.jl
@@ -45,69 +45,7 @@ function _information_gain(
end

"""
Return a rough estimate for the index of the cutpoint.
Choose the highest suitable index if there is more than one suitable index.
The reason is that this will split the data nicely in combination with the `<` used later on.
For example, for [1, 2, 3, 4], both 2 and 3 satisfy the 0.5 quantile.
In this case, we pick the ceil, so 3.
Next, the tree will split on 3, causing left (<) to contain 1 and 2 and right (≥) to contain 3 and 4.
"""
_rough_cutpoint_index_estimate(n::Int, quantile::Real) = Int(ceil(quantile * n))

"Return the empirical `quantile` for data `V`."
function _empirical_quantile(V::AbstractVector, quantile::Real)
@assert 0.0 ≤ quantile ≤ 1.0
n = length(V)
index = _rough_cutpoint_index_estimate(n, quantile)
if index == 0
index = 1
end
if index == n + 1
index = n
end
sorted = sort(V)
return Float(sorted[index])
end

"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`."
function _cutpoints(V::AbstractVector, q::Int)
@assert 2 ≤ q
# Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints.
# Tree on left and right have always respectively length 0 and 1 then anyway.
length = q + 2
quantiles = range(0.0; stop=1.0, length)[2:end-1]
return Float[_empirical_quantile(V, quantile) for quantile in quantiles]
end

"Return the number of features `p` in a dataset `X`."
_p(X::AbstractMatrix) = size(X, 2)
_p(X) = length(Tables.columnnames(X))

"Set of possible cutpoints, that is, numbers from the empirical quantiles."
const Cutpoints = Vector{Float}

_view_feature(X::AbstractMatrix, feature::Int) = view(X, :, feature)
_view_feature(X, feature::Int) = X[feature]

"""
Return a vector of vectors containing
- one inner vector for each feature in the dataset and
- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V.
Using unique here to avoid checking splits twice.
"""
function _cutpoints(X, q::Int)
p = _p(X)
cutpoints = Vector{Cutpoints}(undef, p)
for feature in 1:p
V = _view_feature(X, feature)
cutpoints[feature] = unique(_cutpoints(V, q))
end
return cutpoints
end

"""
SplitPoint(feature::Int, value::Float, feature_name::String)
SplitPoint(feature::Int, value::Float32, feature_name::String)
A location where the tree splits.
@@ -118,10 +56,10 @@ Arguments:
"""
struct SplitPoint
feature::Int
value::Float
value::Float32
feature_name::String255

function SplitPoint(feature::Int, value::Float, feature_name::String)
function SplitPoint(feature::Int, value::Float32, feature_name::String)
return new(feature, value, String255(feature_name))
end
end
@@ -164,13 +102,13 @@ function _split(
y::AbstractVector,
classes::AbstractVector,
colnames::Vector{String},
cutpoints::Vector{Cutpoints};
max_split_candidates::Int=_p(X)
cps::Vector{Cutpoints};
max_split_candidates::Int=nfeatures(X)
)
best_score = 0.0
best_score_feature = 0
best_score_cutpoint = 0.0
p = _p(X)
p = nfeatures(X)
mc = max_split_candidates
possible_features = mc == p ? (1:p) : _rand_subset(rng, 1:p, mc)
starting_impurity = _gini(y, classes)
@@ -182,7 +120,7 @@ function _split(
@inbounds for i in eachindex(feat_data)
feat_data[i] = X[i, feature]
end
for cutpoint in cutpoints[feature]
for cutpoint in cps[feature]
vl = _view_y!(yl, feat_data, y, <, cutpoint)
isempty(vl) && continue
vr = _view_y!(yr, feat_data, y, ≥, cutpoint)
@@ -258,7 +196,7 @@ end

"Return `names(X)` if defined for `X` and string numbers otherwise."
function _colnames(X)::Vector{String}
fallback() = string.(1:_p(X))
fallback() = string.(1:nfeatures(X))
try
names = collect(string.(Tables.columnnames(X)))
if isempty(names)
@@ -286,11 +224,11 @@ function _tree!(
y::AbstractVector,
classes::AbstractVector,
colnames::Vector{String}=_colnames(X);
max_split_candidates=_p(X),
max_split_candidates=nfeatures(X),
depth=0,
max_depth=2,
q=10,
cutpoints::Vector{Cutpoints}=_cutpoints(X, q),
cps::Vector{Cutpoints}=cutpoints(X, q),
min_data_in_leaf=5
)
if X isa Tables.MatrixTable
@@ -300,19 +238,19 @@
if depth == max_depth
return Leaf(classes, y)
end
sp = _split(rng, X, y, classes, colnames, cutpoints; max_split_candidates)
sp = _split(rng, X, y, classes, colnames, cps; max_split_candidates)
if isnothing(sp) || length(y) ≤ min_data_in_leaf
return Leaf(classes, y)
end
depth += 1

left = let
_X, yl = _view_X_y!(mask, X, y, sp, <)
_tree!(rng, mask, _X, yl, classes, colnames; cutpoints, depth, max_depth)
_tree!(rng, mask, _X, yl, classes, colnames; cps, depth, max_depth)
end
right = let
_X, yr = _view_X_y!(mask, X, y, sp, ≥)
_tree!(rng, mask, _X, yr, classes, colnames; cutpoints, depth, max_depth)
_tree!(rng, mask, _X, yr, classes, colnames; cps, depth, max_depth)
end
node = Node(sp, left, right)
return node
@@ -393,10 +331,10 @@ function _forest(
error("Minimum tree depth is 1; got $max_depth")
end
# It is essential for the stability to determine the cutpoints over the whole dataset.
cutpoints = _cutpoints(X, q)
cps = cutpoints(X, q)
classes = _classes(y)

max_split_candidates = round(Int, sqrt(_p(X)))
max_split_candidates = round(Int, sqrt(nfeatures(X)))
n_samples = floor(Int, partial_sampling * length(y))

trees = Vector{Union{Node,Leaf}}(undef, n_trees)
@@ -419,7 +357,7 @@
max_split_candidates,
max_depth,
q,
cutpoints,
cps,
min_data_in_leaf
)
trees[i] = tree
10 changes: 10 additions & 0 deletions src/helpers.jl
@@ -0,0 +1,10 @@
module Helpers

"Return the number of features `p` in a dataset `X`."
nfeatures(X::AbstractMatrix) = size(X, 2)
nfeatures(X) = length(Tables.columnnames(X))

view_feature(X::AbstractMatrix, feature::Int) = view(X, :, feature)
view_feature(X, feature::Int) = view(X, feature)

end
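
As a rough usage sketch of these helpers (matrix methods only; the `Tables.columnnames` method assumes `Tables` is available in the module's scope, and the internal module is assumed reachable as `SIRUS.Helpers`):

using SIRUS.Helpers: nfeatures, view_feature

X = [1 2;
     3 4]
nfeatures(X)         # 2, the number of columns (features)
view_feature(X, 1)   # a view of column 1, i.e. [1, 3]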
6 changes: 3 additions & 3 deletions src/rules.jl
@@ -9,7 +9,7 @@ struct Split
direction::Symbol # :L or :R
end

function Split(feature::Int, name::String, splitval::Float, direction::Symbol)
function Split(feature::Int, name::String, splitval::Float32, direction::Symbol)
return Split(SplitPoint(feature, splitval, name), direction)
end

@@ -45,7 +45,7 @@ function TreePath(text::String)
feature = parse(Int, feature_text)
splitval = let
start = direction == :L ? findfirst('<', c) + 1 : findfirst('≥', c) + 3
parse(Float, c[start:end])
parse(Float32, c[start:end])
end
feature_name = string(feature)::String
Split(feature, feature_name, splitval, direction)
@@ -467,7 +467,7 @@ function _probability(row::AbstractVector, rule::Rule)
return satisfies(row, rule) ? rule.then_probs : rule.else_probs
end

function _predict(pair::Tuple{Rule,Float16}, row::AbstractVector)
function _predict(pair::Tuple{Rule, Float16}, row::AbstractVector)
rule, weight = pair
probs = _probability(row, rule)
return weight .* probs
11 changes: 11 additions & 0 deletions test/empiricalquantiles.jl
@@ -0,0 +1,11 @@
X = [1 2;
3 4]
y = [1, 2]

@test cutpoints([3, 1, 2], 2) == Float32[1, 2]
@test cutpoints(1:9, 3) == Float32[3, 5, 7]
@test cutpoints(1:4, 3) == Float32[1, 2, 3]
@test cutpoints([1, 3, 5, 7], 2) == Float32[3, 5]

@test cutpoints(X, 2) == [Float32[1, 3], Float32[2, 4]]
@test cutpoints([3 4; 1 5; 2 6], 2) == [Float32[1, 2], Float32[4, 5]]