In [None]:
import pydot
import matplotlib.pyplot as plt
from inspect import signature

def _is_control(obj):
    """Return True if *obj* is a ``Control`` instance (or subclass instance).

    When a ``Control`` class is in this module's namespace (e.g. imported in an
    earlier notebook cell), this is exactly ``isinstance(obj, Control)``.
    Otherwise it falls back to matching the class name along the MRO, so the
    diagram can still be drawn instead of failing with ``NameError``.
    """
    control_cls = globals().get("Control")
    if isinstance(control_cls, type):
        return isinstance(obj, control_cls)
    return any(cls.__name__ == "Control" for cls in type(obj).__mro__)


def _parameter_names(func):
    """Return the parameter names of *func*, or an empty list if it has no signature."""
    try:
        return list(signature(func).parameters)
    except (TypeError, ValueError):
        # inspect.signature raises TypeError for unsupported objects (e.g. a
        # non-callable constant) and ValueError when no signature can be
        # determined; such entries are drawn without incoming edges.
        return []


def _control_dependencies(var_name, control):
    """Return the variables a Control decision is drawn as depending on.

    Prefers the control's own information set, read from its ``iset``
    attribute (NOTE(review): this is where HARK's ``Control`` stores it;
    confirm against the installed version). Controls without an ``iset``
    fall back to the model-specific defaults this function originally
    hard-coded.
    """
    iset = getattr(control, "iset", None)
    if isinstance(iset, str):
        return [iset]
    if iset is not None:
        return list(iset)
    # Legacy fallback: consumption chosen from wealth, interest rate from assets.
    fallback = {"c": ["w"], "R": ["a"]}
    return list(fallback.get(var_name, []))


def _add_styled_node(graph, name, node_style):
    """Add a filled node named *name* to *graph* using a ``{"shape", "color"}`` style."""
    graph.add_node(
        pydot.Node(
            name,
            shape=node_style["shape"],
            fillcolor=node_style["color"],
            style="filled",
        )
    )


def create_block_diagram(block, output_path="model_structure.png", display_diagram=True):
    """
    Creates a visual diagram of a model block structure.

    Every entry of ``block.dynamics`` becomes a node. Control entries are
    boxes; all other entries are ellipses. Edges run from each input variable
    to the variable it determines. Entries of ``block.shocks`` become triangles.
    Entries of ``block.reward`` become diamonds, with edges from the reward
    function's arguments. Any of the three attributes may be absent.

    Parameters:
    -----------
    block : object
        Any model block with dynamics, shocks, and reward attributes
        (each expected to be a mapping keyed by variable name).
    output_path : str
        File path for saving the diagram (written as PNG via pydot's
        ``write_png``, which requires Graphviz to be installed).
    display_diagram : bool
        Whether to display the diagram with matplotlib after creation.

    Returns:
    --------
    graph : pydot.Dot
        The generated graph object
    """

    # Define different colors & shapes for variable types
    style_map = {
        "Control": {"shape": "rectangle", "color": "blue"},   # Control variables (decisions)
        "Reward": {"shape": "diamond", "color": "green"},     # Rewards
        "Shock": {"shape": "triangle", "color": "red"},       # Shocks (exogenous)
        "State": {"shape": "ellipse", "color": "yellow"},     # Default for state variables
    }

    # Left-to-right layout reads like the timing of the model.
    graph = pydot.Dot("ModelStructure", graph_type="digraph", rankdir="LR")

    # Dynamics: each entry defines one variable; its inputs become edges.
    if hasattr(block, "dynamics"):
        for var_name, func in block.dynamics.items():
            if _is_control(func):
                var_type = "Control"
                dependencies = _control_dependencies(var_name, func)
            else:
                var_type = "State"
                # The transition function's argument names are its inputs. A
                # self-reference (e.g. an AR income process "y" reading "y")
                # is kept and rendered as a loop.
                dependencies = _parameter_names(func)

            _add_styled_node(graph, var_name, style_map[var_type])
            for dep in dependencies:
                graph.add_edge(pydot.Edge(dep, var_name))

    # Shocks: nodes only. Any edges come from dynamics that read them.
    if hasattr(block, "shocks"):
        for shock in block.shocks:
            _add_styled_node(graph, shock, style_map["Shock"])

    # Rewards: node plus an edge from every argument of the reward function.
    if hasattr(block, "reward"):
        for reward_name, reward_func in block.reward.items():
            _add_styled_node(graph, reward_name, style_map["Reward"])
            for arg in _parameter_names(reward_func):
                graph.add_edge(pydot.Edge(arg, reward_name))

    # Save the graph to file
    graph.write_png(output_path)

    # Display if requested
    if display_diagram:
        img = plt.imread(output_path)
        plt.figure(figsize=(10, 7))
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    return graph