In [1]:
import itertools
import pprint
from collections import defaultdict
from typing import Dict, Tuple, Any

import numpy as np
import pandas as pd
import seaborn as sns
from anytree import Node, RenderTree

## Construct a tree

In [2]:
# function g: combined features Shapley values ranking
def g(scores: Dict[Tuple[Any, ...], float]) -> Tuple[Tuple[Any, ...], float]:
    """Return the ``(combination, score)`` entry with the largest absolute score.

    Args:
        scores: Mapping from a feature-combination key (a tuple, possibly
            nested) to its score.

    Returns:
        The ``(key, score)`` item whose ``abs(score)`` is largest. When several
        items tie, the one that appears first in ``scores`` is returned.

    Raises:
        IndexError: If ``scores`` is empty.
    """
    if not scores:
        # Keep the exception type an empty mapping has always produced here,
        # but say what actually went wrong.
        raise IndexError('scores is empty: no feature combination to rank')
    # A single pass is enough; max() returns the first maximal item, which is
    # the same winner a stable descending sort would put at position 0.
    return max(scores.items(), key=lambda item: abs(item[1]))

def flatten(li):
    """Lazily yield the leaves of an arbitrarily nested list/tuple structure.

    Lists and tuples are descended into, depth first and left to right; every
    other element (including strings) is yielded unchanged.
    """
    for item in li:
        if not isinstance(item, (list, tuple)):
            yield item
            continue
        # nested container: emit its leaves in place
        yield from flatten(item)

In [3]:
np.random.seed(8)  # fixed seed so the matrix below is reproducible

cols = list('ABCD')
num_feature = len(cols)

# Draw a random square matrix, then make it symmetric by copying the strict
# lower triangle onto the mirrored positions of the upper triangle.
siv = np.random.randn(num_feature, num_feature)
r, c = np.tril_indices(num_feature, -1)
r_u, c_u = tuple(c), tuple(r)  # mirrored coordinates: swap row and column
siv[r_u, c_u] = siv[r, c]

# labelled view of the matrix for display
siv_display = pd.DataFrame(data=siv, index=cols, columns=cols)
siv_display

Unnamed: 0,A,B,C,D
A,0.091205,-2.296492,0.794828,-1.123327
B,-2.296492,2.409834,0.976421,-0.664035
C,0.794828,0.976421,-1.183427,-0.378359
D,-1.123327,-0.664035,-0.378359,-0.791615


In [4]:
# The main effect of each feature is the corresponding diagonal entry of siv.
diag_idx = np.diag_indices(num_feature)
r_diag, c_diag = diag_idx
main_effect = siv[diag_idx]
pd.Series(data=main_effect, index=cols, name='main_effect')

A    0.091205
B    2.409834
C   -1.183427
D   -0.791615
Name: main_effect, dtype: float64

In [5]:
# init: every feature starts as its own parentless node holding its main effect.
# Keys are the feature's position in `cols`. A comprehension is used so no
# loop variable leaks into the notebook namespace (the previous loop variable
# `c` overwrote the index array `c` from the matrix-building cell).
nodes = {
    idx: Node(name=name, parent=None, value=main_effect[idx])
    for idx, name in enumerate(cols)
}
print('Init Nodes:')
pprint.pprint(nodes)

Init Nodes:
{0: Node('/A', value=0.09120471661981977),
 1: Node('/B', value=2.4098343033415413),
 2: Node('/C', value=-1.183427147333015),
 3: Node('/D', value=-0.7916152714963363)}


In [6]:
scores = {}
done = set()  # keys of nodes already merged under a parent; skipped in later rounds

# candidates are the nodes that have not been merged yet
nodes_to_run = [key for key in nodes if key not in done]
print(f'candidates: {nodes_to_run}')

# first round: score every pair of candidates (pairs, so the tree is binary)
for cmbs in itertools.combinations(nodes_to_run, 2):
    members = tuple(flatten(cmbs))
    # a pair's score is the sum of siv over the full grid of its member features
    if cmbs not in scores:
        r, c = zip(*itertools.product(members, members))
        scores[cmbs] = siv[r, c].sum()
    # show which cells of siv were summed for this pair
    feature_name = ''.join(cols[i] for i in members)
    print(f'Feature Combination: {feature_name}')
    print(f'co-cordinates: {list(zip(r, c))}')
    print(f'Values: {siv[r, c]}')
    print()

# the pair with the largest absolute score is the one to merge next
cmbs, max_value = g(scores)
feature_name = ''.join(cols[i] for i in flatten(cmbs))
print(f'Max Shapley value combination: {feature_name} = {max_value:.4f}')

candidates: [0, 1, 2, 3]
Feature Combination: AB
co-cordinates: [(0, 0), (0, 1), (1, 0), (1, 1)]
Values: [ 0.09120472 -2.29649157 -2.29649157  2.4098343 ]

Feature Combination: AC
co-cordinates: [(0, 0), (0, 2), (2, 0), (2, 2)]
Values: [ 0.09120472  0.79482764  0.79482764 -1.18342715]

Feature Combination: AD
co-cordinates: [(0, 0), (0, 3), (3, 0), (3, 3)]
Values: [ 0.09120472 -1.1233268  -1.1233268  -0.79161527]

Feature Combination: BC
co-cordinates: [(1, 1), (1, 2), (2, 1), (2, 2)]
Values: [ 2.4098343   0.9764211   0.9764211  -1.18342715]

Feature Combination: BD
co-cordinates: [(1, 1), (1, 3), (3, 1), (3, 3)]
Values: [ 2.4098343  -0.66403547 -0.66403547 -0.79161527]

Feature Combination: CD
co-cordinates: [(2, 2), (2, 3), (3, 2), (3, 3)]
Values: [-1.18342715 -0.37835857 -0.37835857 -0.79161527]

Max Shapley value combination: BC = 3.1792


In [7]:
# Display every pairwise score computed so far, keyed by the pair of node keys.
scores

{(0, 1): -2.0919441284429663,
 (0, 2): 0.49743284798360676,
 (0, 3): -2.9470641559787936,
 (1, 2): 3.1792493488104068,
 (1, 3): 0.29014809267699393,
 (2, 3): -2.731759559566469}

In [8]:
# Merge the winning pair: its members become the children of a new parent node.
feature_name = ''.join(cols[i] for i in flatten(cmbs))
children = [nodes[key] for key in cmbs]
done.update(cmbs)

# Any cached score that involves a merged member can no longer be chosen.
stale_pairs = [pair for pair in scores if any(key in pair for key in cmbs)]
for pair in stale_pairs:
    del scores[pair]

nodes[cmbs] = Node(name=feature_name, value=max_value, children=children)

print('New Nodes')
pprint.pprint(nodes)
print('Scores Left:')
pprint.pprint(scores)

New Nodes
{0: Node('/A', value=0.09120471661981977),
 1: Node('/BC/B', value=2.4098343033415413),
 2: Node('/BC/C', value=-1.183427147333015),
 3: Node('/D', value=-0.7916152714963363),
 (1, 2): Node('/BC', value=3.1792493488104068)}
Scores Left:
{(0, 3): -2.9470641559787936}


In [9]:
# candidates are the nodes that have not been merged yet
nodes_to_run = [key for key in nodes if key not in done]
print(f'candidates: {nodes_to_run}')

# second round: score every candidate pair that has no cached score yet
for cmbs in itertools.combinations(nodes_to_run, 2):  # pairs -> binary tree
    if cmbs in scores:
        continue
    # a pair's score is the sum of siv over the full grid of its member features
    members = tuple(flatten(cmbs))
    r, c = zip(*itertools.product(members, members))
    scores[cmbs] = siv[r, c].sum()

# the pair with the largest absolute score is the one to merge next
cmbs, max_value = g(scores)
feature_name = ''.join(cols[i] for i in flatten(cmbs))
print(f'Max Shapley value combination: {feature_name} = {max_value:.4f}')

candidates: [0, 3, (1, 2)]
Max Shapley value combination: AD = -2.9471


In [10]:
# Display the scores considered in the second round, keyed by the pair of node keys.
scores

{(0, 3): -2.9470641559787936,
 (0, (1, 2)): 0.267126195722702,
 (3, (1, 2)): 0.3028459974087421}

In [11]:
# Merge the winning pair: its members become the children of a new parent node.
feature_name = ''.join(cols[i] for i in flatten(cmbs))
children = [nodes[member] for member in cmbs]
done.update(cmbs)

# Drop every cached score that involves one of the members just merged.
for stale in [key for key in scores if any(member in key for member in cmbs)]:
    del scores[stale]

nodes[cmbs] = Node(name=feature_name, value=max_value, children=children)

print('New Nodes')
pprint.pprint(nodes)
print('Scores Left:')
pprint.pprint(scores)

New Nodes
{0: Node('/AD/A', value=0.09120471661981977),
 1: Node('/BC/B', value=2.4098343033415413),
 2: Node('/BC/C', value=-1.183427147333015),
 3: Node('/AD/D', value=-0.7916152714963363),
 (0, 3): Node('/AD', value=-2.9470641559787936),
 (1, 2): Node('/BC', value=3.1792493488104068)}
Scores Left:
{}


In [12]:
# Display all nodes built so far: integer keys are single features, tuple keys are merged pairs.
nodes

{0: Node('/AD/A', value=0.09120471661981977),
 1: Node('/BC/B', value=2.4098343033415413),
 2: Node('/BC/C', value=-1.183427147333015),
 3: Node('/AD/D', value=-0.7916152714963363),
 (1, 2): Node('/BC', value=3.1792493488104068),
 (0, 3): Node('/AD', value=-2.9470641559787936)}

In [13]:
# TODO: decide on a general stopping rule. For now the two merged subtrees are
# joined by hand under a single root whose value is the sum of the whole matrix.
root_children = [nodes[(0, 3)], nodes[(1, 2)]]
nodes['root'] = Node(''.join(cols), value=siv.sum(), children=root_children)

# iterable view of the tree, used for printing
r = RenderTree(nodes['root'])

In [14]:
# Print one line per node: the tree-drawing prefix followed by the node's name.
for pre, _, node in r:
    print(f"{pre}{node.name}")

ABCD
├── AD
│   ├── A
│   └── D
└── BC
    ├── B
    └── C
