Skip to content

Commit

Permalink
Fix example
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed May 21, 2024
1 parent 471448f commit 691e5be
Showing 1 changed file with 53 additions and 37 deletions.
90 changes: 53 additions & 37 deletions docs/src/binary-classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ Say, we take the following subset of length `0.7 * length(nodes)`:
# ╔═╡ ee12350a-627b-4a11-99cb-38c496977d18
md"""
Now, the algorithm would choose a different location and, hence, introduce instability.
To solve this, Bénard et al. decided to limit the splitpoints that the algorithm can use to split to data to a pre-defined set of points.
For each feature, they find `q` empirical quantiles where `q` is typically 10.
Let's overlay these quantiles on top of the `age` feature:
Expand All @@ -181,6 +182,19 @@ md"""
Next, let's see where the cutpoints are when we take the same random subset as above:
"""

# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e
# hideall
function _plot_cutpoints(data::DataFrame; q=10)
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
cps = Float64.(unique(cutpoints(data.nodes, q)))
scatter!(ax, data.nodes, fill(1, nrow(data)); color=data.color, marker=data.marker)
vlines!(ax, cps; color=:black, linestyle=:dash)
ylims!(ax, 0.9, 1.2)
hideydecorations!(ax)
return fig
end;

# ╔═╡ 01b08d44-4b9b-42e2-bb20-f34cb9b407f3
md"""
As can be seen, many cutpoints are at the same location as before.
Expand Down Expand Up @@ -331,23 +345,6 @@ md"""
## Appendix
"""

# ╔═╡ ede038b3-d92e-4208-b8ab-984f3ca1810e
function _plot_cutpoints(data::AbstractVector)
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
cps = Float64.(unique(cutpoints(data, 10)))
scatter!(ax, data, fill(1, length(data)))
vlines!(ax, cps; color=:black, linestyle=:dash)
textlocs = [(c, 1.1) for c in cps]
for cutpoint in cps
annotation = string(round(cutpoint; digits=2))::String
text!(ax, cutpoint + 0.2, 1.08; text=annotation, fontsize=13)
end
ylims!(ax, 0.9, 1.2)
hideydecorations!(ax)
return fig
end;

# ╔═╡ 93a7dd3b-7810-4021-bf6e-ae9c04acea46
_rng(seed::Int=1) = StableRNG(seed);

Expand Down Expand Up @@ -421,47 +418,63 @@ The `then` or `else` outcome is chosen for all the rules and, finally, the outco
"""

# ╔═╡ 172d3263-2e39-483c-9d82-8c22059e63c3
nodes = sort(data.age);
# hideall
processed = let
df = sort(data, :nodes)
df[!, :marker] = [survival == 1 ? :circle : :cross for survival in df.survival]
df[!, :color] = [survival == 1 ? :black : :gray for survival in df.survival]
df
end;

# ╔═╡ cf1816e5-4e8d-4e60-812f-bd6ae7011d6c
# hideall
ln = length(nodes);
l = length(processed.nodes);

# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289
# hideall
index = l - 21;

# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c
# hideall
subset_indexes = collect(S._rand_subset(_rng(3), 1:l, round(Int, 0.6 * l)));

# ╔═╡ de90efc9-2171-4406-93a1-9a213ab32259
# hideall
let
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, nodes, fill(1, ln))
color = (; colormap=:tab10, colorrange=(1, 10))
scatter!(ax, processed.nodes, fill(1, l); color=processed.color, marker=processed.marker)
hideydecorations!(ax)
fig
end

# ╔═╡ 8d1b30bd-0ad2-416e-a36a-f263ef781289
# hideall
index = length(nodes) - 3;

# ╔═╡ 2c1adef4-822e-4dc0-946b-dc574e50b305
# hideall
let
V = processed.nodes
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, nodes, fill(1, ln))
vlines!(ax, [nodes[index]]; color=:red)
annotation = string(round(nodes[index]; digits=2))
text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11)
scatter!(ax, V, fill(1, l); color=processed.color, marker=processed.marker)
vlines!(ax, [V[index]]; color=:red)
annotation = string(round(V[index]; digits=2))
text!(ax, V[index] + 0.003, 1.08; text=annotation, fontsize=11)
hideydecorations!(ax)
ylims!(ax, 0.9, 1.2)
fig
end

# ╔═╡ bfcb5e17-8937-4448-b090-2782818c6b6c
# ╔═╡ 6fb30208-cf39-42cd-bdda-a7941173822e
# hideall
subset = collect(S._rand_subset(_rng(3), nodes, round(Int, 0.7 * ln)));
subset = let
df = processed[subset_indexes, :]
sort!(df, :nodes)
df
end

# ╔═╡ dff9eb71-a853-4186-8245-a64206379b6f
# hideall
ls = length(subset);
ls = nrow(subset);

# ╔═╡ 8fdc24d9-1f6b-4094-9722-6b5b6c713f12
# hideall
Expand All @@ -470,20 +483,22 @@ _plot_cutpoints(subset)
# ╔═╡ 25ad7a18-f989-40f7-8ef1-4ca506446478
# hideall
let
V = subset.nodes
fig = Figure(; size=(800, 100))
ax = Axis(fig[1, 1])
scatter!(ax, subset, fill(1, ls))
vlines!(ax, [nodes[index]]; color=:red, linestyle=:dash)
annotation = string(round(nodes[index]; digits=2))
text!(ax, nodes[index] + 0.003, 1.08; text=annotation, fontsize=11)
scatter!(ax, V, fill(1, ls); color=subset.color, marker=subset.marker)
loc = processed.nodes[index]
vlines!(ax, [loc]; color=:red, linestyle=:dash)
annotation = string(round(loc; digits=2))
text!(ax, loc + 0.003, 1.08; text=annotation, fontsize=11)
hideydecorations!(ax)
ylims!(ax, 0.9, 1.2)
fig
end

# ╔═╡ 4935d8f5-32e1-429c-a8c1-84c242eff4bf
# hideall
_plot_cutpoints(nodes)
_plot_cutpoints(processed; q=10)

# ╔═╡ a64dae3c-3b97-4076-98f4-3c9a0e5c0621
function _odds_plot(models::Vector{<:StableRules}, feat_names::Vector{String})
Expand Down Expand Up @@ -701,12 +716,14 @@ end
# ╠═2c1adef4-822e-4dc0-946b-dc574e50b305
# ╠═896e00dc-2ce9-4a9f-acc1-519aec21dd83
# ╠═bfcb5e17-8937-4448-b090-2782818c6b6c
# ╠═6fb30208-cf39-42cd-bdda-a7941173822e
# ╠═dff9eb71-a853-4186-8245-a64206379b6f
# ╠═25ad7a18-f989-40f7-8ef1-4ca506446478
# ╠═ee12350a-627b-4a11-99cb-38c496977d18
# ╠═4935d8f5-32e1-429c-a8c1-84c242eff4bf
# ╠═0cc970cd-b7ed-4782-a520-ff0a76fe0453
# ╠═8fdc24d9-1f6b-4094-9722-6b5b6c713f12
# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e
# ╠═01b08d44-4b9b-42e2-bb20-f34cb9b407f3
# ╠═7e1d46b4-5f93-478d-9105-a5b0db1eaf08
# ╠═ab103b4e-24eb-4575-8c04-ae3fd9ec1673
Expand All @@ -733,7 +750,6 @@ end
# ╠═e7f396dc-38a7-40f7-9e5b-6fbea9d61789
# ╠═7c688412-d1b4-492d-bda2-0b9181057d4d
# ╠═e1890517-7a44-4814-999d-6af27e2a136a
# ╠═ede038b3-d92e-4208-b8ab-984f3ca1810e
# ╠═93a7dd3b-7810-4021-bf6e-ae9c04acea46
# ╠═be324728-1b60-4584-b8ea-c4fe9e3466af
# ╠═7ad3cf67-2acd-44c6-aa91-7d5ae809dfbc
Expand Down

0 comments on commit 691e5be

Please sign in to comment.