# PyDart Library – Second Testing Checkpoint

## Overview

In this iteration, I addressed retrieval issues from the previous checkpoint and began exploring DAG-based approaches for more accurate task execution representation.

## Main Contributions

1. **Retrieval Problem Resolution**  
   - Fixed missing input errors identified in the last iteration.  
   - Ensured proper retrieval of necessary data before execution.  

2. **Validation of Output Generation**  
   - Verified whether retrieval was functioning correctly.  
   - Checked if outputs were being generated or missing due to errors in execution stages.  
   - Identified that, in most cases, errors in running stages prevented output generation.  

3. **Shift Towards DAG-Based Approaches**  
   - Realized that simply splitting the forward pass was insufficient.  
   - Began investigating DAG-based execution to better represent real DNN task execution.  

## Iterative Development Process

As with previous phases, this checkpoint involved multiple iterations while developing the required classes. The earlier formats of the classes was maintained throughput the process.

---

**Note**: Multiple iterations were performed during the development of these classes. The key checkpoints included here highlight the most significant developments. Subsequent iterations followed a similar approach.


In [None]:
# -*- coding: utf-8 -*-
"""Unified_Taskset_Evaluation.ipynb

This script implements a single Taskset approach, utilizing single tensor inputs for both profiling and evaluation phases.

"""

import networkx as nx
import os
import torch
import threading
import queue
import time
import re
import pandas as pd
import torch.nn as nn
import torch.profiler
from torch.utils.data import DataLoader, TensorDataset
from typing import Callable, Any, List, Dict, Optional
import copy
import math
import matplotlib.pyplot as plt


# --- Node Class (Unchanged) ---
class Node:
    """
    Represents a computational resource: either a CPU-only node (1 CPU core)
    or a GPU+CPU pair. Each Node has:
      - node_id (e.g., 'CPU-0', 'GPU-0-CPU-1')
      - A worker thread + queue to run tasks
    """

    def __init__(self, node_id: str, cpus=None, gpu=None):
        self._node_id = node_id
        self._cpus = tuple(cpus or [])
        self._gpu = gpu

        self._original_affinity = os.sched_getaffinity(0)
        self._task_queue = queue.Queue()
        self._stop_signal = False

        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

        self.current_load = 0.0  # Initialize current load
        self.assigned_stages = []  # List to track assigned stages

    @property
    def node_id(self):
        return self._node_id

    @property
    def cpus(self):
        return self._cpus

    @property
    def gpu(self):
        return self._gpu

    def assign_task(self, func: Callable, *args, **kwargs) -> queue.Queue:
        """
        Enqueue a function to this node. Returns a queue from which
        the caller can retrieve the result (blocking).
        """
        result_queue = queue.Queue(maxsize=1)
        self._task_queue.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        """
        Signal the node to stop after processing queued tasks.
        """
        self._stop_signal = True
        self._task_queue.put(None)
        self._worker_thread.join()

    def _worker_loop(self):
        while not self._stop_signal:
            item = self._task_queue.get()
            if item is None:
                break
            func, args, kwargs, result_queue = item
            try:
                self._set_context()
                result = func(*args, **kwargs)
            except Exception as e:
                result = e
            finally:
                self._reset_context()

            result_queue.put(result)

    def _set_context(self):
        if self._cpus:
            os.sched_setaffinity(0, self._cpus)
        if self._gpu is not None and torch.cuda.is_available():
            torch.cuda.set_device(self._gpu)

    def _reset_context(self):
        os.sched_setaffinity(0, self._original_affinity)
        # Optionally reset GPU device if needed

    @staticmethod
    def discover_nodes() -> List['Node']:
        """
        Create a Node for each CPU core, and for each GPU+CPU pair.
        """
        nodes = []
        num_cpus = os.cpu_count() or 1
        ngpus = torch.cuda.device_count()

        # CPU-only nodes
        for core_id in range(num_cpus):
            node = Node(node_id=f"CPU-{core_id}", cpus=[core_id])
            nodes.append(node)

        # GPU+CPU nodes
        for g in range(ngpus):
            for core_id in range(num_cpus):
                node = Node(node_id=f"GPU-{g}-CPU-{core_id}", cpus=[core_id], gpu=g)
                nodes.append(node)

        return nodes

    def __repr__(self):
        return f"Node({self._node_id}, cpus={self._cpus}, gpu={self._gpu})"

# --- Profiler Class (Modified) ---
class Profiler:
    """
    In 'init' mode: Gather detailed profiling info for each leaf layer on each Node,
    storing results in a CSV-based ProfileDB.
    In 'runtime' mode: Potentially gather minimal logs (optional).
    """

    def __init__(self, mode: str, profile_db_path='profiling_results.csv', log_dir='logs'):
        assert mode in ['init', 'runtime']
        self.mode = mode
        self.profile_db_path = profile_db_path
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        columns = [
            'Model', 'Layer', 'Compute',
            'Self CPU (us)', 'CPU Total (us)', 'CUDA Total (us)',
            'Self CPU Mem (bytes)', 'Self CUDA Mem (bytes)',
            'Total Execution Time (us)', 'Total Memory Used (bytes)'
        ]
        if os.path.exists(self.profile_db_path):
            self.profile_db = pd.read_csv(self.profile_db_path)
        else:
            self.profile_db = pd.DataFrame(columns=columns)

        self.runtime_csv = os.path.join(self.log_dir, 'runtime_results.csv')
        if not os.path.exists(self.runtime_csv):
            rt_cols = ['Model', 'Layer', 'Compute', 'Execution Time (us)']
            pd.DataFrame(columns=rt_cols).to_csv(self.runtime_csv, index=False)

    def _register_hooks(self, model: nn.Module):
        def hook_wrapper(layer_name):
            def hook(mod, inp, out):
                with torch.profiler.record_function(layer_name):
                    pass
            return hook

        for idx, (name, layer) in enumerate(model.named_modules()):
            if not isinstance(layer, nn.Sequential) and not isinstance(layer, nn.ModuleList) and layer != model:
                layer.register_forward_hook(hook_wrapper(f"{name}_{idx}"))

    def profile_model(self, model: nn.Module, input_data: Any, node, model_name: str, warmup_iters=3):
        """
        Schedule a profiling task on 'node'. In 'init' mode, we gather
        full per-layer times.
        """
        def profiling_task():
            device = torch.device(f"cuda:{node.gpu}" if node.gpu is not None and torch.cuda.is_available() else "cpu")
            model.to(device)

            if self.mode == 'init':
                # Warmup
                with torch.no_grad():
                    for _ in range(warmup_iters):
                        model(input_data.to(device))
                self._profile_init(model, input_data, node, model_name, device)
            else:
                self._profile_runtime(model, input_data, node, model_name, device)

        rq = node.assign_task(profiling_task)
        rq.get()  # block

    def _profile_init(self, model, input_data, node, model_name, device):
        self._register_hooks(model)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            profile_memory=True
        ) as prof:
            with torch.no_grad():
                model(input_data.to(device))
                prof.step()

        stats = self._process_events(prof, model, node, runtime=False)
        self._update_profile_db(stats, model_name, node, runtime=False)

    def _profile_runtime(self, model, input_data, node, model_name, device):
        self._register_hooks(model)
        with torch.no_grad():
            with torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
            ) as prof:
                model(input_data.to(device))
                prof.step()
            stats = self._process_events(prof, model, node, runtime=True)
            self._append_runtime_csv(stats, model_name, node)

    def _process_events(self, profiler, model, node, runtime=False):
        recognized = set()
        for n, m in model.named_modules():
            if n:
                recognized.add(n)

        aggregated = {
            'forward_pass': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                                 self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id),
            'misc': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                         self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id)
        }

        events = list(profiler.events())
        found_root = False

        def strip_suffix(s):
            return re.sub(r'(\.|_)\d+$', '', s)

        for e in events:
            if e.name == "":
                found_root = True
                aggregated['forward_pass']['self_cpu_time_total'] += e.self_cpu_time_total
                aggregated['forward_pass']['cpu_time_total'] += e.cpu_time_total
                aggregated['forward_pass']['cuda_time_total'] += e.device_time_total
                if not runtime:
                    aggregated['forward_pass']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                    aggregated['forward_pass']['self_cuda_memory_usage'] += e.self_device_memory_usage  # Updated
            else:
                base = strip_suffix(e.name)
                if base in recognized:
                    if base not in aggregated:
                        aggregated[base] = dict(
                            self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                            self_cpu_memory_usage=0, self_cuda_memory_usage=0,
                            compute=node.node_id
                        )
                    aggregated[base]['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated[base]['cpu_time_total'] += e.cpu_time_total
                    aggregated[base]['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated[base]['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated[base]['self_cuda_memory_usage'] += e.self_device_memory_usage
                else:
                    aggregated['misc']['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated['misc']['cpu_time_total'] += e.cpu_time_total
                    aggregated['misc']['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated['misc']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated

        # If no root event found, sum all into forward_pass
        if not found_root:
            for k in list(aggregated.keys()):
                if k not in ('forward_pass', 'misc'):
                    aggregated['forward_pass']['self_cpu_time_total'] += aggregated[k]['self_cpu_time_total']
                    aggregated['forward_pass']['cpu_time_total'] += aggregated[k]['cpu_time_total']
                    aggregated['forward_pass']['cuda_time_total'] += aggregated[k]['cuda_time_total']
                    if not runtime:
                        aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated[k]['self_cpu_memory_usage']
                        aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated[k]['self_cuda_memory_usage']

            aggregated['forward_pass']['self_cpu_time_total'] += aggregated['misc']['self_cpu_time_total']
            aggregated['forward_pass']['cpu_time_total'] += aggregated['misc']['cpu_time_total']
            aggregated['forward_pass']['cuda_time_total'] += aggregated['misc']['cuda_time_total']
            if not runtime:
                aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated['misc']['self_cpu_memory_usage']
                aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated['misc']['self_cuda_memory_usage']

        return aggregated

    def _update_profile_db(self, stats, model_name, node, runtime=False):
        if runtime:
            return
        for layer_name, data in stats.items():
            total_t = data['cpu_time_total'] + data['cuda_time_total']
            total_m = data['self_cpu_memory_usage'] + data['self_cuda_memory_usage']
            row = {
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Self CPU (us)': data['self_cpu_time_total'],
                'CPU Total (us)': data['cpu_time_total'],
                'CUDA Total (us)': data['cuda_time_total'],
                'Self CPU Mem (bytes)': data['self_cpu_memory_usage'],
                'Self CUDA Mem (bytes)': data['self_cuda_memory_usage'],
                'Total Execution Time (us)': total_t * 1_000_000,  # Convert to microseconds
                'Total Memory Used (bytes)': total_m
            }
            self.profile_db = self._upsert(self.profile_db, row)
        self.profile_db.to_csv(self.profile_db_path, index=False)

    def _upsert(self, df, row):
        mask = (
            (df['Model'] == row['Model']) &
            (df['Layer'] == row['Layer']) &
            (df['Compute'] == row['Compute'])
        )
        if not df[mask].empty:
            existing_time = df.loc[mask, 'Total Execution Time (us)'].max()
            if row['Total Execution Time (us)'] > existing_time:
                for k, v in row.items():
                    df.loc[mask, k] = v
        else:
            new_row = pd.DataFrame([row])
            if not new_row.dropna().empty:
                df = pd.concat([df, new_row], ignore_index=True)  # Fixed FutureWarning by checking non-empty
        return df

    def _append_runtime_csv(self, stats, model_name, node):
        rows = []
        for layer_name, data in stats.items():
            exec_time = data['cpu_time_total'] + data['cuda_time_total']
            rows.append({
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Execution Time (us)': exec_time * 1_000_000  # Convert to microseconds
            })
        if rows:
            rdf = pd.read_csv(self.runtime_csv)
            rdf = pd.concat([rdf, pd.DataFrame(rows)], ignore_index=True)
            rdf.to_csv(self.runtime_csv, index=False)

    def get_profile_db(self):
        return self.profile_db

    def print_profile_db(self):
        if self.profile_db.empty:
            print("ProfileDB is empty.")
        else:
            print("ProfileDB:\n", self.profile_db.to_string(index=False))

# --- Stage Class (Unchanged) ---
class Stage:
    """
    Represents a partitioned segment of a model, assigned to a specific Node.

    Attributes:
        stage_id (str): Unique identifier for the stage.
        layers (nn.ModuleList): The layers assigned to this stage.
        assigned_node (Node): The Node responsible for executing this stage.
        dependencies (List[str]): List of stage_ids that this stage depends on.
        dependents (List[str]): List of stage_ids that depend on this stage.
        execution_time (Optional[float]): Time taken to execute this stage.
        input_data (Optional[torch.Tensor]): Input tensor for this stage.
        output_data (Optional[torch.Tensor]): Output tensor from this stage.
        transfer_time (float): Time spent on data transfers (in seconds).
        task (Task): Reference to the Task instance this stage belongs to.
    """

    def __init__(self, stage_id: str, layers: nn.ModuleList, assigned_node: 'Node', task: 'Task'):
        """
        Initializes the Stage object.

        Args:
            stage_id (str): Unique identifier for the stage.
            layers (nn.ModuleList): The layers assigned to this stage.
            assigned_node (Node): The Node responsible for executing this stage.
            task (Task): The Task instance this stage belongs to.
        """
        self.stage_id = stage_id
        self.layers = layers
        self.assigned_node = assigned_node

        self.dependencies: List[str] = []
        self.dependents: List[str] = []

        self.execution_time: Optional[float] = None
        self.input_data: Optional[torch.Tensor] = None
        self.output_data: Optional[torch.Tensor] = None

        self.transfer_time: float = 0.0

        self.task = task  # Reference to the Task

    def add_dependency(self, stage_id: str):
        """
        Adds a dependency to this stage.

        Args:
            stage_id (str): The stage_id that this stage depends on.
        """
        self.dependencies.append(stage_id)

    def add_dependent(self, stage_id: str):
        """
        Adds a dependent to this stage.

        Args:
            stage_id (str): The stage_id that depends on this stage.
        """
        self.dependents.append(stage_id)

    def run_stage(self):
        """
        Executes the stage's layers on the assigned node.
        """
        start_time = time.time()
        transfer_start = time.time()
        try:
            device = torch.device(
                f"cuda:{self.assigned_node.gpu}" if self.assigned_node.gpu is not None and torch.cuda.is_available()
                else "cpu"
            )

            if self.input_data is None:
                print(f"[Stage] {self.stage_id}: No input data provided. Executing with empty tensor.")
                out = torch.tensor([])
                transfer_end = time.time()
                self.transfer_time += transfer_end - transfer_start
            else:
                inp = self.input_data.to(device)
                transfer_end = time.time()
                self.transfer_time += transfer_end - transfer_start

                with torch.no_grad():
                    out = inp
                    for layer in self.layers:
                        out = layer(out)

                if device.type == 'cuda':
                    transfer_start = time.time()
                    out = out.cpu()
                    transfer_end = time.time()
                    self.transfer_time += transfer_end - transfer_start

            self.output_data = out

        except Exception as e:
            print(f"[Stage] {self.stage_id}: Error during execution: {e}")
            self.output_data = None
        finally:
            end_time = time.time()
            self.execution_time = end_time - start_time
            print(f"[Stage] {self.stage_id}: Executed on {self.assigned_node.node_id} in {self.execution_time:.6f} seconds. Transfer Time: {self.transfer_time:.6f} seconds.")

            # Update Task's busy time with both execution and transfer times
            self.task.update_busy_time(self.execution_time, self.transfer_time)

            # If this is the final stage, set the Task's output data
            if not self.dependents:
                self.task.set_output_data(self.output_data)

            # Notify Scheduler of stage completion
            self.task.scheduler.stage_completed(self.stage_id)

        return self.output_data

    def __repr__(self):
        return (f"Stage(stage_id={self.stage_id}, assigned_node={self.assigned_node.node_id}, "
                f"dependencies={self.dependencies}, dependents={self.dependents}, "
                f"execution_time={self.execution_time:.6f if self.execution_time else 'N/A'}, "
                f"transfer_time={self.transfer_time:.6f}, output_data_present={self.output_data is not None})")

    def __deepcopy__(self, memo):
        """
        Create a new Stage, copying only what we need, while referencing the same Node.
        """
        # 1) Create the new Stage without doing a deepcopy on 'assigned_node'
        new_stage = Stage(
            stage_id = copy.deepcopy(self.stage_id, memo),
            layers   = copy.deepcopy(self.layers, memo),
            assigned_node = self.assigned_node,  # <-- same Node object
            task = None  # or self.task if you want the same reference or plan to reassign later
        )

        # 2) Copy lists and basic attributes
        new_stage.dependencies   = copy.deepcopy(self.dependencies, memo)
        new_stage.dependents     = copy.deepcopy(self.dependents, memo)
        new_stage.execution_time = self.execution_time
        new_stage.transfer_time  = self.transfer_time

        # 3) Decide whether to copy input_data/output_data
        #    For final input/output analysis, you might want to do the following:
        new_stage.input_data  = copy.deepcopy(self.input_data, memo)
        new_stage.output_data = copy.deepcopy(self.output_data, memo)
        #
        # But if your pipeline is re-initialized or you always supply fresh input_data, do this:
        # new_stage.input_data  = None
        # new_stage.output_data = None

        return new_stage


# --- Task Class (Unchanged) ---
class Task:
    """
    Represents a single DNN inference task.

    Attributes:
        task_id (str): Unique identifier for the task.
        model (nn.Module): The DNN model to be executed.
        input_data (torch.Tensor): The input tensor for the model.
        model_name (str): Name of the model (used for profiling).
        stages (Dict[str, Stage]): Dictionary of Stage objects representing the task's execution stages.
        scheduler (Scheduler): Reference to the Scheduler handling this task.
        start_time (Optional[float]): Timestamp when the task started execution.
        finish_time (Optional[float]): Timestamp when the task finished execution.
        output_data (Optional[torch.Tensor]): The final output tensor after executing all stages.
        busy_time (float): Total time spent executing stages (in seconds).
        computation_time (float): Total time spent on computations (in seconds).
        transfer_time (float): Total time spent on data transfers (in seconds).
    """

    def __init__(self, task_id: str, model: nn.Module, input_data: torch.Tensor, model_name: str, scheduler: 'Scheduler'):
        """
            Initializes the Task object.

            Args:
                task_id (str): Unique identifier for the task.
                model (nn.Module): The DNN model to be executed.
                input_data (torch.Tensor): The input tensor for the model.
                model_name (str): Name of the model (used for profiling).
                scheduler (Scheduler): Reference to the Scheduler handling this task.
        """
        self.task_id = task_id
        self.model = model
        self.input_data = input_data
        self.model_name = model_name

        self.scheduler = scheduler  # Reference to the Scheduler

        self.stages: Dict[str, Stage] = {}

        self.start_time: Optional[float] = None
        self.finish_time: Optional[float] = None

        self.output_data: Optional[torch.Tensor] = None
        self.busy_time: float = 0.0
        self.computation_time: float = 0.0
        self.transfer_time: float = 0.0

    def add_stage(self, stage: 'Stage'):
        """
        Adds a Stage object to the task.

        Args:
            stage (Stage): The Stage object to be added.
        """
        if stage.stage_id in self.stages:
            raise ValueError(f"Stage ID {stage.stage_id} already exists in Task {self.task_id}.")
        self.stages[stage.stage_id] = stage

    def get_stage(self, stage_id: str) -> Optional['Stage']:
        """
        Retrieves a Stage object by its stage_id.

        Args:
            stage_id (str): The unique identifier of the stage.

        Returns:
            Optional[Stage]: The Stage object if found, else None.
        """
        return self.stages.get(stage_id, None)

    def get_total_execution_time(self) -> float:
        """
        Calculates the total execution time of the task.

        Returns:
            float: Total execution time in seconds.
        """
        if self.start_time and self.finish_time:
            return self.finish_time - self.start_time
        return 0.0

    def update_busy_time(self, stage_execution_time: float, stage_transfer_time: float = 0.0):
        """
        Updates the cumulative busy time of the task.

        Args:
            stage_execution_time (float): Execution time of a stage in seconds.
            stage_transfer_time (float, optional): Transfer time of a stage in seconds. Defaults to 0.0.
        """
        self.busy_time += stage_execution_time
        self.transfer_time += stage_transfer_time
        self.computation_time += (stage_execution_time - stage_transfer_time)

    def set_output_data(self, output: torch.Tensor):
        """
        Sets the final output data of the task.

        Args:
            output (torch.Tensor): The final output tensor after all stages.
        """
        self.output_data = output
        self.finish_time = time.time()

    def __repr__(self):
        return (f"Task(task_id={self.task_id}, model_name={self.model_name}, "
                f"stages={list(self.stages.keys())}, "
                f"start_time={self.start_time}, finish_time={self.finish_time}, "
                f"busy_time={self.busy_time:.6f}, computation_time={self.computation_time:.6f}, "
                f"transfer_time={self.transfer_time:.6f}, output_data_present={self.output_data is not None})")

    def __deepcopy__(self, memo):
        # Create a new Task instance without the scheduler reference
        new_task = Task(
            task_id=copy.deepcopy(self.task_id, memo),
            model=copy.deepcopy(self.model, memo),
            input_data=copy.deepcopy(self.input_data, memo),
            model_name=copy.deepcopy(self.model_name, memo),
            scheduler=None  # Exclude scheduler to prevent deepcopy issues
        )
        # Deepcopy stages
        new_task.stages = copy.deepcopy(self.stages, memo)
        # Copy other attributes
        new_task.start_time = self.start_time
        new_task.finish_time = self.finish_time
        new_task.output_data = self.output_data
        new_task.busy_time = self.busy_time
        new_task.computation_time = self.computation_time
        new_task.transfer_time = self.transfer_time
        return new_task

# --- Taskset Class ---
class Taskset:
    """
    Manages a collection of Tasks and orchestrates their execution using the Scheduler.

    Attributes:
        tasks (List[Task]): A list of Task instances to be executed.
        scheduler (Scheduler): The Scheduler responsible for decomposing and allocating tasks.
        total_utilization (float): Overall resource utilization of the taskset.
        average_turnaround_time (float): Average turnaround time of all tasks.
        throughput (float): Number of tasks completed per unit time.
        makespan (float): Total time to complete all tasks.
        task_completion_rate (float): Ratio of completed tasks to total tasks.
        average_resource_utilization_per_node (Dict[str, float]): Average utilization per node.
    """

    def __init__(self, tasks: List['Task'], scheduler: 'Scheduler'):
        """
        Initializes the Taskset object.

        Args:
            tasks (List[Task]): A list of Task instances to be managed.
            scheduler (Scheduler): The Scheduler responsible for task decomposition and allocation.
        """
        self.tasks = tasks
        self.scheduler = scheduler

        # Performance Metrics
        self.total_utilization: float = 0.0
        self.average_turnaround_time: float = 0.0
        self.throughput: float = 0.0
        self.makespan: float = 0.0
        self.task_completion_rate: float = 0.0
        self.average_resource_utilization_per_node: Dict[str, float] = {}

    def schedule_all_tasks(self):
        """
        Decompose and allocate all tasks using the Scheduler.
        """
        for task in self.tasks:
            self.scheduler.decompose_and_allocate_task(task)

    def execute_all(self):
        """
        Execute all tasks in the Taskset concurrently using threading based on the precomputed allocation strategy.
        """
        threads = []
        for task in self.tasks:
            # thread = threading.Thread(target=self.scheduler.execute_task_with_graph, args=(task,))
            thread = threading.Thread(target=self.scheduler.execute_task, args=(task,))
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

        self.calculate_metrics()

    def calculate_metrics(self):
        """
        Calculates all performance metrics for the taskset.
        """
        # Total Utilization
        total_busy_time = sum(task.busy_time for task in self.tasks)
        total_available_time = self.scheduler.observation_window * len(self.scheduler.nodes)
        self.total_utilization = total_busy_time / total_available_time if total_available_time > 0 else 0.0

        # Average Turnaround Time
        turnaround_times = [task.get_total_execution_time() for task in self.tasks]
        self.average_turnaround_time = sum(turnaround_times) / len(turnaround_times) if turnaround_times else 0.0

        # Throughput
        self.throughput = len(self.tasks) / self.makespan if self.makespan > 0 else 0.0

        # Makespan
        start_times = [task.start_time for task in self.tasks if task.start_time is not None]
        finish_times = [task.finish_time for task in self.tasks if task.finish_time is not None]
        if start_times and finish_times:
            earliest_start = min(start_times)
            latest_finish = max(finish_times)
            self.makespan = latest_finish - earliest_start
        else:
            self.makespan = 0.0

        # Task Completion Rate
        completed_tasks = [task for task in self.tasks if task.output_data is not None]
        self.task_completion_rate = len(completed_tasks) / len(self.tasks) if self.tasks else 0.0

        # Average Resource Utilization per Node
        node_utilization = {node.node_id: 0.0 for node in self.scheduler.nodes}
        for task in self.tasks:
            for stage in task.stages.values():
                node_id = stage.assigned_node.node_id
                node_utilization[node_id] += stage.busy_time
        for node in self.scheduler.nodes:
            total_node_time = self.scheduler.observation_window
            self.average_resource_utilization_per_node[node.node_id] = (
                node_utilization[node.node_id] / total_node_time if total_node_time > 0 else 0.0
            )

    def __repr__(self):
        return (
            f"Taskset(total_tasks={len(self.tasks)}, "
            f"total_utilization={self.total_utilization:.2%}, "
            f"average_turnaround_time={self.average_turnaround_time:.6f} seconds, "
            f"throughput={self.throughput:.2f} tasks/sec, "
            f"makespan={self.makespan:.6f} seconds, "
            f"task_completion_rate={self.task_completion_rate:.2%}, "
            f"average_resource_utilization_per_node={self.average_resource_utilization_per_node})"
        )

# --- Scheduler Class (Unchanged, but ensure execution graphs are built correctly) ---
class Scheduler:
    """
    Scheduler class responsible for decomposing tasks into stages,
    allocating stages to nodes using Dynamic Programming to minimize
    maximum node utilization, dispatching stages for execution,
    and managing dependencies between stages via an execution graph.
    """

    def __init__(
        self,
        nodes: List['Node'],
        profiler: 'Profiler',
        observation_window: float = 1000.0,
        load_metric_func: Optional[Callable[[float, float], float]] = None
    ):
        """
        Initializes the Scheduler.

        Args:
            nodes (List[Node]): List of available Node instances.
            profiler (Profiler): Profiler instance used for gathering execution times.
            observation_window (float, optional): Time window for utilization calculations. Defaults to 1000.0.
            load_metric_func (Callable, optional): User-defined load metric function. Defaults to None.
        """
        self.nodes = nodes
        self.profiler = profiler
        self.observation_window = observation_window
        self.load_metric = load_metric_func if load_metric_func else self.default_load_metric
        self.lock = threading.Lock()

        self.tasks: Dict[str, Task] = {}  # task_id -> Task
        self.stage_map: Dict[str, Stage] = {}  # stage_id -> Stage
        self.completed_stages: set = set()

        # Execution Graphs: task_id -> DiGraph
        self.execution_graphs: Dict[str, nx.DiGraph] = {}

    def default_load_metric(self, execution_time: float, observation_window: float) -> float:
        """
        Default load metric: execution_time divided by observation_window.

        Args:
            execution_time (float): Execution time of the stage.
            observation_window (float): Observation window.

        Returns:
            float: Utilization.
        """
        return execution_time / observation_window

    def decompose_and_allocate_task(self, task: Task):
        """
        Decomposes the Task into Stages, builds the execution graph, and allocates each Stage to a Node based on the DP allocation strategy.

        Args:
            task (Task): The Task instance to decompose and allocate.
        """
        with self.lock:
            self.tasks[task.task_id] = task
            task.start_time = time.time()

            print(f"[Scheduler] Starting decomposition and allocation for Task '{task.task_id}'.")

            # Decompose task into stages
            stages = self.decompose_task_into_stages(task)

            # Build execution graph for the task
            exec_graph = self.build_execution_graph(task)
            self.execution_graphs[task.task_id] = exec_graph
            print(f"[Scheduler] Built execution graph for Task '{task.task_id}'. Nodes: {exec_graph.number_of_nodes()}, Edges: {exec_graph.number_of_edges()}.")

            # Allocate stages to nodes with dynamic grouping
            allocation = self.dp_allocate(task, stages)

            # Assign stages to nodes and update node loads
            for stage_id, node in allocation.items():
                stage = task.get_stage(stage_id)
                stage.assigned_node = node
                self.stage_map[stage_id] = stage
                node.assigned_stages.append(stage_id)
                node.current_load += self.load_metric(
                    self.get_execution_time(stage, node),
                    self.observation_window
                )
                # Print allocation details
                print(f"[Scheduler] Allocated Stage '{stage_id}' to Node '{node.node_id}'.")

            # Perform grouping of allocated stages
            grouped_allocation = self.group_allocated_stages(task, allocation)

            print(f"[Scheduler] Completed allocation and grouping for Task '{task.task_id}'.")

    def decompose_task_into_stages(self, task: Task) -> List[Stage]:
        """
        Decomposes a Task into multiple Stages (groups of layers).

        Args:
            task (Task): The Task instance to decompose.

        Returns:
            List[Stage]: List of Stage instances.
        """
        stages = []
        for idx, (name, layer) in enumerate(task.model.named_children()):
            # Incorporate layer name into stage_id for easier profiling lookup
            stage_id = f"{task.task_id}-stage-{idx}_{name}"
            stage = Stage(stage_id=stage_id, layers=nn.ModuleList([layer]), assigned_node=None, task=task)
            task.add_stage(stage)
            stages.append(stage)
            # Print stage creation details
            print(f"[Scheduler] Created Stage '{stage_id}' with Layer '{name}'.")
        return stages

    def dp_allocate(self, task: Task, stages: List[Stage]) -> Dict[str, 'Node']:
        """
        Allocates stages to nodes using the Dynamic Programming algorithm to minimize maximum node utilization.
        Allows grouping of consecutive stages for better load balancing.

        Args:
            task (Task): The Task instance.
            stages (List[Stage]): List of Stage instances to allocate.

        Returns:
            Dict[str, Node]: Mapping from stage_id to Node.
        """
        num_stages = len(stages)
        num_nodes = len(self.nodes)

        # Initialize current node loads
        w = [node.current_load for node in self.nodes]

        # Initialize DP table
        M = [[math.inf for _ in range(num_nodes + 1)] for _ in range(num_stages + 1)]
        # Base cases
        for k in range(num_nodes + 1):
            M[0][k] = 0.0

        # Fill DP table
        for n in range(1, num_stages + 1):
            for k in range(1, num_nodes + 1):
                for x in range(0, n):
                    # Calculate total execution time for grouping stages x to n-1 on node k-1
                    grouped_execution_time = sum(
                        self.get_execution_time(stages[y], self.nodes[k - 1]) for y in range(x, n)
                    )
                    util_sum = self.load_metric(grouped_execution_time, self.observation_window)
                    current_max = max(M[x][k - 1], w[k - 1] + util_sum)
                    if current_max < M[n][k]:
                        M[n][k] = current_max

        # Backtrack to find allocation
        allocation = {}
        n = num_stages
        k = num_nodes

        while n > 0 and k > 0:
            for x in range(0, n):
                grouped_execution_time = sum(
                    self.get_execution_time(stages[y], self.nodes[k - 1]) for y in range(x, n)
                )
                util_sum = self.load_metric(grouped_execution_time, self.observation_window)
                current_max = max(M[x][k - 1], w[k - 1] + util_sum)
                if math.isclose(current_max, M[n][k], rel_tol=1e-6):
                    # Assign stages x to n-1 to node k-1
                    for y in range(x, n):
                        stage_id = stages[y].stage_id
                        allocation[stage_id] = self.nodes[k - 1]
                    n = x
                    k = k - 1
                    break

        # Print allocation details
        if allocation:
            print(f"[Scheduler] Allocation Mapping for Task '{task.task_id}':")
            current_node = None
            current_group = []
            # Sort allocation based on stage index for proper grouping
            sorted_allocation = sorted(allocation.items(), key=lambda x: int(x[0].split("-stage-")[1].split("_")[0]))
            for stage_id, node in sorted_allocation:
                if node != current_node:
                    if current_group:
                        print(f"  - Grouped Stages '{', '.join(current_group)}' allocated to Node '{current_node.node_id}'.")
                        current_group = []
                    current_node = node
                current_group.append(stage_id)
            if current_group:
                print(f"  - Grouped Stages '{', '.join(current_group)}' allocated to Node '{current_node.node_id}'.")
        else:
            print(f"[Scheduler] No allocation performed for Task '{task.task_id}'.")

        return allocation

    def group_allocated_stages(self, task: Task, allocation: Dict[str, 'Node']) -> Dict[str, 'Node']:
        """
        Groups allocated stages into a single parent stage by merging layers.

        Args:
            task (Task): The Task instance.
            allocation (Dict[str, 'Node']): Mapping from stage_id to Node.

        Returns:
            Dict[str, 'Node']: Updated mapping from new grouped stage_id to Node.
        """
        # Identify groups based on allocation
        groups = {}
        current_node = None
        current_group = []
        sorted_allocation = sorted(allocation.items(), key=lambda x: int(x[0].split("-stage-")[1].split("_")[0]))
        for stage_id, node in sorted_allocation:
            if node != current_node:
                if current_group:
                    if current_node.node_id not in groups:
                        groups[current_node.node_id] = []
                    groups[current_node.node_id].append(current_group.copy())
                    current_group = []
                current_node = node
                if current_node.node_id not in groups:
                    groups[current_node.node_id] = []
            current_group.append(stage_id)
        if current_group:
            if current_node.node_id not in groups:
                groups[current_node.node_id] = []
            groups[current_node.node_id].append(current_group.copy())

        # Merge groups into new parent stages
        new_allocation = {}
        for node_id, stage_groups in groups.items():
            for group in stage_groups:
                if len(group) == 1:
                    # Single stage, no need to merge
                    stage_id = group[0]
                    new_allocation[stage_id] = self.nodes[self.get_node_index(node_id)]
                else:
                    # Merge stages
                    new_stage_id = f"{group[0]}_to_{group[-1]}"
                    merged_layers = nn.ModuleList()
                    for stage_id in group:
                        stage = task.get_stage(stage_id)
                        merged_layers.extend(stage.layers)
                    # Create new merged stage
                    merged_stage = Stage(stage_id=new_stage_id, layers=merged_layers, assigned_node=self.nodes[self.get_node_index(node_id)], task=task)
                    # Update dependencies
                    first_stage = task.get_stage(group[0])
                    last_stage = task.get_stage(group[-1])

                    # Set dependencies
                    merged_stage.dependencies = first_stage.dependencies.copy()
                    for dep in merged_stage.dependencies:
                        dep_stage = task.get_stage(dep)
                        dep_stage.dependents.remove(group[0])
                        dep_stage.dependents.append(new_stage_id)

                    # Set dependents
                    merged_stage.dependents = last_stage.dependents.copy()
                    for dep in merged_stage.dependents:
                        dep_stage = task.get_stage(dep)
                        dep_stage.dependencies.remove(group[-1])
                        dep_stage.dependencies.append(new_stage_id)

                    # Add merged stage to task
                    task.add_stage(merged_stage)

                    # Remove old stages from task and execution graph
                    for stage_id in group:
                        del task.stages[stage_id]
                        if stage_id in self.stage_map:
                            del self.stage_map[stage_id]
                        if stage_id in self.execution_graphs[task.task_id]:
                            self.execution_graphs[task.task_id].remove_node(stage_id)
                        # Remove from node's assigned stages
                        node = self.nodes[self.get_node_index(node_id)]
                        if stage_id in node.assigned_stages:
                            node.assigned_stages.remove(stage_id)

                    # Assign merged stage to node
                    new_allocation[new_stage_id] = self.nodes[self.get_node_index(node_id)]
                    self.stage_map[new_stage_id] = merged_stage
                    node = self.nodes[self.get_node_index(node_id)]
                    node.assigned_stages.append(new_stage_id)
                    node.current_load += self.load_metric(
                        self.get_execution_time(merged_stage, node),
                        self.observation_window
                    )
                    print(f"[Scheduler] Merged Stages '{', '.join(group)}' into '{new_stage_id}' and allocated to Node '{node.node_id}'.")

        # Update execution graph to include merged stages
        self.update_execution_graph_after_grouping(task)

        return new_allocation

    def get_node_index(self, node_id: str) -> int:
        """
        Retrieves the index of a node based on its node_id.

        Args:
            node_id (str): The Node ID.

        Returns:
            int: Index of the node in the nodes list.
        """
        for idx, node in enumerate(self.nodes):
            if node.node_id == node_id:
                return idx
        raise ValueError(f"Node with ID '{node_id}' not found.")

    def update_execution_graph_after_grouping(self, task: Task):
        """
        Updates the execution graph after stages have been merged.

        Args:
            task (Task): The Task instance.
        """
        G = self.execution_graphs.get(task.task_id, None)
        if not G:
            print(f"[Scheduler] No execution graph found for Task '{task.task_id}' to update.")
            return

        # Rebuild the execution graph based on the updated stages
        G.clear()
        for stage_id, stage in task.stages.items():
            G.add_node(stage_id)
            for dep in stage.dependencies:
                G.add_edge(dep, stage_id)

        print(f"[Scheduler] Updated execution graph for Task '{task.task_id}'. Nodes: {G.number_of_nodes()}, Edges: {G.number_of_edges()}.")

    def visualize_execution_graph(self, task_id: str):
        """
        Visualizes the execution graph for a given Task.

        Args:
            task_id (str): The Task ID whose execution graph is to be visualized.
        """
        if task_id not in self.execution_graphs:
            print(f"[Scheduler] No execution graph found for Task '{task_id}'.")
            return

        G = self.execution_graphs[task_id]
        pos = nx.spring_layout(G)
        plt.figure(figsize=(8, 6))
        nx.draw(
            G, pos, with_labels=True,
            node_color='lightblue', edge_color='gray',
            node_size=2000, font_size=10, arrows=True
        )
        plt.title(f"Execution Graph for Task '{task_id}'")
        plt.show()

    def get_execution_time(self, stage: Stage, node: 'Node') -> float:
        """
        Retrieves the execution time of a stage on a node by querying the Profiler's profile database.

        Args:
            stage (Stage): The Stage instance.
            node (Node): The Node instance.

        Returns:
            float: Execution time in seconds. Returns observation_window if not found.
        """
        # Extract layer names from stage_id
        try:
            # print(stage.stage_id.split('_')[-1])
            layer_info = stage.stage_id.split('_',1)[-1]
            # Handle grouped stages by splitting '_to_'
            layer_names = layer_info.split("_to_")
            print(layer_names)
        except ValueError:
            print(f"[Scheduler] Warning: Unable to parse layer names from stage_id '{stage.stage_id}'. Assigning max execution time.")
            return self.observation_window

        # Sum execution times of all layers in the group
        total_exec_time = 0.0
        for layer_name in layer_names:
            # Remove stage identifier suffix if present
            base_layer_name = re.sub(r'(\.|_)\d+$', '', layer_name)
            query = (
                (self.profiler.profile_db['Model'] == stage.task.model_name) &
                (self.profiler.profile_db['Layer'] == base_layer_name) &
                (self.profiler.profile_db['Compute'] == node.node_id)
            )
            execution_time_us = self.profiler.profile_db.loc[query, 'Total Execution Time (us)']
            if not execution_time_us.empty:
                exec_time = execution_time_us.values[0] / 1_000_000  # Convert microseconds to seconds
                total_exec_time += exec_time
                print(f"[Scheduler] Retrieved execution time for Layer '{base_layer_name}' on Node '{node.node_id}': {exec_time:.6f} seconds.")
            else:
                print(f"[Scheduler] Warning: No profiling data for Layer '{base_layer_name}' on Node '{node.node_id}'. Assigning max execution time ({self.observation_window} seconds).")
                total_exec_time += self.observation_window

        return total_exec_time

    def dispatch_allocation(self, allocation: Dict[str, 'Node']):
        """
        Dispatches all allocated stages to their respective nodes for execution.

        Args:
            allocation (Dict[str, Node]): Mapping from stage_id to Node.
        """
        for stage_id, node in allocation.items():
            stage = self.stage_map[stage_id]
            self.dispatch_stage(stage)

    def dispatch_stage(self, stage: Stage):
        """
        Dispatches a single stage to its assigned node for execution.

        Args:
            stage (Stage): The Stage instance to dispatch.
        """
        # Set stage.input_data based on dependencies
        if stage.dependencies:
            # Assuming single dependency
            dep_stage_id = stage.dependencies[0]
            dep_stage = self.stage_map.get(dep_stage_id, None)
            if dep_stage and dep_stage.output_data is not None:
                stage.input_data = dep_stage.output_data
            else:
                print(f"[Scheduler] Warning: Dependency stage '{dep_stage_id}' output data is None for Stage '{stage.stage_id}'.")
        else:
            # For initial stages, input_data is from Task's input_data
            stage.input_data = stage.task.input_data

        def stage_execution():
            try:
                stage.run_stage()
            finally:
                self.stage_completed(stage.stage_id)

        # Assign the stage's run_stage method to the node's task queue
        node_queue = stage.assigned_node.assign_task(stage_execution)
        # Print dispatching details
        print(f"[Scheduler] Dispatched Stage '{stage.stage_id}' to Node '{stage.assigned_node.node_id}'.")

    def execute_task(self, task: Task):
        """
        Executes a single Task by dispatching its allocated Stages.

        Args:
            task (Task): The Task instance to execute.
        """
        allocation = {stage_id: stage.assigned_node for stage_id, stage in task.stages.items()}
        print(f"[Scheduler] Executing Task '{task.task_id}' with {len(allocation)} stages.")
        self.dispatch_allocation(allocation)

    def execute_task_with_graph(self, task: Task):
        """
        Executes a Task using its execution graph to manage dependencies.

        Args:
            task (Task): The Task instance to execute.
        """
        execution_graph = self.execution_graphs.get(task.task_id, None)
        if not execution_graph:
            print(f"[Scheduler] No execution graph found for Task '{task.task_id}'. Cannot execute with graph.")
            return

        print(f"[Scheduler] Starting execution of Task '{task.task_id}' using execution graph.")

        # Perform topological sort to determine execution order
        try:
            sorted_stages = list(nx.topological_sort(execution_graph))
        except nx.NetworkXUnfeasible:
            print(f"[Scheduler] Error: Execution graph for Task '{task.task_id}' has cycles. Cannot proceed.")
            return

        for stage_id in sorted_stages:
            stage = task.get_stage(stage_id)
            if all(dep in self.completed_stages for dep in stage.dependencies):
                self.dispatch_stage(stage)

        print(f"[Scheduler] Completed execution dispatch for Task '{task.task_id}'.")

    def stage_completed(self, stage_id: str):
        """
        Called when a stage is completed to trigger dependent stages.

        Args:
            stage_id (str): The ID of the completed stage.
        """
        with self.lock:
            self.completed_stages.add(stage_id)
            task_id = self.get_task_id_from_stage(stage_id)
            if not task_id:
                print(f"[Scheduler] Warning: Task ID not found for stage '{stage_id}'.")
                return
            task = self.tasks.get(task_id)
            if not task:
                print(f"[Scheduler] Warning: Task '{task_id}' not found for stage '{stage_id}'.")
                return
            stage = task.get_stage(stage_id)
            if not stage:
                print(f"[Scheduler] Warning: Stage '{stage_id}' not found in Task '{task_id}'.")
                return
            # Trigger dependent stages if all their dependencies are met
            for dependent_stage_id in stage.dependents:
                dependent_stage = task.get_stage(dependent_stage_id)
                if dependent_stage and all(dep in self.completed_stages for dep in dependent_stage.dependencies):
                    print(f"[Scheduler] Dependencies met for Stage '{dependent_stage_id}'. Dispatching for execution.")
                    self.dispatch_stage(dependent_stage)

    def get_task_id_from_stage(self, stage_id: str) -> Optional[str]:
        """
        Retrieves the Task ID from a given Stage ID.

        Args:
            stage_id (str): The Stage ID.

        Returns:
            Optional[str]: The corresponding Task ID, or None if not found.
        """
        try:
            task_id, _ = stage_id.split("-stage-")
            return task_id
        except ValueError:
            return None

    def shutdown(self):
        """
        Shuts down the Scheduler gracefully by stopping all Nodes.
        """
        print("[Scheduler] Shutting down all Nodes.")
        for node in self.nodes:
            node.stop()
        print("[Scheduler] All Nodes have been shut down.")

    def calculate_average_utilization(self, task: Task) -> float:
        """
        Calculates the average utilization of a task across all nodes.

        Args:
            task (Task): The Task instance.

        Returns:
            float: Average utilization.
        """
        total_U = 0.0
        count = 0
        for stage in task.stages.values():
            for node in self.nodes:
                total_U += self.load_metric(
                    self.get_execution_time(stage, node),
                    self.observation_window
                )
                count += 1
        return total_U / count if count > 0 else 0.0

    def build_execution_graph(self, task: Task) -> nx.DiGraph:
        """
        Builds an execution graph for a given Task.

        Args:
            task (Task): The Task instance.

        Returns:
            nx.DiGraph: Directed graph representing Stage dependencies.
        """
        G = nx.DiGraph()
        for stage_id, stage in task.stages.items():
            G.add_node(stage_id)
            for dep in stage.dependencies:
                G.add_edge(dep, stage_id)
        return G

# --- Evaluator Class ---
class Evaluator:
    """
    Evaluator class responsible for running tasks using both naive PyTorch execution and
    the custom parallel and pipeline approach. It analyzes speedup, throughput, and
    verifies output correctness.
    """

    def __init__(self, scheduler: 'Scheduler', taskset: 'Taskset', profiler: 'Profiler'):
        """
        Initializes the Evaluator.

        Args:
            scheduler (Scheduler): The Scheduler instance handling task allocation.
            taskset (Taskset): The Taskset containing all tasks to be evaluated.
            profiler (Profiler): The Profiler instance for gathering execution metrics.
        """
        self.scheduler = scheduler
        self.taskset = taskset
        self.profiler = profiler

        # Dictionaries to store outputs from both execution methods
        self.naive_outputs: Dict[str, torch.Tensor] = {}
        self.parallel_outputs: Dict[str, torch.Tensor] = {}

        # Execution times
        self.naive_execution_times: Dict[str, float] = {}
        self.parallel_execution_times: Dict[str, float] = {}

    def run_evaluation(self):
        """
        Runs the entire evaluation process: profiling, naive execution, parallel execution,
        output comparison, and performance analysis.
        """
        print("=== Starting Evaluation ===\n")

        # Step 1: Set Observation Time
        # (Already set in init_phase)

        # Step 2: Run Evaluation Phase
        # (Handled externally via eval_phase utility function)

        print("=== Evaluation Completed ===\n")

    def run_naive_execution(self):
        """
        Executes all evaluation tasks sequentially on a single device (CPU or GPU).
        """
        print("[Evaluator] Starting Naive Execution.")
        for task in self.taskset.tasks:
            model = task.model
            input_tensor = task.input_data  # Now a single Tensor for evaluation
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)  # Move model to device
            input_tensor = input_tensor.to(device)  # Move input to device

            start_time = time.time()
            with torch.no_grad():
                output = model(input_tensor)
            end_time = time.time()
            exec_time = end_time - start_time
            print(f"[Evaluator] Task '{task.task_id}' (Naive) executed in {exec_time:.6f} seconds.")
            self.naive_execution_times[task.task_id] = exec_time
            self.naive_outputs[task.task_id] = output

        print("[Evaluator] Naive Execution Completed.\n")

        # Clean up resources
        self.cleanup_resources()

    def run_parallel_execution(self):
        """
        Executes all tasks in the taskset using the Scheduler's parallel and pipeline approach.
        Stores the outputs and execution times for comparison.
        """
        print("[Evaluator] Starting Parallel Execution.")

        # Reset outputs and execution times
        self.parallel_outputs = {}
        self.parallel_execution_times = {}

        # Deep copy tasks to avoid interference
        parallel_tasks = copy.deepcopy(self.taskset.tasks)

        # Reassign the scheduler to the copied tasks
        for task in parallel_tasks:
            task.scheduler = self.scheduler

        # Assign a separate Taskset for parallel execution
        parallel_taskset = Taskset(tasks=parallel_tasks, scheduler=self.scheduler)

        # Start parallel execution
        parallel_start_time = time.time()
        parallel_taskset.execute_all()
        parallel_end_time = time.time()
        total_parallel_time = parallel_end_time - parallel_start_time

        # Collect outputs and execution times
        for task in parallel_tasks:
            self.parallel_outputs[task.task_id] = task.output_data
            self.parallel_execution_times[task.task_id] = task.get_total_execution_time()

        print(f"[Evaluator] Parallel Execution Completed in {total_parallel_time:.6f} seconds.\n")

        # Clean up resources
        self.cleanup_resources()

    def compare_outputs(self):
        """
        Compares the outputs from naive and parallel executions to verify correctness.
        """
        print("[Evaluator] Comparing Outputs for Correctness.")

        all_match = True
        for task_id in self.naive_outputs:
            naive_output = self.naive_outputs.get(task_id)
            parallel_output = self.parallel_outputs.get(task_id)

            if naive_output is None or parallel_output is None:
                print(f"[Evaluator] Task '{task_id}' missing output in one of the executions.")
                all_match = False
                continue

            if torch.equal(naive_output, parallel_output):
                print(f"[Evaluator] Task '{task_id}' outputs match exactly.")
            elif torch.allclose(naive_output, parallel_output, atol=1e-6):
                print(f"[Evaluator] Task '{task_id}' outputs are close within tolerance.")
            else:
                print(f"[Evaluator] Task '{task_id}' outputs do NOT match.")
                all_match = False

        if all_match:
            print("[Evaluator] All task outputs match between naive and parallel executions.\n")
        else:
            print("[Evaluator] Some task outputs do not match. Investigate discrepancies.\n")

    def analyze_speedup_throughput(self):
        """
        Analyzes speedup and throughput between naive and parallel executions.
        """
        print("[Evaluator] Analyzing Speedup and Throughput.")

        # Calculate total naive and parallel execution times
        total_naive_time = sum(self.naive_execution_times.values())
        total_parallel_time = sum(self.parallel_execution_times.values())

        # Speedup: naive_time / parallel_time
        speedup = total_naive_time / total_parallel_time if total_parallel_time > 0 else float('inf')

        # Throughput: number of tasks / total time
        num_tasks = len(self.taskset.tasks)
        throughput_naive = num_tasks / total_naive_time if total_naive_time > 0 else 0.0
        throughput_parallel = num_tasks / total_parallel_time if total_parallel_time > 0 else 0.0

        print(f"[Evaluator] Total Naive Execution Time: {total_naive_time:.6f} seconds.")
        print(f"[Evaluator] Total Parallel Execution Time: {total_parallel_time:.6f} seconds.")
        print(f"[Evaluator] Speedup: {speedup:.2f}x.")
        print(f"[Evaluator] Throughput (Naive): {throughput_naive:.2f} tasks/sec.")
        print(f"[Evaluator] Throughput (Parallel): {throughput_parallel:.2f} tasks/sec.\n")

    def cleanup_resources(self):
        """
        Cleans up resources after execution by clearing outputs and resetting node states.
        """
        print("[Evaluator] Cleaning up resources.")

        # Clear stored outputs and execution times
        self.naive_outputs.clear()
        self.parallel_outputs.clear()
        self.naive_execution_times.clear()
        self.parallel_execution_times.clear()

        # Reset node loads and assigned stages
        for node in self.scheduler.nodes:
            node.current_load = 0.0
            node.assigned_stages.clear()

        # Reset Scheduler's completed stages
        with self.scheduler.lock:
            self.scheduler.completed_stages.clear()

        print("[Evaluator] Resources cleaned up.\n")

# --- Utility Functions (Modified to use single Taskset and single tensor inputs) ---
def init_phase(profiler: 'Profiler',taskset: 'Taskset', nodes: List['Node'], runs: int = 3, slack_percentage: float = 0.1):
    """
    Executes the initialization phase by running the profiler in 'init' mode multiple times
    and calculating the observation window based on profiling data.

    Args:
        profiler (Profiler): The Profiler instance.
        taskset (Taskset): The Taskset containing all tasks to be profiled.
        nodes (List[Node]): List of Node instances to profile on.
        runs (int, optional): Number of profiling runs. Defaults to 3.
        slack_percentage (float, optional): Slack to add to the observation time. Defaults to 0.1 (10%).
    """
    print("[Utility] Starting Init Phase.")
    for run in range(1, runs + 1):
        print(f"[Utility] Init Phase Run {run}/{runs}")
        for node in nodes:
            for task in taskset.tasks:
                profiler.profile_model(
                    model=copy.deepcopy(task.model),          # Use a copy to prevent state changes
                    input_data=copy.deepcopy(task.input_data),# Ensure input data is fresh
                    node=node,
                    model_name=task.model_name
                )
    print(f"[Utility] Completed Init Phase after {runs} runs.\n")

    # Calculate Observation Window
    profiler.print_profile_db()
    total_forward_time = profiler.profile_db['Total Execution Time (us)'].sum() / 1_000_000  # Convert to seconds
    observation_window = total_forward_time * (1 + slack_percentage)
    profiler.observation_window = observation_window  # Update profiler's observation window

    print(f"[Utility] Observation window set to {observation_window:.6f} seconds (Total Forward Time: {total_forward_time:.6f} + {slack_percentage*100}% slack).\n")

    taskset.schedule_all_tasks()
    print("[Taskset/Scheduler] Scheduling all tasks onto the compute")


    return observation_window

def eval_phase(evaluator: 'Evaluator', taskset: 'Taskset'):
    """
    Executes the evaluation phase by running naive and parallel executions,
    comparing their outputs, and analyzing speedup and throughput.

    Args:
        evaluator (Evaluator): The Evaluator instance.
        taskset (Taskset): The Taskset containing all tasks to be evaluated.
    """
    print("[Utility] Starting Evaluation Phase.")

    # Run Naive Execution
    evaluator.run_naive_execution()

    # Run Parallel Execution
    evaluator.run_parallel_execution()

    # Compare Outputs
    evaluator.compare_outputs()

    # Analyze Speedup and Throughput
    evaluator.analyze_speedup_throughput()

    print("[Utility] Evaluation Phase Completed.\n")

# --- Test Script (Modified to have single Taskset with single tensor inputs) ---
# --- Define a Simple Model for Demonstration ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(32 * 8 * 8, 64)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = torch.flatten(out, 1)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)
        return out

# --- Create Synthetic DataLoader ---
def create_synthetic_dataloader(batch_size: int = 1, num_samples: int = 1):
    inputs = torch.randn(num_samples, 3, 8, 8)  # Example input size
    targets = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(inputs, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    return dataloader

# --- Modified Task Initialization ---
def initialize_components():
    # Discover Nodes
    nodes = Node.discover_nodes()
    print(f"[Main] Discovered Nodes: {nodes}\n")

    # Initialize Profiler
    profiler = Profiler(mode='init')
    print("[Main] Initialized Profiler.\n")

    # Initialize Scheduler
    scheduler = Scheduler(
        nodes=nodes,
        profiler=profiler,
        observation_window=1000.0  # Initial observation window (will be recalculated)
    )
    print("[Main] Initialized Scheduler.\n")

    # Create Tasks with Single Tensors for both profiling and evaluation
    num_tasks = 5  # Example number of tasks
    tasks = []
    for i in range(num_tasks):
        model = SimpleCNN()
        # Create a single input tensor
        dataloader = create_synthetic_dataloader(batch_size=1, num_samples=1)
        single_input, single_target = next(iter(dataloader))
        task = Task(
            task_id=f"task{i+1}",
            model=model,
            input_data=single_input,  # input_data is a single Tensor
            model_name=model.__class__.__name__,
            scheduler=scheduler
        )
        tasks.append(task)
    print(f"[Main] Created {num_tasks} Tasks.\n")

    # Initialize Taskset with all tasks
    taskset = Taskset(tasks=tasks, scheduler=scheduler)
    print("[Main] Initialized Taskset.\n")

    # Initialize Evaluator with the single Taskset
    evaluator = Evaluator(
        scheduler=scheduler,
        taskset=taskset,
        profiler=profiler
    )
    print("[Main] Initialized Evaluator.\n")

    return evaluator, taskset, profiler, scheduler, nodes

# --- Run Evaluation ---
def run_evaluation():
    evaluator, taskset, profiler, scheduler, nodes = initialize_components()

    # Run Init Phase with Taskset
    init_observation_window = init_phase(profiler, taskset, nodes, runs=3, slack_percentage=0.1)
    # for task in taskset.tasks:
    #   scheduler.visualize_execution_graph(task.task_id)

    # Run Evaluation Phase with Taskset
    eval_phase(evaluator, taskset)

    # Print Performance Metrics from Taskset
    print("=== Performance Metrics ===")
    print(taskset)
    print("============================\n")

    # Shutdown Scheduler
    scheduler.shutdown()


if __name__ == "__main__":
    run_evaluation()


[Main] Discovered Nodes: [Node(CPU-0, cpus=(0,), gpu=None), Node(CPU-1, cpus=(1,), gpu=None), Node(GPU-0-CPU-0, cpus=(0,), gpu=0), Node(GPU-0-CPU-1, cpus=(1,), gpu=0)]

[Main] Initialized Profiler.

[Main] Initialized Scheduler.

[Main] Created 5 Tasks.

[Main] Initialized Taskset.

[Main] Initialized Evaluator.

[Utility] Starting Init Phase.
[Utility] Init Phase Run 1/3


  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated


[Utility] Init Phase Run 2/3


  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.se

[Utility] Init Phase Run 3/3


  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated
  aggregated['misc']['self_cuda_memory_usage'] += e.self_cuda_memory_usage  # Updated


[Utility] Completed Init Phase after 3 runs.

ProfileDB:
     Model        Layer     Compute  Self CPU (us)  CPU Total (us) CUDA Total (us) Self CPU Mem (bytes) Self CUDA Mem (bytes)  Total Execution Time (us) Total Memory Used (bytes)
SimpleCNN forward_pass       CPU-0     353908.825      371084.675               0                    0                     0               3.710847e+11                         0
SimpleCNN         misc       CPU-0     353436.028      370611.878               0                    0                     0               3.706119e+11                         0
SimpleCNN        conv1       CPU-0       8148.621        8148.621               0                    0                     0               8.148621e+09                         0
SimpleCNN        relu1       CPU-0         76.311          76.311               0                    0                     0               7.631100e+07                         0
SimpleCNN        conv2       CPU-0         77.788   

AttributeError: 'Stage' object has no attribute 'busy_time'

[Stage] task1-stage-0_conv1: Error during execution: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA___slow_conv2d_forward)
[Stage] task1-stage-0_conv1: Executed on CPU-1 in 0.118975 seconds. Transfer Time: 0.000019 seconds.
[Stage] task2-stage-0_conv1: Error during execution: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA___slow_conv2d_forward)
[Stage] task2-stage-0_conv1: Executed on CPU-1 in 0.000280 seconds. Transfer Time: 0.000015 seconds.
[Stage] task3-stage-0_conv1: Error during execution: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA___slow_conv2d_forward)
[Stage] task3-stage-0_conv1: Executed on CPU-1 in 0.000262 seconds. Transfer Time: 0

In [None]:
# -*- coding: utf-8 -*-
"""Unified_Taskset_Evaluation.ipynb

This script implements a single Taskset approach, utilizing single tensor inputs for both profiling and evaluation phases.
"""

import networkx as nx
import os
import torch
import threading
import queue
import time
import re
import pandas as pd
import torch.nn as nn
import torch.profiler
from torch.utils.data import DataLoader, TensorDataset
from typing import Callable, Any, List, Dict, Optional
import copy
import math
import matplotlib.pyplot as plt


# --- Node Class ---
class Node:
    """
    Represents a computational resource: either a CPU-only node (1 CPU core)
    or a GPU+CPU pair. Each Node has:
      - node_id (e.g., 'CPU-0', 'GPU-0-CPU-1')
      - A worker thread + queue to run tasks
    """

    def __init__(self, node_id: str, cpus=None, gpu=None):
        self._node_id = node_id
        self._cpus = tuple(cpus or [])
        self._gpu = gpu

        self._original_affinity = os.sched_getaffinity(0)
        self._task_queue = queue.Queue()
        self._stop_signal = False

        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

        self.current_load = 0.0  # Initialize current load
        self.assigned_stages = []  # List to track assigned stages

    @property
    def node_id(self):
        return self._node_id

    @property
    def cpus(self):
        return self._cpus

    @property
    def gpu(self):
        return self._gpu

    def assign_task(self, func: Callable, *args, **kwargs) -> queue.Queue:
        """
        Enqueue a function to this node. Returns a queue from which
        the caller can retrieve the result (blocking).
        """
        result_queue = queue.Queue(maxsize=1)
        self._task_queue.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        """
        Signal the node to stop after processing queued tasks.
        """
        self._stop_signal = True
        self._task_queue.put(None)
        self._worker_thread.join()

    def _worker_loop(self):
        while not self._stop_signal:
            item = self._task_queue.get()
            if item is None:
                break
            func, args, kwargs, result_queue = item
            try:
                self._set_context()
                result = func(*args, **kwargs)
            except Exception as e:
                result = e
            finally:
                self._reset_context()

            result_queue.put(result)

    def _set_context(self):
        if self._cpus:
            os.sched_setaffinity(0, self._cpus)
        if self._gpu is not None and torch.cuda.is_available():
            torch.cuda.set_device(self._gpu)

    def _reset_context(self):
        os.sched_setaffinity(0, self._original_affinity)
        # Optionally reset GPU device if needed

    @staticmethod
    def discover_nodes() -> List['Node']:
        """
        Create a Node for each CPU core, and for each GPU+CPU pair.
        """
        nodes = []
        num_cpus = os.cpu_count() or 1
        ngpus = torch.cuda.device_count()

        # CPU-only nodes
        for core_id in range(num_cpus):
            node = Node(node_id=f"CPU-{core_id}", cpus=[core_id])
            nodes.append(node)

        # GPU+CPU nodes
        for g in range(ngpus):
            for core_id in range(num_cpus):
                node = Node(node_id=f"GPU-{g}-CPU-{core_id}", cpus=[core_id], gpu=g)
                nodes.append(node)

        return nodes

    def __repr__(self):
        return f"Node({self._node_id}, cpus={self._cpus}, gpu={self._gpu})"


# --- Profiler Class ---
class Profiler:
    """
    In 'init' mode: Gather detailed profiling info for each leaf layer on each Node,
    storing results in a CSV-based ProfileDB.
    In 'runtime' mode: Potentially gather minimal logs (optional).
    """

    def __init__(self, mode: str, profile_db_path='profiling_results.csv', log_dir='logs'):
        assert mode in ['init', 'runtime']
        self.mode = mode
        self.profile_db_path = profile_db_path
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        columns = [
            'Model', 'Layer', 'Compute',
            'Self CPU (us)', 'CPU Total (us)', 'CUDA Total (us)',
            'Self CPU Mem (bytes)', 'Self CUDA Mem (bytes)',
            'Total Execution Time (us)', 'Total Memory Used (bytes)'
        ]
        if os.path.exists(self.profile_db_path):
            self.profile_db = pd.read_csv(self.profile_db_path)
        else:
            self.profile_db = pd.DataFrame(columns=columns)

        self.runtime_csv = os.path.join(self.log_dir, 'runtime_results.csv')
        if not os.path.exists(self.runtime_csv):
            rt_cols = ['Model', 'Layer', 'Compute', 'Execution Time (us)']
            pd.DataFrame(columns=rt_cols).to_csv(self.runtime_csv, index=False)

        # We'll store an 'observation_window' if needed
        self.observation_window = 0.0

    def _register_hooks(self, model: nn.Module):
        def hook_wrapper(layer_name):
            def hook(mod, inp, out):
                with torch.profiler.record_function(layer_name):
                    pass
            return hook

        for idx, (name, layer) in enumerate(model.named_modules()):
            if not isinstance(layer, nn.Sequential) and not isinstance(layer, nn.ModuleList) and layer != model:
                layer.register_forward_hook(hook_wrapper(f"{name}_{idx}"))

    def profile_model(self, model: nn.Module, input_data: Any, node, model_name: str, warmup_iters=3):
        """
        Schedule a profiling task on 'node'. In 'init' mode, we gather
        full per-layer times.
        """
        def profiling_task():
            device = torch.device(f"cuda:{node.gpu}" if node.gpu is not None and torch.cuda.is_available() else "cpu")
            model.to(device)

            if self.mode == 'init':
                # Warmup
                with torch.no_grad():
                    for _ in range(warmup_iters):
                        model(input_data.to(device))
                self._profile_init(model, input_data, node, model_name, device)
            else:
                self._profile_runtime(model, input_data, node, model_name, device)

        rq = node.assign_task(profiling_task)
        rq.get()  # block until done

    def _profile_init(self, model, input_data, node, model_name, device):
        self._register_hooks(model)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            profile_memory=True
        ) as prof:
            with torch.no_grad():
                model(input_data.to(device))
                prof.step()

        stats = self._process_events(prof, model, node, runtime=False)
        self._update_profile_db(stats, model_name, node, runtime=False)

    def _profile_runtime(self, model, input_data, node, model_name, device):
        self._register_hooks(model)
        with torch.no_grad():
            with torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
            ) as prof:
                model(input_data.to(device))
                prof.step()
            stats = self._process_events(prof, model, node, runtime=True)
            self._append_runtime_csv(stats, model_name, node)

    def _process_events(self, profiler, model, node, runtime=False):
        recognized = set()
        for n, m in model.named_modules():
            if n:
                recognized.add(n)

        aggregated = {
            'forward_pass': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                                 self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id),
            'misc': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                         self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id)
        }

        events = list(profiler.events())
        found_root = False

        def strip_suffix(s):
            return re.sub(r'(\.|_)\d+$', '', s)

        for e in events:
            if e.name == "":
                # Root event (some profiler versions yield blank top-level)
                found_root = True
                aggregated['forward_pass']['self_cpu_time_total'] += e.self_cpu_time_total
                aggregated['forward_pass']['cpu_time_total'] += e.cpu_time_total
                aggregated['forward_pass']['cuda_time_total'] += e.device_time_total
                if not runtime:
                    aggregated['forward_pass']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                    aggregated['forward_pass']['self_cuda_memory_usage'] += e.self_device_memory_usage
            else:
                base = strip_suffix(e.name)
                if base in recognized:
                    if base not in aggregated:
                        aggregated[base] = dict(
                            self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                            self_cpu_memory_usage=0, self_cuda_memory_usage=0,
                            compute=node.node_id
                        )
                    aggregated[base]['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated[base]['cpu_time_total'] += e.cpu_time_total
                    aggregated[base]['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated[base]['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated[base]['self_cuda_memory_usage'] += e.self_device_memory_usage
                else:
                    aggregated['misc']['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated['misc']['cpu_time_total'] += e.cpu_time_total
                    aggregated['misc']['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated['misc']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated['misc']['self_cuda_memory_usage'] += e.self_device_memory_usage

        # If no root event found, sum everything else into forward_pass
        if not found_root:
            # Merge everything else into 'forward_pass'
            for k in list(aggregated.keys()):
                if k not in ('forward_pass', 'misc'):
                    aggregated['forward_pass']['self_cpu_time_total'] += aggregated[k]['self_cpu_time_total']
                    aggregated['forward_pass']['cpu_time_total'] += aggregated[k]['cpu_time_total']
                    aggregated['forward_pass']['cuda_time_total'] += aggregated[k]['cuda_time_total']
                    if not runtime:
                        aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated[k]['self_cpu_memory_usage']
                        aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated[k]['self_cuda_memory_usage']

            aggregated['forward_pass']['self_cpu_time_total'] += aggregated['misc']['self_cpu_time_total']
            aggregated['forward_pass']['cpu_time_total'] += aggregated['misc']['cpu_time_total']
            aggregated['forward_pass']['cuda_time_total'] += aggregated['misc']['cuda_time_total']
            if not runtime:
                aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated['misc']['self_cpu_memory_usage']
                aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated['misc']['self_cuda_memory_usage']

        return aggregated

    def _update_profile_db(self, stats, model_name, node, runtime=False):
        if runtime:
            return
        for layer_name, data in stats.items():
            total_t = data['cpu_time_total'] + data['cuda_time_total']
            total_m = data['self_cpu_memory_usage'] + data['self_cuda_memory_usage']
            row = {
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Self CPU (us)': data['self_cpu_time_total'],
                'CPU Total (us)': data['cpu_time_total'],
                'CUDA Total (us)': data['cuda_time_total'],
                'Self CPU Mem (bytes)': data['self_cpu_memory_usage'],
                'Self CUDA Mem (bytes)': data['self_cuda_memory_usage'],
                'Total Execution Time (us)': total_t * 1_000_000,  # Convert to microseconds
                'Total Memory Used (bytes)': total_m
            }
            self.profile_db = self._upsert(self.profile_db, row)
        self.profile_db.to_csv(self.profile_db_path, index=False)

    def _upsert(self, df, row):
        mask = (
            (df['Model'] == row['Model']) &
            (df['Layer'] == row['Layer']) &
            (df['Compute'] == row['Compute'])
        )
        if not df[mask].empty:
            existing_time = df.loc[mask, 'Total Execution Time (us)'].max()
            if row['Total Execution Time (us)'] > existing_time:
                for k, v in row.items():
                    df.loc[mask, k] = v
        else:
            new_row = pd.DataFrame([row])
            if not new_row.dropna().empty:
                df = pd.concat([df, new_row], ignore_index=True)
        return df

    def _append_runtime_csv(self, stats, model_name, node):
        rows = []
        for layer_name, data in stats.items():
            exec_time = data['cpu_time_total'] + data['cuda_time_total']
            rows.append({
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Execution Time (us)': exec_time * 1_000_000  # microseconds
            })
        if rows:
            rdf = pd.read_csv(self.runtime_csv)
            rdf = pd.concat([rdf, pd.DataFrame(rows)], ignore_index=True)
            rdf.to_csv(self.runtime_csv, index=False)

    def get_profile_db(self):
        return self.profile_db

    def print_profile_db(self):
        if self.profile_db.empty:
            print("ProfileDB is empty.")
        else:
            print("ProfileDB:\n", self.profile_db.to_string(index=False))


# --- Stage Class ---
class Stage:
    """
    Represents a partitioned segment of a model, assigned to a specific Node.
    """

    def __init__(self, stage_id: str, layers: nn.ModuleList, assigned_node: 'Node', task: 'Task'):
        self.stage_id = stage_id
        self.layers = layers
        self.assigned_node = assigned_node

        self.dependencies: List[str] = []
        self.dependents: List[str] = []

        self.execution_time: Optional[float] = None
        self.input_data: Optional[torch.Tensor] = None
        self.output_data: Optional[torch.Tensor] = None
        self.transfer_time: float = 0.0

        self.task = task

    def add_dependency(self, stage_id: str):
        self.dependencies.append(stage_id)

    def add_dependent(self, stage_id: str):
        self.dependents.append(stage_id)

    def run_stage(self):
        start_time = time.time()
        transfer_start = time.time()
        try:
            device = torch.device(
                f"cuda:{self.assigned_node.gpu}" if (self.assigned_node.gpu is not None and torch.cuda.is_available())
                else "cpu"
            )

            if self.input_data is None:
                print(f"[Stage] {self.stage_id}: No input data provided. Executing with empty tensor.")
                out = torch.tensor([])
                transfer_end = time.time()
                self.transfer_time += (transfer_end - transfer_start)
            else:
                inp = self.input_data.to(device)
                transfer_end = time.time()
                self.transfer_time += (transfer_end - transfer_start)

                with torch.no_grad():
                    out = inp
                    for layer in self.layers:
                        out = layer(out)

                if device.type == 'cuda':
                    transfer_start = time.time()
                    out = out.cpu()
                    transfer_end = time.time()
                    self.transfer_time += (transfer_end - transfer_start)

            self.output_data = out

        except Exception as e:
            print(f"[Stage] {self.stage_id}: Error during execution: {e}")
            self.output_data = None
        finally:
            end_time = time.time()
            self.execution_time = end_time - start_time
            print(f"[Stage] {self.stage_id}: Executed on {self.assigned_node.node_id} in {self.execution_time:.6f} seconds. Transfer Time: {self.transfer_time:.6f} seconds.")

            # Update Task's busy time with both execution and transfer times
            self.task.update_busy_time(self.execution_time, self.transfer_time)

            # If this is the final stage, set the Task's output data
            if not self.dependents:
                self.task.set_output_data(self.output_data)

            # Notify Scheduler of stage completion
            self.task.scheduler.stage_completed(self.stage_id)

        return self.output_data

    def __repr__(self):
        return (f"Stage(stage_id={self.stage_id}, node={self.assigned_node.node_id if self.assigned_node else 'None'}, "
                f"deps={self.dependencies}, exec_time={self.execution_time}, transfer_time={self.transfer_time}, "
                f"output_data_present={self.output_data is not None})")

    def __deepcopy__(self, memo):
        new_stage = Stage(
            stage_id=copy.deepcopy(self.stage_id, memo),
            layers=copy.deepcopy(self.layers, memo),
            assigned_node=self.assigned_node,  # keep the same Node reference
            task=None  # or self.task if you prefer
        )
        new_stage.dependencies = copy.deepcopy(self.dependencies, memo)
        new_stage.dependents = copy.deepcopy(self.dependents, memo)
        new_stage.execution_time = self.execution_time
        new_stage.transfer_time = self.transfer_time
        # Optionally copy input/output
        new_stage.input_data = copy.deepcopy(self.input_data, memo)
        new_stage.output_data = copy.deepcopy(self.output_data, memo)
        return new_stage


# --- Task Class ---
class Task:
    """
    Represents a single DNN inference task.
    """

    def __init__(self, task_id: str, model: nn.Module, input_data: torch.Tensor, model_name: str, scheduler: 'Scheduler'):
        self.task_id = task_id
        self.model = model
        self.input_data = input_data
        self.model_name = model_name

        self.scheduler = scheduler

        self.stages: Dict[str, Stage] = {}

        self.start_time: Optional[float] = None
        self.finish_time: Optional[float] = None

        self.output_data: Optional[torch.Tensor] = None
        self.busy_time: float = 0.0
        self.computation_time: float = 0.0
        self.transfer_time: float = 0.0

    def add_stage(self, stage: 'Stage'):
        if stage.stage_id in self.stages:
            raise ValueError(f"Stage ID {stage.stage_id} already exists in Task {self.task_id}.")
        self.stages[stage.stage_id] = stage

    def get_stage(self, stage_id: str) -> Optional['Stage']:
        return self.stages.get(stage_id, None)

    def get_total_execution_time(self) -> float:
        if self.start_time and self.finish_time:
            return self.finish_time - self.start_time
        return 0.0

    def update_busy_time(self, stage_execution_time: float, stage_transfer_time: float = 0.0):
        self.busy_time += stage_execution_time
        self.transfer_time += stage_transfer_time
        self.computation_time += (stage_execution_time - stage_transfer_time)

    def set_output_data(self, output: torch.Tensor):
        self.output_data = output
        self.finish_time = time.time()

    def __repr__(self):
        return (f"Task(task_id={self.task_id}, model_name={self.model_name}, "
                f"stages={list(self.stages.keys())}, "
                f"busy_time={self.busy_time:.6f}, transfer_time={self.transfer_time:.6f}, "
                f"output_data_present={self.output_data is not None})")

    def __deepcopy__(self, memo):
        new_task = Task(
            task_id=copy.deepcopy(self.task_id, memo),
            model=copy.deepcopy(self.model, memo),
            input_data=copy.deepcopy(self.input_data, memo),
            model_name=copy.deepcopy(self.model_name, memo),
            scheduler=None  # exclude to avoid recursion
        )
        new_task.stages = copy.deepcopy(self.stages, memo)
        new_task.start_time = self.start_time
        new_task.finish_time = self.finish_time
        new_task.output_data = self.output_data
        new_task.busy_time = self.busy_time
        new_task.computation_time = self.computation_time
        new_task.transfer_time = self.transfer_time
        return new_task


# --- Taskset Class ---
class Taskset:
    """
    Manages a collection of Tasks and orchestrates their execution using the Scheduler.
    """

    def __init__(self, tasks: List['Task'], scheduler: 'Scheduler'):
        self.tasks = tasks
        self.scheduler = scheduler

        # Performance Metrics
        self.total_utilization: float = 0.0
        self.average_turnaround_time: float = 0.0
        self.throughput: float = 0.0
        self.makespan: float = 0.0
        self.task_completion_rate: float = 0.0
        self.average_resource_utilization_per_node: Dict[str, float] = {}

    def schedule_all_tasks(self):
        for task in self.tasks:
            self.scheduler.decompose_and_allocate_task(task)

    def execute_all(self):
        threads = []
        for task in self.tasks:
            t = threading.Thread(target=self.scheduler.execute_task, args=(task,))
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        self.calculate_metrics()

    def calculate_metrics(self):
        # 1) total_busy_time = sum of (execution_time) across all tasks
        total_busy_time = 0.0
        for task in self.tasks:
            total_busy_time += task.busy_time

        # 2) total_available_time = observation_window * #nodes
        total_available_time = self.scheduler.observation_window * len(self.scheduler.nodes)
        self.total_utilization = (total_busy_time / total_available_time) if total_available_time > 0 else 0.0

        # 3) average turnaround
        turnaround_times = [task.get_total_execution_time() for task in self.tasks]
        if turnaround_times:
            self.average_turnaround_time = sum(turnaround_times) / len(turnaround_times)
        else:
            self.average_turnaround_time = 0.0

        # 4) makespan = difference between earliest start and latest finish
        start_times = [task.start_time for task in self.tasks if task.start_time is not None]
        finish_times = [task.finish_time for task in self.tasks if task.finish_time is not None]
        if start_times and finish_times:
            earliest_start = min(start_times)
            latest_finish = max(finish_times)
            self.makespan = latest_finish - earliest_start
        else:
            self.makespan = 0.0

        # 5) throughput = number_of_tasks / makespan
        if self.makespan > 0:
            self.throughput = len(self.tasks) / self.makespan
        else:
            self.throughput = 0.0

        # 6) task completion rate
        completed_tasks = [t for t in self.tasks if t.output_data is not None]
        self.task_completion_rate = len(completed_tasks) / len(self.tasks) if self.tasks else 0.0

        # 7) average_resource_utilization_per_node
        node_utilization = {node.node_id: 0.0 for node in self.scheduler.nodes}
        for task in self.tasks:
            for stage in task.stages.values():
                if stage.execution_time is not None:
                    node_utilization[stage.assigned_node.node_id] += (stage.execution_time + stage.transfer_time)
        for node in self.scheduler.nodes:
            total_node_time = self.scheduler.observation_window
            if total_node_time > 0:
                self.average_resource_utilization_per_node[node.node_id] = node_utilization[node.node_id] / total_node_time
            else:
                self.average_resource_utilization_per_node[node.node_id] = 0.0

    def __repr__(self):
        return (
            f"Taskset(total_tasks={len(self.tasks)}, "
            f"total_utilization={self.total_utilization:.2%}, "
            f"average_turnaround_time={self.average_turnaround_time:.6f} sec, "
            f"throughput={self.throughput:.2f} tasks/sec, "
            f"makespan={self.makespan:.6f} sec, "
            f"task_completion_rate={self.task_completion_rate:.2%}, "
            f"avg_util_per_node={self.average_resource_utilization_per_node})"
        )


# --- Scheduler Class ---
class Scheduler:
    """
    Scheduler class responsible for decomposing tasks into stages,
    allocating stages to nodes using DP, dispatching stages for execution,
    and managing dependencies via an execution graph.
    """

    def __init__(
        self,
        nodes: List['Node'],
        profiler: 'Profiler',
        observation_window: float = 1000.0,
        load_metric_func: Optional[Callable[[float, float], float]] = None
    ):
        self.nodes = nodes
        self.profiler = profiler
        self.observation_window = observation_window
        self.load_metric = load_metric_func if load_metric_func else self.default_load_metric
        self.lock = threading.Lock()

        self.tasks: Dict[str, Task] = {}
        self.stage_map: Dict[str, Stage] = {}
        self.completed_stages: set = set()

        # Execution Graphs: task_id -> nx.DiGraph
        self.execution_graphs: Dict[str, nx.DiGraph] = {}

    def default_load_metric(self, execution_time: float, observation_window: float) -> float:
        return execution_time / observation_window

    def build_execution_graph(self, task: Task) -> nx.DiGraph:
      """
      Builds an execution graph for a given Task by adding each stage
      as a node and adding edges based on dependencies.
      """
      G = nx.DiGraph()
      for stage_id, stage in task.stages.items():
          G.add_node(stage_id)
          for dep_id in stage.dependencies:
              G.add_edge(dep_id, stage_id)
      return G


    def decompose_and_allocate_task(self, task: Task):
        with self.lock:
            self.tasks[task.task_id] = task
            task.start_time = time.time()

            print(f"[Scheduler] Starting decomposition and allocation for Task '{task.task_id}'.")

            # 1) Decompose
            stages = self.decompose_task_into_stages(task)

            # 2) Build execution graph
            exec_graph = self.build_execution_graph(task)
            self.execution_graphs[task.task_id] = exec_graph
            print(f"[Scheduler] Built execution graph for Task '{task.task_id}'. "
                  f"Nodes: {exec_graph.number_of_nodes()}, Edges: {exec_graph.number_of_edges()}.")

            # 3) DP Allocate
            allocation = self.dp_allocate(task, stages)

            # 4) Assign stages to nodes
            for stage_id, node in allocation.items():
                stage = task.get_stage(stage_id)
                stage.assigned_node = node
                self.stage_map[stage_id] = stage
                node.assigned_stages.append(stage_id)
                node.current_load += self.load_metric(
                    self.get_execution_time(stage, node),
                    self.observation_window
                )
                print(f"[Scheduler] Allocated Stage '{stage_id}' to Node '{node.node_id}'.")

            # 5) Group allocated stages
            self.group_allocated_stages(task, allocation)
            print(f"[Scheduler] Completed allocation and grouping for Task '{task.task_id}'.")

    def decompose_task_into_stages(self, task: Task) -> List[Stage]:
        """
        Decompose the model into stages (one layer per stage).
        Also link each stage to the previous stage for a linear chain:
          stage0 -> stage1 -> stage2 -> ...
        """
        stages = []
        previous_stage_id = None

        for idx, (name, layer) in enumerate(task.model.named_children()):
            stage_id = f"{task.task_id}-stage-{idx}_{name}"
            stage = Stage(stage_id=stage_id, layers=nn.ModuleList([layer]), assigned_node=None, task=task)
            task.add_stage(stage)
            stages.append(stage)
            print(f"[Scheduler] Created Stage '{stage_id}' with Layer '{name}'.")

            # Add a linear dependency from the previous stage
            if previous_stage_id is not None:
                stage.add_dependency(previous_stage_id)
                prev_stage = task.get_stage(previous_stage_id)
                prev_stage.add_dependent(stage_id)

            previous_stage_id = stage_id

        return stages

    def dp_allocate(self, task: Task, stages: List[Stage]) -> Dict[str, 'Node']:
        num_stages = len(stages)
        num_nodes = len(self.nodes)

        # Current node loads
        w = [node.current_load for node in self.nodes]

        # DP table M
        M = [[math.inf for _ in range(num_nodes + 1)] for _ in range(num_stages + 1)]
        for k in range(num_nodes + 1):
            M[0][k] = 0.0

        # Fill DP
        for n in range(1, num_stages + 1):
            for k in range(1, num_nodes + 1):
                for x in range(0, n):
                    grouped_execution_time = sum(
                        self.get_execution_time(stages[y], self.nodes[k - 1]) for y in range(x, n)
                    )
                    util_sum = self.load_metric(grouped_execution_time, self.observation_window)
                    current_max = max(M[x][k - 1], w[k - 1] + util_sum)
                    if current_max < M[n][k]:
                        M[n][k] = current_max

        # Backtrack
        allocation = {}
        n = num_stages
        k = num_nodes
        while n > 0 and k > 0:
            for x in range(0, n):
                grouped_execution_time = sum(
                    self.get_execution_time(stages[y], self.nodes[k - 1]) for y in range(x, n)
                )
                util_sum = self.load_metric(grouped_execution_time, self.observation_window)
                current_max = max(M[x][k - 1], w[k - 1] + util_sum)
                if math.isclose(current_max, M[n][k], rel_tol=1e-6):
                    for y in range(x, n):
                        allocation[stages[y].stage_id] = self.nodes[k - 1]
                    n = x
                    k -= 1
                    break

        # Print final mapping
        if allocation:
            print(f"[Scheduler] Allocation Mapping for Task '{task.task_id}':")
            sorted_allocation = sorted(allocation.items(),
                                       key=lambda x: int(x[0].split("-stage-")[1].split("_")[0]))
            current_node = None
            current_group = []
            for stage_id, node in sorted_allocation:
                if node != current_node:
                    if current_group:
                        print(f"  - Grouped Stages '{', '.join(current_group)}' allocated to Node '{current_node.node_id}'.")
                        current_group = []
                    current_node = node
                current_group.append(stage_id)
            if current_group:
                print(f"  - Grouped Stages '{', '.join(current_group)}' allocated to Node '{current_node.node_id}'.")
        else:
            print(f"[Scheduler] No allocation performed for Task '{task.task_id}'.")

        return allocation

    def group_allocated_stages(self, task: Task, allocation: Dict[str, 'Node']) -> Dict[str, 'Node']:
        """
        Merge consecutive stages on the same node into one Stage (optional).
        """
        groups = {}
        current_node = None
        current_group = []
        sorted_allocation = sorted(allocation.items(),
                                   key=lambda x: int(x[0].split("-stage-")[1].split("_")[0]))
        for stage_id, node in sorted_allocation:
            if node != current_node:
                if current_group:
                    if current_node.node_id not in groups:
                        groups[current_node.node_id] = []
                    groups[current_node.node_id].append(current_group.copy())
                    current_group = []
                current_node = node
                if current_node.node_id not in groups:
                    groups[current_node.node_id] = []
            current_group.append(stage_id)
        if current_group:
            if current_node.node_id not in groups:
                groups[current_node.node_id] = []
            groups[current_node.node_id].append(current_group.copy())

        new_allocation = {}
        for node_id, stage_groups in groups.items():
            for group in stage_groups:
                if len(group) == 1:
                    # single stage
                    stage_id = group[0]
                    new_allocation[stage_id] = self.nodes[self.get_node_index(node_id)]
                else:
                    # Merge
                    new_stage_id = f"{group[0]}_to_{group[-1]}"
                    merged_layers = nn.ModuleList()
                    for sid in group:
                        st = task.get_stage(sid)
                        merged_layers.extend(st.layers)
                    merged_stage = Stage(
                        stage_id=new_stage_id,
                        layers=merged_layers,
                        assigned_node=self.nodes[self.get_node_index(node_id)],
                        task=task
                    )
                    # Fix dependencies
                    first_stage = task.get_stage(group[0])
                    last_stage = task.get_stage(group[-1])
                    merged_stage.dependencies = first_stage.dependencies.copy()
                    for dep in merged_stage.dependencies:
                        dep_st = task.get_stage(dep)
                        dep_st.dependents.remove(group[0])
                        dep_st.dependents.append(new_stage_id)
                    merged_stage.dependents = last_stage.dependents.copy()
                    for dp in merged_stage.dependents:
                        dp_st = task.get_stage(dp)
                        dp_st.dependencies.remove(group[-1])
                        dp_st.dependencies.append(new_stage_id)

                    # Add to task, remove old
                    task.add_stage(merged_stage)
                    for sid in group:
                        del task.stages[sid]
                        if sid in self.stage_map:
                            del self.stage_map[sid]
                        if sid in self.execution_graphs[task.task_id]:
                            self.execution_graphs[task.task_id].remove_node(sid)
                        node_ = self.nodes[self.get_node_index(node_id)]
                        if sid in node_.assigned_stages:
                            node_.assigned_stages.remove(sid)

                    new_allocation[new_stage_id] = self.nodes[self.get_node_index(node_id)]
                    self.stage_map[new_stage_id] = merged_stage
                    node_ = self.nodes[self.get_node_index(node_id)]
                    node_.assigned_stages.append(new_stage_id)
                    node_.current_load += self.load_metric(
                        self.get_execution_time(merged_stage, node_),
                        self.observation_window
                    )
                    print(f"[Scheduler] Merged Stages '{', '.join(group)}' into '{new_stage_id}' and allocated to Node '{node_.node_id}'.")

        self.update_execution_graph_after_grouping(task)
        return new_allocation

    def get_node_index(self, node_id: str) -> int:
        for i, node in enumerate(self.nodes):
            if node.node_id == node_id:
                return i
        raise ValueError(f"Node with ID '{node_id}' not found.")

    def update_execution_graph_after_grouping(self, task: Task):
        G = self.execution_graphs.get(task.task_id, None)
        if not G:
            print(f"[Scheduler] No execution graph for Task '{task.task_id}' to update.")
            return
        G.clear()
        for sid, st in task.stages.items():
            G.add_node(sid)
            for dep in st.dependencies:
                G.add_edge(dep, sid)
        print(f"[Scheduler] Updated execution graph for Task '{task.task_id}'. "
              f"Nodes: {G.number_of_nodes()}, Edges: {G.number_of_edges()}.")

    def visualize_execution_graph(self, task_id: str):
        if task_id not in self.execution_graphs:
            print(f"[Scheduler] No execution graph found for Task '{task_id}'.")
            return
        G = self.execution_graphs[task_id]
        pos = nx.spring_layout(G)
        plt.figure(figsize=(8, 6))
        nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray',
                node_size=2000, font_size=10, arrows=True)
        plt.title(f"Execution Graph for Task '{task_id}'")
        plt.show()

    def get_execution_time(self, stage: Stage, node: 'Node') -> float:
        """
        Look up the layer(s) in the profiler DB. If not found, return observation_window.
        """
        try:
            layer_info = stage.stage_id.split('_', 1)[-1]  # after 'stage-#_'
            # If grouped, e.g. "..._to_...", split
            layer_names = layer_info.split("_to_")
        except ValueError:
            print(f"[Scheduler] Warning: Unable to parse layer from '{stage.stage_id}'. Using max time.")
            return self.observation_window

        total_exec_time = 0.0
        for lname in layer_names:
            base_layer_name = re.sub(r'(\.|_)\d+$', '', lname)
            query = (
                (self.profiler.profile_db['Model'] == stage.task.model_name) &
                (self.profiler.profile_db['Layer'] == base_layer_name) &
                (self.profiler.profile_db['Compute'] == node.node_id)
            )
            execution_time_us = self.profiler.profile_db.loc[query, 'Total Execution Time (us)']
            if not execution_time_us.empty:
                val_s = execution_time_us.values[0] / 1_000_000
                total_exec_time += val_s
                print(f"[Scheduler] Retrieved execution time for Layer '{base_layer_name}' on Node '{node.node_id}': {val_s:.6f} seconds.")
            else:
                print(f"[Scheduler] Warning: No profiling data for Layer '{base_layer_name}' on Node '{node.node_id}'. "
                      f"Assigning max execution time ({self.observation_window} seconds).")
                total_exec_time += self.observation_window

        return total_exec_time

    def dispatch_allocation(self, allocation: Dict[str, 'Node']):
        for stage_id, node in allocation.items():
            stage = self.stage_map[stage_id]
            self.dispatch_stage(stage)

    def dispatch_stage(self, stage: Stage):
        # Transfer the output from dependencies
        if stage.dependencies:
            dep_id = stage.dependencies[0]
            dep_stage = self.stage_map.get(dep_id)
            if dep_stage and dep_stage.output_data is not None:
                stage.input_data = dep_stage.output_data
            else:
                print(f"[Scheduler] Warning: Dependency '{dep_id}' output is None for Stage '{stage.stage_id}'.")
        else:
            stage.input_data = stage.task.input_data

        def stage_execution():
            try:
                stage.run_stage()
            finally:
                self.stage_completed(stage.stage_id)

        node_queue = stage.assigned_node.assign_task(stage_execution)
        print(f"[Scheduler] Dispatched Stage '{stage.stage_id}' to Node '{stage.assigned_node.node_id}'.")

    def execute_task(self, task: Task):
        allocation = {sid: st.assigned_node for sid, st in task.stages.items()}
        print(f"[Scheduler] Executing Task '{task.task_id}' with {len(allocation)} stages.")
        self.dispatch_allocation(allocation)

    def stage_completed(self, stage_id: str):
        with self.lock:
            self.completed_stages.add(stage_id)
            task_id = self.get_task_id_from_stage(stage_id)
            if not task_id:
                print(f"[Scheduler] Warning: No Task ID found for stage '{stage_id}'.")
                return
            task = self.tasks.get(task_id)
            if not task:
                print(f"[Scheduler] Warning: Task '{task_id}' not found for stage '{stage_id}'.")
                return
            stage = task.get_stage(stage_id)
            if not stage:
                print(f"[Scheduler] Warning: Stage '{stage_id}' not found in Task '{task_id}'.")
                return
            # Trigger dependent stages if all deps are done
            for dep_id in stage.dependents:
                dep_stage = task.get_stage(dep_id)
                if dep_stage and all(d in self.completed_stages for d in dep_stage.dependencies):
                    print(f"[Scheduler] Dependencies met for Stage '{dep_id}'. Dispatching.")
                    self.dispatch_stage(dep_stage)

    def get_task_id_from_stage(self, stage_id: str) -> Optional[str]:
        try:
            tid, _ = stage_id.split("-stage-")
            return tid
        except ValueError:
            return None

    def shutdown(self):
        print("[Scheduler] Shutting down all Nodes.")
        for node in self.nodes:
            node.stop()
        print("[Scheduler] All Nodes have been shut down.")


# --- Evaluator Class ---
class Evaluator:
    """
    Evaluator runs tasks in naive mode vs. parallel mode and compares outputs.
    """

    def __init__(self, scheduler: 'Scheduler', taskset: 'Taskset', profiler: 'Profiler'):
        self.scheduler = scheduler
        self.taskset = taskset
        self.profiler = profiler

        self.naive_outputs: Dict[str, torch.Tensor] = {}
        self.parallel_outputs: Dict[str, torch.Tensor] = {}
        self.naive_execution_times: Dict[str, float] = {}
        self.parallel_execution_times: Dict[str, float] = {}

    def run_evaluation(self):
        print("=== Starting Evaluation ===\n")
        print("=== Evaluation Completed ===\n")

    def run_naive_execution(self):
        print("[Evaluator] Starting Naive Execution.")
        for task in self.taskset.tasks:
            model = task.model
            input_tensor = task.input_data
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            input_tensor = input_tensor.to(device)

            t0 = time.time()
            with torch.no_grad():
                output = model(input_tensor)
            t1 = time.time()

            exec_time = t1 - t0
            print(f"[Evaluator] Task '{task.task_id}' (Naive) executed in {exec_time:.6f} seconds.")
            self.naive_execution_times[task.task_id] = exec_time
            self.naive_outputs[task.task_id] = output

        print("[Evaluator] Naive Execution Completed.\n")
        self.cleanup_resources()

    def run_parallel_execution(self):
        print("[Evaluator] Starting Parallel Execution.")
        self.parallel_outputs.clear()
        self.parallel_execution_times.clear()

        # Deep copy tasks
        parallel_tasks = copy.deepcopy(self.taskset.tasks)
        for t in parallel_tasks:
            t.scheduler = self.scheduler

        parallel_taskset = Taskset(tasks=parallel_tasks, scheduler=self.scheduler)
        parallel_start = time.time()
        parallel_taskset.execute_all()
        parallel_end = time.time()
        total_parallel_time = parallel_end - parallel_start

        for t in parallel_tasks:
            self.parallel_outputs[t.task_id] = t.output_data
            self.parallel_execution_times[t.task_id] = t.get_total_execution_time()

        print(f"[Evaluator] Parallel Execution Completed in {total_parallel_time:.6f} seconds.\n")
        self.cleanup_resources()

    def compare_outputs(self):
        print("[Evaluator] Comparing Outputs.")
        all_match = True
        for task_id in self.naive_outputs:
            naive_out = self.naive_outputs[task_id]
            parallel_out = self.parallel_outputs.get(task_id, None)
            if naive_out is None or parallel_out is None:
                print(f"[Evaluator] Task '{task_id}' missing output in one execution.")
                all_match = False
                continue
            if torch.equal(naive_out, parallel_out):
                print(f"[Evaluator] Task '{task_id}' outputs match exactly.")
            elif torch.allclose(naive_out, parallel_out, atol=1e-5):
                print(f"[Evaluator] Task '{task_id}' outputs are close within tolerance.")
            else:
                print(f"[Evaluator] Task '{task_id}' outputs do NOT match.")
                all_match = False
        if all_match:
            print("[Evaluator] All outputs match.\n")
        else:
            print("[Evaluator] Some outputs differ.\n")

    def analyze_speedup_throughput(self):
        print("[Evaluator] Analyzing Speedup.")
        total_naive = sum(self.naive_execution_times.values())
        total_parallel = sum(self.parallel_execution_times.values())
        speedup = total_naive / total_parallel if total_parallel > 0 else float('inf')
        num_tasks = len(self.taskset.tasks)
        naive_thr = num_tasks / total_naive if total_naive > 0 else 0
        parallel_thr = num_tasks / total_parallel if total_parallel > 0 else 0
        print(f"[Evaluator] Speedup: {speedup:.2f}x. "
              f"Naive Throughput: {naive_thr:.2f} tasks/s, Parallel Throughput: {parallel_thr:.2f} tasks/s.\n")

    def cleanup_resources(self):
        print("[Evaluator] Cleaning up resources.")
        self.naive_outputs.clear()
        self.parallel_outputs.clear()
        self.naive_execution_times.clear()
        self.parallel_execution_times.clear()

        # Reset node loads
        for node in self.scheduler.nodes:
            node.current_load = 0.0
            node.assigned_stages.clear()

        # Reset completed stages
        with self.scheduler.lock:
            self.scheduler.completed_stages.clear()

        print("[Evaluator] Resources cleaned up.\n")


# --- Utility Functions ---
def init_phase(profiler: 'Profiler', taskset: 'Taskset', nodes: List['Node'], runs: int = 3, slack_percentage: float = 0.1):
    print("[Utility] Starting Init Phase.")
    for run in range(1, runs + 1):
        print(f"[Utility] Init Phase Run {run}/{runs}")
        for node in nodes:
            for task in taskset.tasks:
                profiler.profile_model(
                    model=copy.deepcopy(task.model),
                    input_data=copy.deepcopy(task.input_data),
                    node=node,
                    model_name=task.model_name
                )
    print(f"[Utility] Completed Init Phase after {runs} runs.\n")
    profiler.print_profile_db()

    total_forward_time = profiler.profile_db['Total Execution Time (us)'].sum() / 1_000_000
    observation_window = total_forward_time * (1 + slack_percentage)
    profiler.observation_window = observation_window

    print(f"[Utility] Observation window set to {observation_window:.6f} seconds "
          f"(Total Forward Time: {total_forward_time:.6f} + {slack_percentage*100}% slack).\n")

    taskset.scheduler.observation_window = observation_window
    taskset.schedule_all_tasks()
    print("[Taskset/Scheduler] Scheduling all tasks onto the compute")

    return observation_window


def eval_phase(evaluator: 'Evaluator', taskset: 'Taskset'):
    print("[Utility] Starting Evaluation Phase.")

    # 1) Naive
    evaluator.run_naive_execution()

    # 2) Parallel
    evaluator.run_parallel_execution()

    # 3) Compare
    evaluator.compare_outputs()

    # 4) Speedup
    evaluator.analyze_speedup_throughput()

    print("[Utility] Evaluation Phase Completed.\n")


# --- Simple CNN Model ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(32 * 8 * 8, 64)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = torch.flatten(out, 1)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)
        return out


def create_synthetic_dataloader(batch_size: int = 1, num_samples: int = 1):
    inputs = torch.randn(num_samples, 3, 8, 8)
    targets = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(inputs, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    return dataloader


def initialize_components():
    nodes = Node.discover_nodes()
    print(f"[Main] Discovered Nodes: {nodes}\n")

    profiler = Profiler(mode='init')
    print("[Main] Initialized Profiler.\n")

    scheduler = Scheduler(
        nodes=nodes,
        profiler=profiler,
        observation_window=1000.0
    )
    print("[Main] Initialized Scheduler.\n")

    num_tasks = 5
    tasks = []
    for i in range(num_tasks):
        model = SimpleCNN()
        dl = create_synthetic_dataloader(batch_size=1, num_samples=1)
        single_input, _ = next(iter(dl))
        task = Task(
            task_id=f"task{i+1}",
            model=model,
            input_data=single_input,
            model_name=model.__class__.__name__,
            scheduler=scheduler
        )
        tasks.append(task)
    print(f"[Main] Created {num_tasks} Tasks.\n")

    taskset = Taskset(tasks=tasks, scheduler=scheduler)
    print("[Main] Initialized Taskset.\n")

    evaluator = Evaluator(
        scheduler=scheduler,
        taskset=taskset,
        profiler=profiler
    )
    print("[Main] Initialized Evaluator.\n")

    return evaluator, taskset, profiler, scheduler, nodes


def run_evaluation():
    evaluator, taskset, profiler, scheduler, nodes = initialize_components()

    # Init Phase
    init_observation_window = init_phase(profiler, taskset, nodes, runs=3, slack_percentage=0.1)

    # Evaluation Phase
    eval_phase(evaluator, taskset)

    # Print final metrics
    print("=== Performance Metrics ===")
    print(taskset)
    print("===========================")

    # Shutdown
    scheduler.shutdown()


if __name__ == "__main__":
    run_evaluation()


[Main] Discovered Nodes: [Node(CPU-0, cpus=(0,), gpu=None), Node(CPU-1, cpus=(1,), gpu=None), Node(GPU-0-CPU-0, cpus=(0,), gpu=0), Node(GPU-0-CPU-1, cpus=(1,), gpu=0)]

[Main] Initialized Profiler.

[Main] Initialized Scheduler.

[Main] Created 5 Tasks.

[Main] Initialized Taskset.

[Main] Initialized Evaluator.

[Utility] Starting Init Phase.
[Utility] Init Phase Run 1/3


  df = pd.concat([df, new_row], ignore_index=True)


[Utility] Init Phase Run 2/3
[Utility] Init Phase Run 3/3
[Utility] Completed Init Phase after 3 runs.

ProfileDB:
     Model        Layer     Compute  Self CPU (us)  CPU Total (us) CUDA Total (us) Self CPU Mem (bytes) Self CUDA Mem (bytes)  Total Execution Time (us) Total Memory Used (bytes)
SimpleCNN forward_pass       CPU-0     145935.871      161809.978               0                    0                     0               1.618100e+11                         0
SimpleCNN         misc       CPU-0     145614.326      161488.433               0                    0                     0               1.614884e+11                         0
SimpleCNN        conv1       CPU-0         84.699          84.699               0                    0                     0               8.469900e+07                         0
SimpleCNN        relu1       CPU-0         61.206          61.206               0                    0                     0               6.120600e+07                     

In [None]:
import torch
import torch.nn as nn
import copy
import time

# ---------------------------------------------------------------------
# 1) Use your existing Node, Profiler, Task, Taskset classes
#    We'll show simplified versions or placeholders below.
# ---------------------------------------------------------------------

class Node:
    def __init__(self, node_id, cpus=None, gpu=None):
        self.node_id = node_id
        self.cpus = cpus or []
        self.gpu = gpu
        self.current_load = 0.0
        self.assigned_stages = []

        import queue
        import threading
        self._task_queue = queue.Queue()
        self._stop_signal = False

        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

    def assign_task(self, func):
        """Enqueue a function. Return a queue so we can optionally .get() the result."""
        import queue
        result_queue = queue.Queue(maxsize=1)
        self._task_queue.put((func, result_queue))
        return result_queue

    def _worker_loop(self):
        while not self._stop_signal:
            item = self._task_queue.get()
            if item is None:
                break
            func, result_q = item
            result = None
            try:
                result = func()
            except Exception as e:
                result = e
            result_q.put(result)

    def stop(self):
        self._stop_signal = True
        self._task_queue.put(None)
        self._worker_thread.join()

    def __repr__(self):
        return f"Node({self.node_id}, gpu={self.gpu})"


class Profiler:
    """
    Minimal placeholder. We'll assume we have a DataFrame with
    columns: ['Model','Layer','Compute','Total Execution Time (us)'].
    """
    def __init__(self, mode='init'):
        import pandas as pd
        self.mode = mode
        # Hardcode a small DataFrame for demonstration:
        data = {
            'Model':   ['SimpleCNN','SimpleCNN','SimpleCNN','SimpleCNN'],
            'Layer':   ['Conv2d','ReLU','Linear','ReLU'],  # placeholders
            'Compute': ['CPU-0','CPU-1','CPU-0','GPU-0-CPU-1'],  # placeholders
            'Total Execution Time (us)': [5_000, 2_000, 4_000, 2_500]
        }
        self.profile_db = pd.DataFrame(data)

    def get_profile_db(self):
        return self.profile_db


class Task:
    def __init__(self, task_id, model, input_data, model_name, scheduler=None):
        self.task_id = task_id
        self.model = model
        self.input_data = input_data
        self.model_name = model_name
        self.scheduler = scheduler

        self.stages = {}
        self.start_time = None
        self.finish_time = None
        self.output_data = None
        self.busy_time = 0.0
        self.computation_time = 0.0
        self.transfer_time = 0.0

    def add_stage(self, stage):
        if stage.stage_id in self.stages:
            raise ValueError(f"Stage ID {stage.stage_id} already exists.")
        self.stages[stage.stage_id] = stage

    def set_output_data(self, output):
        self.output_data = output
        self.finish_time = time.time()

    def update_busy_time(self, exec_time, transfer_time=0.0):
        self.busy_time += exec_time
        self.transfer_time += transfer_time
        self.computation_time += (exec_time - transfer_time)


class Taskset:
    """
    For demonstration, we just store tasks and can schedule/execute them.
    """
    def __init__(self, tasks, scheduler):
        self.tasks = tasks
        self.scheduler = scheduler

    def schedule_all_tasks(self):
        for t in self.tasks:
            self.scheduler.decompose_and_allocate_task(t)

    def execute_all(self):
        # For demonstration, we just do:
        for t in self.tasks:
            # each stage is dispatched in the scheduler
            # or we can do a simple synchronous approach
            self.scheduler.execute_task(t)

    def __repr__(self):
        return (f"Taskset with {len(self.tasks)} tasks.")


# ---------------------------------------------------------------------
# 2) We use your "recreated" Scheduler and Stage classes from the
#    previous conversation snippet. We'll define them here in short form.
# ---------------------------------------------------------------------

import networkx as nx
import re
import math
from typing import List

# -- We'll import the new Stage from your updated snippet:
# from your_code import Stage

class Stage:
    """From updated snippet: now supports add_layer()."""
    def __init__(self, stage_id, assigned_node=None, task=None):
        self.stage_id = stage_id
        self.assigned_node = assigned_node
        self.task = task

        self.layers = nn.ModuleList()
        self.dependencies = []
        self.dependents = []

        self.input_data = None
        self.output_data = None

        self.execution_time = None
        self.transfer_time = 0.0

    def add_layer(self, layer: nn.Module):
        self.layers.append(layer)

    @property
    def num_layers(self):
        return len(self.layers)

    def add_dependency(self, stage_id: str):
        self.dependencies.append(stage_id)

    def add_dependent(self, stage_id: str):
        self.dependents.append(stage_id)

    def run_stage(self):
        # print(self.input_data)
        import time
        start_time = time.time()
        try:
            node = self.assigned_node
            device_str = "cpu"
            if (node is not None) and (node.gpu is not None) and torch.cuda.is_available():
                device_str = f"cuda:{node.gpu}"
            device = torch.device(device_str)

            # Move layers to device
            for layer in self.layers:
                layer.to(device)

            if self.input_data is None:
                self.output_data = torch.tensor([])
            else:
                inp = self.input_data.to(device)
                with torch.no_grad():
                    out = inp
                    for lyr in self.layers:
                        out = lyr(out)
                if device.type == 'cuda':
                    self.output_data = out.cpu()
                else:
                    self.output_data = out

        except Exception as e:
            print(f"[Stage] {self.stage_id}: Error: {e}")
            self.output_data = None

        end_time = time.time()
        self.execution_time = end_time - start_time
        if self.task:
            self.task.update_busy_time(self.execution_time, self.transfer_time)

        if self.task and not self.dependents:
            self.task.set_output_data(self.output_data)

        if self.task and self.task.scheduler:
            self.task.scheduler.stage_completed(self.stage_id)

        # print(self.output_data)
        return self.output_data

    def __repr__(self):
        return (f"Stage(stage_id={self.stage_id}, #layers={self.num_layers}, "
                f"deps={self.dependencies}, node={self.assigned_node.node_id if self.assigned_node else None})")


class Scheduler:
    """
    Simplified version from the prior snippet.
    """
    def __init__(self, nodes:List[Node], profiler:Profiler, observation_window=1000.0):
        self.nodes = nodes
        self.profiler = profiler
        self.observation_window = observation_window

        self.tasks = {}
        self.stage_map = {}
        self.completed_stages = set()
        self.execution_graphs = {}

        import threading
        self.lock = threading.Lock()

    def decompose_and_allocate_task(self, task):
        with self.lock:
            self.tasks[task.task_id] = task
            task.start_time = time.time()

        # Step 1: Build initial stages (1-liner per top-level layer)
        initial_stages = self.build_initial_stages(task)

        # Step 2: Build graph
        G = self.build_execution_graph(initial_stages)
        self.execution_graphs[task.task_id] = G

        # Step 3: Topo sort
        topo_order = list(nx.topological_sort(G))

        # Step 4: DP allocate + merge
        merged_stages = self.dp_allocate(task, topo_order, initial_stages)

        # Step 5: Register final merged stages
        for stg in merged_stages:
            task.add_stage(stg)
            self.stage_map[stg.stage_id] = stg

        # Optionally print the final allocation
        print(f"[Scheduler] Task '{task.task_id}' final stages:")
        for ms in merged_stages:
            print("   ", ms)

    def build_initial_stages(self, task):
        stages = []
        idx = 0
        for name, layer in task.model.named_children():
            stage_id = f"{task.task_id}-stage-{idx}_{name}"
            stg = Stage(stage_id, assigned_node=None, task=task)
            stg.add_layer(layer)
            stages.append(stg)
            idx += 1
        return stages

    def build_execution_graph(self, stages):
        G = nx.DiGraph()
        for i, st in enumerate(stages):
            G.add_node(st.stage_id)
            if i > 0:
                # linear for simplicity
                st.add_dependency(stages[i-1].stage_id)
                stages[i-1].add_dependent(st.stage_id)
                G.add_edge(stages[i-1].stage_id, st.stage_id)
        return G

    def dp_allocate(self, task, topo_order, stages):
        # 1) re-order
        ordered_stages = [s for s in stages if s.stage_id in topo_order]
        ordered_stages.sort(key=lambda s: topo_order.index(s.stage_id))

        N = len(ordered_stages)
        K = len(self.nodes)
        w = [node.current_load for node in self.nodes]

        # DP table
        M = [[math.inf]*(K+1) for _ in range(N+1)]
        choice = [[0]*(K+1) for _ in range(N+1)]

        for k in range(K+1):
            M[0][k] = 0

        for n in range(1, N+1):
            for k in range(1, K+1):
                for x in range(0, n):
                    grouped_time = self.sum_exec_times(ordered_stages[x:n], self.nodes[k-1], task)
                    load_sum = grouped_time / self.observation_window
                    curr_max = max(M[x][k-1], w[k-1] + load_sum)
                    if curr_max < M[n][k]:
                        M[n][k] = curr_max
                        choice[n][k] = x

        merged_stages = []
        n = N
        k = K
        while n>0 and k>0:
            x = choice[n][k]
            node = self.nodes[k-1]
            # merge stages x..n-1
            new_stage = self.merge_stages(ordered_stages[x:n], node, task)
            merged_stages.insert(0, new_stage)

            n = x
            k -= 1

        return merged_stages

    def merge_stages(self, stage_list, node, task):
        if not stage_list:
            return None
        if len(stage_list)==1:
            st = stage_list[0]
            st.assigned_node = node
            node.assigned_stages.append(st.stage_id)
            return st

        first_id = stage_list[0].stage_id
        last_id  = stage_list[-1].stage_id
        new_id   = f"{first_id}_to_{last_id}"
        new_stage = Stage(new_id, node, task)

        # gather layers
        for s in stage_list:
            for lyr in s.layers:
                new_stage.add_layer(lyr)
        # dependencies / dependents
        new_stage.dependencies = stage_list[0].dependencies.copy()
        new_stage.dependents   = stage_list[-1].dependents.copy()

        node.assigned_stages.append(new_id)
        return new_stage

    def sum_exec_times(self, stage_list, node, task):
        """
        For testing, we just sum up times from the Profiler for each layer.
        We'll do naive matching by layer.__class__.__name__ -> 'Layer' column in the DB
        and 'Compute' == node.node_id if it exists. Otherwise fallback.
        """
        total_time = 0.0
        for st in stage_list:
            for lyr in st.layers:
                layer_class_name = lyr.__class__.__name__
                # Just do a naive match:
                # e.g. Conv2d -> 'Conv2d', ReLU -> 'ReLU', ...
                df = self.profiler.get_profile_db()
                query = (
                    (df['Model']==task.model_name) &
                    (df['Layer']==layer_class_name) &
                    (df['Compute']==node.node_id)
                )
                rows = df.loc[query, 'Total Execution Time (us)']
                if not rows.empty:
                    val_us = rows.values[0]
                    total_time += (val_us / 1e6)
                else:
                    total_time += self.observation_window

        return total_time

    def stage_completed(self, stage_id:str):
        with self.lock:
            self.completed_stages.add(stage_id)

    def execute_task(self, task):
        """
        Dispatch all final stages that have no dependencies or have
        dependencies already completed.
        For a linear chain, we just dispatch the first stage (the rest
        will get triggered in stage_completed).
        """
        # find stages that have no dependencies in task
        for sid, stg in task.stages.items():
            if not stg.dependencies:
                self.dispatch_stage(stg)

    def dispatch_stage(self, stage):
        # print(stage.)
        if stage.dependencies:
            # single dependency if linear
            dep_id = stage.dependencies[0]
            dep_stage = stage.task.stages.get(dep_id)
            print("******************")
            print(dep_stage.output_data)
            print("******************")
            stage.input_data = dep_stage.output_data
        else:
            print("******************")
            print(stage.task.input_data)
            print("******************")
            stage.input_data = stage.task.input_data

        result_q = stage.assigned_node.assign_task(stage.run_stage)
        # optionally .get() if you want synchronous. We'll do async.

    def shutdown(self):
        print("[Scheduler] Shutting down all Nodes.")
        for node in self.nodes:
            node.stop()
        print("[Scheduler] All Nodes have been shut down.")

# ---------------------------------------------------------------------
# 3) Define a simple CNN for demonstration
# ---------------------------------------------------------------------

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1   = nn.Linear(16*8*8, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x


# ---------------------------------------------------------------------
# 4) Demonstration of the test run
# ---------------------------------------------------------------------

if __name__ == "__main__":

    # (1) Create 2 "Node" objects for CPU
    node_cpu0 = Node(node_id="CPU-0")
    node_cpu1 = Node(node_id="CPU-1")
    nodes = [node_cpu0, node_cpu1]

    # (2) Create a Profiler placeholder
    profiler = Profiler(mode='init')

    # (3) Create a Scheduler
    scheduler = Scheduler(nodes=nodes, profiler=profiler, observation_window=100.0)

    # (4) Build a simple CNN
    model = SimpleCNN()
    input_tensor = torch.randn(1, 3, 8, 8)  # example input

    # (5) Create a single Task
    task = Task(
        task_id="task1",
        model=model,
        input_data=input_tensor,
        model_name="SimpleCNN",
        scheduler=scheduler
    )

    # (6) Put it in a Taskset, schedule and execute
    tset = Taskset([task], scheduler=scheduler)
    tset.schedule_all_tasks()
    tset.execute_all()

    # (7) Wait a bit for all stages to complete (since we used threads)
    import time
    time.sleep(2)

    # (8) Compare pipeline output with naive
    with torch.no_grad():
        naive_out = model(input_tensor)
    pipeline_out = task.output_data

    print("\n=== Comparison ===")
    print("Naive Output:", naive_out)
    print("Pipeline Output:", pipeline_out)

    if pipeline_out is None:
        print("Pipeline did not produce any output!")
    else:
        # check closeness
        if torch.allclose(naive_out, pipeline_out, atol=1e-5):
            print("[Test] SUCCESS: Outputs match (within tolerance).")
        else:
            print("[Test] WARNING: Outputs differ!")

    # (9) Print final stage info
    print("\n=== Stages Info ===")
    for sid, stg in task.stages.items():
        print(stg)

    # (10) Cleanup
    scheduler.shutdown()


[Scheduler] Task 'task1' final stages:
    Stage(stage_id=task1-stage-0_conv1_to_task1-stage-2_conv2, #layers=3, deps=[], node=CPU-0)
    Stage(stage_id=task1-stage-3_relu2_to_task1-stage-4_fc1, #layers=2, deps=['task1-stage-2_conv2'], node=CPU-1)
******************
tensor([[[[ 0.9599, -0.4344,  0.9591,  0.0693, -0.4149, -0.0587, -1.4252,
           -0.5186],
          [-0.0158,  0.1946,  1.0743,  0.2321,  0.4702, -0.5130,  0.4202,
            0.5887],
          [-2.1693, -1.9995, -0.4183,  0.5400, -1.9718, -0.6927,  0.1251,
            0.0414],
          [-0.0327, -0.0607, -1.1242, -1.8818,  0.8721,  1.5568, -0.9984,
           -1.4274],
          [ 1.9087, -0.4064, -3.2843, -2.2190, -0.5669,  1.3269,  1.1347,
           -0.9561],
          [ 0.3090,  0.6564,  1.0238,  0.6457,  0.3690,  0.6289,  2.0433,
            0.8297],
          [ 0.7926, -0.6222, -1.8273, -0.6226,  0.3800, -1.1853, -1.2737,
           -0.4199],
          [ 0.2040, -1.4604, -0.2295,  0.7866, -1.5409,  1.5228, -0.

In [None]:
import torch
import torch.nn as nn
import networkx as nx
import time
import math
import re
from typing import List, Optional
import queue  # Added import for handling queue.Empty

# ---------------------------------------------------------------------
# 1) Reuse your existing Node, Profiler, Task, Taskset, Stage
#    classes with necessary modifications.
# ---------------------------------------------------------------------

class Node:
    def __init__(self, node_id, cpus=None, gpu=None):
        self.node_id = node_id
        self.cpus = cpus or []
        self.gpu = gpu
        self.current_load = 0.0
        self.assigned_stages = []

        self._task_queue = queue.Queue()
        self._stop_signal = False

        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

    def assign_task(self, func):
        """Enqueue a function. Return a queue so we can optionally .get() the result."""
        result_queue = queue.Queue(maxsize=1)
        self._task_queue.put((func, result_queue))
        return result_queue

    def _worker_loop(self):
        while not self._stop_signal:
            item = self._task_queue.get()
            if item is None:
                break
            func, result_q = item
            result = None
            try:
                result = func()
            except Exception as e:
                result = e
            result_q.put(result)

    def stop(self):
        self._stop_signal = True
        self._task_queue.put(None)
        self._worker_thread.join()

    def __repr__(self):
        return f"Node({self.node_id}, gpu={self.gpu})"


class Profiler:
    """Minimal example with a small profile_db DataFrame."""
    def __init__(self, mode='init'):
        import pandas as pd
        self.mode = mode
        data = {
            'Model':   ['SimpleCNN','SimpleCNN','SimpleCNN','SimpleCNN'],
            'Layer':   ['Conv2d','ReLU','Linear','ReLU'],  # placeholders
            'Compute': ['CPU-0','CPU-1','CPU-0','CPU-1'],  # Updated to match node IDs
            'Total Execution Time (us)': [5_000,2_000,4_000,2_500]
        }
        self.profile_db = pd.DataFrame(data)

    def get_profile_db(self):
        return self.profile_db


class Task:
    def __init__(self, task_id, model, input_data, model_name, scheduler=None):
        self.task_id = task_id
        self.model = model
        self.input_data = input_data
        self.model_name = model_name
        self.scheduler = scheduler

        self.stages = {}
        self.start_time = None
        self.finish_time = None
        self.output_data = None
        self.busy_time = 0.0
        self.computation_time = 0.0
        self.transfer_time = 0.0

    def add_stage(self, stage):
        if stage.stage_id in self.stages:
            raise ValueError(f"Stage ID {stage.stage_id} already exists.")
        self.stages[stage.stage_id] = stage

    def set_output_data(self, output):
        self.output_data = output
        self.finish_time = time.time()

    def update_busy_time(self, exec_time, transfer_time=0.0):
        self.busy_time += exec_time
        self.transfer_time += transfer_time
        self.computation_time += (exec_time - transfer_time)


class Taskset:
    def __init__(self, tasks, scheduler):
        self.tasks = tasks
        self.scheduler = scheduler

    def schedule_all_tasks(self):
        for t in self.tasks:
            self.scheduler.decompose_and_allocate_task(t)

    def execute_all(self):
        for t in self.tasks:
            self.scheduler.execute_task(t)

    def __repr__(self):
        return f"Taskset with {len(self.tasks)} tasks."


class Stage:
    def __init__(self, stage_id, assigned_node=None, task=None):
        self.stage_id = stage_id
        self.assigned_node = assigned_node
        self.task = task

        self.layers = nn.ModuleList()
        self.dependencies = []
        self.dependents = []

        self.input_data = None
        self.output_data = None
        self.execution_time = None
        self.transfer_time = 0.0

    def add_layer(self, layer):
        self.layers.append(layer)

    @property
    def num_layers(self):
        return len(self.layers)

    def add_dependency(self, stage_id):
        self.dependencies.append(stage_id)

    def add_dependent(self, stage_id):
        self.dependents.append(stage_id)

    def run_stage(self):
        start_time = time.time()
        try:
            node = self.assigned_node
            device_str = "cpu"
            if node and node.gpu is not None and torch.cuda.is_available():
                device_str = f"cuda:{node.gpu}"
            device = torch.device(device_str)

            # Move each layer to the assigned device
            for layer in self.layers:
                layer.to(device)

            if self.input_data is None:
                # No input => produce empty tensor
                self.output_data = torch.tensor([])
            else:
                inp = self.input_data.to(device)
                with torch.no_grad():
                    out = inp
                    for lyr in self.layers:
                        out = lyr(out)
                if device.type == 'cuda':
                    self.output_data = out.cpu()
                else:
                    self.output_data = out

            # Debug: Print output_data for verification
            print(f"[Stage] {self.stage_id} executed on {node.node_id}. Output shape: {self.output_data.shape}")

        except Exception as e:
            print(f"[Stage] {self.stage_id}: Error: {e}")
            self.output_data = None

        end_time = time.time()
        self.execution_time = end_time - start_time

        if self.task:
            self.task.update_busy_time(self.execution_time, self.transfer_time)
        if self.task and not self.dependents:
            self.task.set_output_data(self.output_data)

        # Mark stage completed
        if self.task and self.task.scheduler:
            self.task.scheduler.stage_completed(self.stage_id)

        return self.output_data

    def __repr__(self):
        return (f"Stage(stage_id={self.stage_id}, #layers={self.num_layers}, "
                f"deps={self.dependencies}, node={self.assigned_node.node_id if self.assigned_node else None})")


# ---------------------------------------------------------------------
# 2) Updated 'SchedulerNoMerge' that dispatches dependent stages
# ---------------------------------------------------------------------
class SchedulerNoMerge:
    """Scheduler that topologically sorts the layers, assigns each as a separate stage, no merging."""
    def __init__(self, nodes:List[Node], profiler:Profiler, observation_window=1000.0):
        self.nodes = nodes
        self.profiler = profiler
        self.observation_window = observation_window

        self.tasks = {}
        self.stage_map = {}
        self.completed_stages = set()
        self.execution_graphs = {}

        self.lock = threading.Lock()

    def decompose_and_allocate_task(self, task:Task):
        with self.lock:
            self.tasks[task.task_id] = task
            task.start_time = time.time()

        # 1) Build an initial stage per layer (linear chain)
        stages = []
        idx = 0
        for name, layer in task.model.named_children():
            st_id = f"{task.task_id}-stage-{idx}_{name}"
            stg = Stage(st_id, assigned_node=None, task=task)
            stg.add_layer(layer)
            stages.append(stg)
            idx += 1

        # 2) Build a linear execution graph:
        G = nx.DiGraph()
        for i in range(len(stages)):
            s = stages[i]
            G.add_node(s.stage_id)
            if i > 0:
                # linear for simplicity
                s.add_dependency(stages[i-1].stage_id)
                stages[i-1].add_dependent(s.stage_id)
                G.add_edge(stages[i-1].stage_id, s.stage_id)
        self.execution_graphs[task.task_id] = G

        # 3) Do a topological sort
        topo_order = list(nx.topological_sort(G))

        # 4) Allocate each stage individually (no merges)
        #    Use a minimal load approach
        for i, stg in enumerate(stages):
            assigned_node = self.choose_node_for_stage(stg, task)
            stg.assigned_node = assigned_node
            assigned_node.assigned_stages.append(stg.stage_id)
            task.add_stage(stg)
            self.stage_map[stg.stage_id] = stg

        # Debug print
        print(f"[SchedulerNoMerge] Task '{task.task_id}' final stages:")
        for s in stages:
            print("   ", s)

    def choose_node_for_stage(self, stage:Stage, task:Task) -> Node:
        """
        Pick the node with the least projected load.
        """
        best_node = None
        best_load = math.inf
        best_exec_time = 0.0
        for nd in self.nodes:
            stage_time = self.sum_exec_times(stage, task, nd)
            projected_load = nd.current_load + stage_time / self.observation_window
            if projected_load < best_load:
                best_load = projected_load
                best_node = nd
                best_exec_time = stage_time
        # Update the node's load
        best_node.current_load += best_exec_time / self.observation_window
        return best_node

    def sum_exec_times(self, stage:Stage, task:Task, node:Node) -> float:
        """
        Sum up times from the Profiler for each layer on the given node.
        """
        df = self.profiler.get_profile_db()
        total_time = 0.0
        for lyr in stage.layers:
            layer_class = lyr.__class__.__name__
            query = (
                (df['Model'] == task.model_name) &
                (df['Layer'] == layer_class) &
                (df['Compute'] == node.node_id)  # Match node ID
            )
            row = df.loc[query, 'Total Execution Time (us)']
            if not row.empty:
                val_us = row.values[0]
                total_time += (val_us / 1e6)
            else:
                total_time += self.observation_window  # Fallback
        return total_time

    def stage_completed(self, stage_id:str):
        with self.lock:
            self.completed_stages.add(stage_id)
            task_id = self.get_task_id_from_stage(stage_id)
            if not task_id:
                return
            task = self.tasks.get(task_id)
            if not task:
                return
            stg = task.stages.get(stage_id)
            if not stg:
                return
            # Dispatch dependents if all dependencies are met
            for dep_stg_id in stg.dependents:
                dep_stg = task.stages.get(dep_stg_id)
                if dep_stg:
                    if all(d in self.completed_stages for d in dep_stg.dependencies):
                        self.dispatch_stage(dep_stg)

    def get_task_id_from_stage(self, stage_id: str) -> Optional[str]:
        """
        Extracts the task_id from a stage_id.
        """
        parts = stage_id.split("-stage-", maxsplit=1)
        if len(parts) < 2:
            return None
        return parts[0]

    def execute_task(self, task:Task):
        # Dispatch all stages with no dependencies
        for st_id, stg in task.stages.items():
            if not stg.dependencies:
                self.dispatch_stage(stg)

    def dispatch_stage(self, stage:Stage):
        # Set input data
        if stage.dependencies:
            # assume single dependency for linear chain
            dep_id = stage.dependencies[0]
            dep_stage = stage.task.stages.get(dep_id)
            stage.input_data = dep_stage.output_data
        else:
            stage.input_data = stage.task.input_data

        print(f"[SchedulerNoMerge] Dispatching Stage '{stage.stage_id}' to Node '{stage.assigned_node.node_id}'")
        result_q = stage.assigned_node.assign_task(stage.run_stage)

        # Optionally, wait for the stage to complete and handle errors
        try:
            result = result_q.get(timeout=5)  # Wait up to 5 seconds
            if isinstance(result, Exception):
                print(f"[SchedulerNoMerge] Stage '{stage.stage_id}' encountered an error: {result}")
        except queue.Empty:
            print(f"[SchedulerNoMerge] Stage '{stage.stage_id}' timed out.")

    def shutdown(self):
        print("[SchedulerNoMerge] Shutting down all Nodes...")
        for nd in self.nodes:
            nd.stop()
        print("[SchedulerNoMerge] All Nodes stopped.")


# ---------------------------------------------------------------------
# 3) Define a simple CNN for demonstration
# ---------------------------------------------------------------------
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1   = nn.Linear(16*8*8, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x


# ---------------------------------------------------------------------
# 4) Demonstration of the test run
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import threading  # Ensure threading is imported
    import pandas as pd  # Ensure pandas is available

    # 1) Create two Nodes for CPU
    node_cpu0 = Node(node_id="CPU-0")
    node_cpu1 = Node(node_id="CPU-1")
    nodes = [node_cpu0, node_cpu1]

    # 2) Create a Profiler placeholder
    profiler = Profiler(mode='init')

    # 3) Create a SchedulerNoMerge
    scheduler = SchedulerNoMerge(nodes=nodes, profiler=profiler, observation_window=100.0)

    # 4) Build a simple CNN
    model = SimpleCNN()
    input_tensor = torch.randn(1, 3, 8, 8)  # example input

    # 5) Create a single Task
    task = Task(
        task_id="task1",
        model=model,
        input_data=input_tensor,
        model_name="SimpleCNN",
        scheduler=scheduler
    )

    # 6) Put it in a Taskset, schedule and execute
    tset = Taskset([task], scheduler=scheduler)
    tset.schedule_all_tasks()
    tset.execute_all()

    # 7) Wait a bit for all stages to complete (since we used threads)
    time.sleep(5)  # Increased to 5 seconds to ensure all stages complete

    # 8) Compare pipeline output with naive
    with torch.no_grad():
        naive_out = model(input_tensor)
    pipeline_out = task.output_data

    print("\n=== Comparison ===")
    print("Naive Output:", naive_out)
    print("Pipeline Output:", pipeline_out)

    if pipeline_out is None:
        print("[Test] Pipeline produced no output!")
    else:
        # Check closeness
        if torch.allclose(naive_out, pipeline_out, atol=1e-5):
            print("[Test] SUCCESS: Outputs match (within tolerance).")
        else:
            print("[Test] WARNING: Outputs differ!\n"
                  f"Naive: {naive_out}\nPipe:  {pipeline_out}")

    # 9) Print final stage info
    print("\n=== Stages Info ===")
    for sid, stg in task.stages.items():
        print(stg)

    # 10) Cleanup
    scheduler.shutdown()


[SchedulerNoMerge] Task 'task1' final stages:
    Stage(stage_id=task1-stage-0_conv1, #layers=1, deps=[], node=CPU-0)
    Stage(stage_id=task1-stage-1_relu1, #layers=1, deps=['task1-stage-0_conv1'], node=CPU-1)
    Stage(stage_id=task1-stage-2_conv2, #layers=1, deps=['task1-stage-1_relu1'], node=CPU-0)
    Stage(stage_id=task1-stage-3_relu2, #layers=1, deps=['task1-stage-2_conv2'], node=CPU-1)
    Stage(stage_id=task1-stage-4_fc1, #layers=1, deps=['task1-stage-3_relu2'], node=CPU-0)
[SchedulerNoMerge] Dispatching Stage 'task1-stage-0_conv1' to Node 'CPU-0'
[Stage] task1-stage-0_conv1 executed on CPU-0. Output shape: torch.Size([1, 8, 8, 8])
[SchedulerNoMerge] Dispatching Stage 'task1-stage-1_relu1' to Node 'CPU-1'
[Stage] task1-stage-1_relu1 executed on CPU-1. Output shape: torch.Size([1, 8, 8, 8])
[SchedulerNoMerge] Stage 'task1-stage-0_conv1' timed out.
[SchedulerNoMerge] Stage 'task1-stage-1_relu1' timed out.
[SchedulerNoMerge] Dispatching Stage 'task1-stage-2_conv2' to Node 'CPU-0'