Skip to content

Commit

Permalink
Reuse mask in each thread
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 8, 2023
1 parent 5619bbd commit 82c9015
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ function _view_X_y!(mask, X, y, splitpoint::SplitPoint, comparison)
result = comparison(value, splitpoint.value)
@inbounds mask[i] = result
end
X_view = @inbounds view(X, mask, :)
y_view = @inbounds view(y, mask)
mask_subset = view(mask, 1:length(y))
X_view = @inbounds view(X, mask_subset, :)
y_view = @inbounds view(y, mask_subset)
return (X_view, y_view)
end

Expand Down Expand Up @@ -262,8 +263,9 @@ Arguments:
During random forest creation, the number of split candidates is limited to make the trees less correlated.
See Section 8.2.2 of https://doi.org/10.1007/978-1-0716-1418-1 for details.
"""
function _tree(
function _tree!(
rng::AbstractRNG,
mask::Vector{Bool},
X,
y::AbstractVector,
classes::AbstractVector,
Expand All @@ -288,14 +290,13 @@ function _tree(
end
depth += 1

mask = Vector{Bool}(undef, length(y))
left = let
_X, yl = _view_X_y!(mask, X, y, sp, <)
_tree(rng, _X, yl, classes, colnames; cutpoints, depth, max_depth)
_tree!(rng, mask, _X, yl, classes, colnames; cutpoints, depth, max_depth)
end
right = let
_X, yr = _view_X_y!(mask, X, y, sp, ≥)
_tree(rng, _X, yr, classes, colnames; cutpoints, depth, max_depth)
_tree!(rng, mask, _X, yr, classes, colnames; cutpoints, depth, max_depth)
end
node = Node(sp, left, right)
return node
Expand Down Expand Up @@ -391,8 +392,10 @@ function _forest(
rows = rand(_rng, 1:length(y), n_samples)
_X = view(X, rows, :)
_y = view(y, rows)
tree = _tree(
mask = Vector{Bool}(undef, length(y))
tree = _tree!(
_rng,
mask,
_X,
_y,
classes,
Expand Down

0 comments on commit 82c9015

Please sign in to comment.