Skip to content

Commit

Permalink
Use mode since it makes more sense
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Jun 21, 2023
1 parent df46527 commit 58feb1c
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
PrecompileSignatures = "91cefc8d-f054-46dc-8f8c-26e11d7c5411"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Expand All @@ -22,5 +23,6 @@ InlineStrings = "1"
MLJLinearModels = "0.8, 0.9"
MLJModelInterface = "1.4"
PrecompileSignatures = "3"
StatsBase = "0.34"
Tables = "1.7"
julia = "1.6"
1 change: 1 addition & 0 deletions src/SIRUS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using MLJModelInterface: UnivariateFinite, Probabilistic, fit
using PrecompileSignatures: @precompile_signatures
using Random: AbstractRNG, default_rng, seed!, shuffle
using Statistics: mean, median
using StatsBase: mode
using Tables: Tables, matrix

export feature_names, directions, satisfies
Expand Down
3 changes: 2 additions & 1 deletion src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,15 @@ end

_mean(V::AbstractVector{<:AbstractVector}) = _apply_statistic(V, mean)
_median(V::AbstractVector{<:AbstractVector}) = _apply_statistic(V, median)
_mode(V::AbstractVector{<:AbstractVector}) = _apply_statistic(V, mode)

function _predict(forest::StableForest, row::AbstractVector)
isempty(_elements(forest)) && _isempty_error(forest)
predictions = [_predict(tree, row) for tree in forest.trees]
if forest.algo isa Classification
return _median(predictions)
else
m = mean(predictions)
m = median(predictions)
@assert m isa Number
return m
end
Expand Down
2 changes: 1 addition & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ function _combine_paths(
rules = duplicate_paths[path]
# Taking the mode because that might make more sense here.
# Doesn't seem to affect accuracy so much.
aggregate = algo isa Classification ? _median : _mean
aggregate = algo isa Classification ? _mode : _mean
then = aggregate(getproperty.(rules, :then))
otherwise = aggregate(getproperty.(rules, :otherwise))
combined_rule = Rule(path, then, otherwise)
Expand Down

0 comments on commit 58feb1c

Please sign in to comment.