From ee4f104fa7666c4514d2f286e5b5b3258fe7e5bc Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Wed, 7 Jun 2023 18:31:46 +0200 Subject: [PATCH] Tune a few things --- .github/workflows/CI.yml | 2 +- .github/workflows/Docs.yml | 2 +- src/forest.jl | 44 ++++++++++++++------------------------ 3 files changed, 18 insertions(+), 30 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6e34e22..99f8d9f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: - ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index 24bd96a..8a32290 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -14,7 +14,7 @@ jobs: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-docdeploy@v1 env: diff --git a/src/forest.jl b/src/forest.jl index 5c46090..ae1859f 100644 --- a/src/forest.jl +++ b/src/forest.jl @@ -1,17 +1,3 @@ -""" -Return the number of elements in `V` being equal to `x`. -This method allocates less memory than `count(V .== y)`. 
-""" -function _count_equal(V::AbstractVector, x)::Int - c = 0 - for v in V - if x == v - c += 1 - end - end - return c -end - """ _gini(y::AbstractVector, classes::AbstractVector) @@ -27,13 +13,25 @@ function _gini(y::AbstractVector, classes) len_y = length(y) impurity = 1.0 for class in classes - c = _count_equal(y, class) + c = count(==(class), y) proportion = c / len_y impurity -= proportion^2 end return impurity end +function _information_gain( + y, + yl, + yr, + classes, + starting_impurity::Real + ) + p = length(yl) / length(y) + impurity_change = p * _gini(yl, classes) + (1 - p) * _gini(yr, classes) + return starting_impurity - impurity_change +end + """ Return a rough estimate for the index of the cutpoint. Choose the highest suitable index if there is more than one suitable index. @@ -120,18 +118,6 @@ _feature(sp::SplitPoint) = sp.feature _value(sp::SplitPoint) = sp.value _feature_name(sp::SplitPoint) = sp.feature_name -function _information_gain( - y, - yl, - yr, - classes, - starting_impurity::Real - ) - p = length(yl) / length(y) - impurity_change = p * _gini(yl, classes) + (1 - p) * _gini(yr, classes) - return starting_impurity - impurity_change -end - "Return a random subset of `V` sampled without replacement." function _rand_subset(rng::AbstractRNG, V::AbstractVector, n::Int) return view(shuffle(rng, V), 1:n) @@ -139,10 +125,12 @@ end """ Return a view on all `y` for which the `comparison` holds in `data`. + +The mutable `y_view` is used to have a view of `y` in contiguous memory. """ function _view_y!(y_view, X, feature::Int, y, comparison, cutpoint) len = 0 - for i in eachindex(y) + @inbounds for i in eachindex(y) value = @inbounds X[i, feature] result = comparison(value, cutpoint) if result