Skip to content

Commit

Permalink
Speed up loop by reusing vector
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 7, 2023
1 parent 7bcf8e4 commit 141d61b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
34 changes: 20 additions & 14 deletions src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ _feature_name(sp::SplitPoint) = sp.feature_name

function _information_gain(
y,
y_left,
y_right,
yl,
yr,
classes,
starting_impurity::Real
)
p = length(y_left) / length(y)
impurity_change = p * _gini(y_left, classes) + (1 - p) * _gini(y_right, classes)
p = length(yl) / length(y)
impurity_change = p * _gini(yl, classes) + (1 - p) * _gini(yr, classes)
return starting_impurity - impurity_change
end

Expand All @@ -139,14 +139,18 @@ end

"""
Return a view on all `y` for which the `comparison` holds in `data`.
`mask` is re-used between calls.
"""
function _view_y!(mask, X, feature::Int, y, comparison, cutpoint)
function _view_y!(y_view, X, feature::Int, y, comparison, cutpoint)
len = 0
for i in eachindex(y)
value = @inbounds X[i, feature]
mask[i] = comparison(value, cutpoint)
result = comparison(value, cutpoint)
if result
len += 1
@inbounds y_view[len] = y[i]
end
end
return @inbounds view(y, mask)
return @inbounds view(y_view, 1:len)
end

"""
Expand All @@ -169,15 +173,17 @@ function _split(
p = _p(X)
mc = max_split_candidates
possible_features = mc == p ? (1:p) : _rand_subset(rng, 1:p, mc)
mask = Vector{Bool}(undef, length(y))
starting_impurity = _gini(y, classes)

yl = Vector{eltype(y)}(undef, length(y))
yr = Vector{eltype(y)}(undef, length(y))
for feature in possible_features
for cutpoint in cutpoints[feature]
y_left = _view_y!(mask, X, feature, y, <, cutpoint)
isempty(y_left) && continue
y_right = _view_y!(mask, X, feature, y, ≥, cutpoint)
isempty(y_right) && continue
gain = _information_gain(y, y_left, y_right, classes, starting_impurity)
vl = _view_y!(yl, X, feature, y, <, cutpoint)
isempty(vl) && continue
vr = _view_y!(yr, X, feature, y, ≥, cutpoint)
isempty(vr) && continue
gain = _information_gain(y, vl, vr, classes, starting_impurity)
if best_score ≤ gain
best_score = gain
best_score_feature = feature
Expand Down
6 changes: 3 additions & 3 deletions test/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ X = [1 2;
3 4]
y = [1, 2]

indexes = Vector{Bool}(undef, 2)
y_view = Vector{eltype(y)}(undef, length(y))
feature = 1
@test collect(ST._view_y!(indexes, X, feature, [1 2], <, 2)) == [1]
@test collect(ST._view_y!(indexes, X, feature, [1 2], >, 2)) == [2]
@test collect(ST._view_y!(y_view, X, feature, [1 2], <, 2)) == [1]
@test collect(ST._view_y!(y_view, X, feature, [1 2], >, 2)) == [2]

@test ST._cutpoints([3, 1, 2], 2) == Float[1, 2]
@test ST._cutpoints(1:9, 3) == Float[3, 5, 7]
Expand Down

0 comments on commit 141d61b

Please sign in to comment.