In [49]:
import re, math, json
from typing import Tuple, Union, Dict
from collections import defaultdict
import numpy as np

In [76]:
class DotContent:
    """Parse a Graphviz DOT dump of a binary decision tree into a heap-indexed table.

    Nodes are re-indexed heap-style: the root is 0 and node ``i`` has children
    ``2*i + 1`` (first outgoing edge seen) and ``2*i + 2`` (second edge seen).
    Each table entry is ``(attribute, value)``:

    * ``(feature_index, threshold)`` for a split node (label ``X[i] <= t``);
    * ``(-1, class_index)`` for a leaf (node with no outgoing edge).

    NOTE(review): the label layout parsed here is assumed to be scikit-learn's
    ``export_graphviz`` output -- other DOT producers are untested.
    """

    # ``X[<feature>] <= <threshold>``; the threshold may be signed or use
    # exponent notation (e.g. ``-0.5``, ``1e-05``).
    _SPLIT_RE = re.compile(r'\[(\d+)\]\D*?([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)')
    # Bracketed list of numbers such as ``[50, 0, 0]`` or ``[0.2, 0.8]``.
    _VALUE_RE = re.compile(r'\[[\d.,\s]+\]')

    def __init__(self, path) -> None:
        """Read and tokenize the DOT file at ``path`` (decoded as UTF-8)."""
        with open(path, encoding='utf-8') as dot_file:
            self._load(dot_file.read())

    @classmethod
    def from_string(cls, dot_content: str) -> 'DotContent':
        """Create an instance from DOT text already in memory (no file I/O)."""
        instance = cls.__new__(cls)
        instance._load(dot_content)
        return instance

    def _load(self, dot_content: str) -> None:
        """Tokenize raw DOT text and reset every derived attribute."""
        self.dot_content = dot_content
        # Edges look like ``0 -> 1 [...]``; only the two node ids are kept.
        self.son_relation_raw = re.findall(r'\d+ -> \d+', dot_content)
        # Node lines look like ``3 [label="..."] ;`` or, with extra attributes,
        # ``3 [label="...", fillcolor="#..."] ;`` -- so ``]`` is not required
        # immediately after the label.
        self.label_raw = re.findall(r'(\d+) \[label="([^"]+)"', dot_content)

        self.depth = 0
        self.node_map = {'0': 0}
        self.children_num = defaultdict(int)
        self.raw_tree = defaultdict(lambda: (-2, 0))
        self.full_tree = defaultdict(lambda: (-2, 0))
        self.serialized = []

    def _process_node(self, label_line) -> Tuple[int, float]:
        """Parse a split-node label into ``(feature_index, threshold)``.

        Only the first label line (the ``X[i] <= t`` test) is inspected. Label
        lines in the DOT text are separated by a literal backslash-n.

        Raises:
            ValueError: the first line does not contain ``[<int>]`` followed by a number.
        """
        judgement = label_line.split('\\n')[0]
        match = self._SPLIT_RE.search(judgement)
        if match is None:
            raise ValueError(f'cannot parse split condition from {judgement!r}')
        return int(match.group(1)), float(match.group(2))

    def _process_leaf(self, label_line) -> Tuple[int, int]:
        """Parse a leaf label into ``(-1, class_index)``.

        The class is the index of the largest entry in the last bracketed number
        list of the label (the per-class ``value``), i.e. the majority class;
        for a pure leaf this is its only non-zero entry.

        Raises:
            ValueError: the label contains no bracketed number list.
        """
        value_lists = self._VALUE_RE.findall(label_line)
        if not value_lists:
            raise ValueError(f'no per-class value list in leaf label {label_line!r}')
        # json.loads instead of eval: the text comes from a file and must never be executed.
        counts = np.asarray(json.loads(value_lists[-1]), dtype=float)
        return -1, int(np.argmax(counts))

    def build(self) -> Dict[int, Tuple[int, Union[int, float]]]:
        """Map DOT node ids to heap indices and parse every node label.

        Safe to call repeatedly: all derived state is rebuilt from the parsed
        DOT text on each call.

        Returns:
            ``raw_tree``: heap index -> ``(feature, threshold)`` for splits or
            ``(-1, class_index)`` for leaves. Only nodes present in the DOT text
            appear, so the tree is generally not complete.

        Raises:
            ValueError: a node has more than two children, no root (id 0) label
                was found, or a label cannot be parsed.
        """
        self.depth = 0
        self.node_map = {'0': 0}
        self.children_num = defaultdict(int)
        self.raw_tree = defaultdict(lambda: (-2, 0))

        for line in self.son_relation_raw:
            father, son = line.split(' -> ')
            slot = self.children_num[father]
            if slot > 1:
                raise ValueError(f'node {father} has more than two children')
            # First edge from a node -> left child (2i+1), second -> right (2i+2).
            # NOTE(review): assumes the DOT lists the left branch first -- confirm
            # for producers other than sklearn's export_graphviz.
            self.node_map[son] = self.node_map[father] * 2 + 1 + slot
            self.children_num[father] += 1

        for name, label_line in self.label_raw:
            index = self.node_map[name]
            # A leaf is a node with no outgoing edge; unlike sniffing the label
            # text, this works for any split criterion (gini, entropy, ...).
            if self.children_num.get(name, 0) == 0:
                self.raw_tree[index] = self._process_leaf(label_line)
            else:
                self.raw_tree[index] = self._process_node(label_line)

        if 0 not in self.raw_tree:
            raise ValueError('no labelled root node (id 0) found in DOT content')

        # Levels needed to hold the largest heap index m:
        # 2**(depth-1) - 1 <= m < 2**depth - 1  <=>  depth == (m + 1).bit_length()
        self.depth = (max(self.node_map.values()) + 1).bit_length()
        return self.raw_tree

    def expand(self, save_json='') -> list:
        """Pad ``raw_tree`` into a complete binary tree with ``depth`` levels.

        * A missing node copies its parent, so every subtree below a leaf is
          filled with copies of that leaf.
        * A leaf that is not on the last level is replaced by a dummy split
          ``(0, 0)``; both of its subtrees are identical copies of the leaf, so
          the routing decision does not change the result.

        Args:
            save_json: if non-empty, path the serialized list is written to as JSON.

        Returns:
            ``serialized``: list of ``(index, attribute, value)`` for every heap
            index ``0 .. 2**depth - 2``, in index order.

        Raises:
            RuntimeError: ``build()`` has not been called yet.
        """
        if self.depth == 0:
            raise RuntimeError("Should call `build()` first!")
        first_last_level = 2 ** (self.depth - 1) - 1
        node_count = 2 ** self.depth - 1
        # (first heap index on the last level, total node count)
        self.deepest = (first_last_level, node_count)
        self.full_tree = self.raw_tree.copy()

        # A parent's index is always smaller than its child's, so it is filled first.
        for n in range(node_count):
            if n not in self.full_tree:
                self.full_tree[n] = self.full_tree[(n - 1) // 2]

        for n in range(first_last_level):
            if self.full_tree[n][0] == -1:
                self.full_tree[n] = (0, 0)

        self.serialized = [(n, *self.full_tree[n]) for n in range(node_count)]
        if save_json:
            # Context manager guarantees the file is flushed and closed.
            with open(save_json, 'w', encoding='utf-8') as json_file:
                json.dump(self.serialized, json_file)
        return self.serialized


# NOTE(review): paths are relative to the notebook's working directory; the DOT
# file is presumably a scikit-learn export_graphviz dump -- confirm.
DOT_PATH = '../dct-viz.dot'   # Graphviz DOT description of the decision tree
JSON_PATH = './dct.json'      # destination for the padded, serialized tree

dot = DotContent(DOT_PATH)
dot.build()
dot.expand(save_json=JSON_PATH)

[(0, 2, 2.45),
 (1, 0, 0),
 (2, 3, 1.75),
 (3, 0, 0),
 (4, 0, 0),
 (5, 2, 4.95),
 (6, 2, 4.85),
 (7, 0, 0),
 (8, 0, 0),
 (9, 0, 0),
 (10, 0, 0),
 (11, 3, 1.65),
 (12, 3, 1.55),
 (13, 1, 3.1),
 (14, 0, 0),
 (15, 0, 0),
 (16, 0, 0),
 (17, 0, 0),
 (18, 0, 0),
 (19, 0, 0),
 (20, 0, 0),
 (21, 0, 0),
 (22, 0, 0),
 (23, 0, 0),
 (24, 0, 0),
 (25, 0, 0),
 (26, 0, 6.95),
 (27, 0, 0),
 (28, 0, 0),
 (29, 0, 0),
 (30, 0, 0),
 (31, -1, 0),
 (32, -1, 0),
 (33, -1, 0),
 (34, -1, 0),
 (35, -1, 0),
 (36, -1, 0),
 (37, -1, 0),
 (38, -1, 0),
 (39, -1, 0),
 (40, -1, 0),
 (41, -1, 0),
 (42, -1, 0),
 (43, -1, 0),
 (44, -1, 0),
 (45, -1, 0),
 (46, -1, 0),
 (47, -1, 1),
 (48, -1, 1),
 (49, -1, 2),
 (50, -1, 2),
 (51, -1, 2),
 (52, -1, 2),
 (53, -1, 1),
 (54, -1, 2),
 (55, -1, 2),
 (56, -1, 2),
 (57, -1, 1),
 (58, -1, 1),
 (59, -1, 2),
 (60, -1, 2),
 (61, -1, 2),
 (62, -1, 2)]

In [65]:
# Sparse heap-indexed tree from build() (node i -> children 2i+1, 2i+2), before expand() pads it.
dot.raw_tree

defaultdict(<function __main__.DotContent.__init__.<locals>.<lambda>()>,
            {0: (2, 2.45),
             1: (-1, 0),
             2: (3, 1.75),
             5: (2, 4.95),
             11: (3, 1.65),
             23: (-1, 1),
             24: (-1, 2),
             12: (3, 1.55),
             25: (-1, 2),
             26: (0, 6.95),
             53: (-1, 1),
             54: (-1, 2),
             6: (2, 4.85),
             13: (1, 3.1),
             27: (-1, 2),
             28: (-1, 1),
             14: (-1, 2)})