Extend API
rikhuijzer committed Nov 14, 2023
1 parent 73322b3 commit 71abde0
Showing 6 changed files with 70 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "SIRUS"
uuid = "cdeec39e-fb35-4959-aadb-a1dd5dede958"
authors = ["Rik Huijzer <github@huijzer.xyz>"]
version = "1.3.4"
version = "2.0.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
2 changes: 2 additions & 0 deletions src/SIRUS.jl
@@ -34,6 +34,8 @@ include("ruleshow.jl")
include("weights.jl")
export TreePath
include("dependent.jl")
include("extract.jl")
export sum_weights

include("mlj.jl")
const StableForestClassifier = MLJImplementation.StableForestClassifier
41 changes: 41 additions & 0 deletions src/extract.jl
@@ -0,0 +1,41 @@
"Estimate the importance of a rule."
function _rule_importance(weight::Number, rule::Rule)
    importance = 0.0
    thens = rule.then::Vector{Float64}
    otherwises = rule.otherwise::Vector{Float64}
    for (then, otherwise) in zip(thens, otherwises)
        importance += weight * abs(then - otherwise)
    end
    return importance
end

"""
    feature_importance(
        model::StableRules,
        feature_name::AbstractString
    )

Estimate the importance of the given `feature_name`.
The aim of this function is to satisfy the following property:

> Given two features X and Y, if X has more effect on the outcome, then
> feature_importance(model, X) > feature_importance(model, Y).

This function provides only an estimate of the importance because
the effect on the outcome depends on the data.
"""
function feature_importance(
    model::StableRules,
    feature_name::AbstractString
)
    importance = 0.0
    for (i, rule) in enumerate(model.rules)
        for clause::Split in rule.path.splits
            if _feature_name(clause) == feature_name
                weight = model.weights[i]
                importance += _rule_importance(weight, rule)
            end
        end
    end
    return importance
end
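
As a side note (not part of the commit), the arithmetic in _rule_importance can be traced with made-up numbers: a rule with weight 0.2, then = [0.1, 0.9], and otherwise = [0.3, 0.7] contributes 0.2 * (|0.1 - 0.3| + |0.9 - 0.7|) = 0.08. A minimal, hypothetical call on a fitted model, mirroring the test added below:

    model = mach.fitresult::StableRules  # assumes a fitted MLJ machine `mach`
    importance = SIRUS.feature_importance(model, "x1")  # sums contributions of all rules whose clauses mention "x1"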
8 changes: 6 additions & 2 deletions src/rules.jl
@@ -7,8 +7,12 @@ Each rule is based on one or more splits.
Data can be accessed via `_feature`, `_value`, `_feature_name`, `_direction`, and `_reverse`.
"""
struct Split
    splitpoint::SplitPoint
struct SubClause
    # Removed the splitpoint::SplitPoint field in favor of
    # storing the split data directly in the fields below.
    feature::Int
    feature_name::String
    splitval::Float32
    direction::Symbol # :L or :R
end

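Purely as an illustration (not part of the diff), a value of the renamed struct now carries the split data directly; the field values below are invented and use the default constructor:

    clause = SubClause(1, "x1", 32.0f0, :L)  # feature index, feature name, split value, direction (go left)
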
16 changes: 16 additions & 0 deletions test/extract.jl
@@ -0,0 +1,16 @@
function _haberman_data()
    df = haberman()
    X = MLJBase.table(MLJBase.matrix(df[:, Not(:survival)]))
    y = categorical(df.survival)
    (X, y)
end

X, y = _haberman_data()

classifier = StableRulesClassifier(; max_depth=2, max_rules=8, n_trees=1000, rng=_rng())
mach = machine(classifier, X, y)
fit!(mach)

model = mach.fitresult::StableRules

importance = SIRUS.feature_importance(model, "x1")
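
A hedged follow-up that this test could add (it is not part of the commit): assuming the fitted rule weights are non-negative, the estimate sums weight * abs(then - otherwise) terms and therefore can never be negative.

    @test 0.0 <= importance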
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -44,6 +44,10 @@ if CAN_RUN_R_SIRUS
    end
end

@testset "extract" begin
    include("extract.jl")
end

@testset "mlj" begin
    include("mlj.jl")
end
