In [None]:
import matplotlib.pyplot as plt
import random

    
def hierarchy_pos(G, root=None, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5):

    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
    Licensed under Creative Commons Attribution-Share Alike

    If the graph is a tree this will return the positions to plot this in a
    hierarchical layout.

    G: the graph (must be a tree)

    root: the root node of current branch
    - if the tree is directed and this is not given,
      the root will be found and used
    - if the tree is directed and this is given, then
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given,
      then a random choice will be used.

    width: horizontal space allocated for this branch - avoids overlap with other branches

    vert_gap: gap between levels of hierarchy

    vert_loc: vertical location of root

    xcenter: horizontal location of root

    Returns: dict mapping every node reached from ``root`` to an ``(x, y)`` tuple.
    Raises: TypeError if ``G`` is not a tree.
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            # The first node of a topological order is the root of a directed tree.
            # next(iter(...)) allows back compatibility with nx version 1.11.
            root = next(iter(nx.topological_sort(G)))
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5, pos = None, parent = None):
        '''
        see hierarchy_pos docstring for most arguments

        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed

        Returns the (mutated) ``pos`` dict.
        '''
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            # In an undirected tree the parent is also a neighbour; don't revisit it.
            children.remove(parent)
        if children:
            # Split this branch's width into equal slots, one per child,
            # centred on xcenter.
            dx = width / len(children)
            nextx = xcenter - width/2 - dx/2
            for child in children:
                nextx += dx
                # Each child's subtree gets exactly its own slot (dx) so it cannot
                # spill into a sibling's slot. For binary trees dx == width/2, so
                # binary layouts are unchanged.
                pos = _hierarchy_pos(G, child, width = dx, vert_gap = vert_gap,
                                     vert_loc = vert_loc-vert_gap, xcenter=nextx,
                                     pos=pos, parent = root)
        return pos

    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

In [None]:
# Node coordinates for the tree rooted at bayes_spyct.root_node; width and
# vert_gap are enlarged from hierarchy_pos's defaults (1.0 / 0.2) to spread it out.
pos = hierarchy_pos(
    G,
    bayes_spyct.root_node,
    width=5.0,
    vert_gap=5.0,
)

In [None]:
# get_leaves is handed an empty list and its return value is ignored, so it
# presumably appends the leaves reachable from the root into that list.
leaves = list()
bayes_spyct.get_leaves(bayes_spyct.root_node, leaves)
leaves

In [None]:
# Internal (split) nodes: everything in `nodes` that is not one of the leaves.
non_leaves = [candidate for candidate in nodes if candidate not in leaves]
non_leaves

In [None]:
# Leaf label: the sum of all entries of node.prototype, as an int.
# NOTE(review): int() truncates toward zero; use round() if the label should be
# the nearest integer to the expected value — confirm intent.
prototype_expected_value = {node: int(torch.sum(node.prototype).item()) for node in leaves}

# All leaves are drawn with the same marker size. One entry per element of
# `leaves`, so the list always lines up with nodelist=leaves when drawing.
LEAF_NODE_SIZE = 1000
leaf_sizes = [LEAF_NODE_SIZE] * len(leaves)

In [None]:
# First row of each split node's split_model.weight, keyed by the node.
weights_per_node = {node: node.split_model.weight[0] for node in non_leaves}

In [None]:
# For each split node: {feature index -> |weight|}, ordered by |weight| descending.
# NOTE: despite the name, nothing is truncated to 5 — the plotting cell looks up
# every feature index in these dicts, so the full ranking must be kept.
weights_top5 = {}
for key, value in weights_per_node.items():
    # detach()/cpu() so .numpy() also works when the tensor requires grad or
    # lives on a GPU; a bare .numpy() raises in both cases.
    abs_values = np.abs(value.detach().cpu().numpy())
    sorted_indices = np.argsort(abs_values)[::-1]
    weights_top5[key] = dict(zip(sorted_indices, abs_values[sorted_indices]))

In [None]:
# Feature indices per group (isco_idcs, p16_idcs, other_idcs are defined
# elsewhere in the notebook).
feature_groups = {'isco': isco_idcs,
                  'p16': p16_idcs,
                  'other': other_idcs}

feature_group_colors = {'isco': 'red', 'p16': 'blue', 'other': 'green'}

# Map every feature index to its group's colour. If an index appears in more
# than one group, the group listed last in feature_group_colors wins.
feature_color_map = {
    feature: color
    for group, color in feature_group_colors.items()
    for feature in feature_groups[group]
}

In [None]:
from matplotlib.patches import Patch

# Curvature added per feature group so a split's coloured arcs fan out instead
# of overlapping, and the factor turning a weight share into a line width.
RAD_STEP = 0.1
EDGE_WIDTH_SCALE = 1.5

fig, ax = plt.subplots(figsize=(18, 10))

# For each split node draw, along every outgoing edge, one arc per feature group
# whose width is that group's share of the node's total |weight|.
for node in non_leaves:
    ranked_weights = weights_top5[node]  # feature index -> |weight|
    total_weight = sum(ranked_weights.values())
    if total_weight == 0:
        continue  # all-zero weights: group shares are undefined
    present_colors = {feature_color_map[feature] for feature in ranked_weights}
    node_edges = [edge for edge in edges if edge[0] == node]

    offset = 0.0
    # Walk groups in legend order instead of over a set: the iteration order of a
    # set of strings changes between interpreter runs (hash randomisation), which
    # made the colour -> curvature assignment non-reproducible.
    for color in dict.fromkeys(feature_group_colors.values()):
        if color not in present_colors:
            continue
        offset += RAD_STEP
        group_weight = sum(ranked_weights.get(k, 0)
                           for k, v in feature_color_map.items() if v == color)
        width = group_weight / total_weight
        # Alternate the curvature sign per outgoing edge (1st +, 2nd -, ...):
        # identical to before for binary splits, and no IndexError otherwise.
        for j, edge in enumerate(node_edges):
            rad = offset if j % 2 == 0 else -offset
            nx.draw_networkx_edges(G, pos, ax=ax, edgelist=[edge], edge_color=color,
                                   width=width * EDGE_WIDTH_SCALE,
                                   connectionstyle=f'arc3, rad={rad}', alpha=0.5,
                                   min_source_margin=5, min_target_margin=5)

nx.draw_networkx_nodes(G, pos=pos, ax=ax, nodelist=non_leaves, node_color='w', edgecolors='black', alpha=1, node_shape='s')
nx.draw_networkx_nodes(G, pos=pos, ax=ax, nodelist=leaves, node_color='w', edgecolors='black', alpha=1, node_size=leaf_sizes)
nx.draw_networkx_labels(G, pos=pos, ax=ax, labels=prototype_expected_value, font_size=10)

legend_elements = [Patch(color=color, label=group) for group, color in feature_group_colors.items()]
ax.legend(handles=legend_elements, title='Feature Groups', loc='center right');