From 691e5beae1f324918db0af2af51f614014929a9b Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Tue, 21 May 2024 16:14:42 +0200 Subject: [PATCH] Fix example --- docs/src/binary-classification.jl | 90 ++++++++++++++++++------------- 1 file changed, 53 insertions(+), 37 deletions(-) diff --git a/docs/src/binary-classification.jl b/docs/src/binary-classification.jl index d44f9b8..c4d17eb 100644 --- a/docs/src/binary-classification.jl +++ b/docs/src/binary-classification.jl @@ -171,6 +171,7 @@ Say, we take the following subset of length `0.7 * length(nodes)`: # ╔═╡ ee12350a-627b-4a11-99cb-38c496977d18 md""" Now, the algorithm would choose a different location and, hence, introduce instability. + To solve this, Bénard et al. decided to limit the splitpoints that the algorithm can use to split to data to a pre-defined set of points. For each feature, they find `q` empirical quantiles where `q` is typically 10. Let's overlay these quantiles on top of the `age` feature: @@ -181,6 +182,19 @@ md""" Next, let's see where the cutpoints are when we take the same random subset as above: """ +# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e +# hideall +function _plot_cutpoints(data::DataFrame; q=10) + fig = Figure(; size=(800, 100)) + ax = Axis(fig[1, 1]) + cps = Float64.(unique(cutpoints(data.nodes, q))) + scatter!(ax, data.nodes, fill(1, nrow(data)); color=data.color, marker=data.marker) + vlines!(ax, cps; color=:black, linestyle=:dash) + ylims!(ax, 0.9, 1.2) + hideydecorations!(ax) + return fig +end; + # ╔═╡ 01b08d44-4b9b-42e2-bb20-f34cb9b407f3 md""" As can be seen, many cutpoints are at the same location as before. @@ -331,23 +345,6 @@ md""" ## Appendix """ -# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e -function _plot_cutpoints(data::AbstractVector) - fig = Figure(; size=(800, 100)) - ax = Axis(fig[1, 1]) - cps = Float64.(unique(cutpoints(data, 10))) - scatter!(ax, data, fill(1, length(data))) - vlines!(ax, cps; color=:black, linestyle=:dash) - textlocs = [(c, 1.1) for c in cps] - for cutpoint in cps - annotation = string(round(cutpoint; digits=2))::String - text!(ax, cutpoint + 0.2, 1.08; text=annotation, fontsize=13) - end - ylims!(ax, 0.9, 1.2) - hideydecorations!(ax) - return fig -end; - # ╔═╡ 93a7dd3b-7810-4021-bf6e-ae9c04acea46 _rng(seed::Int=1) = StableRNG(seed); @@ -421,47 +418,63 @@ The `then` or `else` outcome is chosen for all the rules and, finally, the outco """ # ╔═╡ 172d3263-2e39-483c-9d82-8c22059e63c3 -nodes = sort(data.age); +# hideall +processed = let + df = sort(data, :nodes) + df[!, :marker] = [survival == 1 ? :circle : :cross for survival in df.survival] + df[!, :color] = [survival == 1 ? :black : :gray for survival in df.survival] + df +end; # ╔═╡ cf1816e5-4e8d-4e60-812f-bd6ae7011d6c # hideall -ln = length(nodes); +l = length(processed.nodes); + +# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289 +# hideall +index = l - 21; + +# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c +# hideall +subset_indexes = collect(S._rand_subset(_rng(3), 1:l, round(Int, 0.6 * l))); # ╔═╡ de90efc9-2171-4406-93a1-9a213ab32259 # hideall let fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) - scatter!(ax, nodes, fill(1, ln)) + color = (; colormap=:tab10, colorrange=(1, 10)) + scatter!(ax, processed.nodes, fill(1, l); color=processed.color, marker=processed.marker) hideydecorations!(ax) fig end -# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289 -# hideall -index = length(nodes) - 3; - # ╔═╡ 2c1adef4-822e-4dc0-946b-dc574e50b305 # hideall let + V = processed.nodes fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) - scatter!(ax, nodes, fill(1, ln)) - vlines!(ax, [nodes[index]]; color=:red) - annotation = string(round(nodes[index]; digits=2)) - text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11) + scatter!(ax, V, fill(1, l); color=processed.color, marker=processed.marker) + vlines!(ax, [V[index]]; color=:red) + annotation = string(round(V[index]; digits=2)) + text!(ax, V[index] + 0.003, 1.08; text=annotation, fontsize=11) hideydecorations!(ax) ylims!(ax, 0.9, 1.2) fig end -# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c +# ╔═╡ 6fb30208-cf39-42cd-bdda-a7941173822e # hideall -subset = collect(S._rand_subset(_rng(3), nodes, round(Int, 0.7 * ln))); +subset = let + df = processed[subset_indexes, :] + sort!(df, :nodes) + df +end # ╔═╡ dff9eb71-a853-4186-8245-a64206379b6f # hideall -ls = length(subset); +ls = nrow(subset); # ╔═╡ 8fdc24d9-1f6b-4094-9722-6b5b6c713f12 # hideall @@ -470,12 +483,14 @@ _plot_cutpoints(subset) # ╔═╡ 25ad7a18-f989-40f7-8ef1-4ca506446478 # hideall let + V = subset.nodes fig = Figure(; size=(800, 100)) ax = Axis(fig[1, 1]) - scatter!(ax, subset, fill(1, ls)) - vlines!(ax, [nodes[index]]; color=:red, linestyle=:dash) - annotation = string(round(nodes[index]; digits=2)) - text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11) + scatter!(ax, V, fill(1, ls); color=subset.color, marker=subset.marker) + loc = processed.nodes[index] + vlines!(ax, [loc]; color=:red, linestyle=:dash) + annotation = string(round(loc; digits=2)) + text!(ax, loc + 0.003, 1.08; text=annotation, fontsize=11) hideydecorations!(ax) ylims!(ax, 0.9, 1.2) fig @@ -483,7 +498,7 @@ end # ╔═╡ 4935d8f5-32e1-429c-a8c1-84c242eff4bf # hideall -_plot_cutpoints(nodes) +_plot_cutpoints(processed; q=10) # ╔═╡ a64dae3c-3b97-4076-98f4-3c9a0e5c0621 function _odds_plot(models::Vector{<:StableRules}, feat_names::Vector{String}) @@ -701,12 +716,14 @@ end # ╠═2c1adef4-822e-4dc0-946b-dc574e50b305 # ╠═896e00dc-2ce9-4a9f-acc1-519aec21dd83 # ╠═bfcb5e17-8937-4448-b090-2782818c6b6c +# ╠═6fb30208-cf39-42cd-bdda-a7941173822e # ╠═dff9eb71-a853-4186-8245-a64206379b6f # ╠═25ad7a18-f989-40f7-8ef1-4ca506446478 # ╠═ee12350a-627b-4a11-99cb-38c496977d18 # ╠═4935d8f5-32e1-429c-a8c1-84c242eff4bf # ╠═0cc970cd-b7ed-4782-a520-ff0a76fe0453 # ╠═8fdc24d9-1f6b-4094-9722-6b5b6c713f12 +# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e # ╠═01b08d44-4b9b-42e2-bb20-f34cb9b407f3 # ╠═7e1d46b4-5f93-478d-9105-a5b0db1eaf08 # ╠═ab103b4e-24eb-4575-8c04-ae3fd9ec1673 @@ -733,7 +750,6 @@ end # ╠═e7f396dc-38a7-40f7-9e5b-6fbea9d61789 # ╠═7c688412-d1b4-492d-bda2-0b9181057d4d # ╠═e1890517-7a44-4814-999d-6af27e2a136a -# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e # ╠═93a7dd3b-7810-4021-bf6e-ae9c04acea46 # ╠═be324728-1b60-4584-b8ea-c4fe9e3466af # ╠═7ad3cf67-2acd-44c6-aa91-7d5ae809dfbc