# Cubic Tree SHAP

Here we demonstrate an algorithm for exactly computing the SHAP values of a decision tree's prediction in $O(LD^2)$ time, where $L$ is the number of leaves in the tree and $D$ is the maximum depth of the tree.

## Tree SHAP algorithm

In [1]:
# data we keep about our decision path
# note that pweight is included for convenience and is not tied with the other attributes:
# unwind_path! shifts feature_index/zero_fraction/one_fraction down over a removed element
# but leaves every pweight in its slot
# the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
type PathElement
    feature_index  # feature split on; 0 for the dummy element tree_shap! creates at the root
    zero_fraction  # product of child-cover/parent-cover ratios of the splits on this feature (set by tree_shap!)
    one_fraction   # 1 if x followed the split(s) on this feature along the path, 0 if not (set by tree_shap!)
    pweight        # permutation weight indexed by slot, see above
end

# Extend the decision path by one element that records a split on feature_index.
#
# unique_path must have room for unique_depth+1 elements: slots 1..unique_depth hold the
# current path and slot unique_depth+1 receives the new element. The permutation weights of
# the existing slots are redistributed using the new element's zero/one fractions.
function extend_path!(unique_path, unique_depth, zero_fraction, one_fraction, feature_index)
    depth = unique_depth
    denom = depth + 1
    # only the very first (root) element starts out with weight 1
    start_weight = depth == 0 ? 1 : 0
    unique_path[depth + 1] = PathElement(feature_index, zero_fraction, one_fraction, start_weight)

    # go from the top slot down: each slot is scaled by zero_fraction (when i is its index)
    # before it receives its one_fraction share from the slot below (when i is one less)
    i = depth
    while i >= 1
        w = unique_path[i].pweight
        unique_path[i + 1].pweight += one_fraction * w * i / denom
        unique_path[i].pweight = zero_fraction * w * (depth - i + 1) / denom
        i -= 1
    end
    return nothing
end

# undo a previous extension of the decision path (the inverse of extend_path!)
#
# unique_path holds unique_depth+1 elements on entry and path_index is the slot of the element
# to remove. The pweights of slots 1..unique_depth are recomputed as they were before that
# element was added. Then feature_index/zero_fraction/one_fraction of the later slots are
# shifted down one position over it. pweight is not shifted because it is indexed by slot.
# The array is not resized; the caller must decrement its own unique_depth afterwards
# (tree_shap! does this).
function unwind_path!(unique_path, unique_depth, path_index)
    one_fraction = unique_path[path_index].one_fraction
    zero_fraction = unique_path[path_index].zero_fraction
    # running weight carried downwards from the top slot while undoing the extension
    next_one_portion = unique_path[unique_depth+1].pweight
    
    for i in unique_depth:-1:1
        if one_fraction != 0
            # recover slot i's old weight from the one_fraction share it passed upward,
            # then remove its zero_fraction-scaled part from the slot below's running total
            tmp = unique_path[i].pweight
            unique_path[i].pweight = next_one_portion*(unique_depth+1)/(i*one_fraction)
            next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth-i+1)/(unique_depth+1)
        else
            # with one_fraction == 0 the extension only scaled each slot by zero_fraction
            # NOTE(review): divides by zero when zero_fraction is also 0 (zero-cover branch)
            unique_path[i].pweight = (unique_path[i].pweight*(unique_depth+1))/(zero_fraction*(unique_depth-i+1))
        end
    end
    
    # drop the element at path_index by shifting the later elements' attributes down
    for i in path_index:unique_depth
        unique_path[i].feature_index = unique_path[i+1].feature_index
        unique_path[i].zero_fraction = unique_path[i+1].zero_fraction
        unique_path[i].one_fraction = unique_path[i+1].one_fraction
    end
end

# determine what the total permutation weight would be if we unwound a previous extension in
# the decision path, without modifying unique_path
#
# unique_path holds unique_depth+1 elements and path_index selects the element whose extension
# would be undone. The arithmetic mirrors unwind_path!, but it only sums the resulting pweights
# instead of writing them back.
#
# An element with both zero_fraction == 0 and one_fraction == 0 comes from a zero-cover branch.
# Its extension multiplied every pweight by zero, so it contributes nothing to the sum.
# Skipping it avoids the 0/0 = NaN (or Inf) the plain formula would produce, which would
# otherwise poison phi.
function unwound_path_sum(unique_path, unique_depth, path_index)
    one_fraction = unique_path[path_index].one_fraction
    zero_fraction = unique_path[path_index].zero_fraction
    next_one_portion = unique_path[unique_depth+1].pweight
    total = 0.0
    for i in unique_depth:-1:1
        if one_fraction != 0
            tmp = next_one_portion*(unique_depth+1)/(i*one_fraction)
            total += tmp
            next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth-i+1)/(unique_depth+1))
        elseif zero_fraction != 0
            total += (unique_path[i].pweight/zero_fraction)/((unique_depth-i+1)/(unique_depth+1))
        end
        # both fractions zero: nothing to add (see header comment)
    end
    total
end

# recursive computation of SHAP values for a decision tree
#
# phi        - accumulator updated in place: phi[j] += contribution of feature j
# x          - the instance being explained; an entry that is `nothing` follows the node's
#              missing branch
# tree_nodes - nodes as produced by parse_xgboost_tree (feature_index == 0 marks a leaf)
# The remaining arguments carry recursion state; callers should leave them at their defaults.
# They are the current node, the parent's path length, the parent's path, and the zero/one
# fractions and feature of the split that led to this node.
function tree_shap!(phi, x, tree_nodes, node_index=1, unique_depth=0, parent_unique_path=PathElement[],
                    parent_zero_fraction=1, parent_one_fraction=1, parent_feature_index=0)
    node = tree_nodes[node_index]
    
    # extend a private copy of the unique path (the sibling subtree reuses the parent's path)
    unique_path = Array(PathElement, unique_depth+1)
    unique_path[1:unique_depth] = deepcopy(parent_unique_path[1:unique_depth])
    extend_path!(unique_path, unique_depth, parent_zero_fraction, parent_one_fraction, parent_feature_index)
    
    # leaf node: credit each feature on the path; slot 1 is the dummy root element
    if node.feature_index == 0
        for i in 2:unique_depth+1
            w = unwound_path_sum(unique_path, unique_depth, i)
            el = unique_path[i]
            phi[el.feature_index] += w*(el.one_fraction-el.zero_fraction)*node.value
        end
    
    # internal node
    else
        # find which branch is "hot" (meaning x would follow it)
        hot_index = 0
        if x[node.feature_index] === nothing
            hot_index = node.missing_index
        elseif x[node.feature_index] < node.value
            hot_index = node.yes_index
        else
            hot_index = node.no_index
        end
        cold_index = (hot_index == node.yes_index ? node.no_index : node.yes_index)
        hot_zero_fraction = tree_nodes[hot_index].cover/node.cover
        cold_zero_fraction = tree_nodes[cold_index].cover/node.cover
        incoming_zero_fraction = incoming_one_fraction = 1
        
        # see if we have already split on this feature, if so we undo that split so we can redo it for this node
        path_index = findfirst([e.feature_index for e in unique_path], node.feature_index)
        if path_index != 0
            incoming_zero_fraction = unique_path[path_index].zero_fraction
            incoming_one_fraction = unique_path[path_index].one_fraction
            unwind_path!(unique_path, unique_depth, path_index)
            unique_depth -= 1
        end
        
        # x follows the hot branch, so only it keeps a nonzero one_fraction
        tree_shap!(
            phi, x, tree_nodes, hot_index, unique_depth+1, unique_path,
            hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, node.feature_index
        )
        tree_shap!(
            phi, x, tree_nodes, cold_index, unique_depth+1, unique_path,
            cold_zero_fraction*incoming_zero_fraction, 0, node.feature_index
        )
    end
    return nothing
end

tree_shap! (generic function with 7 methods)

## Supporting evaluation and comparison code

### Helper functions to run the SHAP algorithm

In [2]:
using XGBoost

# Explain booster `bst` at instance `x`.
#
# Returns (expected_value, phi). phi holds the per-feature SHAP values summed over all trees.
# expected_value is the sum of each tree's unconditional expectation plus the booster's base
# score. The base score is recovered from the first tree by comparing XGBoost's one-tree
# prediction with our own evaluation of that tree.
function xgboost_shap(bst, x)
    global weight_cache = Dict()
    n = length(x)
    phi = zeros(n)
    tree_ptrs = XGBoost.XGBoosterDumpModel(bst.handle, "", 1)
    expected_value = 0.0
    base_score = 0.0
    for (k, ptr) in enumerate(tree_ptrs)
        nodes = parse_xgboost_tree(unsafe_string(ptr))

        if k == 1
            own_tree_value = eval_tree(x, nodes)
            single_tree_pred = predict(bst, reshape(x, 1, n), ntree_limit=1)[1]
            base_score = single_tree_pred - own_tree_value
        end

        expected_value += cond_expectation([], x, nodes)[1]
        tree_shap!(phi, x, nodes)
    end
    return expected_value + base_score, phi
end
# Per-feature SHAP values of `x` for trees supplied directly as XGBoost text dumps (one
# string per tree in `data`), summed over all trees.
#
# Unlike xgboost_shap, no booster is available here, so the base score cannot be recovered
# and no expected value is computed; only phi is returned.
function xgboost_shap_data(data, x)
    # NOTE(review): weight_cache is not read anywhere in this notebook; the assignment is
    # kept only so the global it sets is unchanged for any other code relying on it
    global weight_cache = Dict()
    phi = zeros(length(x))
    for tree_str in data
        tree_shap!(phi, x, parse_xgboost_tree(tree_str))
    end
    phi
end

xgboost_shap_data (generic function with 1 method)

### Function to build a Julia representation of the XGBoost trees

In [3]:
# One node of a tree parsed from an XGBoost text dump by parse_xgboost_tree.
type TreeNode
    feature_index  # 1-based index of the split feature; 0 marks a leaf
    value          # split threshold for internal nodes, output value for leaves
    yes_index      # 1-based index of the child taken when x[feature_index] < value (0 for leaves)
    no_index       # 1-based index of the child taken otherwise (0 for leaves)
    missing_index  # 1-based index of the child taken when the feature is missing (0 for leaves)
    cover          # XGBoost's cover statistic; child/parent cover ratios weight the branches
end

# Parse one tree from an XGBoost text dump (dumped with stats, so every line carries a cover)
# into a vector of TreeNodes. Node id k is stored at nodes[k+1], and feature fK becomes
# feature_index K+1.
#
# - Covers are parsed as Float64. XGBoost reports cover as a sum of hessians, which is
#   fractional for many objectives; an integer-only pattern silently truncated it.
# - Blank lines are skipped, so the result no longer depends on a trailing newline.
# - A line matching neither the internal-node nor the leaf pattern raises an ArgumentError.
function parse_xgboost_tree(tree_str)
    lines = filter(l -> !isempty(l), strip.(split(tree_str, '\n')))
    nodes = Array(TreeNode, length(lines))
    internal_regex = r"([0-9]+):\[f([0-9]+)<([-+e0-9.]+)\] yes=([0-9]+),no=([0-9]+),missing=([0-9]+),gain=[-e+0-9.]+,cover=([-+e0-9.]+)"
    leaf_regex = r"([0-9]+):leaf=([-e+0-9.]+),cover=([-+e0-9.]+)"
    for line in lines
        m = match(internal_regex, line)
        if m !== nothing
            index = parse(Int64, m.captures[1])+1
            feature_index = parse(Int64, m.captures[2])+1
            split_val = parse(Float64, m.captures[3])
            yes_index = parse(Int64, m.captures[4])+1
            no_index = parse(Int64, m.captures[5])+1
            missing_index = parse(Int64, m.captures[6])+1
            cover = parse(Float64, m.captures[7])
            nodes[index] = TreeNode(feature_index, split_val, yes_index, no_index, missing_index, cover)
        else
            m = match(leaf_regex, line)
            if m === nothing
                throw(ArgumentError("unrecognized line in XGBoost tree dump: $line"))
            end
            index = parse(Int64, m.captures[1])+1
            leaf_val = parse(Float64, m.captures[2])
            cover = parse(Float64, m.captures[3])
            nodes[index] = TreeNode(0, leaf_val, 0, 0, 0, cover)
        end
    end
    nodes
end

using Base.Test
import Base.==

# Structural equality for TreeNodes, used to check parse_xgboost_tree's output. Two nodes are
# equal when all six fields are equal. Comparing only feature_index would accept a parse with
# a wrong threshold, child index, or cover.
function ==(x::TreeNode, y::TreeNode)
    x.feature_index == y.feature_index &&
        x.value == y.value &&
        x.yes_index == y.yes_index &&
        x.no_index == y.no_index &&
        x.missing_index == y.missing_index &&
        x.cover == y.cover
end

# hash must agree with == so TreeNodes behave consistently as Dict keys and Set elements
Base.hash(n::TreeNode, h::UInt) =
    hash(n.feature_index, hash(n.value, hash(n.yes_index, hash(n.no_index,
        hash(n.missing_index, hash(n.cover, h))))))

# Parse a small depth-2 dump and check it against hand-written nodes, using the == method
# defined for TreeNode. Node ids and feature numbers are shifted to 1-based indexing:
# XGBoost node 0 is nodes[1] and f1 is feature 2. Leaves use 0 for feature_index and for
# all three child indices.
tree_str = "0:[f1<0.210065] yes=1,no=2,missing=1,gain=7.70125,cover=100\n\t1:[f0<1.14837] yes=3,no=4,missing=3,gain=5.454,cover=59\n\t\t3:leaf=0.0268182,cover=53\n\t\t4:leaf=-0.0672221,cover=6\n\t2:[f2<-1.61475] yes=5,no=6,missing=5,gain=2.06263,cover=41\n\t\t5:leaf=0.0641373,cover=1\n\t\t6:leaf=-0.0436476,cover=40\n"
nodes = parse_xgboost_tree(tree_str)

# expected order: the root, its two internal children, then the four leaves
@test all(nodes .== [
    TreeNode(2,0.210065,2,3,2,100),
    TreeNode(1,1.14837,4,5,4,59),
    TreeNode(3,-1.61475,6,7,6,41),
    TreeNode(0,0.0268182,0,0,0,53),
    TreeNode(0,-0.0672221,0,0,0,6),
    TreeNode(0,0.0641373,0,0,0,1),
    TreeNode(0,-0.0436476,0,0,0,40)
])

[1m[32mTest Passed
[0m  Expression: all(nodes .== [TreeNode(2,0.210065,2,3,2,100),TreeNode(1,1.14837,4,5,4,59),TreeNode(3,-1.61475,6,7,6,41),TreeNode(0,0.0268182,0,0,0,53),TreeNode(0,-0.0672221,0,0,0,6),TreeNode(0,0.0641373,0,0,0,1),TreeNode(0,-0.0436476,0,0,0,40)])

### Exponential-time exact (brute-force) algorithm

In [4]:
# compute the value of the tree for a given x by following x's path from node_index to a leaf
#
# At each split, a feature value == missing_value takes the node's missing branch.
# Otherwise x[feature] < threshold takes the yes branch and anything else the no branch.
# missing_value is passed down the recursion, so a caller-supplied marker is honoured at
# every depth rather than only at the root. Note that == never matches NaN, so NaN cannot
# serve as missing_value here.
function eval_tree(x, tree_nodes, node_index=1, missing_value=nothing)
    node = tree_nodes[node_index]
    if node.feature_index == 0
        return node.value
    end
    xi = x[node.feature_index]
    if xi == missing_value
        next_index = node.missing_index
    elseif xi < node.value
        next_index = node.yes_index
    else
        next_index = node.no_index
    end
    return eval_tree(x, tree_nodes, next_index, missing_value)
end

eval_tree (generic function with 3 methods)

In [5]:
using Iterators

# compute the expectation of the tree output conditioned on the variables x_S
#
# Returns (value, weight).
# - A split on a feature in S follows only the branch x takes (the missing branch when
#   x[j] == missing_value).
# - Any other split averages both children, each weighted by its share of the parent's cover.
# - `weight` is the fraction of cover that reached node_index; it is returned alongside the
#   value so the parent can combine its children.
function cond_expectation(S, x, tree_nodes, node_index=1, missing_value=nothing, weight=1.0)
    node = tree_nodes[node_index]
    feature = node.feature_index

    # leaf: its value, together with the weight that reached it
    if feature == 0
        return node.value, weight
    end

    # conditioned feature: follow the single branch x takes
    if feature in S
        xj = x[feature]
        if xj == missing_value
            child = node.missing_index
        elseif xj < node.value
            child = node.yes_index
        else
            child = node.no_index
        end
        return cond_expectation(S, x, tree_nodes, child, missing_value, weight)
    end

    # unconditioned feature: cover-weighted average of the yes and no children
    # (the code assumes the missing branch is one of those two)
    @assert node.missing_index == node.no_index || node.missing_index == node.yes_index
    yes_child, no_child = node.yes_index, node.no_index
    yes_val, yes_weight = cond_expectation(S, x, tree_nodes, yes_child, missing_value, weight*(tree_nodes[yes_child].cover/node.cover))
    no_val, no_weight = cond_expectation(S, x, tree_nodes, no_child, missing_value, weight*(tree_nodes[no_child].cover/node.cover))
    total = yes_weight + no_weight
    return (yes_weight*yes_val + no_weight*no_val)/total, total
end

# Shapley weight s!(M-s-1)!/M! of a coalition of size s that excludes the explained feature,
# out of M features in total.
#
# Computed as the equivalent 1/(M * binomial(M-1, s)). The factorial form overflows Int64
# (factorial(21) throws) once M > 20. widemul keeps the product exact in Int128;
# binomial itself throws OverflowError if it cannot be represented.
function shapley_weight(M, s)
    1 / widemul(M, binomial(M - 1, s))
end

# Brute-force SHAP value of feature i for booster `bst` at instance `x`: dump the booster's
# trees as text and let brute_force_phi_data do the exhaustive computation.
function brute_force_phi(bst, x, i)
    tree_dumps = map(unsafe_string, XGBoost.XGBoosterDumpModel(bst.handle, "", 1))
    return brute_force_phi_data(tree_dumps, x, i)
end

# Brute-force SHAP value of feature i for trees given as XGBoost text dumps.
#
# For every tree and every subset S of the other features, adds
# shapley_weight(M, |S|) * (E[f | x_{S ∪ {i}}] - E[f | x_S]), with the expectations taken
# from cond_expectation. The cost is exponential in length(x).
function brute_force_phi_data(data, x, i)
    M = length(x)
    others = setdiff(1:M, [i])
    total = 0.0
    for tree_str in data
        nodes = parse_xgboost_tree(tree_str)
        for S in subsets(others)
            with_i, _ = cond_expectation(union(S, [i]), x, nodes)
            without_i, _ = cond_expectation(S, x, nodes)
            total += shapley_weight(M, length(S))*(with_i - without_i)
        end
    end
    return total
end

brute_force_phi_data (generic function with 1 method)

In [6]:
# Compare the fast Tree SHAP result with the brute-force one for trees `data` at instance `x`.
# Prints both vectors. Throws an AssertionError naming the discrepancy when their Euclidean
# distance exceeds 1e-8 or is NaN. The explicit throw is used instead of `@assert false`,
# which gave no information and can be disabled.
function verify_match(data, x)
    phi1 = [brute_force_phi_data(data, x, i) for i in 1:length(x)]
    phi2 = xgboost_shap_data(data, x)
    println("brute_force = $phi1")
    println("fast_method = $phi2")
    err = norm(phi1 .- phi2)
    # !(err <= tol) is also true when err is NaN
    if !(err <= 1e-8)
        throw(AssertionError("fast Tree SHAP disagrees with brute force: norm of difference = $err"))
    end
end

verify_match (generic function with 1 method)

In [7]:
# every hand-written tree below is explained at x = (1, 1, 1, 1)
x = ones(4)
# depth 2 over f0 then f1: output is 1 only when f0 >= 0.5 and f1 >= 0.5
data = ["0:[f0<0.5] yes=1,no=2,missing=1,gain=5.95771,cover=100\n\t1:[f1<0.5] yes=3,no=4,missing=3,gain=4.72408,cover=50\n\t\t3:leaf=0,cover=25\n\t\t4:leaf=0,cover=25\n\t2:[f1<0.5] yes=5,no=6,missing=5,gain=2.30706,cover=50\n\t\t5:leaf=0,cover=25\n\t\t6:leaf=1,cover=25\n"]
verify_match(data, x)

# same shape: output is 1 only when f0 >= 0.5 and f1 < 0.5
data = ["0:[f0<0.5] yes=1,no=2,missing=1,gain=5.95771,cover=100\n\t1:[f1<0.5] yes=3,no=4,missing=3,gain=4.72408,cover=50\n\t\t3:leaf=0,cover=25\n\t\t4:leaf=0,cover=25\n\t2:[f1<0.5] yes=5,no=6,missing=5,gain=2.30706,cover=50\n\t\t5:leaf=1,cover=25\n\t\t6:leaf=0,cover=25\n"];
verify_match(data, x)

# f0 is split on again below the root (threshold 0.4): exercises unwinding a repeated feature
data = ["0:[f0<0.5] yes=1,no=2,missing=1,gain=5.95771,cover=100\n\t1:[f1<0.5] yes=3,no=4,missing=3,gain=4.72408,cover=50\n\t\t3:leaf=0,cover=25\n\t\t4:leaf=0,cover=25\n\t2:[f0<0.4] yes=5,no=6,missing=5,gain=2.30706,cover=50\n\t\t5:leaf=1,cover=25\n\t\t6:leaf=1,cover=25\n"]
verify_match(data, x)

# every split is on f0, at three different thresholds
data = ["0:[f0<0.0547004] yes=1,no=2,missing=1,gain=6.41592,cover=100\n\t1:[f0<-0.1] yes=3,no=4,missing=3,gain=7.23454,cover=50\n\t\t3:leaf=0,cover=25\n\t\t4:leaf=0,cover=25\n\t2:[f0<0.5] yes=5,no=6,missing=5,gain=9.11159,cover=50\n\t\t5:leaf=0,cover=25\n\t\t6:leaf=1,cover=25\n"];
verify_match(data, x)

# as above, but the leftmost leaf (node 3) also outputs 1
data = ["0:[f0<0.0547004] yes=1,no=2,missing=1,gain=6.41592,cover=100\n\t1:[f0<-0.1] yes=3,no=4,missing=3,gain=7.23454,cover=50\n\t\t3:leaf=1,cover=25\n\t\t4:leaf=0,cover=25\n\t2:[f0<0.5] yes=5,no=6,missing=5,gain=9.11159,cover=50\n\t\t5:leaf=0,cover=25\n\t\t6:leaf=1,cover=25\n"];
verify_match(data, x)

# f0 then f1, with unequal child covers
data = ["0:[f0<0.0547004] yes=1,no=2,missing=1,gain=6.41592,cover=100\n\t1:[f1<-0.1] yes=3,no=4,missing=3,gain=7.23454,cover=50\n\t\t3:leaf=1,cover=15\n\t\t4:leaf=0,cover=35\n\t2:[f1<0.5] yes=5,no=6,missing=5,gain=9.11159,cover=50\n\t\t5:leaf=0,cover=5\n\t\t6:leaf=1,cover=45\n"];
verify_match(data, x)

# depth 3 over f0..f3 with f2 split on in both subtrees; only leaf 12 is nonzero
data = ["0:[f0<-0.108652] yes=1,no=2,missing=1,gain=9.91912,cover=200\n\t1:[f1<-0.0500525] yes=3,no=4,missing=3,gain=7.68742,cover=100\n\t\t3:[f2<-1.18479] yes=7,no=8,missing=7,gain=5.72911,cover=50\n\t\t\t7:leaf=0,cover=25\n\t\t\t8:leaf=0,cover=25\n\t\t4:[f2<-0.28887] yes=9,no=10,missing=9,gain=4.89582,cover=50\n\t\t\t9:leaf=0,cover=25\n\t\t\t10:leaf=0,cover=25\n\t2:[f3<-1.82883] yes=5,no=6,missing=5,gain=5.23317,cover=100\n\t\t5:[f2<0.914076] yes=11,no=12,missing=11,gain=6.40652,cover=50\n\t\t\t11:leaf=0,cover=25\n\t\t\t12:leaf=1,cover=25\n\t\t6:[f2<0.914076] yes=13,no=14,missing=13,gain=6.40652,cover=50\n\t\t\t13:leaf=0,cover=35\n\t\t\t14:leaf=0,cover=15\n"]
verify_match(data, x)

brute_force = [0.375,0.375,0.0,0.0]
fast_method = [0.375,0.375,0.0,0.0]
brute_force = [0.125,-0.375,0.0,0.0]
fast_method = [0.125,-0.375,0.0,0.0]
brute_force = [0.5,0.0,0.0,0.0]
fast_method = [0.5,0.0,0.0,0.0]
brute_force = [0.75,0.0,0.0,0.0]
fast_method = [0.75,0.0,0.0,0.0]
brute_force = [0.5,0.0,0.0,0.0]
fast_method = [0.5,0.0,0.0,0.0]
brute_force = [0.4,0.0,0.0,0.0]
fast_method = [0.4,-1.38778e-17,0.0,0.0]
brute_force = [0.0833333,0.0,0.0833333,-0.291667]
fast_method = [0.0833333,0.0,0.0833333,-0.291667]


## Build an XGBoost tree model to explain

In [8]:
using Iterators
using XGBoost

# Train one boosting round of a depth-3 regression tree on random data, then check the fast
# algorithm against brute force on its dump. The RNG is seeded so the printed tree and SHAP
# values are reproducible from run to run.
srand(0)
N = 1000
M = 10
X = randn(N,M)
y = randn(N)
dtrain = DMatrix(X,label=y)
bst = xgboost(dtrain, 1, param=Dict("max_depth"=>3, "objective"=>"reg:linear", "eta"=>.1, "base_score"=>0.3), silent=true)

data=unsafe_string.(XGBoost.XGBoosterDumpModel(bst.handle, "", 1))
# explain the model at x = (1, ..., 1)
x = ones(M)
println(data[1])
verify_match(data, x)

0:[f1<-1.69235] yes=1,no=2,missing=1,gain=15.3372,cover=1000
	1:[f7<0.161436] yes=3,no=4,missing=3,gain=8.69375,cover=35
		3:[f2<0.699213] yes=7,no=8,missing=7,gain=3.15086,cover=23
			7:leaf=-0.0282265,cover=16
			8:leaf=0.0478976,cover=7
		4:[f1<-1.72871] yes=9,no=10,missing=9,gain=3.38603,cover=12
			9:leaf=0.119984,cover=10
			10:leaf=-0.0147658,cover=2
	2:[f6<-0.509197] yes=5,no=6,missing=5,gain=12.2108,cover=965
		5:[f6<-2.61395] yes=11,no=12,missing=11,gain=8.48565,cover=273
			11:leaf=0.101897,cover=5
			12:leaf=-0.0185253,cover=268
		6:[f2<1.77262] yes=13,no=14,missing=13,gain=6.6369,cover=692
			13:leaf=-0.0390368,cover=668
			14:leaf=-0.0921749,cover=24

brute_force = [0.0,-0.00371667,0.00196208,0.0,0.0,0.0,-0.00656882,0.000976718,0.0,0.0]
fast_method = [0.0,-0.00371667,0.00196208,0.0,0.0,0.0,-0.00656882,0.000976718,0.0,0.0]
