## Printing the complete call-stack tree

In [14]:
import re

def parse_log_file(file_path):
    """Parse a tracing log into a flat list of call/exit events.

    Entry lines start with '-->' (two or more dashes) or an optionally indented
    '>' followed by "call function NAME in FILE:LINE". Exit lines start with
    '<--' or an optionally indented '<' followed by "exit function NAME in FILE:LINE".
    Every other line is ignored.

    Args:
        file_path: Path of the log file to read (decoded as UTF-8).

    Returns:
        A list of (function_name, source_path, line_number, op) tuples in file
        order, where line_number is an int and op is "push" for an entry line
        or "pop" for an exit line.
    """
    # Entry pattern is tried first, exactly like the original if/continue chain.
    patterns = (
        (re.compile(r'^(--+>|\s*>)\s*call function (.+?) in (.+?):(\d+)'), "push"),
        (re.compile(r'^(<--+|\s*<)\s*exit function (.+?) in (.+?):(\d+)'), "pop"),
    )
    logs = []
    # Explicit encoding: the default is locale-dependent and may differ between machines.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            for pattern, op in patterns:
                match = pattern.match(line)
                if match:
                    # Separate name so the `file_path` argument is not clobbered.
                    function_name, source_path, line_number = match.group(2, 3, 4)
                    logs.append((function_name, source_path, int(line_number), op))
                    break
    return logs

class CallStackNode:
    """One traced function call in a call-stack tree.

    Attributes (set in __init__): function_name, file_path and line_number
    describe the call; `children` holds child nodes in the order they were
    added; `parent` is the node this one was attached to (None until attached).
    """

    def __init__(self, function_name, file_path, line_number):
        self.function_name = function_name
        self.file_path = file_path
        self.line_number = line_number
        self.children = []
        self.parent = None

    def add_child(self, child_node):
        """Append `child_node` as the last child of this node and set its parent link."""
        child_node.parent = self
        self.children.append(child_node)

    def __repr__(self, level=0, is_last_child=True, parent_last_childs=None, show_location=True):
        """Render this node and its whole subtree as box-drawing tree text.

        Args:
            level: Unused; kept only so existing callers that pass it still work.
            is_last_child: Whether this node is its parent's last child; selects
                the '└── ' branch instead of '├── '.
            parent_last_childs: One flag per ancestor level (outermost first),
                True where that ancestor was a last child. Each flag renders a
                blank column ('    ') or a '│' guide ('│   '). Defaults to none.
            show_location: When True (default) each line reads
                "name, file_path:line_number"; when False only the name is shown.

        Returns:
            The rendered tree in pre-order, one newline-terminated line per node.
        """
        if parent_last_childs is None:
            parent_last_childs = []
        base_indent = ''.join('    ' if last else '│   ' for last in parent_last_childs)

        lines = []
        # Explicit stack instead of recursion: deep traces would otherwise exceed
        # Python's recursion limit, and joining a list avoids quadratic `+=`.
        stack = [(self, is_last_child, base_indent)]
        while stack:
            node, is_last, indent = stack.pop()
            branch = '└── ' if is_last else '├── '
            if show_location:
                label = f"{node.function_name}, {node.file_path}:{node.line_number}"
            else:
                label = f"{node.function_name}"
            lines.append(f"{indent}{branch}{label}\n")
            child_indent = indent + ('    ' if is_last else '│   ')
            last_index = len(node.children) - 1
            # Push in reverse so children are popped (rendered) in their original order.
            for i in range(last_index, -1, -1):
                stack.append((node.children[i], i == last_index, child_indent))
        return ''.join(lines)

class CallStackTree:
    """Call tree built by replaying push (call) / pop (exit) events.

    A synthetic "root" node sits at the top; `current_node` is the innermost
    node that has been pushed and not yet popped.
    """

    def __init__(self):
        self.root = CallStackNode("root", "", 0)
        self.current_node = self.root

    def push(self, function_name, file_path, line_number):
        """Record a call: attach a new child under the current node and descend into it."""
        callee = CallStackNode(function_name, file_path, line_number)
        self.current_node.add_child(callee)
        self.current_node = callee

    def pop(self):
        """Record an exit: move back to the parent node; does nothing at the root."""
        at_root = self.current_node is self.root
        if not at_root:
            self.current_node = self.current_node.parent

    def __repr__(self):
        return repr(self.root)

# Read the tracing log, then replay each call/exit event to grow the call-stack tree.
log_file_path = "/root/vescale_prj/veScale/test/parallel/pipeline/instruction/logs-test_schedule-host_f78e8e970a17-pid_199624-py/tracing-test_schedule-20240829_071457.log"
logs = parse_log_file(log_file_path)

call_stack_tree = CallStackTree()

for log in logs:
    function_name, file_path, line_number, operation = log
    if operation == "pop":
        call_stack_tree.pop()
    elif operation == "push":
        call_stack_tree.push(function_name, file_path, line_number)

print(call_stack_tree)

└── root, :0
    ├── init_device_mesh, /root/vescale_prj/veScale/vescale/devicemesh_api/api.py:48
    │   └── init_device_mesh, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.py:594
    │       └── __init__, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.py:224
    │           ├── update_vescale_debug_mode_from_env, /root/vescale_prj/veScale/vescale/debug/debug_log.py:88
    │           ├── _get_or_create_default_group, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.py:313
    │           │   └── _get_device_handle, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.py:153
    │           ├── _get_device_handle, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.py:153
    │           ├── _get_current_device, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.py:281
    │           ├── _validate_mesh, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.py:339
    │           └── _init_process_groups, /root/vescale_prj/veScale/vescale/dtensor/device_mesh.

### Study plan: breadth-first list of traced functions

In [4]:
import re
from collections import deque

def parse_log_file(file_path):
    """Parse a tracing log into a flat list of call/exit events.

    Entry lines start with '-->' (two or more dashes) or an optionally indented
    '>' followed by "call function NAME in FILE:LINE". Exit lines start with
    '<--' or an optionally indented '<' followed by "exit function NAME in FILE:LINE".
    Every other line is ignored.

    Args:
        file_path: Path of the log file to read (decoded as UTF-8).

    Returns:
        A list of (function_name, source_path, line_number, op) tuples in file
        order, where line_number is an int and op is "push" for an entry line
        or "pop" for an exit line.
    """
    # Entry pattern is tried first, exactly like the original if/continue chain.
    patterns = (
        (re.compile(r'^(--+>|\s*>)\s*call function (.+?) in (.+?):(\d+)'), "push"),
        (re.compile(r'^(<--+|\s*<)\s*exit function (.+?) in (.+?):(\d+)'), "pop"),
    )
    logs = []
    # Explicit encoding: the default is locale-dependent and may differ between machines.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            for pattern, op in patterns:
                match = pattern.match(line)
                if match:
                    # Separate name so the `file_path` argument is not clobbered.
                    function_name, source_path, line_number = match.group(2, 3, 4)
                    logs.append((function_name, source_path, int(line_number), op))
                    break
    return logs

class CallStackNode:
    """One traced function call in a call-stack tree.

    Attributes (set in __init__): function_name, file_path and line_number
    describe the call; `children` holds child nodes in the order they were
    added; `parent` is the node this one was attached to (None until attached).
    """

    def __init__(self, function_name, file_path, line_number):
        self.function_name = function_name
        self.file_path = file_path
        self.line_number = line_number
        self.children = []
        self.parent = None

    def add_child(self, child_node):
        """Append `child_node` as the last child of this node and set its parent link."""
        child_node.parent = self
        self.children.append(child_node)

    def __repr__(self, level=0, is_last_child=True, parent_last_childs=None, show_location=True):
        """Render this node and its whole subtree as box-drawing tree text.

        Args:
            level: Unused; kept only so existing callers that pass it still work.
            is_last_child: Whether this node is its parent's last child; selects
                the '└── ' branch instead of '├── '.
            parent_last_childs: One flag per ancestor level (outermost first),
                True where that ancestor was a last child. Each flag renders a
                blank column ('    ') or a '│' guide ('│   '). Defaults to none.
            show_location: When True (default) each line reads
                "name, file_path:line_number"; when False only the name is shown.

        Returns:
            The rendered tree in pre-order, one newline-terminated line per node.
        """
        if parent_last_childs is None:
            parent_last_childs = []
        base_indent = ''.join('    ' if last else '│   ' for last in parent_last_childs)

        lines = []
        # Explicit stack instead of recursion: deep traces would otherwise exceed
        # Python's recursion limit, and joining a list avoids quadratic `+=`.
        stack = [(self, is_last_child, base_indent)]
        while stack:
            node, is_last, indent = stack.pop()
            branch = '└── ' if is_last else '├── '
            if show_location:
                label = f"{node.function_name}, {node.file_path}:{node.line_number}"
            else:
                label = f"{node.function_name}"
            lines.append(f"{indent}{branch}{label}\n")
            child_indent = indent + ('    ' if is_last else '│   ')
            last_index = len(node.children) - 1
            # Push in reverse so children are popped (rendered) in their original order.
            for i in range(last_index, -1, -1):
                stack.append((node.children[i], i == last_index, child_indent))
        return ''.join(lines)

class CallStackTree:
    """Call tree built by replaying push (call) / pop (exit) events.

    A synthetic "root" node sits at the top; `current_node` is the innermost
    node that has been pushed and not yet popped.
    """

    def __init__(self):
        self.root = CallStackNode("root", "", 0)
        self.current_node = self.root

    def push(self, function_name, file_path, line_number):
        """Record a call: attach a new child under the current node and descend into it."""
        new_node = CallStackNode(function_name, file_path, line_number)
        self.current_node.add_child(new_node)
        self.current_node = new_node

    def pop(self):
        """Record an exit: move back to the parent node; does nothing at the root."""
        if self.current_node is not self.root:
            self.current_node = self.current_node.parent

    def breadth_first_traversal(self):
        """Return the tree's nodes in breadth-first order, one per distinct call site.

        Nodes are identified by (function_name, file_path, line_number); only the
        first occurrence of each identity in breadth-first order is returned, so
        the root comes first and shallower calls precede deeper ones.

        Children of every node are still explored, including repeated ones: a
        later call of an already-listed function may make calls its first
        occurrence did not, and those functions must not vanish from the result.
        """
        queue = deque([self.root])
        visited = set()
        traversal_order = []
        while queue:
            current_node = queue.popleft()
            node_id = (current_node.function_name, current_node.file_path, current_node.line_number)
            if node_id not in visited:
                visited.add(node_id)
                traversal_order.append(current_node)
            # Always descend (deduplicate only what is emitted, not what is explored).
            queue.extend(current_node.children)
        return traversal_order

    def __repr__(self):
        return self.root.__repr__()

# Read the tracing log, then replay each call/exit event to grow the call-stack tree.
log_file_path = "/root/vescale_prj/veScale/test/parallel/pipeline/instruction/logs-test_schedule-host_f78e8e970a17-pid_199624-py/tracing-test_schedule-20240829_071457.log"
logs = parse_log_file(log_file_path)

call_stack_tree = CallStackTree()

for log in logs:
    function_name, file_path, line_number, operation = log
    if operation == "push":
        call_stack_tree.push(function_name, file_path, line_number)
    elif operation == "pop":
        call_stack_tree.pop()

# Learning plan: functions in breadth-first order (shallow calls first), each listed once.
traversal_order = call_stack_tree.breadth_first_traversal()
print("\nLearning Plan (Breadth-First, No Duplicates):")
for node in traversal_order:
    # Skip the synthetic root by identity, so a real traced function that happens
    # to be named "root" is still listed.
    if node is call_stack_tree.root:
        continue
    print(f"Learn function '{node.function_name}' in file '{node.file_path}', line {node.line_number}.")


Learning Plan (Breadth-First, No Duplicates):
Learn function 'init_device_mesh' in file '/root/vescale_prj/veScale/vescale/devicemesh_api/api.py', line 48.
Learn function 'init_ndtimers' in file '/root/vescale_prj/veScale/vescale/ndtimeline/api.py', line 72.
Learn function '__init__' in file '/root/vescale_prj/veScale/test/parallel/pipeline/instruction/test_schedule.py', line 164.
Learn function 'get_global_tensor_parallel_meshes' in file '/root/vescale_prj/veScale/vescale/devicemesh_api/api.py', line 361.
Learn function 'get_linear_pp_module_dep2' in file '/root/vescale_prj/veScale/vescale/pipe/_schedules/instruction_base.py', line 291.
Learn function '<listcomp>' in file '/root/vescale_prj/veScale/test/parallel/pipeline/instruction/test_schedule.py', line 635.
Learn function '__init__' in file '/root/vescale_prj/veScale/vescale/pipe/pipe_emmiter.py', line 133.
Learn function 'execute' in file '/root/vescale_prj/veScale/vescale/pipe/pipe_emmiter.py', line 267.
Learn function 'flush' 