Skip to content

Commit

Permalink
Fix all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Nov 21, 2023
1 parent 4f9321e commit be9a7d6
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
10 changes: 7 additions & 3 deletions src/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,20 @@ the effect on the outcome depends on the data.
"""
function feature_importance(
model::StableRules,
feature_name::AbstractString
feature_name::String
)
importance = 0.0
for (i, rule) in enumerate(model.rules)
for clause::Split in rule.path.splits
if _feature_name(clause) == feature_name
for subclause::SubClause in _subclauses(rule)
if _feature_name(subclause)::String == feature_name
weight = model.weights[i]
importance += _rule_importance(weight, rule)
end
end
end
return importance
end

function feature_importance(model::StableRules, feature_name::AbstractString)
return feature_importance(model, string(feature_name)::String)
end
2 changes: 1 addition & 1 deletion src/ruleshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function _pretty_feature_name(subclause::SubClause)
if feature == feature_name
return feature
else
return string(':', name)::String
return string(':', feature_name)::String
end
end

Expand Down
2 changes: 2 additions & 0 deletions test/rcall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# This file defines the MLJ wrappers around R sirus and tests them.
# Actual comparisons against other models are done in test/mlj.jl.
#
# Assumes sirus is installed in R `install.packages("sirus")`.
#

import MLJModelInterface:
MLJModelInterface,
Expand Down
4 changes: 2 additions & 2 deletions test/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ data = [0.0 2.5
5.0 5.0
0.0 0.0]

r1 = S.Rule(S.TreePath(" X[i, 1] < 1 "), [0.1], [0.0])
r2 = S.Rule(S.TreePath(" X[i, 2] < 2 "), [0.2], [0.0])
r1 = S.Rule(S.Clause(" X[i, 1] < 1 "), [0.1], [0.0])
r2 = S.Rule(S.Clause(" X[i, 2] < 2 "), [0.2], [0.0])

binary_feature_data = Float16[1 0;
0 0;
Expand Down

0 comments on commit be9a7d6

Please sign in to comment.