In [1]:
import sys
sys.path.append('./src')
from test.code import Stage1, Stage2, Stage3

In [6]:
def generate_mermaid_flowchart(objects, diagram_level=1):
    if diagram_level not in (1, 2, 3):
        raise ValueError(f"Invalid diagram_level: {diagram_level}")

    # Dictionary to track which class produces each output
    # Map outputs to their source objects
    output_to_object = {}
    for obj in objects:
        for output in obj.outputs:
            output_to_object[output] = obj.name
    
    # Start building the Mermaid flowchart
    mermaid_code = ["```mermaid\nflowchart TB"]
    
    # Add detailed nodes for each class
    for obj in objects:
        # Replace spaces with underscores for valid Mermaid IDs
        node_id = obj.name.replace(" ", "_")
        # Format the node content with description, inputs, and outputs
        node_content = [f"<ins>**{obj.name}**</ins>"]
        if diagram_level > 1:
            node_content.append(f"\t\t<ins>**Description**</ins>: {obj.description}")
        if diagram_level > 2:
            # Add inputs
            node_content += ["\t\t<ins>**Inputs**</ins>:"]
            for input_item in obj.inputs:
                node_content.append(f"\t\t\t- **{input_item}**: description")
            # Add outputs
            node_content += ["\t\t<ins>**Outputs**</ins>:"]
            for output_item in obj.outputs:
                node_content.append(f"\t\t\t- **{output_item}**: description")
        # Join all the content with line breaks and properly escape double quotes
        formatted_content = "\n".join(node_content).replace('"', '\\"')
        # Add the node with all the formatted content
        mermaid_code.append(f'    {node_id}["{formatted_content}"]')
    
    # Add connections between objects
    # Use a set to avoid duplicate connections
    connections = set()
    for target in objects:
        for input_item in target.inputs:
            if input_item in output_to_object:
                source_name = output_to_object[input_item]
                if source_name != target.name:  # Avoid self-connections
                    source_id = source_name.replace(" ", "_")
                    target_id = target.name.replace(" ", "_")
                    if diagram_level in (1, 2):
                        connection = f'    {source_id} --> {target_id}'
                    elif diagram_level == 3:
                        connection = f'    {source_id} -->|"{input_item}"| {target_id}'
                    connections.add(connection)
    
    # Add connections to the flowchart
    mermaid_code.extend(connections)
    
    # Add some styling to make the boxes look nicer
    mermaid_code.append("classDef default fill:#f9f9f9,stroke:#333,stroke-width:1px")
    
    return "\n".join(mermaid_code) + "\n```"

In [7]:
objects = [Stage1, Stage2, Stage3]
diagram_level = 3
mermaid_code = generate_mermaid_flowchart(objects, diagram_level)

In [8]:
mermaid_code

'```mermaid\nflowchart TB\n    stage1["stage1\n\t\tDescription: abc\n\t\tInputs:\n\t\t\t- **in1_1**: description\n\t\t\t- **in1_2**: description\n\t\t\t- **in1_3**: description\n\t\tOutputs:\n\t\t\t- **out1_1**: description\n\t\t\t- **out1_2**: description\n\t\t\t- **out1_3**: description"]\n    stage2["stage2\n\t\tDescription: abc\n\t\tInputs:\n\t\t\t- **in2_1**: description\n\t\t\t- **in2_2**: description\n\t\t\t- **out1_2**: description\n\t\tOutputs:\n\t\t\t- **out2_1**: description\n\t\t\t- **out2_2**: description\n\t\t\t- **out2_3**: description"]\n    stage3["stage3\n\t\tDescription: abc\n\t\tInputs:\n\t\t\t- **out1_2**: description\n\t\t\t- **out2_3**: description\n\t\t\t- **out1_3**: description\n\t\t\t- **out2_2**: description\n\t\tOutputs:\n\t\t\t- **out3_1**: description\n\t\t\t- **out3_2**: description\n\t\t\t- **out3_3**: description"]\n    stage1 -->|"out1_3"| stage3\n    stage1 -->|"out1_2"| stage2\n    stage2 -->|"out2_3"| stage3\n    stage1 -->|"out1_2"| stage3\n    st

In [9]:
with open('./diagram.md', 'w') as f:
    f.write(mermaid_code)