diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c112c71..a703569 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,29 +21,30 @@ jobs: actions: write # Required by julia-actions/cache. contents: read - runs-on: ${{ matrix.os }} + runs-on: ${{ matrix.config.os }} timeout-minutes: 30 strategy: fail-fast: false matrix: - version: - - '1.6' - - '1' - os: - - ubuntu-latest + config: + - {os: ubuntu-latest, version: '1.6'} + # R crashes on ubuntu-latest with newer versions of Julia. + - {os: macos-latest, version: '1'} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: ${{ matrix.version }} + version: ${{ matrix.config.version }} - uses: julia-actions/cache@v2.0.0 + with: + cache-name: 'test-${{ matrix.config.os }}-${{ matrix.config.version }}' - uses: r-lib/actions/setup-r@v2 with: use-public-rspm: true r-version: '4' - run: echo "LD_LIBRARY_PATH=$(R RHOME)/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV - if: matrix.os == 'ubuntu-latest' + if: matrix.config.os == 'ubuntu-latest' - run: Rscript -e 'install.packages("sirus")' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 diff --git a/docs/Project.toml b/docs/Project.toml index fbc33b9..9885105 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -19,7 +19,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] CSV = "0.10" -CairoMakie = "0.11" +CairoMakie = "0.12" CategoricalArrays = "0.10" DataDeps = "0.7" DataFrames = "1" diff --git a/docs/src/basic-example.jl b/docs/src/basic-example.jl index bc4d31f..71767cd 100644 --- a/docs/src/basic-example.jl +++ b/docs/src/basic-example.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.19.32 +# v0.19.42 using Markdown using InteractiveUtils diff --git a/docs/src/binary-classification.jl b/docs/src/binary-classification.jl index 2e6ffe4..e136a42 100644 --- a/docs/src/binary-classification.jl +++ b/docs/src/binary-classification.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.19.32 +# v0.19.42 using Markdown using InteractiveUtils @@ -171,6 +171,7 @@ Say, we take the following subset of length `0.7 * length(nodes)`: # ╔═╡ ee12350a-627b-4a11-99cb-38c496977d18 md""" Now, the algorithm would choose a different location and, hence, introduce instability. + To solve this, Bénard et al. decided to limit the splitpoints that the algorithm can use to split to data to a pre-defined set of points. For each feature, they find `q` empirical quantiles where `q` is typically 10. Let's overlay these quantiles on top of the `age` feature: @@ -181,6 +182,20 @@ md""" Next, let's see where the cutpoints are when we take the same random subset as above: """ +# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e +# hideall +function _plot_cutpoints(data::DataFrame; q=10) + fig = Figure(; size=(800, 100)) + ax = Axis(fig[1, 1]) + cps = Float64.(unique(cutpoints(data.nodes, q))) + scatter!(ax, data.nodes, fill(1, nrow(data)); color=data.color, marker=data.marker) + vlines!(ax, cps; color=:black, linestyle=:dash) + xlims!(ax, -1, 48) + ylims!(ax, 0.82, 1.2) + hideydecorations!(ax) + return fig +end; + # ╔═╡ 01b08d44-4b9b-42e2-bb20-f34cb9b407f3 md""" As can be seen, many cutpoints are at the same location as before. @@ -331,23 +346,6 @@ md""" ## Appendix """ -# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e -function _plot_cutpoints(data::AbstractVector) - fig = Figure(; size=(800, 100)) - ax = Axis(fig[1, 1]) - cps = Float64.(unique(cutpoints(data, 10))) - scatter!(ax, data, fill(1, length(data))) - vlines!(ax, cps; color=:black, linestyle=:dash) - textlocs = [(c, 1.1) for c in cps] - for cutpoint in cps - annotation = string(round(cutpoint; digits=2))::String - text!(ax, cutpoint + 0.2, 1.08; text=annotation, fontsize=13) - end - ylims!(ax, 0.9, 1.2) - hideydecorations!(ax) - return fig -end; - # ╔═╡ 93a7dd3b-7810-4021-bf6e-ae9c04acea46 _rng(seed::Int=1) = StableRNG(seed); @@ -421,47 +419,66 @@ The `then` or `else` outcome is chosen for all the rules and, finally, the outco """ # ╔═╡ 172d3263-2e39-483c-9d82-8c22059e63c3 -nodes = sort(data.age); +# hideall +processed = let + df = sort(data, :nodes) + df[!, :marker] = [survival == 1 ? :circle : :cross for survival in df.survival] + df[!, :color] = [survival == 1 ? :black : :gray for survival in df.survival] + df +end; # ╔═╡ cf1816e5-4e8d-4e60-812f-bd6ae7011d6c # hideall -ln = length(nodes); +l = length(processed.nodes); + +# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289 +# hideall +index = l - 21; + +# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c +# hideall +subset_indexes = collect(S._rand_subset(_rng(3), 1:l, round(Int, 0.6 * l))); # ╔═╡ de90efc9-2171-4406-93a1-9a213ab32259 # hideall let fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) - scatter!(ax, nodes, fill(1, ln)) + color = (; colormap=:tab10, colorrange=(1, 10)) + scatter!(ax, processed.nodes, fill(1, l); color=processed.color, marker=processed.marker) hideydecorations!(ax) + xlims!(ax, -1, 48) + ylims!(ax, 0.82, 1.2) fig end -# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289 -# hideall -index = length(nodes) - 3; - # ╔═╡ 2c1adef4-822e-4dc0-946b-dc574e50b305 # hideall let + V = processed.nodes fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) - scatter!(ax, nodes, fill(1, ln)) - vlines!(ax, [nodes[index]]; color=:red) - annotation = string(round(nodes[index]; digits=2)) - text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11) + scatter!(ax, V, fill(1, l); color=processed.color, marker=processed.marker) + vlines!(ax, [V[index]]; color=:red) + annotation = string(round(V[index]; digits=2)) + text!(ax, V[index] + 0.003, 1.08; text=annotation, fontsize=11) hideydecorations!(ax) - ylims!(ax, 0.9, 1.2) + xlims!(ax, -1, 48) + ylims!(ax, 0.82, 1.2) fig end -# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c +# ╔═╡ 6fb30208-cf39-42cd-bdda-a7941173822e # hideall -subset = collect(S._rand_subset(_rng(3), nodes, round(Int, 0.7 * ln))); +subset = let + df = processed[subset_indexes, :] + sort!(df, :nodes) + df +end # ╔═╡ dff9eb71-a853-4186-8245-a64206379b6f # hideall -ls = length(subset); +ls = nrow(subset); # ╔═╡ 8fdc24d9-1f6b-4094-9722-6b5b6c713f12 # hideall @@ -470,20 +487,23 @@ _plot_cutpoints(subset) # ╔═╡ 25ad7a18-f989-40f7-8ef1-4ca506446478 # hideall let + V = subset.nodes fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) - scatter!(ax, subset, fill(1, ls)) - vlines!(ax, [nodes[index]]; color=:red, linestyle=:dash) - annotation = string(round(nodes[index]; digits=2)) - text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11) + scatter!(ax, V, fill(1, ls); color=subset.color, marker=subset.marker) + loc = processed.nodes[index] + vlines!(ax, [loc]; color=:red, linestyle=:dash) + annotation = string(round(loc; digits=2)) + text!(ax, loc + 0.003, 1.08; text=annotation, fontsize=11) hideydecorations!(ax) - ylims!(ax, 0.9, 1.2) + xlims!(ax, -1, 48) + ylims!(ax, 0.82, 1.2) fig end # ╔═╡ 4935d8f5-32e1-429c-a8c1-84c242eff4bf # hideall -_plot_cutpoints(nodes) +_plot_cutpoints(processed; q=10) # ╔═╡ a64dae3c-3b97-4076-98f4-3c9a0e5c0621 function _odds_plot(models::Vector{<:StableRules}, feat_names::Vector{String}) @@ -701,12 +721,14 @@ end # ╠═2c1adef4-822e-4dc0-946b-dc574e50b305 # ╠═896e00dc-2ce9-4a9f-acc1-519aec21dd83 # ╠═bfcb5e17-8937-4448-b090-2782818c6b6c +# ╠═6fb30208-cf39-42cd-bdda-a7941173822e # ╠═dff9eb71-a853-4186-8245-a64206379b6f # ╠═25ad7a18-f989-40f7-8ef1-4ca506446478 # ╠═ee12350a-627b-4a11-99cb-38c496977d18 # ╠═4935d8f5-32e1-429c-a8c1-84c242eff4bf # ╠═0cc970cd-b7ed-4782-a520-ff0a76fe0453 # ╠═8fdc24d9-1f6b-4094-9722-6b5b6c713f12 +# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e # ╠═01b08d44-4b9b-42e2-bb20-f34cb9b407f3 # ╠═7e1d46b4-5f93-478d-9105-a5b0db1eaf08 # ╠═ab103b4e-24eb-4575-8c04-ae3fd9ec1673 @@ -733,7 +755,6 @@ end # ╠═e7f396dc-38a7-40f7-9e5b-6fbea9d61789 # ╠═7c688412-d1b4-492d-bda2-0b9181057d4d # ╠═e1890517-7a44-4814-999d-6af27e2a136a -# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e # ╠═93a7dd3b-7810-4021-bf6e-ae9c04acea46 # ╠═be324728-1b60-4584-b8ea-c4fe9e3466af # ╠═7ad3cf67-2acd-44c6-aa91-7d5ae809dfbc