Skip to content

Commit

Permalink
Tune a few things
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 7, 2023
1 parent 141d61b commit ee4f104
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
Expand Down
44 changes: 16 additions & 28 deletions src/forest.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
"""
    _count_equal(V::AbstractVector, x)::Int

Return the number of elements in `V` that are equal to `x`.

Elements are compared with `==`. This allocates less memory than
`count(V .== x)` because no intermediate Boolean array is built.
An empty `V` yields `0`.
"""
function _count_equal(V::AbstractVector, x)::Int
    # `count` with a predicate walks `V` element by element, so no temporary
    # array is materialized.
    return count(==(x), V)
end

"""
_gini(y::AbstractVector, classes::AbstractVector)
Expand All @@ -27,13 +13,25 @@ function _gini(y::AbstractVector, classes)
len_y = length(y)
impurity = 1.0
for class in classes
c = _count_equal(y, class)
c = count(==(class), y)
proportion = c / len_y
impurity -= proportion^2
end
return impurity
end

"""
    _information_gain(y, yl, yr, classes, starting_impurity::Real)

Return how much the impurity drops when `y` is split into `yl` and `yr`.

The impurities that `_gini` reports for `yl` and `yr` (over `classes`) are
combined with weights `length(yl) / length(y)` and its complement; the
resulting weighted impurity is subtracted from `starting_impurity`.
"""
function _information_gain(y, yl, yr, classes, starting_impurity::Real)
    left_weight = length(yl) / length(y)
    left_term = left_weight * _gini(yl, classes)
    right_term = (1 - left_weight) * _gini(yr, classes)
    return starting_impurity - (left_term + right_term)
end

"""
Return a rough estimate for the index of the cutpoint.
Choose the highest suitable index if there is more than one suitable index.
Expand Down Expand Up @@ -120,29 +118,19 @@ _feature(sp::SplitPoint) = sp.feature
"Return the `value` field of the split point `sp`."
function _value(sp::SplitPoint)
    return sp.value
end

"Return the `feature_name` field of the split point `sp`."
function _feature_name(sp::SplitPoint)
    return sp.feature_name
end

"""
    _information_gain(y, yl, yr, classes, starting_impurity::Real)

Return how much the impurity drops when `y` is split into `yl` and `yr`.

The impurities that `_gini` reports for `yl` and `yr` (over `classes`) are
combined with weights `length(yl) / length(y)` and its complement; the
resulting weighted impurity is subtracted from `starting_impurity`.
"""
function _information_gain(y, yl, yr, classes, starting_impurity::Real)
    left_weight = length(yl) / length(y)
    left_term = left_weight * _gini(yl, classes)
    right_term = (1 - left_weight) * _gini(yr, classes)
    return starting_impurity - (left_term + right_term)
end

"""
    _rand_subset(rng::AbstractRNG, V::AbstractVector, n::Int)

Return a view of `n` elements drawn from `V` without replacement.

`V` itself is left untouched: the view refers to the first `n` entries of a
shuffled copy of `V` produced with `rng`.
"""
function _rand_subset(rng::AbstractRNG, V::AbstractVector, n::Int)
    permuted = shuffle(rng, V)
    return @view permuted[1:n]
end

"""
Return a view on all `y` for which the `comparison` holds in `X`.
The mutable `y_view` is used to have a view of `y` in contiguous memory.
"""
function _view_y!(y_view, X, feature::Int, y, comparison, cutpoint)
len = 0
for i in eachindex(y)
@inbounds for i in eachindex(y)
value = @inbounds X[i, feature]
result = comparison(value, cutpoint)
if result
Expand Down

0 comments on commit ee4f104

Please sign in to comment.