diff --git a/src/forest.jl b/src/forest.jl index 2ecd39d..5c46090 100644 --- a/src/forest.jl +++ b/src/forest.jl @@ -122,13 +122,13 @@ _feature_name(sp::SplitPoint) = sp.feature_name function _information_gain( y, - y_left, - y_right, + yl, + yr, classes, starting_impurity::Real ) - p = length(y_left) / length(y) - impurity_change = p * _gini(y_left, classes) + (1 - p) * _gini(y_right, classes) + p = length(yl) / length(y) + impurity_change = p * _gini(yl, classes) + (1 - p) * _gini(yr, classes) return starting_impurity - impurity_change end @@ -139,14 +139,18 @@ end """ Return a view on all `y` for which the `comparison` holds in `data`. -`mask` is re-used between calls. """ -function _view_y!(mask, X, feature::Int, y, comparison, cutpoint) +function _view_y!(y_view, X, feature::Int, y, comparison, cutpoint) + len = 0 for i in eachindex(y) value = @inbounds X[i, feature] - mask[i] = comparison(value, cutpoint) + result = comparison(value, cutpoint) + if result + len += 1 + @inbounds y_view[len] = y[i] + end end - return @inbounds view(y, mask) + return @inbounds view(y_view, 1:len) end """ @@ -169,15 +173,17 @@ function _split( p = _p(X) mc = max_split_candidates possible_features = mc == p ? (1:p) : _rand_subset(rng, 1:p, mc) - mask = Vector{Bool}(undef, length(y)) starting_impurity = _gini(y, classes) + + yl = Vector{eltype(y)}(undef, length(y)) + yr = Vector{eltype(y)}(undef, length(y)) for feature in possible_features for cutpoint in cutpoints[feature] - y_left = _view_y!(mask, X, feature, y, <, cutpoint) - isempty(y_left) && continue - y_right = _view_y!(mask, X, feature, y, ≥, cutpoint) - isempty(y_right) && continue - gain = _information_gain(y, y_left, y_right, classes, starting_impurity) + vl = _view_y!(yl, X, feature, y, <, cutpoint) + isempty(vl) && continue + vr = _view_y!(yr, X, feature, y, ≥, cutpoint) + isempty(vr) && continue + gain = _information_gain(y, vl, vr, classes, starting_impurity) if best_score ≤ gain best_score = gain best_score_feature = feature diff --git a/test/forest.jl b/test/forest.jl index 8659120..11b1591 100644 --- a/test/forest.jl +++ b/test/forest.jl @@ -7,10 +7,10 @@ X = [1 2; 3 4] y = [1, 2] -indexes = Vector{Bool}(undef, 2) +y_view = Vector{eltype(y)}(undef, length(y)) feature = 1 -@test collect(ST._view_y!(indexes, X, feature, [1 2], <, 2)) == [1] -@test collect(ST._view_y!(indexes, X, feature, [1 2], >, 2)) == [2] +@test collect(ST._view_y!(y_view, X, feature, [1 2], <, 2)) == [1] +@test collect(ST._view_y!(y_view, X, feature, [1 2], >, 2)) == [2] @test ST._cutpoints([3, 1, 2], 2) == Float[1, 2] @test ST._cutpoints(1:9, 3) == Float[3, 5, 7]