Skip to content

Commit

Permalink
Improve stability example in docs (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed May 22, 2024
1 parent 12ae28b commit eebb38f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 50 deletions.
17 changes: 9 additions & 8 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,30 @@ jobs:
actions: write
# Required by julia-actions/cache.
contents: read
runs-on: ${{ matrix.os }}
runs-on: ${{ matrix.config.os }}
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
version:
- '1.6'
- '1'
os:
- ubuntu-latest
config:
- {os: ubuntu-latest, version: '1.6'}
# R crashes on ubuntu-latest with newer versions of Julia.
- {os: macos-latest, version: '1'}

steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
version: ${{ matrix.config.version }}
- uses: julia-actions/cache@v2.0.0
with:
cache-name: 'test-${{ matrix.config.os }}-${{ matrix.config.version }}'
- uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true
r-version: '4'
- run: echo "LD_LIBRARY_PATH=$(R RHOME)/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV
if: matrix.os == 'ubuntu-latest'
if: matrix.config.os == 'ubuntu-latest'
- run: Rscript -e 'install.packages("sirus")'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
CSV = "0.10"
CairoMakie = "0.11"
CairoMakie = "0.12"
CategoricalArrays = "0.10"
DataDeps = "0.7"
DataFrames = "1"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/basic-example.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
# v0.19.32
# v0.19.42

using Markdown
using InteractiveUtils
Expand Down
101 changes: 61 additions & 40 deletions docs/src/binary-classification.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
# v0.19.32
# v0.19.42

using Markdown
using InteractiveUtils
Expand Down Expand Up @@ -171,6 +171,7 @@ Say, we take the following subset of length `0.7 * length(nodes)`:
# ╔═╡ ee12350a-627b-4a11-99cb-38c496977d18
md"""
Now, the algorithm would choose a different location and, hence, introduce instability.
To solve this, Bénard et al. decided to limit the splitpoints that the algorithm can use to split to data to a pre-defined set of points.
For each feature, they find `q` empirical quantiles where `q` is typically 10.
Let's overlay these quantiles on top of the `age` feature:
Expand All @@ -181,6 +182,20 @@ md"""
Next, let's see where the cutpoints are when we take the same random subset as above:
"""

# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e
# hideall
function _plot_cutpoints(data::DataFrame; q=10)
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
cps = Float64.(unique(cutpoints(data.nodes, q)))
scatter!(ax, data.nodes, fill(1, nrow(data)); color=data.color, marker=data.marker)
vlines!(ax, cps; color=:black, linestyle=:dash)
xlims!(ax, -1, 48)
ylims!(ax, 0.82, 1.2)
hideydecorations!(ax)
return fig
end;

# ╔═╡ 01b08d44-4b9b-42e2-bb20-f34cb9b407f3
md"""
As can be seen, many cutpoints are at the same location as before.
Expand Down Expand Up @@ -331,23 +346,6 @@ md"""
## Appendix
"""

# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e
function _plot_cutpoints(data::AbstractVector)
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
cps = Float64.(unique(cutpoints(data, 10)))
scatter!(ax, data, fill(1, length(data)))
vlines!(ax, cps; color=:black, linestyle=:dash)
textlocs = [(c, 1.1) for c in cps]
for cutpoint in cps
annotation = string(round(cutpoint; digits=2))::String
text!(ax, cutpoint + 0.2, 1.08; text=annotation, fontsize=13)
end
ylims!(ax, 0.9, 1.2)
hideydecorations!(ax)
return fig
end;

# ╔═╡ 93a7dd3b-7810-4021-bf6e-ae9c04acea46
_rng(seed::Int=1) = StableRNG(seed);

Expand Down Expand Up @@ -421,47 +419,66 @@ The `then` or `else` outcome is chosen for all the rules and, finally, the outco
"""

# ╔═╡ 172d3263-2e39-483c-9d82-8c22059e63c3
nodes = sort(data.age);
# hideall
processed = let
df = sort(data, :nodes)
df[!, :marker] = [survival == 1 ? :circle : :cross for survival in df.survival]
df[!, :color] = [survival == 1 ? :black : :gray for survival in df.survival]
df
end;

# ╔═╡ cf1816e5-4e8d-4e60-812f-bd6ae7011d6c
# hideall
ln = length(nodes);
l = length(processed.nodes);

# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289
# hideall
index = l - 21;

# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c
# hideall
subset_indexes = collect(S._rand_subset(_rng(3), 1:l, round(Int, 0.6 * l)));

# ╔═╡ de90efc9-2171-4406-93a1-9a213ab32259
# hideall
let
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, nodes, fill(1, ln))
color = (; colormap=:tab10, colorrange=(1, 10))
scatter!(ax, processed.nodes, fill(1, l); color=processed.color, marker=processed.marker)
hideydecorations!(ax)
xlims!(ax, -1, 48)
ylims!(ax, 0.82, 1.2)
fig
end

# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289
# hideall
index = length(nodes) - 3;

# ╔═╡ 2c1adef4-822e-4dc0-946b-dc574e50b305
# hideall
let
V = processed.nodes
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, nodes, fill(1, ln))
vlines!(ax, [nodes[index]]; color=:red)
annotation = string(round(nodes[index]; digits=2))
text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11)
scatter!(ax, V, fill(1, l); color=processed.color, marker=processed.marker)
vlines!(ax, [V[index]]; color=:red)
annotation = string(round(V[index]; digits=2))
text!(ax, V[index] + 0.003, 1.08; text=annotation, fontsize=11)
hideydecorations!(ax)
ylims!(ax, 0.9, 1.2)
xlims!(ax, -1, 48)
ylims!(ax, 0.82, 1.2)
fig
end

# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c
# ╔═╡ 6fb30208-cf39-42cd-bdda-a7941173822e
# hideall
subset = collect(S._rand_subset(_rng(3), nodes, round(Int, 0.7 * ln)));
subset = let
df = processed[subset_indexes, :]
sort!(df, :nodes)
df
end

# ╔═╡ dff9eb71-a853-4186-8245-a64206379b6f
# hideall
ls = length(subset);
ls = nrow(subset);

# ╔═╡ 8fdc24d9-1f6b-4094-9722-6b5b6c713f12
# hideall
Expand All @@ -470,20 +487,23 @@ _plot_cutpoints(subset)
# ╔═╡ 25ad7a18-f989-40f7-8ef1-4ca506446478
# hideall
let
V = subset.nodes
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, subset, fill(1, ls))
vlines!(ax, [nodes[index]]; color=:red, linestyle=:dash)
annotation = string(round(nodes[index]; digits=2))
text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11)
scatter!(ax, V, fill(1, ls); color=subset.color, marker=subset.marker)
loc = processed.nodes[index]
vlines!(ax, [loc]; color=:red, linestyle=:dash)
annotation = string(round(loc; digits=2))
text!(ax, loc + 0.003, 1.08; text=annotation, fontsize=11)
hideydecorations!(ax)
ylims!(ax, 0.9, 1.2)
xlims!(ax, -1, 48)
ylims!(ax, 0.82, 1.2)
fig
end

# ╔═╡ 4935d8f5-32e1-429c-a8c1-84c242eff4bf
# hideall
_plot_cutpoints(nodes)
_plot_cutpoints(processed; q=10)

# ╔═╡ a64dae3c-3b97-4076-98f4-3c9a0e5c0621
function _odds_plot(models::Vector{<:StableRules}, feat_names::Vector{String})
Expand Down Expand Up @@ -701,12 +721,14 @@ end
# ╠═2c1adef4-822e-4dc0-946b-dc574e50b305
# ╠═896e00dc-2ce9-4a9f-acc1-519aec21dd83
# ╠═bfcb5e17-8937-4448-b090-2782818c6b6c
# ╠═6fb30208-cf39-42cd-bdda-a7941173822e
# ╠═dff9eb71-a853-4186-8245-a64206379b6f
# ╠═25ad7a18-f989-40f7-8ef1-4ca506446478
# ╠═ee12350a-627b-4a11-99cb-38c496977d18
# ╠═4935d8f5-32e1-429c-a8c1-84c242eff4bf
# ╠═0cc970cd-b7ed-4782-a520-ff0a76fe0453
# ╠═8fdc24d9-1f6b-4094-9722-6b5b6c713f12
# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e
# ╠═01b08d44-4b9b-42e2-bb20-f34cb9b407f3
# ╠═7e1d46b4-5f93-478d-9105-a5b0db1eaf08
# ╠═ab103b4e-24eb-4575-8c04-ae3fd9ec1673
Expand All @@ -733,7 +755,6 @@ end
# ╠═e7f396dc-38a7-40f7-9e5b-6fbea9d61789
# ╠═7c688412-d1b4-492d-bda2-0b9181057d4d
# ╠═e1890517-7a44-4814-999d-6af27e2a136a
# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e
# ╠═93a7dd3b-7810-4021-bf6e-ae9c04acea46
# ╠═be324728-1b60-4584-b8ea-c4fe9e3466af
# ╠═7ad3cf67-2acd-44c6-aa91-7d5ae809dfbc
Expand Down

0 comments on commit eebb38f

Please sign in to comment.