In [27]:
from typing import Dict, Any, List, Set, Optional, Callable
from collections import deque
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import multiprocessing
import pickle
import logging
from dataclasses import dataclass, field
from enum import Enum
import time
import json
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from datetime import datetime
import random


# -------------------------------------------------------------------
# Configuration and Logging
# -------------------------------------------------------------------
# Root logging setup: INFO and above, with the emitting thread's name in every
# record so interleaved worker output can be told apart.
logging.basicConfig(
    format=" - ".join(
        ("%(asctime)s", "%(threadName)s", "%(levelname)s", "%(message)s")
    ),
    level=logging.INFO,
)
# Module-level logger used by every class in this file.
logger = logging.getLogger(__name__)


class ExecutionMode(Enum):
    """Scheduling strategy the engine applies to the nodes of one level."""

    # Nodes of a level are executed one after another on the calling thread.
    SEQUENTIAL = "sequential"
    # Nodes of a level are submitted to a ThreadPoolExecutor.
    THREAD = "thread"


@dataclass
class ExecutionTrace:
    """Profiling record for one execution of one node.

    The engine fills ``start_time`` and ``end_time`` with offsets from the
    moment the run started, so they are already relative values.
    """

    node_name: str  # NodeInstance.name
    node_type: str  # NodeDefinition.type_name
    start_time: float  # seconds
    end_time: float  # seconds
    duration: float  # seconds spent in this execution
    level: int  # index of the topological level the node ran in
    thread_id: str  # name of the thread that ran the node
    success: bool = True
    error: Optional[str] = None  # stringified exception, if any

    @property
    def relative_start(self) -> float:
        """Return the start offset; identical to ``start_time``."""
        started = self.start_time
        return started

    @property
    def relative_end(self) -> float:
        """Return the end offset; identical to ``end_time``."""
        finished = self.end_time
        return finished


@dataclass
class NodeMetrics:
    """Running counters describing how often and how long a node has run."""

    execution_count: int = 0
    total_execution_time: float = 0.0
    last_execution_time: float = 0.0
    error_count: int = 0

    def record_execution(self, duration: float):
        """Account for one finished execution that took ``duration`` seconds."""
        self.last_execution_time = duration
        self.total_execution_time = self.total_execution_time + duration
        self.execution_count = self.execution_count + 1

    def record_error(self):
        """Account for one failed execution attempt."""
        self.error_count = self.error_count + 1

    @property
    def avg_execution_time(self) -> float:
        """Mean duration over all recorded executions, or 0.0 if there are none."""
        runs = self.execution_count
        return self.total_execution_time / runs if runs != 0 else 0.0


# -------------------------------------------------------------------
# NodeDefinition: stateless compute template
# -------------------------------------------------------------------
class NodeDefinition:
    """A node definition describes pure logic. No state, no wiring.

    For production scalability, NodeDefinitions should be:
    - Stateless and pure (no side effects)
    - Serializable (for distributed execution)
    - Idempotent (safe to retry)

    Subclasses set ``type_name``, ``num_inputs`` and ``num_outputs`` and
    override :meth:`compute`.
    """

    type_name = "BASE"
    num_inputs = 0
    num_outputs = 0

    # Execution hints for optimization. Nothing in this file reads them yet;
    # they are metadata for a future scheduler.
    is_expensive = False  # CPU-intensive operations
    is_io_bound = False  # I/O operations
    estimated_duration_ms = 100  # For scheduling

    def compute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Compute the node's outputs from its resolved inputs.

        Args:
            inputs: mapping of input name to the value wired to it.

        Returns:
            Mapping of output name to computed value.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
        """Return True when exactly ``num_inputs`` inputs were supplied."""
        return len(inputs) == self.num_inputs

    def serialize(self) -> bytes:
        """Pickle this definition for distributed execution."""
        return pickle.dumps(self)

    @classmethod
    def deserialize(cls, data: bytes) -> "NodeDefinition":
        """Rebuild a definition from bytes produced by :meth:`serialize`.

        Raises:
            TypeError: if the payload unpickles to something that is not a
                NodeDefinition.
        """
        # SECURITY: pickle.loads can execute arbitrary code while loading.
        # Only pass bytes from a trusted source; the type check below runs
        # after unpickling and is not a defence against a malicious payload.
        restored = pickle.loads(data)
        if not isinstance(restored, NodeDefinition):
            raise TypeError(
                f"Expected a pickled NodeDefinition, got {type(restored).__name__}"
            )
        return restored


# -------------------------------------------------------------------
# Example NodeDefinitions
# -------------------------------------------------------------------
class ConstDefinition(NodeDefinition):
    """Source node: no inputs, one output whose value is preset on the instance."""

    type_name = "CONST"
    num_inputs = 0
    num_outputs = 1

    def compute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return an empty mapping; a constant has nothing to calculate."""
        # The value of a CONST node is written directly into
        # NodeInstance.outputs, so this method contributes no outputs.
        nothing_computed: Dict[str, Any] = {}
        return nothing_computed


class AddDefinition(NodeDefinition):
    """Two-input node whose single output is ``input1 + input2``."""

    type_name = "ADD"
    num_inputs = 2
    num_outputs = 1
    estimated_duration_ms = 10

    def compute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Sleep for a random 0-90 ms (10 ms steps), then add the two inputs."""
        # Artificial latency so the demo produces a visible timeline.
        simulated_delay = random.randrange(0, 10) * 0.01
        time.sleep(simulated_delay)

        left = inputs["input1"]
        right = inputs["input2"]
        return {"output1": left + right}


class MultiplyDefinition(NodeDefinition):
    """Two-input node whose single output is ``input1 * input2``."""

    type_name = "MULTIPLY"
    num_inputs = 2
    num_outputs = 1
    is_expensive = True
    estimated_duration_ms = 50

    def compute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Sleep for a random 0-90 ms (10 ms steps), then multiply the inputs."""
        # Stand-in for an expensive computation.
        simulated_delay = random.randrange(0, 10) * 0.01
        time.sleep(simulated_delay)

        left = inputs["input1"]
        right = inputs["input2"]
        return {"output1": left * right}


# -------------------------------------------------------------------
# NodeInstance: runtime wiring + state
# -------------------------------------------------------------------
class NodeInstance:
    """Runtime instance with wiring, state, and execution metadata.

    Two instances are considered equal (and hash alike) when they share the
    same ``node_id``.
    """

    def __init__(
        self, name: str, definition: "NodeDefinition", node_id: Optional[str] = None
    ):
        """Create an unwired, dirty node.

        Args:
            name: human-readable label used in logs and traces.
            definition: compute template this node runs.
            node_id: unique identifier; defaults to ``name``.
        """
        self.name = name
        self.node_id = node_id or name  # Unique identifier
        self.definition = definition

        # inputN → (source_node, source_output_name)
        self.inputs: Dict[str, Any] = {}

        # outputN → computed values
        self.outputs: Dict[str, Any] = {}

        # reverse dependency graph (children needing this output);
        # each child appears at most once
        self.children: List["NodeInstance"] = []

        # parent nodes (dependencies)
        self.parents: Set["NodeInstance"] = set()

        # flag for incremental execution
        self.is_dirty = True

        # lock for thread-safe output updates
        self.lock = threading.Lock()

        # Execution metrics
        self.metrics = NodeMetrics()

        # Error handling
        self.last_error: Optional[Exception] = None
        self.retry_count = 0
        self.max_retries = 3

    def input_name(self, index: int) -> str:
        """Return the key used for input number ``index`` (``"input<index>"``)."""
        return f"input{index}"

    def output_name(self, index: int) -> str:
        """Return the key used for output number ``index`` (``"output<index>"``)."""
        return f"output{index}"

    def set_input(
        self, input_index: int, src_node: "NodeInstance", src_output_index: int
    ):
        """Wire one output of ``src_node`` to one input of this node.

        Calling this again for the same ``input_index`` rewires the input:
        the previous source is detached from the dependency graph unless
        another input of this node still reads from it.

        Args:
            input_index: index of the input on this node.
            src_node: node providing the value.
            src_output_index: index of the output on ``src_node``.
        """
        input_key = self.input_name(input_index)
        output_key = src_node.output_name(src_output_index)

        previous = self.inputs.get(input_key)
        self.inputs[input_key] = (src_node, output_key)

        # Drop the stale edge left behind by a rewire; otherwise the old
        # source would still count as a dependency of this node.
        if previous is not None:
            old_src = previous[0]
            still_used = any(src == old_src for src, _ in self.inputs.values())
            if not still_used:
                self.parents.discard(old_src)
                old_src.children[:] = [c for c in old_src.children if c != self]

        # Build dependency graph. ``parents`` is a set, so ``children`` must
        # not hold duplicates either: a scheduler that counts edges through
        # ``children`` against ``len(parents)`` would otherwise over-count.
        if self not in src_node.children:
            src_node.children.append(self)
        self.parents.add(src_node)

    def resolve_inputs(self) -> Dict[str, Any]:
        """Collect the current value of every wired input.

        Returns:
            Mapping of input name to the source node's output value.

        Raises:
            ValueError: if a source node has not produced the wired output.
        """
        resolved = {}
        for input_name, (src_node, src_output_name) in self.inputs.items():
            if src_output_name not in src_node.outputs:
                raise ValueError(
                    f"Output {src_output_name} not available from {src_node.name}"
                )
            resolved[input_name] = src_node.outputs[src_output_name]
        return resolved

    def to_json(self) -> Dict[str, Any]:
        """Return a dict snapshot of the node (for persistence/debugging).

        ``outputs`` is included as-is, so the result is only JSON-encodable
        when the output values themselves are.
        """
        return {
            "name": self.name,
            "node_id": self.node_id,
            "type": self.definition.type_name,
            "outputs": self.outputs,
            "is_dirty": self.is_dirty,
            "metrics": {
                "execution_count": self.metrics.execution_count,
                "avg_execution_time": self.metrics.avg_execution_time,
                "error_count": self.metrics.error_count,
            },
        }

    def __repr__(self):
        return f"<NodeInstance {self.name}>"

    def __hash__(self):
        return hash(self.node_id)

    def __eq__(self, other):
        return isinstance(other, NodeInstance) and self.node_id == other.node_id


# -------------------------------------------------------------------
# NodeRegistry with versioning and plugin support
# -------------------------------------------------------------------
class NodeRegistry:
    """Registry for node definitions with versioning support.

    State lives on the class, so every caller shares one registry.
    """

    _registry: Dict[str, Dict[str, "NodeDefinition"]] = (
        {}
    )  # type -> version -> definition
    _latest_versions: Dict[str, str] = {}  # type -> latest version

    @staticmethod
    def _version_key(version: str) -> tuple:
        """Return a sort key that orders dotted versions numerically.

        Numeric components compare as integers, so "1.10.0" > "1.9.0".
        A non-numeric component (e.g. "0-beta") sorts after any numeric one
        in the same position and is otherwise compared as text; this is a
        best-effort ordering, not full semantic-versioning precedence.
        """
        key = []
        for part in version.split("."):
            if part.isdecimal():
                key.append((0, int(part), ""))
            else:
                key.append((1, 0, part))
        return tuple(key)

    @classmethod
    def register(cls, definition: "NodeDefinition", version: str = "1.0.0"):
        """Register ``definition`` under its ``type_name`` and ``version``.

        Re-registering an existing (type, version) pair replaces the stored
        definition. The type's latest version is advanced only when
        ``version`` is newer than the current latest.
        """
        type_name = definition.type_name

        cls._registry.setdefault(type_name, {})[version] = definition

        # Compare component-wise; a plain string comparison would rank
        # "1.10.0" below "1.9.0".
        current_latest = cls._latest_versions.get(type_name)
        if current_latest is None or cls._version_key(version) > cls._version_key(
            current_latest
        ):
            cls._latest_versions[type_name] = version

        logger.info(f"Registered {type_name} v{version}")

    @classmethod
    def create_definition(
        cls, type_name: str, version: Optional[str] = None
    ) -> "NodeDefinition":
        """Return the registered definition for ``type_name``.

        The registered object itself is returned, not a copy.

        Args:
            type_name: registered node type.
            version: exact version to fetch; the latest one when omitted.

        Raises:
            ValueError: if the type or the requested version is unknown.
        """
        if type_name not in cls._registry:
            raise ValueError(f"Unknown node type: {type_name}")

        version = version or cls._latest_versions[type_name]

        if version not in cls._registry[type_name]:
            raise ValueError(f"Version {version} not found for {type_name}")

        return cls._registry[type_name][version]

    @classmethod
    def list_types(cls) -> List[str]:
        """List all registered node types."""
        return list(cls._registry.keys())


# Register the built-in definitions, all at version 1.0.0, at import time.
for _builtin_type in (ConstDefinition, AddDefinition, MultiplyDefinition):
    NodeRegistry.register(_builtin_type(), "1.0.0")


# -------------------------------------------------------------------
# Execution Engine with Production Features
# -------------------------------------------------------------------
class ExecutionEngine:
    """Production-ready execution engine with advanced features."""

    def __init__(
        self,
        nodes: List[NodeInstance],
        max_workers: Optional[int] = None,
        mode: ExecutionMode = ExecutionMode.THREAD,
        enable_checkpointing: bool = False,
        checkpoint_callback: Optional[Callable] = None,
        enable_profiling: bool = True,
    ):
        """Set up an engine over ``nodes``.

        Args:
            nodes: every node of the graph; the list is stored by reference.
            max_workers: thread-pool size; falls back to the CPU count when
                omitted (or 0).
            mode: sequential or thread-pool execution of each level.
            enable_checkpointing: when True and a callback is given, the
                callback is invoked with each node after it executes.
            checkpoint_callback: callable taking the executed NodeInstance.
            enable_profiling: record an ExecutionTrace per node execution.

        Raises:
            ValueError: if two nodes share a ``node_id``.
        """
        self.nodes = nodes
        self.max_workers = max_workers or multiprocessing.cpu_count()
        self.mode = mode
        self.execution_order: List[List[NodeInstance]] = []
        self.enable_checkpointing = enable_checkpointing
        self.checkpoint_callback = checkpoint_callback
        self.enable_profiling = enable_profiling

        # Build node lookup for fast access. Nodes hash and compare by
        # node_id, so a duplicate id would silently shadow another node here
        # and in every dict/set keyed by node; reject it up front.
        self.node_lookup: Dict[str, NodeInstance] = {}
        for node in nodes:
            if node.node_id in self.node_lookup:
                raise ValueError(f"Duplicate node_id in graph: {node.node_id!r}")
            self.node_lookup[node.node_id] = node

        # Execution statistics
        self.total_execution_time = 0.0
        self.level_execution_times: List[float] = []
        self.execution_start_time = 0.0

        # Profiling data
        self.execution_traces: List[ExecutionTrace] = []
        self.traces_lock = threading.Lock()

    def topological_sort(self) -> List[List[NodeInstance]]:
        """
        Perform topological sort and group nodes by execution level.
        Nodes at the same level can be executed in parallel.

        Returns: List of levels, where each level is a list of nodes
                 that can execute in parallel.

        Raises:
            ValueError: if some nodes never become ready, which happens when
                the graph contains a cycle.
        """
        # In-degree is the number of distinct parents (``parents`` is a set).
        in_degree = {node: len(node.parents) for node in self.nodes}

        # Nodes with no dependencies form the first level.
        frontier = [node for node in self.nodes if in_degree[node] == 0]

        levels = []
        processed = 0

        while frontier:
            # Everything in the frontier can execute in parallel.
            levels.append(frontier)
            processed += len(frontier)

            next_frontier = []
            for node in frontier:
                # ``children`` is a list and may name the same child more
                # than once (one source feeding several inputs). Each parent
                # must release a child exactly once, otherwise the child
                # could reach zero before its other parents have run.
                for child in dict.fromkeys(node.children):
                    in_degree[child] -= 1
                    # If all dependencies satisfied, add to next level
                    if in_degree[child] == 0:
                        next_frontier.append(child)
            frontier = next_frontier

        # Check for cycles
        if processed < len(self.nodes):
            unprocessed = [n.name for n in self.nodes if in_degree[n] > 0]
            raise ValueError(
                f"Cycle detected in node graph. Unprocessed: {unprocessed}"
            )

        logger.info(
            f"Topological sort complete: {len(levels)} levels, {processed} nodes"
        )
        return levels

    def mark_dirty(self, node: NodeInstance):
        """Mark a node and its descendants dirty (incremental updates).

        A node that is already dirty is left alone and its children are not
        visited, which also keeps shared descendants from being walked twice.
        The walk is iterative so a long dependency chain cannot exceed the
        interpreter's recursion limit.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.is_dirty:
                continue
            current.is_dirty = True
            pending.extend(current.children)

    def execute_node(self, node: NodeInstance, level: int) -> NodeInstance:
        """Execute a single node with error handling, retries and metrics.

        The node is attempted until it succeeds or ``node.max_retries``
        attempts have failed (at least one attempt is always made). Every
        attempt, failed or successful, is recorded as an ExecutionTrace when
        profiling is enabled.

        Args:
            node: node to execute.
            level: index of the topological level, stored in the trace.

        Returns:
            The same ``node`` once it has executed successfully.

        Raises:
            Exception: the error of the last attempt, once retries are
                exhausted.
        """
        thread_id = threading.current_thread().name

        # Failures are counted locally: a count left on the node by an
        # earlier failed run must not eat into this call's retry budget, and
        # the success path resets node.retry_count before the checkpoint
        # callback runs, so the node's own counter cannot bound the loop.
        failures = 0

        while True:
            start_time = time.time()
            try:
                logger.info(f"Executing {node.name}")

                # Const nodes just have preset output1
                if node.definition.type_name != "CONST":
                    resolved_inputs = node.resolve_inputs()

                    # Validate inputs
                    if not node.definition.validate_inputs(resolved_inputs):
                        raise ValueError(f"Invalid inputs for {node.name}")

                    outputs = node.definition.compute(resolved_inputs)

                    with node.lock:
                        node.outputs = outputs

                logger.info(f"Completed {node.name}: {node.outputs}")

                with node.lock:
                    node.is_dirty = False
                    node.retry_count = 0
                    node.last_error = None

                # Record metrics
                end_time = time.time()
                node.metrics.record_execution(end_time - start_time)

                # Checkpoint if enabled; a failing callback counts as a
                # failed attempt, as before.
                if self.enable_checkpointing and self.checkpoint_callback:
                    self.checkpoint_callback(node)

            except Exception as e:
                end_time = time.time()
                failures += 1
                node.metrics.record_error()

                with node.lock:
                    node.last_error = e
                    node.retry_count = failures

                logger.error(f"Error executing {node.name}: {e}")
                self._record_trace(
                    node, level, thread_id, start_time, end_time, False, str(e)
                )

                # Retry logic
                if failures < node.max_retries:
                    logger.warning(
                        f"Retrying {node.name} (attempt {failures + 1}/{node.max_retries})"
                    )
                    continue

                logger.error(f"Max retries exceeded for {node.name}")
                raise

            self._record_trace(node, level, thread_id, start_time, end_time, True, None)
            return node

    def _record_trace(
        self,
        node: NodeInstance,
        level: int,
        thread_id: str,
        start_time: float,
        end_time: float,
        success: bool,
        error: Optional[str],
    ) -> None:
        """Append one attempt's timing to the profile; no-op if profiling is off."""
        if not self.enable_profiling:
            return

        trace = ExecutionTrace(
            node_name=node.name,
            node_type=node.definition.type_name,
            # Stored relative to the start of the current run.
            start_time=start_time - self.execution_start_time,
            end_time=end_time - self.execution_start_time,
            duration=end_time - start_time,
            level=level,
            thread_id=thread_id,
            success=success,
            error=error,
        )
        with self.traces_lock:
            self.execution_traces.append(trace)

    def execute_level_parallel(self, level: List[NodeInstance], level_idx: int):
        """Run the dirty nodes of one level, sequentially or on a thread pool.

        Clean nodes are skipped; a level with no dirty nodes returns at once
        without recording a level time. The first node failure is logged and
        re-raised to the caller.
        """
        pending = [candidate for candidate in level if candidate.is_dirty]
        if len(pending) == 0:
            return

        started_at = time.time()
        finished = 0

        if self.mode != ExecutionMode.SEQUENTIAL:
            # Fan out to a pool and collect results as they complete.
            logger.info(f"Executing level with {len(pending)} node(s) in parallel")
            with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
                owner_of = {}
                for candidate in pending:
                    submitted = pool.submit(self.execute_node, candidate, level_idx)
                    owner_of[submitted] = candidate

                for done in as_completed(owner_of):
                    owner = owner_of[done]
                    try:
                        done.result()
                    except Exception as exc:
                        logger.error(f"Failed to execute {owner.name}: {exc}")
                        raise
                    finished += 1
        else:
            # One node after another on the calling thread.
            logger.info(f"Executing level with {len(pending)} node(s) sequentially")
            for candidate in pending:
                try:
                    self.execute_node(candidate, level_idx)
                except Exception as exc:
                    logger.error(f"Failed to execute {candidate.name}: {exc}")
                    raise
                finished += 1

        elapsed = time.time() - started_at
        self.level_execution_times.append(elapsed)
        logger.info(
            f"Level completed in {elapsed:.3f}s ({finished}/{len(pending)} nodes)"
        )

    def get_execution_stats(self) -> Dict[str, Any]:
        """Summarise the engine's timings plus per-node metrics.

        ``level_execution_times`` and ``execution_traces`` are the engine's
        own lists, returned by reference.
        """
        per_node = [
            {
                "name": n.name,
                "type": n.definition.type_name,
                "execution_count": n.metrics.execution_count,
                # seconds -> milliseconds
                "avg_time_ms": n.metrics.avg_execution_time * 1000,
                "error_count": n.metrics.error_count,
            }
            for n in self.nodes
        ]

        stats: Dict[str, Any] = {}
        stats["total_execution_time"] = self.total_execution_time
        stats["level_count"] = len(self.execution_order)
        stats["level_execution_times"] = self.level_execution_times
        stats["node_stats"] = per_node
        stats["total_nodes"] = len(self.nodes)
        stats["execution_traces"] = self.execution_traces
        return stats

    def run(self) -> Dict[str, Any]:
        """Run the graph level by level and return execution statistics.

        Each call starts a fresh profile: traces and level timings from a
        previous run are dropped, because trace times are stored relative to
        this run's start and would otherwise overlap on one timeline.
        ``total_execution_time`` is updated even when the run fails.

        Returns:
            The dict produced by ``get_execution_stats()``.

        Raises:
            ValueError: if the graph contains a cycle.
            Exception: whatever a node raised once its retries ran out.
        """
        self.execution_start_time = time.time()
        execution_start = self.execution_start_time

        # Rebind rather than clear, so stats dicts returned by earlier runs
        # keep the lists they were given.
        with self.traces_lock:
            self.execution_traces = []
        self.level_execution_times = []

        logger.info("=" * 60)
        logger.info("Starting execution engine")
        logger.info(f"Mode: {self.mode.value}, Max workers: {self.max_workers}")
        logger.info("=" * 60)

        try:
            # Compute execution order
            self.execution_order = self.topological_sort()

            logger.info(f"Execution plan ({len(self.execution_order)} levels):")
            for i, level in enumerate(self.execution_order):
                node_names = [node.name for node in level]
                logger.info(f"  Level {i}: {node_names}")

            # Execute each level
            for level_idx, level in enumerate(self.execution_order):
                logger.info(f"\n{'='*60}")
                logger.info(f"Executing Level {level_idx}")
                logger.info(f"{'='*60}")
                self.execute_level_parallel(level, level_idx)
        finally:
            self.total_execution_time = time.time() - execution_start

        logger.info("\n" + "=" * 60)
        logger.info(f"Execution completed in {self.total_execution_time:.3f}s")
        logger.info("=" * 60)

        return self.get_execution_stats()

    def plot_execution_profile(self, output_file: str = "execution_profile.png"):
        """Render the recorded traces as a Gantt chart and save it to a file.

        One row per node name (ordered by first start time), one bar per
        trace, coloured by node type; failed traces are drawn more
        transparent. Does nothing but log a warning when no traces exist.

        Args:
            output_file: path of the image to write.
        """
        if not self.execution_traces:
            logger.warning("No execution traces available for profiling")
            return

        fig, ax = plt.subplots(figsize=(14, max(8, len(self.execution_traces) * 0.4)))

        # The figure is closed in ``finally`` so it is released even when
        # drawing or saving fails, and it is addressed explicitly instead of
        # through pyplot's "current figure".
        try:
            # Define colors for different node types
            type_colors = {
                "CONST": "#3498db",  # Blue
                "ADD": "#2ecc71",  # Green
                "MULTIPLY": "#e74c3c",  # Red
                "BASE": "#95a5a6",  # Gray
            }

            # Sort traces by start time for better visualization
            sorted_traces = sorted(self.execution_traces, key=lambda t: t.start_time)

            # Assign each node name a row, in order of first appearance.
            y_positions = {}
            for trace in sorted_traces:
                y_positions.setdefault(trace.node_name, len(y_positions))

            # Plot each execution trace as a horizontal bar
            for trace in sorted_traces:
                y_pos = y_positions[trace.node_name]
                color = type_colors.get(trace.node_type, "#95a5a6")

                # Add alpha for failed executions
                alpha = 0.4 if not trace.success else 0.8

                ax.barh(
                    y_pos,
                    trace.duration,
                    left=trace.start_time,
                    height=0.8,
                    color=color,
                    alpha=alpha,
                    edgecolor="black",
                    linewidth=0.5,
                )

                # Add duration label on the bar
                if trace.duration > 0.001:  # Only show label if bar is wide enough
                    ax.text(
                        trace.start_time + trace.duration / 2,
                        y_pos,
                        f"{trace.duration*1000:.1f}ms",
                        ha="center",
                        va="center",
                        fontsize=8,
                        fontweight="bold",
                        color="white" if trace.duration > 0.02 else "black",
                    )

            # Customize the plot
            ax.set_yticks(range(len(y_positions)))
            ax.set_yticklabels(list(y_positions))
            ax.set_xlabel("Time (seconds)", fontsize=12, fontweight="bold")
            ax.set_ylabel("Node Name", fontsize=12, fontweight="bold")
            ax.set_title(
                "Execution Profile - Parallel Node Execution Timeline",
                fontsize=14,
                fontweight="bold",
                pad=20,
            )

            # Add grid for better readability
            ax.grid(True, axis="x", alpha=0.3, linestyle="--")
            ax.set_axisbelow(True)

            # Legend lists only the node types that actually appear.
            legend_elements = [
                mpatches.Patch(color=color, label=node_type, alpha=0.8)
                for node_type, color in type_colors.items()
                if any(t.node_type == node_type for t in sorted_traces)
            ]
            ax.legend(
                handles=legend_elements,
                loc="lower right",
                title="Node Types",
                framealpha=0.5,
            )

            # Add execution summary text
            total_time = max(t.end_time for t in sorted_traces)
            summary_text = f"Total Execution Time: {total_time:.3f}s\n"
            summary_text += f"Nodes Executed: {len(sorted_traces)}\n"
            summary_text += (
                f"Parallelization: {self.mode.value} ({self.max_workers} workers)"
            )

            ax.text(
                0.02,
                0.98,
                summary_text,
                transform=ax.transAxes,
                fontsize=9,
                verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
            )

            fig.tight_layout()
            fig.savefig(output_file, dpi=150, bbox_inches="tight")
            logger.info(f"Execution profile saved to {output_file}")
        finally:
            plt.close(fig)

    def export_graph(self, filepath: str):
        """Export graph structure to JSON for visualization/debugging.

        The document has a ``nodes`` list (one ``to_json()`` dict per node)
        and an ``edges`` list of ``{"from": node_id, "to": node_id}`` pairs
        taken from each node's ``children``.

        Args:
            filepath: destination file; overwritten if it exists.

        Raises:
            TypeError: if a node's outputs are not JSON-serializable. The
                destination file is left untouched in that case.
        """
        graph_data = {
            "nodes": [node.to_json() for node in self.nodes],
            "edges": [
                {"from": node.node_id, "to": child.node_id}
                for node in self.nodes
                for child in node.children
            ],
        }

        # Serialize before opening: opening with "w" truncates the file, so
        # encoding straight into it would leave a partial document behind if
        # an output value turns out not to be serializable.
        payload = json.dumps(graph_data, indent=2)

        with open(filepath, "w", encoding="utf-8") as f:
            f.write(payload)

        logger.info(f"Graph exported to {filepath}")


# -------------------------------------------------------------------
# Example Usage
# -------------------------------------------------------------------
if __name__ == "__main__":

    # Demo graph with three levels:
    #   level 0: four constants
    #   level 1: Add1 = ConstA + ConstB, Multiply1 = ConstC * ConstD
    #   level 2: Add2 = Add1 + Multiply1
    # Const values are written directly into ``outputs``.
    constA = NodeInstance("ConstA", NodeRegistry.create_definition("CONST"), "const_a")
    constA.outputs["output1"] = 10

    constB = NodeInstance("ConstB", NodeRegistry.create_definition("CONST"), "const_b")
    constB.outputs["output1"] = 4

    constC = NodeInstance("ConstC", NodeRegistry.create_definition("CONST"), "const_c")
    constC.outputs["output1"] = 3

    constD = NodeInstance("ConstD", NodeRegistry.create_definition("CONST"), "const_d")
    constD.outputs["output1"] = 10

    # set_input(input_index, source_node, source_output_index); indices are 1-based here.
    add1 = NodeInstance("Add1", NodeRegistry.create_definition("ADD"), "add_1")
    add1.set_input(1, constA, 1)
    add1.set_input(2, constB, 1)

    multiply1 = NodeInstance(
        "Multiply1", NodeRegistry.create_definition("MULTIPLY"), "mult_1"
    )
    multiply1.set_input(1, constC, 1)
    multiply1.set_input(2, constD, 1)

    add2 = NodeInstance("Add2", NodeRegistry.create_definition("ADD"), "add_2")
    add2.set_input(1, add1, 1)
    add2.set_input(2, multiply1, 1)

    nodes = [constA, constB, constC, constD, add1, multiply1, add2]

    # Run the graph on a thread pool with profiling enabled.
    engine = ExecutionEngine(
        nodes,
        max_workers=4,
        mode=ExecutionMode.THREAD,  # nodes of a level run on a thread pool
        # Pass mode=ExecutionMode.SEQUENTIAL instead to run one node at a time.
        enable_checkpointing=False,
        enable_profiling=True,
    )

    stats = engine.run()

    # Print execution statistics
    print("\n" + "=" * 60)
    print("EXECUTION STATISTICS")
    print("=" * 60)
    print(f"Total time: {stats['total_execution_time']:.3f}s")
    print(f"Levels: {stats['level_count']}")
    print(f"\nNode Statistics:")
    for node_stat in stats["node_stats"]:
        print(
            f"  {node_stat['name']}: {node_stat['execution_count']} exec, "
            f"{node_stat['avg_time_ms']:.2f}ms avg"
        )

    # Print the computed outputs next to the expression each one represents.
    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    print(f"Add1 result: {add1.outputs['output1']} (10 + 4)")
    print(f"Multiply1 result: {multiply1.outputs['output1']} (3 * 10)")
    print(f"Add2 result: {add2.outputs['output1']} (14 + 30)")

    # Export graph for debugging
    engine.export_graph("execution_graph.json")
    print("\nGraph exported to execution_graph.json")

    # Generate execution profile visualization
    engine.plot_execution_profile("execution_profile.png")
    print("Execution profile chart saved to execution_profile.png")

2025-11-24 05:08:32,460 - MainThread - INFO - Registered CONST v1.0.0
2025-11-24 05:08:32,460 - MainThread - INFO - Registered ADD v1.0.0
2025-11-24 05:08:32,461 - MainThread - INFO - Registered MULTIPLY v1.0.0
2025-11-24 05:08:32,463 - MainThread - INFO - Starting execution engine
2025-11-24 05:08:32,463 - MainThread - INFO - Mode: thread, Max workers: 4
2025-11-24 05:08:32,464 - MainThread - INFO - Topological sort complete: 3 levels, 7 nodes
2025-11-24 05:08:32,464 - MainThread - INFO - Execution plan (3 levels):
2025-11-24 05:08:32,464 - MainThread - INFO -   Level 0: ['ConstA', 'ConstB', 'ConstC', 'ConstD']
2025-11-24 05:08:32,465 - MainThread - INFO -   Level 1: ['Add1', 'Multiply1']
2025-11-24 05:08:32,465 - MainThread - INFO -   Level 2: ['Add2']
2025-11-24 05:08:32,465 - MainThread - INFO - 
2025-11-24 05:08:32,465 - MainThread - INFO - Executing Level 0
2025-11-24 05:08:32,465 - MainThread - INFO - Executing level with 4 node(s) in parallel
2025-11-24 05:08:32,466 - ThreadPoo


EXECUTION STATISTICS
Total time: 0.192s
Levels: 3

Node Statistics:
  ConstA: 1 exec, 3.52ms avg
  ConstB: 1 exec, 2.79ms avg
  ConstC: 1 exec, 2.65ms avg
  ConstD: 1 exec, 2.04ms avg
  Add1: 1 exec, 55.89ms avg
  Multiply1: 1 exec, 85.99ms avg
  Add2: 1 exec, 96.02ms avg

FINAL RESULTS
Add1 result: 14 (10 + 4)
Multiply1 result: 30 (3 * 10)
Add2 result: 44 (14 + 30)

Graph exported to execution_graph.json
Execution profile chart saved to execution_profile.png
