After reading some of the issues, I realized it would be useful to share a function I wrote while working with your project.
The function takes a fitted scikit-learn Random Forest and reads each of its trees, collecting the forest's metadata and every node of every tree into a list that can be printed or written to a text file.
It may not be 100% correct.
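The function relies on scikit-learn's array-based tree representation (described in the scikit-learn example linked in the code): each fitted tree's `tree_` exposes parallel arrays indexed by node id, and a leaf has both `children_left` and `children_right` set to -1. As a minimal sketch, the toy arrays below are hypothetical values standing in for a fitted tree, and `depths_and_leaves` is a helper introduced only for illustration; it reproduces the function's depth-first pass without needing scikit-learn.

```python
# Toy tree in scikit-learn's parallel-array layout (hypothetical values).
# Node 0 splits into nodes 1 and 2; node 2 splits into nodes 3 and 4.
# -1 means "no child", so nodes 1, 3 and 4 are leaves.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]


def depths_and_leaves(children_left, children_right):
    """Depth-first pass from the root, mirroring the loop in model_to_txt."""
    n_nodes = len(children_left)
    node_depth = [0] * n_nodes
    is_leaf = [False] * n_nodes
    stack = [(0, 0)]  # (node id, depth), starting at the root
    while stack:
        node_id, depth = stack.pop()
        node_depth[node_id] = depth
        if children_left[node_id] != children_right[node_id]:
            # Split node: push both children one level deeper.
            stack.append((children_left[node_id], depth + 1))
            stack.append((children_right[node_id], depth + 1))
        else:
            # Both children are -1, so this node is a leaf.
            is_leaf[node_id] = True
    return node_depth, is_leaf


depths, leaves = depths_and_leaves(children_left, children_right)
print(depths)  # [0, 1, 1, 2, 2]
print(leaves)  # [False, True, False, True, True]
```

A stack is used instead of recursion, as in the scikit-learn example, so deep trees cannot hit Python's recursion limit.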
import numpy as np

# NOTE: `config` is not defined in this snippet; it is assumed to be a
# project-level dict (e.g. loaded from a config file) providing
# config['DATASET']['NAME'].


def model_to_txt(self, index, show: bool = True, save: bool = False):
    """Dump every tree of a fitted sklearn RandomForestClassifier (`self`) as text."""
    # https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#sphx-glr-auto-examples-tree-plot-unveil-tree-structure-py
    forest = self.estimators_
    model_info = list()
    model_info.append(
        f"DATASET_NAME: {config['DATASET']['NAME']}.train{index}.csv"
        f"\nENSEMBLE: RF"
        f"\nNB_TREES: {len(forest)}"
        f"\nNB_FEATURES: {forest[0].tree_.n_features}"
        f"\nNB_CLASSES: {forest[0].tree_.n_classes[0]}"
        # Trees in a forest can differ in depth, so report the deepest one.
        f"\nMAX_TREE_DEPTH: {max(est.tree_.max_depth for est in forest)}"
        "\nFormat: node / node type (LN - leave node, IN - internal node) "
        "left child / right child / feature / threshold / node_depth / "
        "majority class (starts with index 0)"
    )
    for tree_idx, est in enumerate(forest):
        tree = est.tree_
        n_nodes = tree.node_count
        children_left = tree.children_left
        children_right = tree.children_right
        node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
        is_leaves = np.zeros(shape=n_nodes, dtype=bool)
        stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
        model_info.append(f"\n\n[TREE {tree_idx}]\nNB_NODES: {n_nodes}")
        # Depth-first traversal recording each node's depth and leaf status.
        while len(stack) > 0:
            node_id, depth = stack.pop()
            node_depth[node_id] = depth
            # A split node has two distinct children; a leaf has both set to -1.
            if children_left[node_id] != children_right[node_id]:
                stack.append((children_left[node_id], depth + 1))
                stack.append((children_right[node_id], depth + 1))
            else:
                is_leaves[node_id] = True
        for i in range(n_nodes):
            # tree.value[i][0] holds one value per class; argmax is the majority class.
            class_idx = np.argmax(tree.value[i][0])
            if is_leaves[i]:
                model_info.append(f"\n{i} LN -1 -1 -1 -1 {node_depth[i]} {class_idx}")
            else:
                # sklearn sends a sample left when x[feature] <= threshold.
                model_info.append(
                    f"\n{i} IN {children_left[i]} {children_right[i]} "
                    f"{tree.feature[i]} {tree.threshold[i]} {node_depth[i]} -1"
                )
    model_info.append("\n\n")
    if show:
        # Join without a separator: print(*model_info) would insert a space
        # at the start of every line after the first item.
        print("".join(model_info))
    if save:
        with open(
            f"./data/processed/forests/{config['DATASET']['NAME']}.RF{index}.txt",
            "w"
        ) as f:
            f.write("".join(model_info))
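One way to spot-check a saved file is to classify a sample using the exported node lines of a single tree. This is a sketch under stated assumptions, not part of the original function: `predict_from_lines` is a hypothetical helper, the tree lines are made-up values in the format written above, and it assumes scikit-learn's split rule (a sample goes left when `x[feature] <= threshold`).

```python
def predict_from_lines(node_lines, x):
    """Classify sample x by walking one tree's exported node lines.

    Each line is "node type left right feature threshold depth class",
    e.g. "0 IN 1 2 0 2.5 0 -1" or "1 LN -1 -1 -1 -1 1 0".
    Assumes sklearn's rule: go left when x[feature] <= threshold.
    """
    nodes = {}
    for line in node_lines:
        parts = line.split()
        nodes[int(parts[0])] = parts
    node_id = 0  # start at the root
    while True:
        parts = nodes[node_id]
        if parts[1] == "LN":
            return int(parts[7])  # majority class stored on the leaf
        left, right = int(parts[2]), int(parts[3])
        feature, threshold = int(parts[4]), float(parts[5])
        node_id = left if x[feature] <= threshold else right


# Hypothetical tree with 5 nodes (values made up for illustration).
tree_lines = [
    "0 IN 1 2 0 2.5 0 -1",
    "1 LN -1 -1 -1 -1 1 0",
    "2 IN 3 4 1 1.0 1 -1",
    "3 LN -1 -1 -1 -1 2 1",
    "4 LN -1 -1 -1 -1 2 2",
]
print(predict_from_lines(tree_lines, [1.0, 0.0]))  # 0
print(predict_from_lines(tree_lines, [3.0, 0.5]))  # 1
print(predict_from_lines(tree_lines, [3.0, 2.0]))  # 2
```

When checking a whole forest this way, note that scikit-learn's `RandomForestClassifier` predicts by averaging the trees' class probabilities, so a hard majority vote over these leaf classes can occasionally disagree with `predict`.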
crimson-luis changed the title from "Function to save Ranfom Forest model to txt file." to "Function to save Random Forest model to txt file." on Jun 3, 2022.