Skip to content

Commit

Permalink
Reduce allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 8, 2023
1 parent ee4f104 commit 5619bbd
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
Expand Down
28 changes: 18 additions & 10 deletions src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ function _split(
end
end
end
if best_score == Float(0)
if best_score == 0.0
return nothing
end
feature_name = colnames[best_score_feature]
Expand Down Expand Up @@ -217,11 +217,17 @@ nodevalue(node::Node) = node.splitpoint
"""
Return a view on all rows in `X` and `y` for which the `comparison` holds in `X[:, feature]`.
"""
function _view_X_y(X, y, splitpoint::SplitPoint, comparison)
data = view(X, :, splitpoint.feature)
mask = comparison.(data, splitpoint.value)
X_view = view(X, mask, :)
y_view = view(y, mask)
function _view_X_y!(mask, X, y, splitpoint::SplitPoint, comparison)
data = @inbounds view(X, :, splitpoint.feature)
@assert length(data) == length(y)
len = 0
for i in eachindex(y)
value = @inbounds data[i]
result = comparison(value, splitpoint.value)
@inbounds mask[i] = result
end
X_view = @inbounds view(X, mask, :)
y_view = @inbounds view(y, mask)
return (X_view, y_view)
end

Expand Down Expand Up @@ -281,13 +287,15 @@ function _tree(
return Leaf(classes, y)
end
depth += 1

mask = Vector{Bool}(undef, length(y))
left = let
_X, _y = _view_X_y(X, y, sp, <)
_tree(rng, _X, _y, classes, colnames; cutpoints, depth, max_depth)
_X, yl = _view_X_y!(mask, X, y, sp, <)
_tree(rng, _X, yl, classes, colnames; cutpoints, depth, max_depth)
end
right = let
_X, _y = _view_X_y(X, y, sp, )
_tree(rng, _X, _y, classes, colnames; cutpoints, depth, max_depth)
_X, yr = _view_X_y!(mask, X, y, sp, )
_tree(rng, _X, yr, classes, colnames; cutpoints, depth, max_depth)
end
node = Node(sp, left, right)
return node
Expand Down

0 comments on commit 5619bbd

Please sign in to comment.