Skip to content

Commit

Permalink
Extend API (#74)
Browse files Browse the repository at this point in the history
Fixes #44 and fixes
#66.
  • Loading branch information
rikhuijzer committed Dec 1, 2023
1 parent de9fb40 commit c3bae52
Show file tree
Hide file tree
Showing 17 changed files with 599 additions and 192 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
- uses: julia-actions/cache@v1.3.0
with:
cache-name: 'docs'
- run: julia -e 'using Pkg; Pkg.add("Revise");'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/Typos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ jobs:
- uses: actions/checkout@v4

- uses: crate-ci/typos@master
with:
config: './test/typos.toml'
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
CSV = "0.10"
CairoMakie = "0.10"
CairoMakie = "0.11"
CategoricalArrays = "0.10"
DataDeps = "0.7"
DataFrames = "1"
Expand Down
32 changes: 27 additions & 5 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# API

## Types
## MLJ Interface Types

```@docs
StableRulesClassifier
Expand All @@ -9,13 +9,35 @@ StableForestClassifier
StableForestRegressor
```

## Methods
## SIRUS Types

```@docs
SubClause
Clause
Rule
```

## SIRUS Methods

```@docs
feature
features
feature_name
feature_names
splitval
splitvals
clause
subclauses
direction
directions
values(::SIRUS.Rule)
satisfies
Cutpoints
feature_importance
feature_importances
then
otherwise
gap_size
cutpoints
satisfies
unpack_rule
unpack_model
unpack_models
```
142 changes: 50 additions & 92 deletions docs/src/binary-classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ begin
Pkg.develop(; path=PKGDIR)
end

# ╔═╡ 27bd0e48-9870-472f-8d78-a9b460c9e858
# hideall
using Revise

# ╔═╡ f833dab6-31d4-4353-a68b-ef0501d606d4
begin
using CairoMakie
Expand Down Expand Up @@ -230,9 +234,9 @@ Since we know that the model performs well on the cross-validations, we can fit

# ╔═╡ 3c415a26-803e-4f35-866f-2e582c6c1c45
md"""
## Visualization
## Plot
Since our rules are relatively simple with only a binary outcome and only one subclause in each rule, the following figure is a way to visualize the obtained rules per fold.
Since our rules are relatively simple with only a binary outcome and only one subclause in each rule (because of `max_depth=1`), the following figure is a way to visualize the obtained rules per fold.
For multiple subclauses, I would not know how to visualize the rules.
Also, this plot is probably not perfect; let me know if you have suggestions.
Expand Down Expand Up @@ -330,7 +334,7 @@ md"""

# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e
function _plot_cutpoints(data::AbstractVector)
fig = Figure(; resolution=(800, 100))
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
cps = Float64.(unique(cutpoints(data, 10)))
scatter!(ax, data, fill(1, length(data)))
Expand Down Expand Up @@ -427,7 +431,7 @@ ln = length(nodes);
# ╔═╡ de90efc9-2171-4406-93a1-9a213ab32259
# hideall
let
fig = Figure(; resolution=(800, 100))
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, nodes, fill(1, ln))
hideydecorations!(ax)
Expand All @@ -441,7 +445,7 @@ index = length(nodes) - 3;
# ╔═╡ 2c1adef4-822e-4dc0-946b-dc574e50b305
# hideall
let
fig = Figure(; resolution=(800, 100))
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, nodes, fill(1, ln))
vlines!(ax, [nodes[index]]; color=:red)
Expand All @@ -467,7 +471,7 @@ _plot_cutpoints(subset)
# ╔═╡ 25ad7a18-f989-40f7-8ef1-4ca506446478
# hideall
let
fig = Figure(; resolution=(800, 100))
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, subset, fill(1, ls))
vlines!(ax, [nodes[index]]; color=:red, linestyle=:dash)
Expand All @@ -483,98 +487,49 @@ end
_plot_cutpoints(nodes)

# ╔═╡ a64dae3c-3b97-4076-98f4-3c9a0e5c0621
# hideall
function _odds_plot(e::PerformanceEvaluation)
function _odds_plot(models::Vector{<:StableRules}, feat_names::Vector{String})
w, h = (1000, 300)
fig = Figure(; resolution=(w, h))
fig = Figure(; size=(w, h))
grid = fig[1, 1:2] = GridLayout()

fitresults = getproperty.(e.fitted_params_per_fold, :fitresult)
feature_names = String[]
for fitresult in fitresults
for rule in fitresult.rules
name = only(SIRUS._subclauses(rule)).feature_name
push!(feature_names, name)
end
end

names = sort(unique(feature_names))
subtitle = "Ratio"

max_height = maximum(maximum.(getproperty.(fitresults, :weights)))

importances = _sum_weights.(Ref(fitresults), names)

matching_rules = DataFrame(; names, importance=importances)
sort!(matching_rules, :importance; rev=true)
names = matching_rules.names
l = length(names)

for (i, feature_name) in enumerate(names)
yticks = (1:1, [feature_name])
ax = i == l ?
Axis(grid[i, 1:3]; yticks, xlabel="Ratio") :
Axis(grid[i, 1:3]; yticks)
vlines!(ax, [0]; color=:gray, linestyle=:dash)
xlims!(ax, -1, 1)
ylabel = feature_name

name = feature_name

nested_rules_weights = map(fitresults) do fitresult
subresult = Tuple{SIRUS.Rule,Float64}[]
zipped = zip(fitresult.rules, fitresult.weights)
for (rule, weight) in zipped
feat_name = only(SIRUS._subclauses(rule)).feature_name
if feat_name == feature_name
push!(subresult, (rule, weight))
end
end
subresult
end
rules_weights = Tuple{SIRUS.Rule,Float64}[]
for nested in nested_rules_weights
isnothing(nested) && continue
for rule_weight in nested
push!(rules_weights, rule_weight)
end
end
rw::Vector{Tuple{SIRUS.Rule,Float64}} =
filter(!isnothing, rules_weights)
thresholds = _threshold.(first.(rw))
t_mean = round(mean(thresholds); digits=1)
t_std = round(std(thresholds); digits=1)

for (rule, weight) in rw
left = last(rule.then)::Float64
right = last(rule.otherwise)::Float64
t::Float64 = _threshold(rule)
ratio = log((right) / (left))
@assert feat_names == sort(unique(feat_names))

probability_for_class_1(probs::Vector) = last(probs)::Float64
# Gets the feature importances in order of importance.
importances = feature_importances(models, feat_names)

# Create a row in the plot for each feature.
for (i, importance) in enumerate(importances)
feat_name, _ = importance
yticks = (1:1, [feat_name])
axl = Axis(grid[i, 1:3]; yticks)
axr = Axis(grid[i, 4:5])
vlines!(axl, [0]; color=:gray, linestyle=:dash)
xlims!(axl, -1, 1)

unpacked_rules = unpack_models(models, feat_name)::Vector{NamedTuple}
# Create a dot and line for each rule that mentions the current feature.
for unpacked_rule::NamedTuple in unpacked_rules
left = probability_for_class_1(unpacked_rule.then)
right = probability_for_class_1(unpacked_rule.otherwise)
value = unpacked_rule.splitval
ratio = log(right / left)
# area = πr²
markersize = 50 * sqrt(weight / π)
scatter!(ax, [ratio], [1]; color=:black, markersize)
end
hideydecorations!(ax; ticklabels=false)

axr = i == l ?
Axis(grid[i, 4:5]; xlabel="Location") :
Axis(grid[i, 4:5])
D = data[:, feature_name]
hist!(axr, D; scale_to=1)
vlines!(axr, thresholds; color=:black, linestyle=:dash)

if i < l
hidexdecorations!(ax)
else
hidexdecorations!(ax; ticks=false, ticklabels=false)
markersize = 50 * sqrt(unpacked_rule.weight / π)
scatter!(axl, [ratio], [1]; color=:black, markersize)

vlines!(axr, unpacked_rule.splitval; color=:black, linestyle=:dash)
end


# Show a histogram in the background.
hist!(axr, data[:, feat_name]; scale_to=1)

hidexdecorations!(axl)
hideydecorations!(axr)
hidexdecorations!(axr; ticks=false, ticklabels=false)
end

rowgap!(grid, 5)
colgap!(grid, 50)
rowgap!(grid, 5) # hide
return fig
end;

Expand Down Expand Up @@ -659,8 +614,10 @@ e4 = let
end;

# ╔═╡ 923affb5-b4ca-4b50-baa5-af29204d2081
# hideall
_odds_plot(e4.e)
let
models = getproperty.(e4.e.fitted_params_per_fold, :fitresult)
_odds_plot(models, sort(names(X)))
end

# ╔═╡ 7fad8dd5-c0a9-4c45-9663-d40a464bca77
# hideall
Expand Down Expand Up @@ -712,10 +669,12 @@ results = let
end

# ╔═╡ Cell order:
# ╠═27bd0e48-9870-472f-8d78-a9b460c9e858
# ╠═7c10c275-54d8-4f1a-947f-7861199cdf21
# ╠═e9028115-d098-4c61-a82f-d4553fe654f8
# ╠═b1c17349-fd80-43f1-bbc2-53fdb539d1c0
# ╠═348d1235-87f2-4e8f-8f42-be89fef5bf87
# ╠═f833dab6-31d4-4353-a68b-ef0501d606d4
# ╠═961aa273-d97b-497f-a79a-06bf89dc34b0
# ╠═6e16f844-9365-43af-9ea7-2984808f1fd5
# ╠═b6957225-1889-49fb-93e2-f022ca7c3b23
Expand Down Expand Up @@ -769,7 +728,6 @@ end
# ╠═e7f396dc-38a7-40f7-9e5b-6fbea9d61789
# ╠═7c688412-d1b4-492d-bda2-0b9181057d4d
# ╠═e1890517-7a44-4814-999d-6af27e2a136a
# ╠═f833dab6-31d4-4353-a68b-ef0501d606d4
# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e
# ╠═93a7dd3b-7810-4021-bf6e-ae9c04acea46
# ╠═be324728-1b60-4584-b8ea-c4fe9e3466af
Expand Down
13 changes: 10 additions & 3 deletions src/SIRUS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,24 @@ using .Helpers: colnames, nfeatures, view_feature

include("empiricalquantiles.jl")
using .EmpiricalQuantiles: Cutpoints, cutpoints
export Cutpoints, cutpoints
export cutpoints

include("forest.jl")
export StableForest
include("classification.jl")
include("regression.jl")
include("rules.jl")
export StableRules, feature_names, directions, satisfies
export SubClause, feature, feature_name, splitval, direction
export Rule, Clause, clause, then, otherwise, subclauses
export features, feature_names, splitvals, directions
export StableRules, satisfies
export unpack_rule, unpack_model, unpack_models
include("ruleshow.jl")
include("weights.jl")
include("dependent.jl")
export gap_size
include("weights.jl")
include("importance.jl")
export feature_importance, feature_importances

include("mlj.jl")
const StableForestClassifier = MLJImplementation.StableForestClassifier
Expand Down
Loading

0 comments on commit c3bae52

Please sign in to comment.