diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index e23ed37..69aa3c3 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -24,6 +24,7 @@ jobs: - uses: julia-actions/cache@v1.3.0 with: cache-name: 'docs' + - run: julia -e 'using Pkg; Pkg.add("Revise");' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-docdeploy@v1 env: diff --git a/.github/workflows/Typos.yml b/.github/workflows/Typos.yml index 99bbe62..1f9531e 100644 --- a/.github/workflows/Typos.yml +++ b/.github/workflows/Typos.yml @@ -14,3 +14,5 @@ jobs: - uses: actions/checkout@v4 - uses: crate-ci/typos@master + with: + config: './test/typos.toml' diff --git a/docs/Project.toml b/docs/Project.toml index 1505ad1..fbc33b9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -19,7 +19,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] CSV = "0.10" -CairoMakie = "0.10" +CairoMakie = "0.11" CategoricalArrays = "0.10" DataDeps = "0.7" DataFrames = "1" diff --git a/docs/src/api.md b/docs/src/api.md index c321985..8967c25 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,6 +1,6 @@ # API -## Types +## MLJ Interface Types ```@docs StableRulesClassifier @@ -9,13 +9,35 @@ StableForestClassifier StableForestRegressor ``` -## Methods +## SIRUS Types ```@docs +SubClause +Clause +Rule +``` + +## SIRUS Methods + +```@docs +feature +features +feature_name feature_names +splitval +splitvals +clause +subclauses +direction directions -values(::SIRUS.Rule) -satisfies -Cutpoints +feature_importance +feature_importances +then +otherwise +gap_size cutpoints +satisfies +unpack_rule +unpack_model +unpack_models ``` diff --git a/docs/src/binary-classification.jl b/docs/src/binary-classification.jl index 1867f39..4d9a2b9 100644 --- a/docs/src/binary-classification.jl +++ b/docs/src/binary-classification.jl @@ -17,6 +17,10 @@ begin Pkg.develop(; path=PKGDIR) end +# ╔═╡ 27bd0e48-9870-472f-8d78-a9b460c9e858 +# hideall +using Revise + # ╔═╡ 
f833dab6-31d4-4353-a68b-ef0501d606d4 begin using CairoMakie @@ -230,9 +234,9 @@ Since we know that the model performs well on the cross-validations, we can fit # ╔═╡ 3c415a26-803e-4f35-866f-2e582c6c1c45 md""" -## Visualization +## Plot -Since our rules are relatively simple with only a binary outcome and only one subclause in each rule, the following figure is a way to visualize the obtained rules per fold. +Since our rules are relatively simple with only a binary outcome and only one subclause in each rule (because of `max_depth=1`), the following figure is a way to visualize the obtained rules per fold. For multiple subclauses, I would not know how to visualize the rules. Also, this plot is probably not perfect; let me know if you have suggestions. @@ -330,7 +334,7 @@ md""" # ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e function _plot_cutpoints(data::AbstractVector) - fig = Figure(; resolution=(800, 100)) + fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) cps = Float64.(unique(cutpoints(data, 10))) scatter!(ax, data, fill(1, length(data))) @@ -427,7 +431,7 @@ ln = length(nodes); # ╔═╡ de90efc9-2171-4406-93a1-9a213ab32259 # hideall let - fig = Figure(; resolution=(800, 100)) + fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) scatter!(ax, nodes, fill(1, ln)) hideydecorations!(ax) @@ -441,7 +445,7 @@ index = length(nodes) - 3; # ╔═╡ 2c1adef4-822e-4dc0-946b-dc574e50b305 # hideall let - fig = Figure(; resolution=(800, 100)) + fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) scatter!(ax, nodes, fill(1, ln)) vlines!(ax, [nodes[index]]; color=:red) @@ -467,7 +471,7 @@ _plot_cutpoints(subset) # ╔═╡ 25ad7a18-f989-40f7-8ef1-4ca506446478 # hideall let - fig = Figure(; resolution=(800, 100)) + fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) scatter!(ax, subset, fill(1, ls)) vlines!(ax, [nodes[index]]; color=:red, linestyle=:dash) @@ -483,98 +487,49 @@ end _plot_cutpoints(nodes) # ╔═╡ a64dae3c-3b97-4076-98f4-3c9a0e5c0621 -# hideall -function 
_odds_plot(e::PerformanceEvaluation) +function _odds_plot(models::Vector{<:StableRules}, feat_names::Vector{String}) w, h = (1000, 300) - fig = Figure(; resolution=(w, h)) + fig = Figure(; size=(w, h)) grid = fig[1, 1:2] = GridLayout() - fitresults = getproperty.(e.fitted_params_per_fold, :fitresult) - feature_names = String[] - for fitresult in fitresults - for rule in fitresult.rules - name = only(SIRUS._subclauses(rule)).feature_name - push!(feature_names, name) - end - end - - names = sort(unique(feature_names)) - subtitle = "Ratio" - - max_height = maximum(maximum.(getproperty.(fitresults, :weights))) - - importances = _sum_weights.(Ref(fitresults), names) - - matching_rules = DataFrame(; names, importance=importances) - sort!(matching_rules, :importance; rev=true) - names = matching_rules.names - l = length(names) - - for (i, feature_name) in enumerate(names) - yticks = (1:1, [feature_name]) - ax = i == l ? - Axis(grid[i, 1:3]; yticks, xlabel="Ratio") : - Axis(grid[i, 1:3]; yticks) - vlines!(ax, [0]; color=:gray, linestyle=:dash) - xlims!(ax, -1, 1) - ylabel = feature_name - - name = feature_name - - nested_rules_weights = map(fitresults) do fitresult - subresult = Tuple{SIRUS.Rule,Float64}[] - zipped = zip(fitresult.rules, fitresult.weights) - for (rule, weight) in zipped - feat_name = only(SIRUS._subclauses(rule)).feature_name - if feat_name == feature_name - push!(subresult, (rule, weight)) - end - end - subresult - end - rules_weights = Tuple{SIRUS.Rule,Float64}[] - for nested in nested_rules_weights - isnothing(nested) && continue - for rule_weight in nested - push!(rules_weights, rule_weight) - end - end - rw::Vector{Tuple{SIRUS.Rule,Float64}} = - filter(!isnothing, rules_weights) - thresholds = _threshold.(first.(rw)) - t_mean = round(mean(thresholds); digits=1) - t_std = round(std(thresholds); digits=1) - - for (rule, weight) in rw - left = last(rule.then)::Float64 - right = last(rule.otherwise)::Float64 - t::Float64 = _threshold(rule) - ratio = 
log((right) / (left)) + @assert feat_names == sort(unique(feat_names)) + + probability_for_class_1(probs::Vector) = last(probs)::Float64 + # Gets the feature importances in order of importance. + importances = feature_importances(models, feat_names) + + # Create a row in the plot for each feature. + for (i, importance) in enumerate(importances) + feat_name, _ = importance + yticks = (1:1, [feat_name]) + axl = Axis(grid[i, 1:3]; yticks) + axr = Axis(grid[i, 4:5]) + vlines!(axl, [0]; color=:gray, linestyle=:dash) + xlims!(axl, -1, 1) + + unpacked_rules = unpack_models(models, feat_name)::Vector{NamedTuple} + # Create a dot and line for each rule that mentions the current feature. + for unpacked_rule::NamedTuple in unpacked_rules + left = probability_for_class_1(unpacked_rule.then) + right = probability_for_class_1(unpacked_rule.otherwise) + value = unpacked_rule.splitval + ratio = log(right / left) # area = πr² - markersize = 50 * sqrt(weight / π) - scatter!(ax, [ratio], [1]; color=:black, markersize) - end - hideydecorations!(ax; ticklabels=false) - - axr = i == l ? - Axis(grid[i, 4:5]; xlabel="Location") : - Axis(grid[i, 4:5]) - D = data[:, feature_name] - hist!(axr, D; scale_to=1) - vlines!(axr, thresholds; color=:black, linestyle=:dash) - - if i < l - hidexdecorations!(ax) - else - hidexdecorations!(ax; ticks=false, ticklabels=false) + markersize = 50 * sqrt(unpacked_rule.weight / π) + scatter!(axl, [ratio], [1]; color=:black, markersize) + + vlines!(axr, unpacked_rule.splitval; color=:black, linestyle=:dash) end - + + # Show a histogram in the background. 
+ hist!(axr, data[:, feat_name]; scale_to=1) + + hidexdecorations!(axl) hideydecorations!(axr) hidexdecorations!(axr; ticks=false, ticklabels=false) end - rowgap!(grid, 5) - colgap!(grid, 50) + rowgap!(grid, 5) # hide return fig end; @@ -659,8 +614,10 @@ e4 = let end; # ╔═╡ 923affb5-b4ca-4b50-baa5-af29204d2081 -# hideall -_odds_plot(e4.e) +let + models = getproperty.(e4.e.fitted_params_per_fold, :fitresult) + _odds_plot(models, sort(names(X))) +end # ╔═╡ 7fad8dd5-c0a9-4c45-9663-d40a464bca77 # hideall @@ -712,10 +669,12 @@ results = let end # ╔═╡ Cell order: +# ╠═27bd0e48-9870-472f-8d78-a9b460c9e858 # ╠═7c10c275-54d8-4f1a-947f-7861199cdf21 # ╠═e9028115-d098-4c61-a82f-d4553fe654f8 # ╠═b1c17349-fd80-43f1-bbc2-53fdb539d1c0 # ╠═348d1235-87f2-4e8f-8f42-be89fef5bf87 +# ╠═f833dab6-31d4-4353-a68b-ef0501d606d4 # ╠═961aa273-d97b-497f-a79a-06bf89dc34b0 # ╠═6e16f844-9365-43af-9ea7-2984808f1fd5 # ╠═b6957225-1889-49fb-93e2-f022ca7c3b23 @@ -769,7 +728,6 @@ end # ╠═e7f396dc-38a7-40f7-9e5b-6fbea9d61789 # ╠═7c688412-d1b4-492d-bda2-0b9181057d4d # ╠═e1890517-7a44-4814-999d-6af27e2a136a -# ╠═f833dab6-31d4-4353-a68b-ef0501d606d4 # ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e # ╠═93a7dd3b-7810-4021-bf6e-ae9c04acea46 # ╠═be324728-1b60-4584-b8ea-c4fe9e3466af diff --git a/src/SIRUS.jl b/src/SIRUS.jl index c9f66b4..a388d18 100644 --- a/src/SIRUS.jl +++ b/src/SIRUS.jl @@ -21,17 +21,24 @@ using .Helpers: colnames, nfeatures, view_feature include("empiricalquantiles.jl") using .EmpiricalQuantiles: Cutpoints, cutpoints -export Cutpoints, cutpoints +export cutpoints include("forest.jl") export StableForest include("classification.jl") include("regression.jl") include("rules.jl") -export StableRules, feature_names, directions, satisfies +export SubClause, feature, feature_name, splitval, direction +export Rule, Clause, clause, then, otherwise, subclauses +export features, feature_names, splitvals, directions +export StableRules, satisfies +export unpack_rule, unpack_model, unpack_models 
include("ruleshow.jl") -include("weights.jl") include("dependent.jl") +export gap_size +include("weights.jl") +include("importance.jl") +export feature_importance, feature_importances include("mlj.jl") const StableForestClassifier = MLJImplementation.StableForestClassifier diff --git a/src/dependent.jl b/src/dependent.jl index ebfa379..d470275 100644 --- a/src/dependent.jl +++ b/src/dependent.jl @@ -1,15 +1,15 @@ "Return whether `a` implies `b`." function _implies(a::SubClause, b::SubClause)::Bool - if _feature(a) == _feature(b) - if _direction(a) == :L - if _direction(b) == :L - return _splitval(a) ≤ _splitval(b) + if feature(a) == feature(b) + if direction(a) == :L + if direction(b) == :L + return splitval(a) ≤ splitval(b) else return false end else - if _direction(b) == :R - return _splitval(a) ≥ _splitval(b) + if direction(b) == :R + return splitval(a) ≥ splitval(b) else return false end @@ -24,8 +24,7 @@ Return whether `condition` implies `rule`, that is, whether `A & B => rule`. """ function _implies(condition::Tuple{SubClause, SubClause}, rule::Rule) A, B = condition - subclauses = _subclauses(rule) - implied = map(subclauses) do subclause + implied = map(subclauses(rule)) do subclause _implies(A, subclause) || _implies(B, subclause) end return all(implied) @@ -77,17 +76,17 @@ the rank increases when adding rules. 
# Example ```jldoctest -julia> A = SIRUS.SubClause(1, "1", 32000.0f0, :L); +julia> A = SubClause(1, "1", 32000.0f0, :L); -julia> B = SIRUS.SubClause(3, "3", 64.0f0, :L); +julia> B = SubClause(3, "3", 64.0f0, :L); -julia> r1 = SIRUS.Rule(SIRUS.Clause(" X[i, 1] < 32000.0 "), [0.061], [0.408]); +julia> r1 = Rule(Clause(" X[i, 1] < 32000.0 "), [0.061], [0.408]); -julia> r5 = SIRUS.Rule(SIRUS.Clause(" X[i, 3] < 64.0 "), [0.056], [0.334]); +julia> r5 = Rule(Clause(" X[i, 3] < 64.0 "), [0.056], [0.334]); -julia> r7 = SIRUS.Rule(SIRUS.Clause(" X[i, 1] ≥ 32000.0 & X[i, 3] ≥ 64.0 "), [0.517], [0.067]); +julia> r7 = Rule(Clause(" X[i, 1] ≥ 32000.0 & X[i, 3] ≥ 64.0 "), [0.517], [0.067]); -julia> r12 = SIRUS.Rule(SIRUS.Clause(" X[i, 1] ≥ 32000.0 & X[i, 3] < 64.0 "), [0.192], [0.102]); +julia> r12 = Rule(Clause(" X[i, 1] ≥ 32000.0 & X[i, 3] < 64.0 "), [0.192], [0.102]); julia> SIRUS.rank(SIRUS._feature_space([r1, r5], A, B)) 3 @@ -119,7 +118,7 @@ function _feature_space(rules::AbstractVector{Rule}, A::SubClause, B::SubClause) end "Canonicalize a SubClause by ensuring that the direction is left." -_canonicalize(s::SubClause) = _direction(s) == :L ? s : _reverse(s) +_canonicalize(s::SubClause) = direction(s) == :L ? s : _reverse(s) """ Return a vector of unique left splits for `rules`. @@ -128,16 +127,16 @@ For example, the pair `x[i, 1] < 32000` (A) and `x[i, 3] < 64` (B) will be used the feature space `A & B`, `A & !B`, `!A & B`, `!A & !B`. 
""" function _unique_left_subclauses(rules::Vector{Rule})::Vector{SubClause} - subclauses = SubClause[] + S = SubClause[] for rule in rules - for subclause in _subclauses(rule) - canonicalized = _canonicalize(subclause) - if !(canonicalized in subclauses) - push!(subclauses, canonicalized) + for s::SubClause in subclauses(rule) + canonicalized = _canonicalize(s) + if !(canonicalized in S) + push!(S, canonicalized) end end end - return subclauses + return S end """ @@ -165,16 +164,16 @@ Here, it is very important to get rid of rules which are about the same feature Otherwise, rules will be wrongly classified as linearly dependent in the next step. """ function _related_rule(rule::Rule, A::SubClause, B::SubClause)::Bool - @assert _direction(A) == :L - @assert _direction(B) == :L - subclauses = _subclauses(rule) - if length(subclauses) == 1 - subclause = only(subclauses) + @assert direction(A) == :L + @assert direction(B) == :L + S = subclauses(rule) + if length(S) == 1 + subclause = only(S) left_subclause = _canonicalize(subclause) return left_subclause == A || left_subclause == B - elseif length(subclauses) == 2 - l1 = _canonicalize(subclauses[1]) - l2 = _canonicalize(subclauses[2]) + elseif length(S) == 2 + l1 = _canonicalize(S[1]) + l2 = _canonicalize(S[2]) return (l1 == A && l2 == B) || (l1 == B && l2 == A) else @error "Rule $rule has more than two splits; this is not supported." @@ -208,7 +207,19 @@ function _linearly_dependent( return dependent end -function _gap_size(rule::Rule) +""" + gap_size(rule::Rule) + +Return the gap size for a rule. +The gap size is used by Bénard et al. in the appendix of their PMLR paper +(). +Via an example, they specify that the gap size is the difference between the +then and otherwise (else) probabilities. + +A smaller gap size implies a smaller CART-splitting criterion, which implies a +smaller occurrence frequency. 
+""" +function gap_size(rule::Rule) @assert length(rule.then) == length(rule.otherwise) gap_size_per_class = abs.(rule.then .- rule.otherwise) sum(gap_size_per_class) @@ -221,7 +232,7 @@ they have a smaller gap. """ function _sort_by_gap_size(rules::Vector{Rule})::Vector{Rule} alg = Helpers.STABLE_SORT_ALG - return sort(rules; alg, by=_gap_size, rev=true) + return sort(rules; alg, by=gap_size, rev=true) end """ diff --git a/src/empiricalquantiles.jl b/src/empiricalquantiles.jl index f09c516..2577ef5 100644 --- a/src/empiricalquantiles.jl +++ b/src/empiricalquantiles.jl @@ -5,7 +5,10 @@ using ..Helpers: nfeatures, view_feature -"Set of possible cutpoints, that is, empirical quantiles." +""" +A type that represents a vector of possible cutpoints, that is, empirical +quantiles. +""" const Cutpoints = Vector{Float32} """ @@ -36,7 +39,11 @@ function _empirical_quantile(V::AbstractVector, quantile::Real) return Float32(sorted[index]) end -"Return a vector of `q` cutpoints taken from the empirical distribution from data `V`." +""" + cutpoints(V::AbstractVector, q::Int) + +Return a vector of `q` cutpoints taken from the empirical distribution from data `V`. +""" function cutpoints(V::AbstractVector, q::Int) @assert 2 ≤ q # Taking 2 extra to avoid getting minimum(V) and maximum(V) becoming cutpoints. @@ -47,9 +54,11 @@ function cutpoints(V::AbstractVector, q::Int) end """ + cutpoints(X, q::Int) + Return a vector of vectors containing -- one inner vector for each feature in the dataset and -- inner vectors containing the unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V. +- one inner vector for each feature in the dataset `X` and +- inner vectors containing `q` unique cutpoints, that is, `length(V[i])` ≤ `q` for all i in V. Using unique here to avoid checking splits twice. 
""" diff --git a/src/importance.jl b/src/importance.jl new file mode 100644 index 0000000..530e4d3 --- /dev/null +++ b/src/importance.jl @@ -0,0 +1,99 @@ + +"Estimate the importance of a rule." +function _rule_importance(weight::Number, rule::Rule) + importance = 0.0 + gap = gap_size(rule) + n_classes = length(rule.then) + return (weight * gap) / n_classes +end + +""" + feature_importance( + models::Union{StableRules, Vector{StableRules}}, + feature_name::AbstractString + ) + +Estimate the importance of the given `feature_name`. +The aim is to satisfy the following property, so that the features can be +ordered by importance: + +> Given two features A and B, if A has more effect on the outcome, then +> feature_importance(model, A) > feature_importance(model, B). + +This is based on the [`gap_size`](@ref) function. The gap size is the +difference between the then and otherwise (else) probabilities. A smaller gap +size implies a smaller CART-splitting criterion, which implies a smaller +occurrence frequency (see the appendix at + for an example). + +!!! note + This function provides only an importance _estimate_ because the effect on + the outcome depends on the data. 
+""" +function feature_importance( + model::StableRules, + feat_name::String + ) + importance = 0.0 + found_feature = false + for (i, rule) in enumerate(model.rules) + for subclause::SubClause in subclauses(rule) + if feature_name(subclause)::String == feat_name + found_feature = true + weight = model.weights[i] + importance += _rule_importance(weight, rule) + end + end + end + if !found_feature + throw(ArgumentError("Feature `$feat_name` not found in the model.")) + end + return importance +end + +function feature_importance(model::StableRules, feature_name::AbstractString) + return feature_importance(model, string(feature_name)::String) +end + +function feature_importance( + models::Vector{<:StableRules}, + feature_name::String + ) + importance = 0.0 + for model in models + importance += feature_importance(model, feature_name) + end + return importance / length(models) +end + +function feature_importance(models::Vector{<:StableRules}, feature_name::AbstractString) + return feature_importance(models, string(feature_name)::String) +end + +""" + feature_importances( + models::Union{StableRules, Vector{StableRules}}, + feat_names::Vector{String} + )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}} + +Return the feature names and importances, sorted by feature importance in descending order. 
+""" +function feature_importances( + models::Union{StableRules, Vector{<:StableRules}}, + feat_names::Vector{String} + )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}} + @assert length(unique(feat_names)) == length(feat_names) + importances = map(feat_names) do feat_name + importance = feature_importance(models, feat_name) + (; feature_name=feat_name, importance) + end + alg = Helpers.STABLE_SORT_ALG + return sort(importances; alg, by=last, rev=true) +end + +function feature_importances( + models::Union{StableRules, Vector{<:StableRules}}, + feature_names + )::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}} + return feature_importances(models, string.(feature_names)) +end diff --git a/src/rules.jl b/src/rules.jl index 6f662ee..02ff13f 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,5 +1,10 @@ """ - SubClause + SubClause( + feature::Int, + feature_name::AbstractString, + splitval::Number, + direction::Symbol + ) A subclause denotes a conditional on one feature. Each rule contains a clause with one or more subclauses. @@ -13,18 +18,16 @@ in the final model, as is discussed in the original SIRUS paper. The data inside a `SubClause` can be accessed via -- `_feature`, -- `_feature_name`, -- `_splitval`, and -- `_direction`. - -To obtain the reverse, use `_reverse`. +- [`feature(::SubClause)`](@ref), +- [`feature_name(::SubClause)`](@ref), +- [`splitval(::SubClause)`](@ref), and +- [`direction(::SubClause)`](@ref). Note: this name is not perfect. A formally better name would be "predicate atom", but that takes more characters and is also not very intuitive. -Instead, the word `Clause` and `SubClause` seem pretty short and clear. +Instead, the word `Clause` and `SubClause` seemed pretty short and clear. 
""" struct SubClause feature::Int @@ -50,10 +53,35 @@ function SubClause( return SubClause(sp.feature, sp.feature_name, sp.value, direction) end -_feature(s::SubClause) = s.feature -_feature_name(s::SubClause) = s.feature_name -_splitval(s::SubClause) = s.splitval -_direction(s::SubClause) = s.direction +""" + feature(s::SubClause) -> Int + +Return the feature number for a subclause. +""" +feature(s::SubClause)::Int = s.feature + +""" + feature_name(s::SubClause) -> String + +Return the feature name for a subclause. +""" +feature_name(s::SubClause)::String = s.feature_name + +""" + splitval(s::SubClause) + +Return the split value for a subclause. +The function currently returns a `Float32` but this might change in the future. +""" +splitval(s::SubClause) = s.splitval + +""" + direction(s::SubClause) -> Symbol + +Return the direction of the comparison for a subclause. +Can be either `:L` or `:R`, which is equivalent to `<` or `≥` respectively. +""" +direction(s::SubClause)::Symbol = s.direction function _reverse(s::SubClause) direction = s.direction == :L ? :R : :L @@ -66,7 +94,7 @@ function Base.:(==)(a::SubClause, b::SubClause) end """ - Clause + Clause(subclauses::Vector{SubClause}) A clause denotes a conditional on one or more features. Each rule contains a clause with one or more subclauses. @@ -80,13 +108,27 @@ As discussed above, in practice the number of subclauses or subclauses `d ≤ 2` Note that a path can also be a path to a node; not necessarily a leaf. -Data can be accessed via `_subclauses`. +Data can be accessed via [`subclauses`](@ref). + +Clauses can be constructed from a textual representation: + +### Example + +```jldoctest +julia> Clause(" X[i, 1] < 32000 ") +Clause(" X[i, 1] < 32000.0 ") +``` """ struct Clause subclauses::Vector{SubClause} end -_subclauses(c::Clause) = c.subclauses +""" + subclauses(c::Clause) -> Vector{SubClause} + +Return the subclauses for a clause. 
+""" +subclauses(c::Clause) = c.subclauses function Clause(text::String) try @@ -139,6 +181,38 @@ function Base.:(==)(a::Clause, b::Clause) return all(a.subclauses .== b.subclauses) end +""" + Rule(clause::Clause, then::LeafContent, otherwise::LeafContent) + +A rule is a clause with a then and otherwise probability. For example, the rule +`if X[i, 1] > 3 & X[i, 2] < 4, then 0.1 else 0.2` is a rule with two +subclauses. The name `otherwise` is used internally instead of `else` since +`else` is a reserved keyword. + +Data can be accessed via + +- [`clause(::Rule)`](@ref), +- [`subclauses(::Rule)`](@ref), +- [`then(::Rule)`](@ref), +- [`otherwise(::Rule)`](@ref), +- [`feature(::Rule)`](@ref), +- [`features(::Rule)`](@ref), +- [`feature_name(::Rule)`](@ref), +- [`feature_names(::Rule)`](@ref), +- [`splitval(::Rule)`](@ref), +- [`splitvals(::Rule)`](@ref), +- [`direction(::Rule)`](@ref), and +- [`directions(::Rule)`](@ref). + +Rules can be constructed from a textual representation: + +### Example + +```jldoctest +julia> Rule(Clause(" X[i, 1] < 32000 "), [0.1], [0.4]) +Rule(Clause(" X[i, 1] < 32000.0 "), [0.1], [0.4]) +``` +""" struct Rule clause::Clause then::LeafContent @@ -146,8 +220,64 @@ struct Rule otherwise::LeafContent end -_clause(rule::Rule) = rule.clause -_subclauses(rule::Rule) = rule.clause.subclauses +""" + clause(rule::Rule) -> Clause + +Return the clause for a rule. +The clause is a path in a decision tree after the conversion to rules. +A clause consists of one or more subclauses. +""" +clause(rule::Rule) = rule.clause + +""" + then(rule::Rule) + +Return the then probabilities for a rule. The return type is a vector of +probabilities; the exact element type may change over time. +""" +then(rule::Rule) = rule.then + +""" + otherwise(rule::Rule) + +Return the otherwise probabilities for a rule. The return type is a vector of +probabilities; the exact element type may change over time. 
+""" +otherwise(rule::Rule) = rule.otherwise + +""" + subclauses(rule::Rule) -> Vector{SubClause} + +Return the subclauses for a rule. +""" +subclauses(rule::Rule) = rule.clause.subclauses + +""" + feature(rule::Rule) -> Int + +Return the feature number for a rule with one subclause. +Throws an error if the rule has multiple subclauses. +Use [`features`](@ref) for rules with multiple subclauses. +""" +feature(rule::Rule)::Int = feature(only(subclauses(rule))) + +""" + features(rule::Rule) -> Vector{Int} + +Return a vector of feature numbers; one for each clause in `rule`. +""" +function features(rule::Rule)::Vector{Int} + return Int[feature(s) for s in subclauses(rule)] +end + +""" + feature_name(rule::Rule) -> String + +Return the feature name for a rule with one subclause. +Throws an error if the rule has multiple subclauses. +Use [`feature_names`](@ref) for rules with multiple subclauses. +""" +feature_name(rule::Rule)::String = feature_name(only(subclauses(rule))) """ feature_names(rule::Rule) -> Vector{String} @@ -155,25 +285,41 @@ _subclauses(rule::Rule) = rule.clause.subclauses Return a vector of feature names; one for each clause in `rule`. """ function feature_names(rule::Rule)::Vector{String} - return String[String(_feature_name(s))::String for s in _subclauses(rule)] + return String[feature_name(s)::String for s in subclauses(rule)] end """ - directions(rule::Rule) -> Vector{Symbol} + splitval(rule::Rule) -Return a vector of split directions; one for each clause in `rule`. +Return the splitvalue for a rule with one subclause. +Throws an error if the rule has multiple subclauses. +Use [`splitvals`](@ref) for rules with multiple subclauses. """ -function directions(rule::Rule)::Vector{Symbol} - return Symbol[_direction(s) for s in _subclauses(rule)] -end +splitval(rule::Rule) = splitval(only(subclauses(rule))) """ - values(rule::Rule) -> Vector{Float64} + splitvals(rule::Rule) -Return a vector split values; one for each subclause in `rule`. 
+Return the splitvalues for a rule; one for each subclause. """ -function Base.values(rule::Rule)::Vector{Float64} - return Float64[Float64(_splitval(s)) for s in _subclauses(rule)] +splitvals(rule::Rule) = splitval.(subclauses(rule)) + +""" + direction(rule::Rule) -> Symbol + +Return the direction for a rule with one subclause. +Throws an error if the rule has multiple subclauses. +Use [`directions`](@ref) for rules with multiple subclauses. +""" +direction(rule::Rule)::Symbol = direction(only(subclauses(rule))) + +""" + directions(rule::Rule) -> Vector{Symbol} + +Return a vector of split directions; one for each clause in `rule`. +""" +function directions(rule::Rule)::Vector{Symbol} + return Symbol[direction(s) for s in subclauses(rule)] end """ @@ -184,18 +330,18 @@ Assumes that the rule has only one split (clause) since two subclauses cannot be reversed. """ function _reverse(rule::Rule)::Rule - subclauses = _subclauses(rule) - @assert length(subclauses) == 1 - subclause = subclauses[1] + S = subclauses(rule) + @assert length(S) == 1 + subclause = S[1] clause = Clause([_reverse(subclause)]) return Rule(clause, rule.otherwise, rule.then) end function _left_rule(rule::Rule)::Rule - subclauses = _subclauses(rule) - @assert length(subclauses) == 1 - split = subclauses[1] - return _direction(split) == :L ? rule : _reverse(rule) + S = subclauses(rule) + @assert length(S) == 1 + s::SubClause = only(S) + return direction(s) == :L ? rule : _reverse(rule) end function Base.:(==)(a::Rule, b::Rule) @@ -203,7 +349,7 @@ function Base.:(==)(a::Rule, b::Rule) end function Base.hash(rule::Rule) - hash([_subclauses(rule), rule.then, rule.otherwise]) + hash([subclauses(rule), rule.then, rule.otherwise]) end function _then_output!( @@ -343,8 +489,7 @@ removing duplicates. 
function _simplify_single_rules(rules::Vector{Rule})::Vector{Rule} out = OrderedSet{Rule}() for rule in rules - subclauses = _subclauses(rule) - if length(subclauses) == 1 + if length(subclauses(rule)) == 1 left_rule = _left_rule(rule) push!(out, left_rule) else @@ -419,11 +564,11 @@ end Return whether data `row` satisfies `rule`. """ function satisfies(row::AbstractVector, rule::Rule)::Bool - constraints = map(_subclauses(rule)) do subclause - comparison = _direction(subclause) == :L ? (<) : (≥) - feature = _feature(subclause) - value = _splitval(subclause) - satisfies_constraint = comparison(row[feature], value) + constraints = map(subclauses(rule)) do subclause + comparison = direction(subclause) == :L ? (<) : (≥) + feat = feature(subclause) + value = splitval(subclause) + satisfies_constraint = comparison(row[feat], value) end return all(constraints) end @@ -446,3 +591,94 @@ function _predict(model::StableRules, row::AbstractVector) end return _sum(rule_predictions) end + +""" + unpack_rule(rule::Rule) -> NamedTuple + +Unpack a rule into its components. This is useful for plotting. It returns a +named tuple with the following fields: + +- `feature` +- `feature_name` +- `splitval` +- `direction` +- `then` +- `otherwise` +""" +function unpack_rule(rule::Rule)::NamedTuple + return (; + feature=feature(rule), + feature_name=feature_name(rule), + splitval=splitval(rule), + direction=direction(rule), + then=then(rule), + otherwise=otherwise(rule) + ) +end + +""" + unpack_model(model::StableRules) -> Vector{NamedTuple} + +Unpack a model containing only single subclauses (`max_depth=1`) into its +components. This is useful for plotting. It returns a vector of named tuples +with the following fields: + +- `weight` +- `feature` +- `feature_name` +- `splitval` +- `direction` +- `then` +- `otherwise` + +One row for each rule in the `model`. 
+""" +function unpack_model(model::StableRules)::Vector{NamedTuple} + @assert length(model.weights) == length(model.rules) + return map(zip(model.weights, model.rules)) do (weight, rule) + (; + weight=weight, + feature=feature(rule), + feature_name=feature_name(rule), + splitval=splitval(rule), + direction=direction(rule), + then=then(rule), + otherwise=otherwise(rule) + ) + end +end + +""" + unpack_models( + models::Vector{StableRules}, + feature_name::String + ) -> Vector{NamedTuple} + +Unpack a vector of models containing only single subclauses (`max_depth=1`) +into its components. This is useful when plotting the rules that the model has +learned for each feature. It returns a vector of named tuples with the +following fields for each rule in the `models` that contains `feature_name`: + +- `weight` +- `feature` +- `feature_name` +- `splitval` +- `direction` +- `then` +- `otherwise` +""" +function unpack_models( + models::Vector{<:StableRules}, + feature_name::String +)::Vector{NamedTuple} + out = NamedTuple[] + for model in models + unpacked = unpack_model(model) + for nt in unpacked + if nt.feature_name == feature_name + push!(out, nt) + end + end + end + return out +end diff --git a/src/ruleshow.jl b/src/ruleshow.jl index d75ada8..1982cae 100644 --- a/src/ruleshow.jl +++ b/src/ruleshow.jl @@ -2,19 +2,19 @@ Return a feature name that can be shown as `[:, 1]` or `[:, :some_var]`. """ function _pretty_feature_name(subclause::SubClause) - feature = string(_feature(subclause)::Int)::String - feature_name = _feature_name(subclause)::String - if feature == feature_name - return feature + feat = string(feature(subclause)::Int)::String + feat_name = feature_name(subclause)::String + if feat == feat_name + return feat else - return string(':', feat_name)::String end end function _pretty_clause(clause::Clause) - texts = map(_subclauses(clause)) do subclause - comparison = _direction(subclause) == :L ? 
'<' : '≥' - value = _splitval(subclause) + texts = map(subclauses(clause)) do subclause + comparison = direction(subclause) == :L ? '<' : '≥' + value = splitval(subclause) feature_descr = _pretty_feature_name(subclause) text = "X[i, $feature_descr] $comparison $value" end diff --git a/test/dependent.jl b/test/dependent.jl index 2c97cbe..69a3879 100644 --- a/test/dependent.jl +++ b/test/dependent.jl @@ -26,10 +26,10 @@ r15 = S.Rule(S.Clause(" X[i, 1] ≥ 32000 & X[i, 4] < 12 "), [0.192], [0.096]) r16 = S.Rule(S.Clause(" X[i, 2] ≥ 8000 & X[i, 4] ≥ 12 "), [0.586], [0.076]) r17 = S.Rule(S.Clause(" X[i, 2] ≥ 8000 & X[i, 4] < 12 "), [0.236], [0.094]) -@test S._gap_size(r1) == 0.408 - 0.061 -@test S._gap_size(r3) == 0.386 - 0.062 -@test S._gap_size(r3) < S._gap_size(r1) -@test S._gap_size(r12) < S._gap_size(r1) +@test gap_size(r1) == 0.408 - 0.061 +@test gap_size(r3) == 0.386 - 0.062 +@test gap_size(r3) < gap_size(r1) +@test gap_size(r12) < gap_size(r1) @test S._sort_by_gap_size([r3, r12, r1]) == [r1, r3, r12] @test S._filter_linearly_dependent([r1, r2, r3, r5]) == [r1, r3, r5] diff --git a/test/docs.jl b/test/docs.jl index 07f8905..8366866 100644 --- a/test/docs.jl +++ b/test/docs.jl @@ -1,7 +1,10 @@ api_docs = read(joinpath(pkgdir(SIRUS), "docs", "src", "api.md"), String) # Testing manually because setting doctest too restrictive doesn't work with PlutoStaticHTML. 
-for name in names(SIRUS) +for name::Symbol in names(SIRUS) + if name == :SIRUS + continue + end @test contains(api_docs, string(name)) end diff --git a/test/importance.jl b/test/importance.jl new file mode 100644 index 0000000..80cd64c --- /dev/null +++ b/test/importance.jl @@ -0,0 +1,55 @@ +r1 = S.Rule(S.Clause(" X[i, 1] < 32000 "), [0.1], [0.4]) +r2 = S.Rule(S.Clause(" X[i, 1] ≥ 32000 "), [0.3], [0.2]) +r3 = S.Rule(S.Clause(" X[i, 2] < 8000 "), [0.1], [0.5]) +w1 = 0.4 +w2 = 0.3 +w3 = 0.3 + +model = let + rules = [r1, r2, r3] + algo = SIRUS.Classification() + classes = [0, 1] + weights = Float16[w1, w2, w3] + SIRUS.StableRules(rules, algo, classes, weights) +end +# StableRules model with 3 rules: +# if X[i, 1] < 32000.0 then [0.04] else [0.16] + +# if X[i, 1] ≥ 32000.0 then [0.09] else [0.06] + +# if X[i, 2] < 8000.0 then [0.03] else [0.15] +# and 2 classes: [0, 1]. +# Note: showing only the probability for class 1 since class 0 has probability 1 - p. + +@test_throws ArgumentError feature_importance(model, "x1") + +importance = feature_importance(model, "1") +# Based on the numbers above. 
+expected = w1 * (0.4 - 0.1) + w2 * (0.3 - 0.2) +@test importance ≈ expected atol=0.01 + +@test feature_importance([model, model], "1") ≈ expected atol=0.01 +@test only(feature_importances(model, ["1"])).importance ≈ expected atol=0.01 + +importances = feature_importances([model], ["1", "2"])::Vector{<:NamedTuple} +@test length(importances) == 2 +@test importances[1].feature_name == "1" +@test importances[1].importance ≈ expected atol=0.01 +@test importances[2].feature_name == "2" + +@test unpack_rule(r1) == (; + feature=1, + feature_name="1", + splitval=32000.0, + direction=:L, + then=[0.1], + otherwise=[0.4] + ) + +@test unpack_model(model)[1] == (; + weight=Float16(w1), + feature=1, + feature_name="1", + splitval=32000.0, + direction=:L, + then=[0.1], + otherwise=[0.4] + ) diff --git a/test/rules.jl b/test/rules.jl index 18416e3..0f724b9 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -23,7 +23,7 @@ left_rule = S.Rule(S.Clause(" X[i, 1] < 32000 "), [0.61], [0.408]) @testset "exported functions" begin @test feature_names(left_rule) == ["1"] @test directions(left_rule) == [:L] - @test values(left_rule) == [32000] + @test splitvals(left_rule) == [32000] end r1 = S.Rule(S.Clause(" X[i, 1] < 32000 "), [0.61], [0.408]) diff --git a/test/runtests.jl b/test/runtests.jl index c125711..1a3f60c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,10 @@ end include("weights.jl") end +@testset "importance" begin + include("importance.jl") +end + if get(ENV, "CAN_RUN_R_SIRUS", "false")::String == "true" @testset "rcall" begin include("rcall.jl") diff --git a/_typos.toml b/test/typos.toml similarity index 100% rename from _typos.toml rename to test/typos.toml