Skip to content

Commit

Permalink
Write down problem
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Nov 23, 2023
1 parent 0714a56 commit f2fe3ee
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 9 deletions.
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ StableForestRegressor

```@docs
feature_names
feature_importance
feature_importances
directions
values(::SIRUS.Rule)
satisfies
Cutpoints
cutpoints
feature_importance
```
2 changes: 1 addition & 1 deletion src/SIRUS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ include("ruleshow.jl")
include("weights.jl")
include("dependent.jl")
include("extract.jl")
export feature_importance
export feature_importance, feature_importances

include("mlj.jl")
const StableForestClassifier = MLJImplementation.StableForestClassifier
Expand Down
31 changes: 30 additions & 1 deletion src/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ end
)
Estimate the importance of the given `feature_name`.
The aim is to satisfy the following property:
The aim is to satisfy the following property, so that the features can be
ordered by importance:
> Given two features A and B, if A has more effect on the outcome, then
> feature_importance(model, A) > feature_importance(model, B).
Expand Down Expand Up @@ -63,3 +64,31 @@ end
# Convenience method for arbitrary `AbstractString` names (e.g. a `SubString`):
# normalise the name to a `String` and re-dispatch.
# NOTE(review): presumably a more specific `String` method is defined earlier in
# this file; confirm, since that is what this call relies on.
function feature_importance(models::Vector{<:StableRules}, feature_name::AbstractString)
    name = string(feature_name)::String
    return feature_importance(models, name)
end

"""
    feature_importances(
        models::Union{StableRules, Vector{<:StableRules}},
        feature_names
    )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}}

Return the feature names and importances, sorted by feature importance in descending order.

`feature_names` is either a `Vector{String}` or any iterable whose elements can be
converted with `string` (for example a vector of symbols or a tuple of strings).
A single string or symbol is treated as a one-element collection.

Throws an `ArgumentError` if `feature_names` contains duplicates.
"""
function feature_importances(
        models::Union{StableRules, Vector{<:StableRules}},
        feature_names::Vector{String}
    )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}}
    # Explicit validation instead of `@assert`, which may be disabled at some
    # optimization levels and is meant for internal invariants only.
    if !allunique(feature_names)
        throw(ArgumentError("feature_names must be unique, got $(feature_names)"))
    end
    importances = map(feature_names) do feature_name
        importance = feature_importance(models, feature_name)
        (; feature_name, importance)
    end
    alg = Helpers.STABLE_SORT_ALG
    # Sort on the named field rather than on the position of the field in the tuple.
    return sort(importances; alg, by=nt -> nt.importance, rev=true)
end

function feature_importances(
        models::Union{StableRules, Vector{<:StableRules}},
        feature_names
    )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}}
    # Always build a concrete `Vector{String}` so that the call below dispatches to
    # the method above. `string.(feature_names)` returns a scalar for a single
    # string and a tuple for a tuple, which would re-enter this method forever.
    names = if feature_names isa Union{AbstractString, Symbol}
        String[string(feature_names)]
    else
        String[string(name) for name in feature_names]
    end
    return feature_importances(models, names)
end
10 changes: 7 additions & 3 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ function _count_unique(V::AbstractVector{T}) where T
return counts
end

# TODO: Is the problem here?
# Maybe only the unique rule clauses should be checked,
# not the then/otherwise.

"""
Return a vector of unique values in `V` sorted by frequency.
"""
Expand All @@ -383,9 +387,9 @@ function _process_rules(
rules::Vector{Rule},
max_rules::Int
)::Vector{Rule}
simplified = _simplify_single_rules(rules)
sorted = _sort_by_frequency(simplified)
filtered = _filter_linearly_dependent(sorted)
simplified = _simplify_single_rules(rules)::Vector{Rule}
sorted = _sort_by_frequency(simplified)::Vector{Rule}
filtered = _filter_linearly_dependent(sorted)::Vector{Rule}
return first(filtered, max_rules)
end

Expand Down
16 changes: 13 additions & 3 deletions test/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,22 @@ mach = machine(classifier, X, y)
fit!(mach)

model = mach.fitresult::StableRules
# StableRules model with 8 rules:
# if X[i, :x3] < 8.0 then 0.084 else 0.03 +
# if X[i, :x3] < 14.0 then 0.147 else 0.098 +
# if X[i, :x3] < 2.0 then 0.073 else 0.047 +
# if X[i, :x3] < 4.0 then 0.079 else 0.048 +
# if X[i, :x3] < 1.0 then 0.076 else 0.06 +
# if X[i, :x2] < 1959.0 then 0.006 else 0.008 +
# if X[i, :x1] < 38.0 then 0.029 else 0.024 +
# if X[i, :x1] < 42.0 then 0.052 else 0.043
# and 2 classes: [0, 1].
# Note: showing only the probability for class 1 since class 0 has probability 1 - p.

importance = feature_importance(model, "x1")
# Based on the numbers that are printed in the following lines:
# if X[i, :x1] < 38.0 then 0.029 else 0.024 +
# if X[i, :x1] < 42.0 then 0.052 else 0.043
# Based on the numbers above.
# Expected importance of "x1": the sum of the (then - otherwise) differences of the
# two rules on `:x1` listed in the comments above.
expected = ((0.029 - 0.024) + (0.052 - 0.043))
# The printed rule weights are rounded, hence the approximate comparison.
@test importance ≈ expected atol=0.01

@test feature_importance([model, model], "x1") ≈ expected atol=0.01
@test only(feature_importances(model, ["x1"])).importance ≈ expected atol=0.01

0 comments on commit f2fe3ee

Please sign in to comment.