Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API to obtain rules for visualizations #66

Closed
rikhuijzer opened this issue Nov 13, 2023 · 0 comments · Fixed by #74
Closed

Add API to obtain rules for visualizations #66

rikhuijzer opened this issue Nov 13, 2023 · 0 comments · Fixed by #74

Comments

@rikhuijzer
Copy link
Owner

rikhuijzer commented Nov 13, 2023

Add an API to obtain rules for plotting. Currently, the following code

"""
    _rule_index(model::StableRules, feature_name::String)

Return the position of the first rule in `model.rules` whose single split is on
`feature_name`, or `nothing` if no rule splits on that feature. `only` throws if a rule
inspected before the match holds more than one split.
"""
function _rule_index(model::StableRules, feature_name::String)
    return findfirst(model.rules) do candidate
        only(candidate.path.splits).splitpoint.feature_name == feature_name
    end
end

# Renamed to `SIRUS.sum_weights(fitresults::Vector{StableRules}, name::AbstractString)`.
"""
    _sum_weights(fitresults::Vector{<:StableRules}, name::AbstractString)

Sum, over all fitted models in `fitresults`, the weight of the first rule (as found by
`_rule_index`) that splits on feature `name`. Models without such a rule contribute `0.0`.
An empty `fitresults` yields `0.0`.

NOTE(review): only the first matching rule per model is counted, although a model can
hold several rules on the same feature — confirm this is the intended importance.
"""
function _sum_weights(fitresults::Vector{<:StableRules}, name::AbstractString)
    # `init=0.0` keeps the result a Float64 (the old `0` vs. weight mix was Int/Float64
    # unstable) and makes the empty case return 0.0 instead of throwing.
    return sum(fitresults; init=0.0) do fitresult
        index = _rule_index(fitresult, name)
        isnothing(index) ? 0.0 : fitresult.weights[index]
    end
end

"""
    _remove_nato_name(name::AbstractString) -> String

Strip a trailing parenthesized suffix, together with the whitespace before it, from
`name`; e.g. `"Foo Bar (NATO)"` becomes `"Foo Bar"`. Names that do not end in a
`(...)` group are returned unchanged (as a `String`).
"""
function _remove_nato_name(name::AbstractString)
    # Remove the whole trailing `(...)` group rather than the last space-separated word:
    # the old approach turned "Foo(x)" into "" and "Foo (two words)" into "Foo (two".
    return String(replace(name, r"\s*\([^()]*\)\s*$" => ""))
end

"""
    _threshold(rule)

Return the split value stored in the single split of `rule.path.splits`; `only` throws
an `ArgumentError` when the path does not hold exactly one split.
"""
_threshold(rule) = only(rule.path.splits).splitpoint.value

"""
    odds_plot(e::PerformanceEvaluation, data::DataFrame, pretty_name::Function;
              max_features::Int=15)

Plot the rules of the models fitted in the cross-validation folds of `e`. There is one
row per feature, for the `max_features` features with the highest summed rule weight
(as computed by `_sum_weights`), ordered by that weight.

Each row holds two axes:
- left: one black dot per rule on that feature, placed at
  `log(last(else_probs) / last(then_probs))`, with marker area proportional to the
  rule weight;
- right: a histogram of `data[:, feature_name]`, with every rule threshold drawn as a
  dashed vertical line.

`pretty_name` maps a raw feature name to the label shown on the y-axis. Every rule is
assumed to hold exactly one split (`only` throws otherwise). Returns the `Figure`.
"""
function odds_plot(
        e::PerformanceEvaluation,
        data::DataFrame,
        pretty_name::Function;
        max_features::Int=15
    )
    w, h = (800, 1000)
    fig = Figure(; resolution=(w, h))
    grid = fig[1, 1:3] = GridLayout()

    # Feature used by the single split of `rule`.
    feature_of(rule) = only(rule.path.splits).splitpoint.feature_name

    fitresults = getproperty.(e.fitted_params_per_fold, :fitresult)
    feature_names = String[
        feature_of(rule) for fitresult in fitresults for rule in fitresult.rules
    ]
    unique_names = sort(unique(feature_names))
    importances = _sum_weights.(Ref(fitresults), unique_names)

    matching_rules = DataFrame(; names=unique_names, importance=importances)
    sort!(matching_rules, :importance; rev=true)
    # `first` rather than `[1:15]`: indexing past the end threw a BoundsError whenever
    # the models used fewer features than requested.
    names = first(matching_rules.names, max_features)
    l = length(names)

    pretty_names = [pretty_name(n) for n in names]

    for (i, feature_name) in enumerate(names)
        yticks = (1:1, [pretty_names[i]])
        ax = i == l ?
            Axis(grid[i, 1:2]; yticks, xlabel="Ratio") :
            Axis(grid[i, 1:2]; yticks)
        vlines!(ax, [0]; color=:gray, linestyle=:dash)
        xlims!(ax, -1.01, 1.01)

        # All (rule, weight) pairs on this feature, fold by fold, in rule order.
        rw = Tuple{SIRUS.Rule,Float64}[]
        for fitresult in fitresults
            for (rule, weight) in zip(fitresult.rules, fitresult.weights)
                if feature_of(rule) == feature_name
                    push!(rw, (rule, weight))
                end
            end
        end
        thresholds = _threshold.(first.(rw))

        for (rule, weight) in rw
            left = last(rule.then_probs)::Float64
            right = last(rule.else_probs)::Float64
            # NOTE(review): the issue reports that every dot lands left of zero, i.e.
            # `right < left` for all rules; the orientation of this ratio is unverified.
            ratio = log(right / left)
            # area = πr²: a size ∝ sqrt(weight / π) makes the marker area ∝ weight.
            markersize = 50 * sqrt(weight / π)
            scatter!(ax, [ratio], [1]; color=:black, markersize)
        end
        hideydecorations!(ax; ticklabels=false)

        axr = i == l ?
            Axis(grid[i, 3:5]; xlabel="Location") :
            Axis(grid[i, 3:5])
        D = data[:, feature_name]
        hist!(axr, D; scale_to=1, color=:white, strokewidth=1, strokecolor=:black)
        vlines!(axr, thresholds; color=:black, linestyle=:dash)

        # Only the bottom row keeps its x ticks and labels on the ratio axis.
        if i < l
            hidexdecorations!(ax)
        else
            hidexdecorations!(ax; ticks=false, ticklabels=false)
        end

        hideydecorations!(axr)
        hidexdecorations!(axr; ticks=false, ticklabels=false)
    end

    rowgap!(grid, 5)
    return fig
end;

produces the following plot:

image

Apart from the bug that places every point to the left of zero, this should provide a good basis for the API, together with #44.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant