We read every piece of feedback, and take your input very seriously.
For a list of all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add an API to obtain rules for plotting. Currently, the following code
"""
    _rule_index(model::StableRules, feature_name::AbstractString)

Return the index of the first rule in `model.rules` whose single split is on
`feature_name`, or `nothing` when no rule splits on that feature.
"""
function _rule_index(model::StableRules, feature_name::AbstractString)
    return findfirst(model.rules) do rule
        only(rule.path.splits).splitpoint.feature_name == feature_name
    end
end

# Renamed to `SIRUS.sum_weights(fitresults::Vector{StableRules}, name::AbstractString)`.
"""
    _sum_weights(fitresults::Vector{<:StableRules}, name::AbstractString)

Sum, across all models in `fitresults`, the weight of the first rule that splits on
`name`. Models without such a rule contribute zero; an empty `fitresults` yields `0.0`.

NOTE(review): only the first matching rule per model is counted (see `_rule_index`),
while `odds_plot` draws every matching rule — confirm this is the intended ranking.
"""
function _sum_weights(fitresults::Vector{<:StableRules}, name::AbstractString)
    total = 0.0
    for fitresult in fitresults
        index = _rule_index(fitresult, name)
        isnothing(index) || (total += fitresult.weights[index])
    end
    return total
end

"""
    _remove_nato_name(name::String)

If `name` contains `'('`, return it with its last space-separated word removed;
otherwise return `name` unchanged.
"""
function _remove_nato_name(name::String)
    if contains(name, '(')
        parts = split(name, ' ')
        return join(parts[1:end-1], ' ')
    else
        return name
    end
end

"""
    _threshold(rule)

Return the split value of the single split in `rule`'s path.
"""
function _threshold(rule)
    sp = only(rule.path.splits).splitpoint
    return sp.value
end

# Collect every `(rule, weight)` pair, over all models, whose single split is on `feature_name`.
function _feature_rules_weights(fitresults, feature_name::AbstractString)
    rules_weights = Tuple{SIRUS.Rule,Float64}[]
    for fitresult in fitresults
        for (rule, weight) in zip(fitresult.rules, fitresult.weights)
            if only(rule.path.splits).splitpoint.feature_name == feature_name
                push!(rules_weights, (rule, weight))
            end
        end
    end
    return rules_weights
end

"""
    odds_plot(e::PerformanceEvaluation, data::DataFrame, pretty_name::Function;
              max_features::Int=15)

Build a figure with one row per feature, for at most `max_features` features that
appear in the rules of the per-fold models of `e` (`e.fitted_params_per_fold`),
ranked by `_sum_weights` (highest first).

Each row has two panels:
- Left: one black marker per matching rule at `log(last(else_probs) / last(then_probs))`.
  The marker area is proportional to the rule weight, and the x-axis is limited to
  `[-1.01, 1.01]`.
- Right: a histogram of `data[:, feature_name]`, with a dashed vertical line at every
  rule threshold.

`pretty_name` maps a feature name to its y-tick label. Returns the `Figure`.
"""
function odds_plot(
    e::PerformanceEvaluation,
    data::DataFrame,
    pretty_name::Function;
    max_features::Int=15,
)
    w, h = (800, 1000)
    fig = Figure(; resolution=(w, h))
    grid = fig[1, 1:3] = GridLayout()

    fitresults = getproperty.(e.fitted_params_per_fold, :fitresult)
    feature_names = String[]
    for fitresult in fitresults
        for rule in fitresult.rules
            push!(feature_names, only(rule.path.splits).splitpoint.feature_name)
        end
    end
    names = sort(unique(feature_names))

    importances = _sum_weights.(Ref(fitresults), names)
    matching_rules = DataFrame(; names, importance=importances)
    sort!(matching_rules, :importance; rev=true)
    # `first` caps at the available count instead of throwing a BoundsError.
    names = first(matching_rules.names, max_features)
    l = length(names)
    pretty_names = [pretty_name(n) for n in names]

    for (i, feature_name) in enumerate(names)
        yticks = (1:1, [pretty_names[i]])
        ax = if i == l
            Axis(grid[i, 1:2]; yticks, xlabel="Ratio")
        else
            Axis(grid[i, 1:2]; yticks)
        end
        vlines!(ax, [0]; color=:gray, linestyle=:dash)
        xlims!(ax, -1.01, 1.01)

        rw = _feature_rules_weights(fitresults, feature_name)
        thresholds = _threshold.(first.(rw))
        for (rule, weight) in rw
            left = last(rule.then_probs)::Float64
            right = last(rule.else_probs)::Float64
            # NOTE(review): the issue reports every point landing left of zero;
            # check the then/else probability ordering before trusting the sign.
            ratio = log(right / left)
            # area = πr²
            markersize = 50 * sqrt(weight / π)
            scatter!(ax, [ratio], [1]; color=:black, markersize)
        end
        hideydecorations!(ax; ticklabels=false)

        axr = i == l ? Axis(grid[i, 3:5]; xlabel="Location") : Axis(grid[i, 3:5])
        D = data[:, feature_name]
        hist!(axr, D; scale_to=1, color=:white, strokewidth=1, strokecolor=:black)
        vlines!(axr, thresholds; color=:black, linestyle=:dash)
        if i < l
            hidexdecorations!(ax)
        else
            hidexdecorations!(ax; ticks=false, ticklabels=false)
        end
        hideydecorations!(axr)
        hidexdecorations!(axr; ticks=false, ticklabels=false)
    end
    rowgap!(grid, 5)
    return fig
end;
Produces the following plot
Apart from the bug that causes all points to be plotted on the left, this should provide a good basis for the API, together with #44.
The text was updated successfully, but these errors were encountered:
Split
Extend API (#74)
c3bae52
Fixes #44 and fixes #66.
Successfully merging a pull request may close this issue.
Add an API to obtain rules for plotting. Currently, the following code
Produces the following plot
Apart from the bug that causes all points to be on the left, this should provide a good basis for the API together with #44.
The text was updated successfully, but these errors were encountered: