Skip to content

Commit

Permalink
Do some unsuccessful performance optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 7, 2023
1 parent 7dba7fd commit ed24751
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 22 deletions.
61 changes: 61 additions & 0 deletions src/cutpoints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Return a rough estimate for the index of the cutpoint.
Choose the highest suitable index if there is more than one suitable index.
The reason is that this will split the data nicely in combination with the `<` used later on.
For example, for [1, 2, 3, 4], both 2 and 3 satisfy the 0.5 quantile.
In this case, we pick the ceil, so 3.
Next, the tree will split on 3, causing left (<) to contain 1 and 2 and right (≥) to contain 3 and 4.
"""
_rough_cutpoint_index_estimate(n::Int, quantile::Real) = Int(ceil(quantile * n))

"Return the empirical `quantile` for data `V`."
function _empirical_quantile(V::AbstractVector, quantile::Real)
@assert 0.0 quantile 1.0
n = length(V)
index = _rough_cutpoint_index_estimate(n, quantile)
if index == 0
index = 1
end
if index == n + 1
index = n
end
sorted = sort(V)
return Float(sorted[index])
end

"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`."
function _cutpoints(V::AbstractVector, q::Int)
@assert 2 q
# Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints.
# Tree on left and right have always respectively length 0 and 1 then anyway.
length = q + 2
quantiles = range(0.0; stop=1.0, length)[2:end-1]
return Float[_empirical_quantile(V, quantile) for quantile in quantiles]
end

"Return the number of features `p` in a dataset `X`."
_p(X::AbstractMatrix) = size(X, 2)
_p(X) = length(Tables.columnnames(X))

"Set of possible cutpoints, that is, numbers from the empirical quantiles."
const Cutpoints = Vector{Float}

"""
    _view_feature(X, feature::Int)

Return the data for the feature with index `feature` in `X`.

For a matrix this is a view on column `feature` (no copy); for any other input it is
whatever `X[feature]` returns.
"""
function _view_feature(X::AbstractMatrix, feature::Int)
    return @view X[:, feature]
end
function _view_feature(X, feature::Int)
    return X[feature]
end

"""
Return a vector of vectors containing
- one inner vector for each feature in the dataset and
- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V.
Using unique here to avoid checking splits twice.
"""
function _cutpoints(X, q::Int)::Vector{Cutpoints}
p = _p(X)
cutpoints = Vector{Cutpoints}(undef, p)
for feature in 1:p
V = _view_feature(X, feature)
cutpoints[feature] = unique(_cutpoints(V, q))
end
return cutpoints
end
44 changes: 25 additions & 19 deletions src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,6 @@ function _cutpoints(X, q::Int)
return cutpoints
end

"""
Return a view on all `y` for which the `comparison` holds in `data`.
`indexes_in_region` is re-used between calls.
"""
function _view_y!(mask, data, y, comparison, cutpoint)
for i in eachindex(data)
mask[i] = comparison(data[i], cutpoint)
end
return view(y, mask)
end

"""
SplitPoint(feature::Int, value::Float, feature_name::String)
Expand All @@ -131,9 +120,14 @@ _feature(sp::SplitPoint) = sp.feature
_value(sp::SplitPoint) = sp.value
_feature_name(sp::SplitPoint) = sp.feature_name

"""
    _information_gain(y, y_left, y_right, classes, starting_impurity=_gini(y, classes))

Return how much the impurity drops when `y` is split into `y_left` and `y_right`.

The impurity after the split is the average of `_gini(y_left, classes)` and
`_gini(y_right, classes)`, weighted by the share of `y` that ended up in each side.
The result is `starting_impurity` minus that weighted average.

`starting_impurity` is the impurity of the unsplit `y`. It is the same for every
candidate split of the same `y`, so callers evaluating many splits can compute it once
and pass it in; when omitted it is computed here as `_gini(y, classes)`.
"""
function _information_gain(
    y,
    y_left,
    y_right,
    classes,
    starting_impurity::Real=_gini(y, classes)
)
    # Share of the observations that went to the left side of the split.
    p = length(y_left) / length(y)
    impurity_change = p * _gini(y_left, classes) + (1 - p) * _gini(y_right, classes)
    return starting_impurity - impurity_change
end
Expand All @@ -143,6 +137,18 @@ function _rand_subset(rng::AbstractRNG, V::AbstractVector, n::Int)
return view(shuffle(rng, V), 1:n)
end

"""
Return a view on all `y` for which the `comparison` holds in `data`.
`mask` is re-used between calls.
"""
function _view_y!(mask, X, feature::Int, y, comparison, cutpoint)
for i in eachindex(y)
value = @inbounds X[i, feature]
mask[i] = comparison(value, cutpoint)
end
return @inbounds view(y, mask)
end

"""
Return the split for which the gini index is maximized.
This function receives the cutpoints for the whole dataset `D` because `X` can be a subset of `D`.
Expand All @@ -157,21 +163,21 @@ function _split(
cutpoints::Vector{Cutpoints};
max_split_candidates::Int=_p(X)
)
best_score = Float(0)
best_score = 0.0
best_score_feature = 0
best_score_cutpoint = Float(0)
best_score_cutpoint = 0.0
p = _p(X)
mc = max_split_candidates
possible_features = mc == p ? (1:p) : _rand_subset(rng, 1:p, mc)
mask = Vector{Bool}(undef, length(y))
starting_impurity = _gini(y, classes)
for feature in possible_features
data = view(X, :, feature)
for cutpoint in cutpoints[feature]
y_left = _view_y!(mask, data, y, <, cutpoint)
y_left = _view_y!(mask, X, feature, y, <, cutpoint)
isempty(y_left) && continue
y_right = _view_y!(mask, data, y, , cutpoint)
y_right = _view_y!(mask, X, feature, y, , cutpoint)
isempty(y_right) && continue
gain = _information_gain(y, y_left, y_right, classes)
gain = _information_gain(y, y_left, y_right, classes, starting_impurity)
if best_score gain
best_score = gain
best_score_feature = feature
Expand Down
5 changes: 3 additions & 2 deletions test/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ X = [1 2;
y = [1, 2]

indexes = Vector{Bool}(undef, 2)
@test collect(ST._view_y!(indexes, X[:, 1], [1 2], <, 2)) == [1]
@test collect(ST._view_y!(indexes, X[:, 1], [1 2], >, 2)) == [2]
feature = 1
@test collect(ST._view_y!(indexes, X, feature, [1 2], <, 2)) == [1]
@test collect(ST._view_y!(indexes, X, feature, [1 2], >, 2)) == [2]

@test ST._cutpoints([3, 1, 2], 2) == Float[1, 2]
@test ST._cutpoints(1:9, 3) == Float[3, 5, 7]
Expand Down
3 changes: 2 additions & 1 deletion test/mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ let
e = _evaluate_baseline!(results, "boston")
end

let
function _evaluate_boston()
hyper = (; rng=_rng(), n_trees=1_500)
e = _evaluate!(results, "boston", StableForestClassifier, hyper)

e = _evaluate!(results, "boston", StableRulesClassifier, hyper)
end
_evaluate_boston()

pretty = rename(results, :se => "1.96*SE")
rename!(pretty, :nfolds => "`nfolds`")
Expand Down

0 comments on commit ed24751

Please sign in to comment.