# PyDart Library – Post Round-Robin Implementation

## Overview

This iteration marks the first implementation of a complete PyDart library following the Round-Robin scheduling scheme. In this version, **Utilization** was introduced as a key metric, particularly for **HPC and batch-based workloads**.

## Main Contributions

1. **Implementing Utilization as a Metric**  
   - Defined utilization as:  
     $$
     \text{Utilization} = \frac{\text{Busy Time}}{\text{Observation Window}}
     $$
   - The observation window includes:  
     - All forward pass execution timings.  
     - Associated data processing.  
     - A slack percentage for scheduling flexibility.  

2. **Targeting HPC and Batch Workloads**  
   - Designed to assess performance in **high-performance computing (HPC) environments**.  
   - Evaluated the impact on batch-processing scenarios.  

3. **Performance Evaluation and Key Insights**  
   - Observed **little to no improvement** in performance when scheduling with this metric.  
   - Although ineffective for batch processing, the approach **may still hold value in real-time applications**.  
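
The utilization definition in item 1 can be made concrete with a small calculation. The sketch below is a minimal illustration, not PyDart's implementation: the function names, their parameters, and the choice to apply slack as a multiplicative margin on the window are all assumptions made for the example.

```python
def observation_window(forward_pass_s, data_processing_s, slack_pct):
    """Assumed window: summed forward-pass and data-processing time, padded by slack."""
    base = sum(forward_pass_s) + sum(data_processing_s)
    return base * (1.0 + slack_pct / 100.0)


def utilization(busy_time_s, window_s):
    """Utilization = busy time / observation window."""
    if window_s <= 0:
        raise ValueError("observation window must be positive")
    return busy_time_s / window_s


# Hypothetical timings in seconds for one node.
forward = [0.25, 0.25]
data = [0.5]
window = observation_window(forward, data, slack_pct=25.0)   # (0.5 + 0.5) * 1.25 = 1.25
print(utilization(busy_time_s=1.0, window_s=window))          # 0.8
```

A larger slack percentage widens the window and therefore lowers the reported utilization for the same busy time.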

## Next Steps

- Further refine utilization-based scheduling to assess its impact on **dynamic and real-time inference workloads**.  
- Explore **adaptive scheduling mechanisms** to improve efficiency beyond static allocation.  

---

**Note**: This iteration represents a key step toward refining execution strategies. While initial results were inconclusive, this metric might be more beneficial in real-time applications.


In [None]:
# -*- coding: utf-8 -*-
"""Simplified_Taskset_Evaluation.ipynb

This script implements a simplified Taskset approach, utilizing manual stage assignments.
"""

import networkx as nx
import os
import torch
import threading
import queue
import time
import re
import pandas as pd
import torch.nn as nn
import torch.profiler
from torch.utils.data import DataLoader, TensorDataset
from typing import Callable, Any, List, Dict, Optional
import copy
import math
import matplotlib.pyplot as plt


In [None]:
# utils.py

import torch
from torch.fx import Node as FxNode  # Alias torch.fx.Node to FxNode
from typing import Any, Dict, List


def resolve_arg(arg: Any, node_outputs: Dict[str, torch.Tensor]) -> Any:
    """
    Recursively replaces FxNode references in args and kwargs with their actual outputs.

    Args:
        arg (Any): The argument to resolve.
        node_outputs (Dict[str, torch.Tensor]): A dictionary mapping node names to their output tensors.

    Returns:
        Any: The resolved argument with FxNode references replaced by tensors.
    """
    if isinstance(arg, FxNode):
        return node_outputs[arg.name]
    elif isinstance(arg, (list, tuple)):
        return type(arg)(resolve_arg(a, node_outputs) for a in arg)
    elif isinstance(arg, dict):
        return {k: resolve_arg(v, node_outputs) for k, v in arg.items()}
    else:
        return arg


def group_topological_order(topological_order: List[str], group_size: int = 2) -> Dict[str, List[str]]:
    """
    Groups the topological order list into fixed-size groups.
    Each group is named as 'stage-1', 'stage-2', etc.

    Args:
        topological_order (List[str]): List of operation names in topological order.
        group_size (int): Number of operations per group.

    Returns:
        Dict[str, List[str]]: Mapping from stage names to lists of operation names.
    """
    groups = {}
    num_groups = (len(topological_order) + group_size - 1) // group_size  # Ceiling division

    for i in range(num_groups):
        start_idx = i * group_size
        end_idx = start_idx + group_size
        group_nodes = topological_order[start_idx:end_idx]
        stage_name = f"stage-{i+1}"
        groups[stage_name] = group_nodes

    return groups
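
As a quick check of the fixed-size grouping above, the sketch below restates the slicing logic in a compact, standalone form and runs it on a made-up list of operation names. The helper name `group_in_stages` and the operation names are illustrative and not part of PyDart.

```python
from typing import Dict, List


def group_in_stages(order: List[str], group_size: int = 2) -> Dict[str, List[str]]:
    """Standalone restatement: consecutive chunks named stage-1, stage-2, ..."""
    if group_size < 1:
        raise ValueError("group_size must be at least 1")
    return {
        f"stage-{i // group_size + 1}": order[i:i + group_size]
        for i in range(0, len(order), group_size)
    }


ops = ["conv1", "relu", "pool", "fc1", "fc2"]   # hypothetical topological order
stages = group_in_stages(ops, group_size=2)
print(stages)
# {'stage-1': ['conv1', 'relu'], 'stage-2': ['pool', 'fc1'], 'stage-3': ['fc2']}
```

When the list length is not a multiple of `group_size`, the last stage simply holds the remaining operations.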

In [None]:
import os
import torch
import threading
import queue
from typing import Callable, List

class Node:
    """
    Represents a computational resource: either a CPU-only node (1 CPU core)
    or a GPU+CPU pair. Each Node has:
      - node_id (e.g., 'CPU-0', 'GPU-0-CPU-1')
      - A worker thread + queue to run tasks
    """

    def __init__(self, node_id: str, cpus=None, gpu=None):
        self._node_id = node_id
        self._cpus = tuple(cpus or [])
        self._gpu = gpu

        self._original_affinity = os.sched_getaffinity(0)
        self._task_queue = queue.Queue()
        self._stop_signal = False

        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

        self.current_load = 0.0  # Initialize current load
        self.assigned_stages = []  # List to track assigned stages

    @property
    def node_id(self):
        return self._node_id

    @property
    def cpus(self):
        return self._cpus

    @property
    def gpu(self):
        return self._gpu

    def assign_task(self, func: Callable, *args, **kwargs) -> queue.Queue:
        """
        Enqueue a function to this node. Returns a queue from which
        the caller can retrieve the result (blocking).
        """
        result_queue = queue.Queue(maxsize=1)
        self._task_queue.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        """
        Signal the node to stop after processing queued tasks.
        """
        self._stop_signal = True
        self._task_queue.put(None)
        self._worker_thread.join()

    def _worker_loop(self):
        while not self._stop_signal:
            item = self._task_queue.get()
            if item is None:
                break
            func, args, kwargs, result_queue = item
            try:
                self._set_context()
                result = func(*args, **kwargs)
            except Exception as e:
                result = e
            finally:
                self._reset_context()

            result_queue.put(result)

    def _set_context(self):
        if self._cpus:
            os.sched_setaffinity(0, self._cpus)
        if self._gpu is not None and torch.cuda.is_available():
            torch.cuda.set_device(self._gpu)

    def _reset_context(self):
        os.sched_setaffinity(0, self._original_affinity)
        # Optionally reset GPU device if needed

    @staticmethod
    def discover_nodes(disjoint: bool = True) -> List['Node']:
        """
        Create a list of Node objects.

        If disjoint=False:
          - For each CPU core, create a CPU-only node.
          - For each GPU device and each CPU core, create a GPU+CPU node.

        If disjoint=True (default):
          - Create one Node per CPU core (CPU-only) in a temporary list.
          - For each GPU device, pop exactly one CPU core from the CPU list
            and create a GPU+CPU node with that single core.
          - If there are more GPUs than CPU cores, leftover GPUs get a Node with no CPU.
        """
        nodes = []
        num_cpus = os.cpu_count() or 1
        ngpus = torch.cuda.device_count()

        if not disjoint:
            # --- Original Behavior (overlapping) ---
            # CPU-only nodes
            for core_id in range(num_cpus):
                node = Node(node_id=f"CPU-{core_id}", cpus=[core_id])
                nodes.append(node)

            # GPU+CPU nodes (one for each GPU and CPU pair)
            for g in range(ngpus):
                for core_id in range(num_cpus):
                    node = Node(node_id=f"GPU-{g}-CPU-{core_id}", cpus=[core_id], gpu=g)
                    nodes.append(node)

            print("[discover_nodes] Generated ALL possible nodes (CPU-only + GPU+CPU).")

        else:
            # --- Disjoint Behavior (no CPU overlap) ---
            # 1) Build a CPU-only node for every CPU core
            cpu_nodes = []
            for core_id in range(num_cpus):
                node = Node(node_id=f"CPU-{core_id}", cpus=[core_id], gpu=None)
                cpu_nodes.append(node)

            # 2) For each GPU, pop exactly one CPU node (if available)
            #    to pair with that GPU. Otherwise, create a GPU node with no CPU.
            gpu_nodes = []
            for g in range(ngpus):
                if cpu_nodes:
                    # Pop a CPU node from the list
                    removed_node = cpu_nodes.pop()
                    claim_core = removed_node.cpus[0]  # The one CPU from that node

                    # Create a GPU+CPU node
                    node = Node(node_id=f"GPU-{g}-CPU-{claim_core}",
                                cpus=[claim_core],
                                gpu=g)
                    gpu_nodes.append(node)
                    print(f"[discover_nodes] Created GPU node '{node.node_id}' claiming CPU core {claim_core}.")
                else:
                    # If no CPU cores left, create a GPU node with no CPU
                    node = Node(node_id=f"GPU-{g}", cpus=[], gpu=g)
                    gpu_nodes.append(node)
                    print(f"[discover_nodes] Created GPU node '{node.node_id}' with no CPU assigned.")

            # Combine the leftover CPU-only nodes + GPU nodes
            nodes = cpu_nodes + gpu_nodes

            print("[discover_nodes] Generated DISJOINT nodes: leftover CPU-only plus GPU+CPU with unique cores.")

        return nodes

    def __repr__(self):
        return f"Node({self._node_id}, cpus={self._cpus}, gpu={self._gpu})"
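
The `assign_task` pattern above hands back a one-slot queue per task, and the worker puts either the return value or the caught exception into it, so callers need to check which one they received. The sketch below reproduces only that hand-off with the standard library; `MiniWorker` is a hypothetical stand-in that omits CPU affinity and GPU selection.

```python
import queue
import threading


class MiniWorker:
    """Hypothetical stand-in for Node: one thread, one task queue, no affinity handling."""

    def __init__(self):
        self._tasks = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def assign_task(self, func, *args, **kwargs):
        result_queue = queue.Queue(maxsize=1)
        self._tasks.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        self._tasks.put(None)      # sentinel: exit once earlier tasks are done
        self._thread.join()

    def _loop(self):
        while True:
            item = self._tasks.get()
            if item is None:
                break
            func, args, kwargs, result_queue = item
            try:
                result = func(*args, **kwargs)
            except Exception as exc:   # mirror Node: hand the exception back as the result
                result = exc
            result_queue.put(result)


worker = MiniWorker()
ok = worker.assign_task(lambda a, b: a + b, 2, 3).get(timeout=5)
bad = worker.assign_task(lambda: 1 / 0).get(timeout=5)
worker.stop()

print(ok)                                   # 5
print(isinstance(bad, ZeroDivisionError))   # True
```

Unlike `Node.stop`, this sketch relies on the sentinel alone, so tasks queued before `stop()` still run. Callers of either version should test the result with `isinstance(result, Exception)` before using it.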


In [None]:
# import os
# import re
# import torch
# from torch import nn
# from torch.profiler import profile, ProfilerActivity, record_function
# import torch.fx
# import pandas as pd
# from typing import Any, Dict, Callable, Optional
# import torchvision.models as models

# class Profiler:
#     """
#     Profiler class that uses torch.fx to trace and instrument a PyTorch model,
#     wrapping each operation with torch.profiler.record_function to collect detailed profiling data.
#     """
#     def __init__(self, mode: str, profile_db_path='profiling_results.csv', log_dir='logs'):
#         """
#         Initializes the Profiler.

#         Args:
#             mode (str): 'init' or 'runtime'.
#             profile_db_path (str): Path to the ProfileDB CSV file.
#             log_dir (str): Directory to store logs.
#         """
#         assert mode in ['init', 'runtime'], "Profiler mode must be either 'init' or 'runtime'."
#         self.mode = mode
#         self.profile_db_path = profile_db_path
#         self.log_dir = log_dir
#         os.makedirs(self.log_dir, exist_ok=True)

#         # Define columns for ProfileDB
#         self.columns = [
#             'Task_ID', 'Model', 'Layer', 'Compute',
#             'Self CPU (us)', 'CPU Total (us)', 'CUDA Total (us)',
#             'Self CPU Mem (bytes)', 'Self CUDA Mem (bytes)',
#             'Total Execution Time (us)', 'Total Memory Used (bytes)'
#         ]

#         # Initialize or load ProfileDB
#         if os.path.exists(self.profile_db_path):
#             self.profile_db = pd.read_csv(self.profile_db_path)
#         else:
#             self.profile_db = pd.DataFrame(columns=self.columns)

#         self.runtime_csv = os.path.join(self.log_dir, 'runtime_results.csv')
#         if not os.path.exists(self.runtime_csv):
#             rt_cols = ['Task_ID', 'Model', 'Layer', 'Compute', 'Execution Time (us)']
#             pd.DataFrame(columns=rt_cols).to_csv(self.runtime_csv, index=False)

#         # We'll store an 'observation_window' if needed
#         self.observation_window = 0.0

#     def _trace_and_instrument_model(self, model: nn.Module) -> torch.fx.GraphModule:
#         """
#         Traces the model using torch.fx, instruments every node by wrapping its operation
#         with torch.profiler.record_function, and returns the modified GraphModule.

#         Args:
#             model (nn.Module): The PyTorch model to profile.

#         Returns:
#             torch.fx.GraphModule: The instrumented GraphModule.
#         """
#         # Trace the model
#         tracer = torch.fx.Tracer()
#         graph = tracer.trace(model)
#         graph_module = torch.fx.GraphModule(model, graph)
#         graph = graph_module.graph

#         # Define a unique attribute name prefix to avoid conflicts
#         profiler_attr_prefix = "_profiler_wrapped_"

#         # Iterate through the nodes and wrap functions
#         for node in list(graph.nodes):
#             if node.op == 'call_function':
#                 func = node.target
#                 func_name = func.__name__
#                 # Define a unique name for the wrapped function
#                 wrapped_func_name = f"{profiler_attr_prefix}{func_name}_{id(func)}"

#                 # Define the wrapped function with record_function
#                 def make_wrapped_func(original_func, name):
#                     def wrapped(*args, **kwargs):
#                         with record_function(name):
#                             return original_func(*args, **kwargs)
#                     return wrapped

#                 wrapped_func = make_wrapped_func(func, func_name)

#                 # Assign the wrapped function as an attribute to the model
#                 setattr(model, wrapped_func_name, wrapped_func)

#                 # Replace the node's target with the wrapped function
#                 node.target = getattr(model, wrapped_func_name)

#             elif node.op == 'call_module':
#                 submodule = dict(model.named_modules())[node.target]
#                 func = submodule.forward
#                 func_name = f"{node.target}.forward"

#                 # Define a unique name for the wrapped function
#                 wrapped_func_name = f"{profiler_attr_prefix}{node.target}_forward_{id(func)}"

#                 # Define the wrapped function with record_function
#                 def make_wrapped_forward(original_forward, name):
#                     def wrapped_forward(*args, **kwargs):
#                         with record_function(name):
#                             return original_forward(*args, **kwargs)
#                     return wrapped_forward

#                 wrapped_forward = make_wrapped_forward(func, func_name)

#                 # Assign the wrapped forward as an attribute to the submodule
#                 setattr(submodule, wrapped_func_name, wrapped_forward)

#                 # Replace the submodule's forward with the wrapped version
#                 submodule.forward = getattr(submodule, wrapped_func_name)

#             elif node.op == 'call_method':
#                 method_name = node.target
#                 obj = node.args[0]
#                 # Retrieve the method from the object
#                 original_method = getattr(obj, method_name)
#                 func_name = f"{obj.name}.{method_name}" if hasattr(obj, 'name') else method_name

#                 # Define a unique name for the wrapped method
#                 wrapped_method_name = f"{profiler_attr_prefix}{method_name}_{id(original_method)}"

#                 # Define the wrapped method with record_function
#                 def make_wrapped_method(original_method, name):
#                     def wrapped_method(*args, **kwargs):
#                         with record_function(name):
#                             return original_method(*args, **kwargs)
#                     return wrapped_method

#                 wrapped_method = make_wrapped_method(original_method, func_name)

#                 # Assign the wrapped method back to the object
#                 setattr(obj, wrapped_method_name, wrapped_method)

#                 # Replace the method call with the wrapped method
#                 # Since it's a method, ensure the GraphModule accesses the wrapped method
#                 # In torch.fx, method calls are kept as 'call_method', so no need to change node.op
#                 # Just ensure that the method is wrapped in the object

#             elif node.op == 'placeholder':
#                 # For placeholders, optionally wrap with a pass-through function
#                 # Currently, no profiling needed for input placeholders
#                 pass

#         # Recompile the GraphModule after instrumentation
#         graph_module.recompile()

#         return graph_module

#     def profile_model(self, model: nn.Module, input_data: Any, node_id: str, task_id: str, warmup_iters=3, profile_iters=5):
#         """
#         Profiles the given model on the specified node.

#         Args:
#             model (nn.Module): The PyTorch model to profile.
#             input_data (Any): The input data for the model.
#             node_id (str): Identifier for the compute node (e.g., 'CPU-0', 'GPU-0').
#             task_id (str): Identifier for the profiling task.
#             warmup_iters (int): Number of warmup iterations.
#             profile_iters (int): Number of profiling iterations.
#         """
#         # Trace and instrument the model
#         instrumented_model = self._trace_and_instrument_model(model)

#         # Move to device
#         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#         instrumented_model.to(device)
#         instrumented_model.eval()

#         # Warmup runs
#         with torch.no_grad():
#             for _ in range(warmup_iters):
#                 _ = instrumented_model(input_data.to(device))

#         print("Starting profiling...")
#         # Start profiler
#         with profile(
#             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#             schedule=torch.profiler.schedule(wait=1, warmup=1, active=profile_iters),
#             on_trace_ready=lambda prof: self._trace_handler(prof, task_id, model.__class__.__name__, node_id),
#             record_shapes=True,
#             profile_memory=True,
#             with_stack=True
#         ) as prof:
#             for _ in range(profile_iters):
#                 with torch.no_grad():
#                     _ = instrumented_model(input_data.to(device))
#                 prof.step()

#         # Save profiling data
#         self.profile_db.to_csv(self.profile_db_path, index=False)
#         print(f"Profiling complete. Data saved to {self.profile_db_path}")

#     def _trace_handler(self, prof, task_id: str, model_name: str, node_id: str):
#         """
#         Handles the trace data and processes events.

#         Args:
#             prof: The profiler instance.
#             task_id (str): Identifier for the profiling task.
#             model_name (str): Name of the model.
#             node_id (str): Compute node identifier.
#         """
#         self._process_profiler_data(prof, task_id, model_name, node_id)

#     def _process_profiler_data(self, profiler, task_id: str, model_name: str, node_id: str):
#         """
#         Processes raw profiler data and aggregates it per node, including a forward_pass record.

#         Args:
#             profiler: The profiler instance.
#             task_id (str): Identifier for the profiling task.
#             model_name (str): Name of the model.
#             node_id (str): Compute node identifier.
#         """
#         aggregated = {}
#         forward_pass = {
#             'Task_ID': task_id,
#             'Model': model_name,
#             'Layer': 'forward_pass',
#             'Compute': node_id,
#             'Self CPU (us)': 0.0,
#             'CPU Total (us)': 0.0,
#             'CUDA Total (us)': 0.0,
#             'Self CPU Mem (bytes)': 0,
#             'Self CUDA Mem (bytes)': 0,
#             'Total Execution Time (us)': 0.0,
#             'Total Memory Used (bytes)': 0
#         }

#         events = profiler.key_averages()

#         for evt in events:
#             layer_name = evt.key
#             if layer_name.startswith("aten::"):
#                 continue  # Skip aten operations

#             # Aggregate forward_pass
#             forward_pass['Self CPU (us)'] += evt.self_cpu_time_total
#             forward_pass['CPU Total (us)'] += evt.cpu_time_total
#             forward_pass['CUDA Total (us)'] += evt.cuda_time_total if hasattr(evt, 'cuda_time_total') else 0.0
#             forward_pass['Self CPU Mem (bytes)'] += evt.self_cpu_memory_usage if hasattr(evt, 'self_cpu_memory_usage') else 0
#             forward_pass['Self CUDA Mem (bytes)'] += evt.self_cuda_memory_usage if hasattr(evt, 'self_cuda_memory_usage') else 0
#             forward_pass['Total Execution Time (us)'] += (evt.cpu_time_total + evt.cuda_time_total)  if hasattr(evt, 'cuda_time_total') else evt.cpu_time_total
#             forward_pass['Total Memory Used (bytes)'] += (evt.self_cpu_memory_usage + evt.self_cuda_memory_usage) if hasattr(evt, 'self_cpu_memory_usage') and hasattr(evt, 'self_cuda_memory_usage') else 0

#             # Aggregate per node
#             if layer_name not in aggregated:
#                 aggregated[layer_name] = {
#                     'Task_ID': task_id,
#                     'Model': model_name,
#                     'Layer': layer_name,
#                     'Compute': node_id,
#                     'Self CPU (us)': 0.0,
#                     'CPU Total (us)': 0.0,
#                     'CUDA Total (us)': 0.0,
#                     'Self CPU Mem (bytes)': 0,
#                     'Self CUDA Mem (bytes)': 0,
#                     'Total Execution Time (us)': 0.0,
#                     'Total Memory Used (bytes)': 0
#                 }

#             aggregated[layer_name]['Self CPU (us)'] += evt.self_cpu_time_total
#             aggregated[layer_name]['CPU Total (us)'] += evt.cpu_time_total
#             aggregated[layer_name]['CUDA Total (us)'] += evt.cuda_time_total  if hasattr(evt, 'cuda_time_total') else 0.0
#             aggregated[layer_name]['Self CPU Mem (bytes)'] += evt.self_cpu_memory_usage if hasattr(evt, 'self_cpu_memory_usage') else 0
#             aggregated[layer_name]['Self CUDA Mem (bytes)'] += evt.self_cuda_memory_usage if hasattr(evt, 'self_cuda_memory_usage') else 0
#             aggregated[layer_name]['Total Execution Time (us)'] += (evt.cpu_time_total + evt.cuda_time_total)  if hasattr(evt, 'cuda_time_total') else evt.cpu_time_total
#             aggregated[layer_name]['Total Memory Used (bytes)'] += (evt.self_cpu_memory_usage + evt.self_cuda_memory_usage) if hasattr(evt, 'self_cpu_memory_usage') and hasattr(evt, 'self_cuda_memory_usage') else 0

#         # Insert forward_pass
#         self.profile_db = self._upsert(self.profile_db, forward_pass)

#         # Insert per-node data
#         for layer_name, data in aggregated.items():
#             self.profile_db = self._upsert(self.profile_db, data)

#         # Save to CSV
#         self.profile_db.to_csv(self.profile_db_path, index=False)

#     def _upsert(self, df: pd.DataFrame, row: Dict[str, Any]) -> pd.DataFrame:
#         """
#         Inserts or updates a row in the DataFrame based on Task_ID, Model, Layer, and Compute.
#         Only updates if 'Total Execution Time (us)' is greater than existing.

#         Args:
#             df (pd.DataFrame): The ProfileDB DataFrame.
#             row (Dict[str, Any]): The row data to upsert.

#         Returns:
#             pd.DataFrame: The updated DataFrame.
#         """
#         mask = (
#             (df['Task_ID'] == row['Task_ID']) &
#             (df['Model'] == row['Model']) &
#             (df['Layer'] == row['Layer']) &
#             (df['Compute'] == row['Compute'])
#         )
#         if mask.any():
#             existing_time = df.loc[mask, 'Total Execution Time (us)'].max()
#             if row['Total Execution Time (us)'] > existing_time:
#                 for key in self.columns:
#                     df.loc[mask, key] = row[key]
#         else:
#             new_row = pd.DataFrame([row])
#             df = pd.concat([df, new_row], ignore_index=True)
#         return df

#     def get_profile_db(self) -> pd.DataFrame:
#         """
#         Returns the ProfileDB DataFrame.

#         Returns:
#             pd.DataFrame: The ProfileDB DataFrame.
#         """
#         return self.profile_db

#     def print_profile_db(self):
#         """
#         Prints the ProfileDB DataFrame.
#         """
#         if self.profile_db.empty:
#             print("ProfileDB is empty.")
#         else:
#             print("ProfileDB:")
#             print(self.profile_db.to_string(index=False))
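
The `_upsert` rule in the commented-out class (insert a new row, or overwrite an existing one only when the new total execution time is larger) keeps the slowest observation per key. A dictionary-based sketch of the same rule, using made-up rows and a hypothetical `upsert_slowest` helper instead of a pandas DataFrame, might look like this:

```python
from typing import Any, Dict, Tuple

KEY_FIELDS = ("Task_ID", "Model", "Layer", "Compute")
TIME_FIELD = "Total Execution Time (us)"

Row = Dict[str, Any]


def upsert_slowest(db: Dict[Tuple, Row], row: Row) -> None:
    """Insert row, or replace the stored row only if the new one took longer."""
    key = tuple(row[f] for f in KEY_FIELDS)
    existing = db.get(key)
    if existing is None or row[TIME_FIELD] > existing[TIME_FIELD]:
        db[key] = dict(row)


db: Dict[Tuple, Row] = {}
base = {"Task_ID": "t1", "Model": "Net", "Layer": "conv1", "Compute": "CPU-0"}
upsert_slowest(db, {**base, TIME_FIELD: 120.0})
upsert_slowest(db, {**base, TIME_FIELD: 90.0})    # faster run: ignored
upsert_slowest(db, {**base, TIME_FIELD: 150.0})   # slower run: replaces the row

print(len(db))                                            # 1
print(db[("t1", "Net", "conv1", "CPU-0")][TIME_FIELD])    # 150.0
```

Keeping the maximum rather than the mean gives a conservative, worst-case estimate per layer and compute node.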

In [None]:
import os
import io
import copy
import torch
import torch.fx
import pandas as pd
from torch import nn
from torch.profiler import profile, ProfilerActivity, record_function
from typing import Any, Dict, Optional

class Profiler:
    """
    Profiler class that uses torch.fx to trace and instrument a PyTorch model,
    wrapping each operation with torch.profiler.record_function to collect detailed profiling data.

    Now includes a "safe clone" inside profile_model:
      - Tries copy.deepcopy first
      - Falls back to torch.save / torch.load if deepcopy fails
    """

    def __init__(self, mode: str, profile_db_path='profiling_results.csv', log_dir='logs'):
        """
        Initializes the Profiler.

        Args:
            mode (str): 'init' or 'runtime'.
            profile_db_path (str): Path to the ProfileDB CSV file.
            log_dir (str): Directory to store logs.
        """
        assert mode in ['init', 'runtime'], "Profiler mode must be either 'init' or 'runtime'."
        self.mode = mode
        self.profile_db_path = profile_db_path
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        # Define columns for ProfileDB
        self.columns = [
            'Task_ID', 'Model', 'Layer', 'Compute',
            'Self CPU (us)', 'CPU Total (us)', 'CUDA Total (us)',
            'Self CPU Mem (bytes)', 'Self CUDA Mem (bytes)',
            'Total Execution Time (us)', 'Total Memory Used (bytes)'
        ]

        # Initialize or load ProfileDB
        if os.path.exists(self.profile_db_path):
            self.profile_db = pd.read_csv(self.profile_db_path)
        else:
            self.profile_db = pd.DataFrame(columns=self.columns)

        # (Optional) Example location for runtime CSV logs
        self.runtime_csv = os.path.join(self.log_dir, 'runtime_results.csv')
        if not os.path.exists(self.runtime_csv):
            rt_cols = ['Task_ID', 'Model', 'Layer', 'Compute', 'Execution Time (us)']
            pd.DataFrame(columns=rt_cols).to_csv(self.runtime_csv, index=False)

        self.observation_window = 0.0

    def profile_model(self,
                      model: nn.Module,
                      input_data: Any,
                      node_id: str,
                      task_id: str,
                      warmup_iters=3,
                      profile_iters=5):
        """
        Profiles the given model on the specified node. Automatically:
          1) Clones the model (to avoid double-instrumentation).
          2) Traces and instruments the clone with torch.fx and record_function.
          3) Collects and saves profiling data.

        Args:
            model (nn.Module): The PyTorch model to profile.
            input_data (Any): The input data for the model.
            node_id (str): Identifier for the compute node (e.g., 'CPU-0', 'GPU-0').
            task_id (str): Identifier for the profiling task.
            warmup_iters (int): Number of warmup iterations.
            profile_iters (int): Number of profiling iterations.
        """
        # 1) Create a fresh, uninstrumented model copy
        model_copy = self._clone_model_safely(model)

        # 2) Trace & instrument the fresh copy
        instrumented_model = self._trace_and_instrument_model(model_copy)

        # 3) Move to device (CPU or CUDA)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        instrumented_model.to(device)
        instrumented_model.eval()

        # 4) Warmup runs
        with torch.no_grad():
            for _ in range(warmup_iters):
                _ = instrumented_model(input_data.to(device))

        print("Starting profiling...")
        # 5) Actual profiling with the fresh, instrumented model
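        # The schedule spends 1 step waiting and 1 step warming up before it
        # records, so run profile_iters + 2 steps to capture profile_iters of them.
        # Only request CUDA activity when a CUDA device is actually in use.
        activities = [ProfilerActivity.CPU]
        if device.type == 'cuda':
            activities.append(ProfilerActivity.CUDA)
        with profile(
            activities=activities,
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=profile_iters),
            on_trace_ready=lambda prof: self._trace_handler(prof, task_id, model_copy.__class__.__name__, node_id),
            record_shapes=True,
            profile_memory=True,
            with_stack=True
        ) as prof:
            for _ in range(profile_iters + 2):
                with torch.no_grad():
                    _ = instrumented_model(input_data.to(device))
                prof.step()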

        # 6) Save profiling data
        self.profile_db.to_csv(self.profile_db_path, index=False)
        print(f"Profiling complete. Data saved to {self.profile_db_path}")

    def _clone_model_safely(self, model: nn.Module) -> nn.Module:
        """
        Tries to clone the given model via copy.deepcopy.
        If that fails for any reason, falls back to torch.save/torch.load.

        Returns:
            A fresh, uninstrumented model instance.
        """
        try:
            return copy.deepcopy(model)
        except Exception as e:
            print(f"[Profiler] deepcopy failed with error: {e}")
            print("[Profiler] Falling back to torch.save / torch.load approach.")
            buffer = io.BytesIO()
            torch.save(model, buffer)
            buffer.seek(0)
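            # Loading a whole pickled module needs weights_only=False on recent
            # PyTorch releases, where torch.load defaults to weights_only=True
            # (the argument is available from PyTorch 1.13 onward).
            return torch.load(buffer, weights_only=False)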

    def _trace_and_instrument_model(self, model: nn.Module) -> torch.fx.GraphModule:
        """
        Traces the model using torch.fx, instruments every node by wrapping its operation
        with torch.profiler.record_function, and returns the modified GraphModule.
        """
        # 1) FX tracing
        tracer = torch.fx.Tracer()
        graph = tracer.trace(model)
        graph_module = torch.fx.GraphModule(model, graph)
        graph = graph_module.graph

        profiler_attr_prefix = "_profiler_wrapped_"

        # 2) Iterate over nodes and wrap calls
        for node in list(graph.nodes):
            node_name = node.name

            if node.op == 'call_function':
                func = node.target  # e.g., torch.add
                wrapped_func_name = f"{profiler_attr_prefix}{node_name}_{id(func)}"

                def make_wrapped_func(original_func, profile_name):
                    def wrapped(*args, **kwargs):
                        with record_function(profile_name):
                            return original_func(*args, **kwargs)
                    return wrapped

                wrapped_func = make_wrapped_func(func, node_name)
                setattr(model, wrapped_func_name, wrapped_func)
                node.target = getattr(model, wrapped_func_name)

            elif node.op == 'call_module':
                submodule = dict(model.named_modules())[node.target]
                func = submodule.forward
                wrapped_func_name = f"{profiler_attr_prefix}{node_name}_{id(func)}"

                def make_wrapped_forward(original_forward, profile_name):
                    def wrapped_forward(*args, **kwargs):
                        with record_function(profile_name):
                            return original_forward(*args, **kwargs)
                    return wrapped_forward

                wrapped_forward = make_wrapped_forward(func, node_name)
                setattr(submodule, wrapped_func_name, wrapped_forward)
                submodule.forward = getattr(submodule, wrapped_func_name)

            elif node.op == 'call_method':
                method_name = node.target
                obj = node.args[0]
                original_method = getattr(obj, method_name, None)
                if original_method is None:
                    continue

                wrapped_method_name = f"{profiler_attr_prefix}{node_name}_{id(original_method)}"

                def make_wrapped_method(orig_meth, profile_name):
                    def wrapped_method(*args, **kwargs):
                        with record_function(profile_name):
                            return orig_meth(*args, **kwargs)
                    return wrapped_method

                wrapped_method = make_wrapped_method(original_method, node_name)
                setattr(obj, wrapped_method_name, wrapped_method)
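                # Note: at graph-construction time node.args[0] is an FX Node, not a
                # tensor, so the lookup above usually returns None and this branch is
                # skipped. Even when it succeeds, node.target is left unchanged, so the
                # wrapper stored on `obj` is never called by the recompiled graph.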

            elif node.op == 'placeholder':
                # Usually no instrumentation needed for placeholders
                pass

        # 3) Recompile the GraphModule after instrumentation
        graph_module.recompile()
        return graph_module

    def _trace_handler(self, prof, task_id: str, model_name: str, node_id: str):
        """
        Handles the trace data once the profiling schedule triggers it.
        """
        self._process_profiler_data(prof, task_id, model_name, node_id)

    def _process_profiler_data(self, profiler, task_id: str, model_name: str, node_id: str):
        """
        Processes raw profiler data and aggregates it into self.profile_db.
        Also adds a forward_pass record combining all ops.

        Args:
            profiler: The profiler instance.
            task_id (str): Identifier for the profiling task.
            model_name (str): Model name (class).
            node_id (str): Compute node identifier (e.g., 'CPU-0').
        """
        aggregated = {}
        forward_pass = {
            'Task_ID': task_id,
            'Model': model_name,
            'Layer': 'forward_pass',
            'Compute': node_id,
            'Self CPU (us)': 0.0,
            'CPU Total (us)': 0.0,
            'CUDA Total (us)': 0.0,
            'Self CPU Mem (bytes)': 0,
            'Self CUDA Mem (bytes)': 0,
            'Total Execution Time (us)': 0.0,
            'Total Memory Used (bytes)': 0
        }

        events = profiler.key_averages()

        for evt in events:
            layer_name = evt.key

            # Skip low-level PyTorch ops if desired
            if layer_name.startswith("aten::"):
                continue

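            # Accumulate a "forward_pass" total.
            # Note: cpu_time_total includes time spent in child events, so summing it
            # over nested record_function ranges can count the same interval twice.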
            forward_pass['Self CPU (us)'] += evt.self_cpu_time_total
            forward_pass['CPU Total (us)'] += evt.cpu_time_total
            forward_pass['CUDA Total (us)'] += getattr(evt, 'cuda_time_total', 0.0)
            forward_pass['Self CPU Mem (bytes)'] += getattr(evt, 'self_cpu_memory_usage', 0)
            forward_pass['Self CUDA Mem (bytes)'] += getattr(evt, 'self_cuda_memory_usage', 0)
            forward_pass['Total Execution Time (us)'] += (
                evt.cpu_time_total + getattr(evt, 'cuda_time_total', 0.0)
            )
            forward_pass['Total Memory Used (bytes)'] += (
                getattr(evt, 'self_cpu_memory_usage', 0) +
                getattr(evt, 'self_cuda_memory_usage', 0)
            )

            if layer_name not in aggregated:
                aggregated[layer_name] = {
                    'Task_ID': task_id,
                    'Model': model_name,
                    'Layer': layer_name,
                    'Compute': node_id,
                    'Self CPU (us)': 0.0,
                    'CPU Total (us)': 0.0,
                    'CUDA Total (us)': 0.0,
                    'Self CPU Mem (bytes)': 0,
                    'Self CUDA Mem (bytes)': 0,
                    'Total Execution Time (us)': 0.0,
                    'Total Memory Used (bytes)': 0
                }

            # Per-layer accumulation
            aggregated[layer_name]['Self CPU (us)'] += evt.self_cpu_time_total
            aggregated[layer_name]['CPU Total (us)'] += evt.cpu_time_total
            aggregated[layer_name]['CUDA Total (us)'] += getattr(evt, 'cuda_time_total', 0.0)
            aggregated[layer_name]['Self CPU Mem (bytes)'] += getattr(evt, 'self_cpu_memory_usage', 0)
            aggregated[layer_name]['Self CUDA Mem (bytes)'] += getattr(evt, 'self_cuda_memory_usage', 0)
            aggregated[layer_name]['Total Execution Time (us)'] += (
                evt.cpu_time_total + getattr(evt, 'cuda_time_total', 0.0)
            )
            aggregated[layer_name]['Total Memory Used (bytes)'] += (
                getattr(evt, 'self_cpu_memory_usage', 0) +
                getattr(evt, 'self_cuda_memory_usage', 0)
            )

        # Insert forward_pass row
        self.profile_db = self._upsert(self.profile_db, forward_pass)

        # Insert each layer's data
        for layer_name, data in aggregated.items():
            self.profile_db = self._upsert(self.profile_db, data)

        # Save to CSV
        self.profile_db.to_csv(self.profile_db_path, index=False)

    def _upsert(self, df: pd.DataFrame, row: Dict[str, Any]) -> pd.DataFrame:
        """
        Inserts or updates a row in the DataFrame based on Task_ID, Model, Layer, and Compute.
        Only updates if 'Total Execution Time (us)' is greater than existing.
        """
        mask = (
            (df['Task_ID'] == row['Task_ID']) &
            (df['Model'] == row['Model']) &
            (df['Layer'] == row['Layer']) &
            (df['Compute'] == row['Compute'])
        )
        if mask.any():
            existing_time = df.loc[mask, 'Total Execution Time (us)'].max()
            if row['Total Execution Time (us)'] > existing_time:
                for key in self.columns:
                    df.loc[mask, key] = row[key]
        else:
            new_row = pd.DataFrame([row])
            df = pd.concat([df, new_row], ignore_index=True)
        return df

    def get_profile_db(self) -> pd.DataFrame:
        """
        Returns the ProfileDB DataFrame.
        """
        return self.profile_db

    def print_profile_db(self):
        """
        Prints the ProfileDB DataFrame.
        """
        if self.profile_db.empty:
            print("ProfileDB is empty.")
        else:
            print("ProfileDB:")
            print(self.profile_db.to_string(index=False))
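
The keep-the-slowest-run rule used by `_upsert` can be seen in isolation. The following is a minimal sketch, not the Profiler's own code: `upsert_keep_max`, `KEY_COLS`, and `TIME_COL` are hypothetical names introduced for illustration, and only the four key columns plus the timing column are modelled.

```python
import pandas as pd

# Assumed key and timing columns, mirroring the ProfileDB layout above.
KEY_COLS = ["Task_ID", "Model", "Layer", "Compute"]
TIME_COL = "Total Execution Time (us)"


def upsert_keep_max(df: pd.DataFrame, row: dict) -> pd.DataFrame:
    """Insert `row`, or overwrite the matching row only if `row` is slower."""
    # Build a boolean mask selecting rows whose key columns all match.
    mask = pd.Series(True, index=df.index)
    for col in KEY_COLS:
        mask = mask & (df[col] == row[col])

    if not mask.any():
        # No row with this key yet: append it.
        return pd.concat([df, pd.DataFrame([row])], ignore_index=True)

    if row[TIME_COL] > df.loc[mask, TIME_COL].max():
        # The new measurement is slower: replace the stored values.
        for col, value in row.items():
            df.loc[mask, col] = value
    return df


base = {"Task_ID": "t1", "Model": "M", "Layer": "fc1", "Compute": "CPU-0"}
db = pd.DataFrame([{**base, TIME_COL: 10.0}])
db = upsert_keep_max(db, {**base, TIME_COL: 25.0})  # slower run replaces the row
db = upsert_keep_max(db, {**base, TIME_COL: 5.0})   # faster run is ignored
db = upsert_keep_max(db, {**base, "Compute": "GPU-0", TIME_COL: 3.0})  # new key
print(db)
```

Under this rule each (task, model, layer, compute) key ends up holding the largest time observed across profiling runs, which acts as a pessimistic estimate rather than an average.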


In [None]:
# stage.py

import torch
import time
from typing import Dict, Optional, List
from torch.fx import Node as FxNode

# from utils import resolve_arg  # Ensure this import path is correct
# from node import Node            # Ensure this import path is correct

def move_tensor_to_device(obj, device):
    """
    Recursively move all Tensors in `obj` to `device`.
    If `obj` is just a Tensor, call .to(device) on it.
    If `obj` is a (list, tuple, dict), recurse.
    Otherwise, return `obj` as-is.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    elif isinstance(obj, (list, tuple)):
        return type(obj)(move_tensor_to_device(elem, device) for elem in obj)
    elif isinstance(obj, dict):
        return {k: move_tensor_to_device(v, device) for k, v in obj.items()}
    else:
        return obj

class Stage:
    """
    Represents a group of operations (FxNodes) assigned to a particular Node.
    """

    def __init__(self, stage_id: str, nodes: List[FxNode], assigned_node: Optional['Node'], task: 'Task'):
        self.stage_id = stage_id
        self.nodes = nodes
        self.assigned_node = assigned_node

        self.dependencies: List[str] = []
        self.dependents: List[str] = []

        self.execution_time: Optional[float] = None
        self.transfer_time: float = 0.0

        self.output_data: Optional[torch.Tensor] = None
        self.input_data: Optional[Dict[str, torch.Tensor]] = None

        self.task = task

        # We'll store the device as a string if needed:
        self.stage_device: str = "cpu"

    def add_dependency(self, stage_id: str):
        self.dependencies.append(stage_id)

    def add_dependent(self, stage_id: str):
        self.dependents.append(stage_id)

    def run_stage(self, node_outputs: Dict[str, torch.Tensor]):
        """
        Executes the stage's operations, updates execution metrics, and handles output data.
        This method is intended to be enqueued to the Node's worker thread.
        """
        start_time = time.time()
        transfer_time = 0.0

        # Decide device from assigned node
        if (self.assigned_node is not None) and (self.assigned_node.gpu is not None) and torch.cuda.is_available():
            device = torch.device(f"cuda:{self.assigned_node.gpu}")
            self.stage_device = str(device)
        else:
            device = torch.device("cpu")
            self.stage_device = "cpu"

        # Synchronize before starting (for accurate timing)
        # if device.type == 'cuda':
        #     torch.cuda.synchronize(device.index)

        try:
            with torch.no_grad():
                for fx_node in self.nodes:
                    # Resolve all fx_node inputs from node_outputs
                    resolved_args = resolve_arg(fx_node.args, node_outputs)
                    resolved_kwargs = resolve_arg(fx_node.kwargs, node_outputs)

                    # Transfer input Tensors to 'device'
                    transfer_start = time.time()
                    resolved_args = move_tensor_to_device(resolved_args, device)
                    resolved_kwargs = move_tensor_to_device(resolved_kwargs, device)
                    transfer_end = time.time()
                    transfer_time += (transfer_end - transfer_start)

                    # Actually run the operation
                    if fx_node.op == 'placeholder':
                        # Typically means the main input to the entire model
                        out = self.task.input_data.to(device)
                        node_outputs[fx_node.name] = out

                    elif fx_node.op == 'call_module':
                        submodule = self.task.model.get_submodule(fx_node.target)
                        submodule.to(device)  # ensure submodule is on the same device
                        out = submodule(*resolved_args, **resolved_kwargs)

                    elif fx_node.op == 'call_function':
                        func = fx_node.target
                        out = func(*resolved_args, **resolved_kwargs)

                    elif fx_node.op == 'call_method':
                        method = getattr(resolved_args[0], fx_node.target)
                        out = method(*resolved_args[1:], **resolved_kwargs)

                    elif fx_node.op == 'output':
                        # The final output node for this stage
                        out = resolved_args[0]

                    else:
                        raise NotImplementedError(f"Operation '{fx_node.op}' not supported in run_stage().")

                    # Save the result into node_outputs
                    node_outputs[fx_node.name] = out

        except AttributeError as e:
            print(f"[Stage] {self.stage_id}: AttributeError during execution: {e}")
            self.execution_time = float('inf')
            self.transfer_time = float('inf')
            node_outputs[self.stage_id] = None
            return  # Early exit on failure

        except Exception as e:
            print(f"[Stage] {self.stage_id}: Error during execution: {e}")
            self.execution_time = float('inf')
            self.transfer_time = float('inf')
            node_outputs[self.stage_id] = None
            return  # Early exit on failure

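        # After execution, synchronize again if on GPU
        # if device.type == 'cuda':
        #     torch.cuda.synchronize(device.index)

        # Record timings only on success. Doing this in a `finally` block would
        # overwrite the float('inf') failure markers set in the handlers above.
        end_time = time.time()
        self.execution_time = end_time - start_time
        self.transfer_time = transfer_time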

        # Update the Task's busy_time
        self.task.update_busy_time(self.execution_time, self.transfer_time)

        # If no dependents, set Task's output_data
        if not self.dependents:
            final_output_node = next((n for n in self.nodes if n.op == 'output'), None)
            if final_output_node:
                # Assuming the first argument is the tensor
                arg = final_output_node.args[0]
                if isinstance(arg, torch.Tensor):
                    final_res = arg.cpu()
                elif isinstance(arg, FxNode):
                    final_res = node_outputs.get(arg.name, None)
                else:
                    final_res = None
                self.task.set_output_data(final_res)
            else:
                self.task.set_output_data(None)

        # Print Stage execution info
        print(f"[Stage] {self.stage_id}: Executed on {self.assigned_node.node_id if self.assigned_node else 'None'} "
              f"in {self.execution_time:.6f} seconds. Transfer Time: {self.transfer_time:.6f} seconds.")

    def __repr__(self):
        return (f"Stage(stage_id={self.stage_id}, device={self.stage_device}, "
                f"node={self.assigned_node.node_id if self.assigned_node else 'None'}, "
                f"deps={self.dependencies}, exec_time={self.execution_time}, "
                f"transfer_time={self.transfer_time})")
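
The recursive traversal in `move_tensor_to_device` rebuilds lists and tuples with `type(obj)(generator)`. That call raises a `TypeError` for namedtuples with more than one field, because their constructors expect one positional argument per field. The sketch below shows the same traversal pattern with a namedtuple branch. It is torch-free so it runs anywhere: `map_leaves`, `fn`, and `leaf_type` are hypothetical names, and with PyTorch the leaf function would be something like `lambda t: t.to(device)` with `leaf_type=torch.Tensor`.

```python
from collections import namedtuple


def map_leaves(obj, fn, leaf_type):
    """Apply `fn` to every `leaf_type` instance inside nested containers."""
    if isinstance(obj, leaf_type):
        return fn(obj)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple: the constructor takes one positional argument per field.
        return type(obj)(*(map_leaves(v, fn, leaf_type) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_leaves(v, fn, leaf_type) for v in obj)
    if isinstance(obj, dict):
        return {k: map_leaves(v, fn, leaf_type) for k, v in obj.items()}
    return obj  # anything else is returned unchanged


Pair = namedtuple("Pair", ["first", "second"])
nested = Pair(1, [2, {"k": 3}, "unchanged"])

# Integers stand in for tensors; multiplying by 10 stands in for `.to(device)`.
result = map_leaves(nested, lambda x: x * 10, int)
print(result)
```

The container types are preserved, so code downstream that unpacks a namedtuple by field name keeps working after the move.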


In [None]:
import time
from typing import Any, Dict, List, Optional, Set

import networkx as nx
import pandas as pd
import torch
import torch.fx as fx
import torch.nn as nn

# from stage import Stage
# from utils import group_topological_order

def retrieve_layer_profile_records(
    profile_db: pd.DataFrame,
    task_id: str,
    model_name: str,
    compute: Optional[str],  # e.g., 'CPU-0'
    layer_names: List[str],
    placeholder_names: Optional[Set[str]] = None
) -> Dict[str, Optional[Dict[str, Any]]]:
    """
    Retrieves profiler records for each layer in layer_names, filtered by:
      - Task_ID == task_id
      - Model == model_name
      - (optional) Compute == compute
      - Layer == each layer_name in layer_names.

    If a layer is in 'placeholder_names' (or is named 'x'), we override with a zero-cost record
    that mirrors the usual keys but has all 0 for times and memory.

    Returns:
        Dict[str, Optional[Dict[str, Any]]]:
            layer_name -> row dict or None if not found
    """
    # If no placeholder set is given, default to empty
    placeholder_names = placeholder_names or set()

    layer_records = {}

    # Helper function to build a zero-cost record
    def build_zero_record(layer_name: str, compute_str: str) -> Dict[str, Any]:
        return {
            "Task_ID": task_id,
            "Model": model_name,
            "Layer": layer_name,
            "Compute": compute_str if compute_str else "N/A",
            "Self CPU (us)": 0.0,
            "CPU Total (us)": 0.0,
            "CUDA Total (us)": 0.0,
            "Self CPU Mem (bytes)": 0,
            "Self CUDA Mem (bytes)": 0,
            "Total Execution Time (us)": 0.0,
            "Total Memory Used (bytes)": 0
        }

    for layer in layer_names:
        # 1) If it's in placeholder_names or specifically called "x", produce a zero-cost record
        if layer in placeholder_names or layer == "x":
            layer_records[layer] = build_zero_record(layer, compute or "")
            continue

        # 2) Otherwise, do normal DataFrame filter
        mask = (
            (profile_db['Task_ID'] == task_id)
            & (profile_db['Model'] == model_name)
            & (profile_db['Layer'] == layer)
        )
        if compute is not None:
            mask = mask & (profile_db['Compute'] == compute)

        matched = profile_db.loc[mask]

        if not matched.empty:
            row_dict = matched.iloc[0].to_dict()  # First match
            layer_records[layer] = row_dict
        else:
            # If no match found, store None
            layer_records[layer] = None

    return layer_records

class Task:
    """
    Represents a single DNN inference task with DAG-based stage allocation.
    """

    def __init__(self, task_id: str, model: nn.Module, input_data: torch.Tensor, model_name: str, profiler: 'Profiler'):
        self.task_id = task_id
        self.model = model
        self.input_data = input_data
        self.model_name = model_name

        self.stages: Dict[str, 'Stage'] = {}
        self.graph = nx.DiGraph()

        self.start_time: Optional[float] = None
        self.finish_time: Optional[float] = None

        self.output_data: Optional[torch.Tensor] = None
        self.busy_time: float = 0.0
        self.computation_time: float = 0.0
        self.transfer_time: float = 0.0

        self.profiler = profiler
        self.model.eval()
        self.init_traced_graph = None
        self.placeholder_names = set()

        self.prof_records: Dict[tuple, Optional[Dict[str, Any]]] = {}

        # Build the DAG of stages
        self._initialize_dag(group_size=4)

    def _initialize_dag(self, group_size: int = 2):
        # 1) Symbolically trace (symbolic_trace returns a GraphModule)
        traced_module = fx.symbolic_trace(self.model)
        traced_graph = traced_module.graph

        # FX stores graph nodes in topological order
        topological_order = list(traced_graph.nodes)
        # print(topological_order)
        self.init_traced_graph = [n.name for n in topological_order]
        # print(self.init_traced_graph)

        self.placeholder_names = set(node.name for node in topological_order if node.op == "placeholder")

        # from utils import group_topological_order  # or your local import
        grouped_stages = group_topological_order([n.name for n in topological_order],
                                                 group_size=group_size)

        # 2) Create Stage objects
        # from stage import Stage  # or local
        for stage_name, node_names in grouped_stages.items():
            stage_id = f"{self.task_id}-{stage_name}"
            nodes = [n for n in traced_graph.nodes if n.name in node_names]


            stage = Stage(
                stage_id=stage_id,
                nodes=nodes,
                assigned_node=None,
                task=self
            )
            self.add_stage(stage)

        # 3) Add dependencies
        for stage_name, node_names in grouped_stages.items():
            stage_id = f"{self.task_id}-{stage_name}"
            if stage_name != "stage-1":
                prev_idx = int(stage_name.split('-')[1]) - 1
                prev_stage_id = f"{self.task_id}-stage-{prev_idx}"
                self.add_dependency(prev_stage_id, stage_id)

        # 4) Retrieve & store profiler records if profiler is available
        if self.profiler is not None:
            # Get the entire profiler DataFrame
            db = self.profiler.get_profile_db()

            # Filter to relevant rows for this Task + Model
            mask = (
                (db['Task_ID'] == self.task_id) &
                (db['Model'] == self.model_name)
            )
            relevant_df = db.loc[mask]

            # Find all Compute strings present in the DB for this task & model
            all_computes = relevant_df['Compute'].unique().tolist()

            # For each compute, retrieve layer records and store them in self.prof_records
            for compute_str in all_computes:
                records_for_compute = retrieve_layer_profile_records(
                    profile_db=db,
                    task_id=self.task_id,
                    model_name=self.model_name,
                    compute=compute_str,
                    layer_names=self.init_traced_graph,
                    placeholder_names=self.placeholder_names

                )
                for layer_name, row_dict in records_for_compute.items():
                    # Key is (compute, layer_name)
                    self.prof_records[(compute_str, layer_name)] = row_dict


    def get_forward_pass_time(self, sum_across_compute: bool = False) -> float:
        """
        Retrieves the 'forward_pass' time by querying the Profiler's DataFrame directly,
        filtering on this Task's ID + Model + 'forward_pass'.

        Args:
            sum_across_compute (bool):
                If True, sums the forward pass times across all devices (CPU-0, CPU-1, etc.).
                If False, returns the max across devices.

        Returns:
            float: Total forward-pass time in microseconds (or 0.0 if none found).
        """
        if not self.profiler:
            # If for some reason there's no Profiler, return 0.
            return 0.0

        # 1) Get the global profile DataFrame
        profile_df = self.profiler.get_profile_db()

        # 2) Filter rows for (Task_ID == self.task_id) & (Model == self.model_name) & (Layer == 'forward_pass')
        mask = (
            (profile_df['Task_ID'] == self.task_id) &
            (profile_df['Model'] == self.model_name) &
            (profile_df['Layer'] == 'forward_pass')
        )
        matched = profile_df.loc[mask]

        if matched.empty:
            return 0.0

        # 3) Extract the 'Total Execution Time (us)' for each device
        times = matched['Total Execution Time (us)']

        # 4) Depending on your strategy, sum across all devices or take max
        if sum_across_compute:
            return float(times.sum())
        else:
            return float(times.max())


    def add_stage(self, stage: 'Stage'):
        if stage.stage_id in self.stages:
            raise ValueError(f"Stage ID {stage.stage_id} already exists in Task {self.task_id}.")
        self.stages[stage.stage_id] = stage
        self.graph.add_node(stage.stage_id, stage=stage)

    def add_dependency(self, from_stage_id: str, to_stage_id: str):
        if from_stage_id not in self.stages or to_stage_id not in self.stages:
            raise ValueError("Stages must exist before adding a dependency.")
        self.graph.add_edge(from_stage_id, to_stage_id)
        self.stages[to_stage_id].add_dependency(from_stage_id)
        self.stages[from_stage_id].add_dependent(to_stage_id)

    def get_execution_order(self) -> List[str]:
        try:
            return list(nx.topological_sort(self.graph))
        except nx.NetworkXUnfeasible:
            raise ValueError("Stage dependencies contain a cycle.")

    def assign_nodes_to_stages(self, available_nodes: List['Node']):
        """
        Round-robin node assignment. Submodules are not moved here; run_stage()
        moves each submodule to the stage's device right before executing it.
        """
        node_count = len(available_nodes)
        # Note: sorted() orders IDs lexicographically ('...-stage-10' comes before
        # '...-stage-2'), so with 10+ stages this is stable but not pipeline order.
        sorted_stage_ids = sorted(self.stages.keys())  # ensure stable ordering

        for idx, stage_id in enumerate(sorted_stage_ids):
            stage = self.stages[stage_id]
            node = available_nodes[idx % node_count]
            stage.assigned_node = node

    def get_stage(self, stage_id: str) -> Optional['Stage']:
        return self.stages.get(stage_id, None)

    def get_total_execution_time(self) -> float:
        if self.start_time and self.finish_time:
            return self.finish_time - self.start_time
        return 0.0

    def update_busy_time(self, stage_execution_time: float, stage_transfer_time: float = 0.0):
        self.busy_time += stage_execution_time
        self.transfer_time += stage_transfer_time
        self.computation_time += (stage_execution_time - stage_transfer_time)

    def set_output_data(self, output: Optional[torch.Tensor]):
        """
        Typically we want final output on CPU for correctness checks.
        """
        if output is not None:
            self.output_data = output.cpu()
        else:
            self.output_data = None
        self.finish_time = time.time()

    def print_stage_allocations(self):
        print(f"=== Stage Allocations for Task '{self.task_id}' ===")
        for stage_id, stage in self.stages.items():
            layer_names = []
            for fx_node in stage.nodes:
                # print(fx_node.name,"****")
                if isinstance(fx_node.target, str):
                    layer_names.append(fx_node.target)
                else:
                    layer_names.append(str(fx_node.target))
            node_id = stage.assigned_node.node_id if stage.assigned_node else "Unassigned"
            print(f"Stage ID: {stage_id}")
            print(f"  Assigned Node: {node_id}")
            print(f"  Layers: {', '.join(layer_names)}" if layer_names else "  Layers: None")
            print(f"  Dependencies: {stage.dependencies}")
            print(f"  Dependents: {stage.dependents}\n")
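
Round-robin assignment depends on the order in which stage IDs are visited. A plain `sorted()` compares strings, so `'t1-stage-10'` sorts before `'t1-stage-2'`. The sketch below sorts on the trailing stage number instead, assuming IDs follow the `{task_id}-stage-{n}` pattern built in `_initialize_dag`. `stage_index` and `round_robin` are hypothetical helpers, and nodes are represented by plain strings rather than `Node` objects.

```python
import re
from typing import Dict, List


def stage_index(stage_id: str) -> int:
    """Extract the trailing integer from an ID such as 't1-stage-10'."""
    match = re.search(r"(\d+)$", stage_id)
    return int(match.group(1)) if match else 0


def round_robin(stage_ids: List[str], node_ids: List[str]) -> Dict[str, str]:
    """Assign stages to nodes cyclically, visiting stages in numeric order."""
    if not node_ids:
        raise ValueError("At least one node is required.")
    ordered = sorted(stage_ids, key=stage_index)
    return {sid: node_ids[i % len(node_ids)] for i, sid in enumerate(ordered)}


stage_ids = [f"t1-stage-{i}" for i in range(1, 12)]  # stage-1 .. stage-11
assignment = round_robin(stage_ids, ["CPU-0", "GPU-0"])

# Lexicographic order places stage-10 directly after stage-1.
lexicographic = sorted(stage_ids)
print(lexicographic[:3])
print(assignment["t1-stage-2"], assignment["t1-stage-10"])
```

With numeric ordering, consecutive pipeline stages alternate between nodes. With lexicographic ordering, neighbouring stages can land on the same node once there are ten or more stages.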


In [None]:
# import networkx as nx
# import torch
# import torch.nn as nn
# import torch.fx as fx
# import time
# from typing import List, Dict, Optional, Any

# # from stage import Stage
# # from utils import group_topological_order

# class Task:
#     """
#     Represents a single DNN inference task with DAG-based stage allocation,
#     enhanced to store/organize its own profiling data in a model-agnostic way.
#     """

#     def __init__(self,
#                  task_id: str,
#                  model: nn.Module,
#                  input_data: torch.Tensor,
#                  model_name: str,
#                  profiler: 'Profiler'):
#         self.task_id = task_id
#         self.model = model
#         self.input_data = input_data
#         self.model_name = model_name

#         # Stage-based DAG
#         self.stages: Dict[str, 'Stage'] = {}
#         self.graph = nx.DiGraph()

#         # Execution timers
#         self.start_time: Optional[float] = None
#         self.finish_time: Optional[float] = None
#         self.output_data: Optional[torch.Tensor] = None

#         # Busy time metrics
#         self.busy_time: float = 0.0
#         self.computation_time: float = 0.0
#         self.transfer_time: float = 0.0

#         self.profiler = profiler
#         self.model.eval()

#         # NEW: Profiling data fields
#         self.raw_profiling_data: List[Dict[str, Any]] = []          # All raw rows for this task
#         self.parsed_layer_data: Dict[str, Dict[str, Any]] = {}      # Mapped layer/operator -> {device: metrics}
#         self.traced_graph_order: List[str] = []                     # Node names in topological order
#         self.forward_pass_summary: Dict[str, Any] = {}              # Optional aggregated metrics for the entire pass

#         # Build the DAG of stages
#         self._initialize_dag(group_size=4)

#     def _initialize_dag(self, group_size: int = 2):
#         """
#         Symbolically trace the model, generate a topological order of nodes,
#         group them into stages, and build a DAG of those stages.
#         """
#         # 1) Symbolic trace
#         traced_module = fx.symbolic_trace(self.model)
#         traced_graph = traced_module.graph

#         # 2) Grab the node list in topological order
#         topological_nodes = [node for node in traced_graph.nodes]
#         self.traced_graph_order = [node.name for node in topological_nodes]
#         print(topological_nodes)  # Debug info, can remove if not needed

#         # 3) Group the node names
#         # from utils import group_topological_order
#         grouped_stages = group_topological_order(self.traced_graph_order, group_size=group_size)

#         # 4) Create Stage objects and add them to self.stages + self.graph
#         # from stage import Stage
#         for stage_name, node_names in grouped_stages.items():
#             stage_id = f"{self.task_id}-{stage_name}"
#             # Filter the actual FxNode objects by name
#             nodes_in_stage = [n for n in traced_graph.nodes if n.name in node_names]

#             stage = Stage(
#                 stage_id=stage_id,
#                 nodes=nodes_in_stage,
#                 assigned_node=None,
#                 task=self
#             )
#             self.add_stage(stage)

#         # 5) Add dependencies for a linear pipeline
#         for stage_name, node_names in grouped_stages.items():
#             stage_id = f"{self.task_id}-{stage_name}"
#             if stage_name != "stage-1":
#                 prev_idx = int(stage_name.split('-')[1]) - 1
#                 prev_stage_id = f"{self.task_id}-stage-{prev_idx}"
#                 self.add_dependency(prev_stage_id, stage_id)

#     def add_stage(self, stage: 'Stage'):
#         if stage.stage_id in self.stages:
#             raise ValueError(f"Stage ID {stage.stage_id} already exists in Task {self.task_id}.")
#         self.stages[stage.stage_id] = stage
#         self.graph.add_node(stage.stage_id, stage=stage)

#     def add_dependency(self, from_stage_id: str, to_stage_id: str):
#         if from_stage_id not in self.stages or to_stage_id not in self.stages:
#             raise ValueError("Stages must exist before adding a dependency.")
#         self.graph.add_edge(from_stage_id, to_stage_id)
#         self.stages[to_stage_id].add_dependency(from_stage_id)
#         self.stages[from_stage_id].add_dependent(to_stage_id)

#     def get_execution_order(self) -> List[str]:
#         """
#         Return the stage IDs in a topological sequence.
#         """
#         try:
#             return list(nx.topological_sort(self.graph))
#         except nx.NetworkXUnfeasible:
#             raise ValueError("Stage dependencies contain a cycle.")

#     def assign_nodes_to_stages(self, available_nodes: List['Node']):
#         """
#         Round-robin node assignment; also moves submodules to the device
#         so submodule weights are on CPU/GPU prior to execution.
#         """
#         node_count = len(available_nodes)
#         sorted_stage_ids = sorted(self.stages.keys())

#         for idx, stage_id in enumerate(sorted_stage_ids):
#             stage = self.stages[stage_id]
#             node = available_nodes[idx % node_count]
#             stage.assigned_node = node

#             # Decide device string
#             if node.gpu is not None and torch.cuda.is_available():
#                 device_str = f"cuda:{node.gpu}"
#             else:
#                 device_str = "cpu"

#             # Move submodules used by this stage to device_str
#             for fx_node in stage.nodes:
#                 if fx_node.op == 'call_module':
#                     submodule = self.model.get_submodule(fx_node.target)
#                     submodule.to(device_str)

#     def get_stage(self, stage_id: str) -> Optional['Stage']:
#         return self.stages.get(stage_id, None)

#     def get_total_execution_time(self) -> float:
#         if self.start_time and self.finish_time:
#             return self.finish_time - self.start_time
#         return 0.0

#     def update_busy_time(self, stage_execution_time: float, stage_transfer_time: float = 0.0):
#         self.busy_time += stage_execution_time
#         self.transfer_time += stage_transfer_time
#         self.computation_time += (stage_execution_time - stage_transfer_time)

#     def set_output_data(self, output: torch.Tensor):
#         """
#         Store final output on CPU for correctness checks, mark finish time.
#         """
#         if output is not None:
#             self.output_data = output.cpu()
#         else:
#             self.output_data = None
#         self.finish_time = time.time()

#     def print_stage_allocations(self):
#         print(f"=== Stage Allocations for Task '{self.task_id}' ===")
#         for stage_id, stage in self.stages.items():
#             layer_names = []
#             for fx_node in stage.nodes:
#                 if isinstance(fx_node.target, str):
#                     layer_names.append(fx_node.target)
#                 else:
#                     layer_names.append(str(fx_node.target))
#             node_id = stage.assigned_node.node_id if stage.assigned_node else "Unassigned"
#             print(f"Stage ID: {stage_id}")
#             print(f"  Assigned Node: {node_id}")
#             print(f"  Layers: {', '.join(layer_names)}" if layer_names else "  Layers: None")
#             print(f"  Dependencies: {stage.dependencies}")
#             print(f"  Dependents: {stage.dependents}\n")

#     # ----------------------------------------------------------------------
#     # NEW: Store and parse profiling data for this Task
#     # ----------------------------------------------------------------------
#     def init_profiling_data(self,
#                             profiling_records: List[Dict[str, Any]],
#                             forward_pass_agg: Optional[Dict[str, Any]] = None) -> None:
#         """
#         Called after profiling is done for this Task. This stores the raw rows,
#         aggregates them into self.parsed_layer_data, and optionally stores
#         forward-pass summary info.

#         Args:
#             profiling_records: A list of dictionaries, each with columns from the profiler
#                               (Layer, Compute, CPU Total (us), etc.).
#             forward_pass_agg: Aggregated metrics for the entire forward pass (optional).
#         """
#         self.raw_profiling_data = profiling_records
#         self.forward_pass_summary = forward_pass_agg or {}

#         temp_data: Dict[str, Dict[str, Any]] = {}

#         for row in profiling_records:
#             raw_layer_name = row.get("Layer", "")
#             compute_dev = row.get("Compute", "Unknown")

#             mapped_name, skip_flag = self._map_profiler_layer_to_fx_name(raw_layer_name)
#             if skip_flag:
#                 continue  # e.g. skip [memory], ProfilerStep*, etc.

#             if mapped_name not in temp_data:
#                 temp_data[mapped_name] = {}

#             cpu_total_us = row.get("CPU Total (us)", 0.0)
#             self_cpu_us = row.get("Self CPU (us)", 0.0)
#             cuda_total_us = row.get("CUDA Total (us)", 0.0)
#             mem_cpu_bytes = row.get("Self CPU Mem (bytes)", 0)
#             mem_cuda_bytes = row.get("Self CUDA Mem (bytes)", 0)
#             total_exec_us = row.get("Total Execution Time (us)", 0.0)

#             temp_data[mapped_name][compute_dev] = {
#                 "cpu_total_us": cpu_total_us,
#                 "self_cpu_us": self_cpu_us,
#                 "cuda_total_us": cuda_total_us,
#                 "mem_cpu_bytes": mem_cpu_bytes,
#                 "mem_cuda_bytes": mem_cuda_bytes,
#                 "total_exec_us": total_exec_us,
#             }

#         self.parsed_layer_data = temp_data

#     def _map_profiler_layer_to_fx_name(self, raw_layer_name: str) -> (str, bool):
#         """
#         Maps a profiler 'Layer' string (like 'resnet18.layer1.0.conv1.forward', 'cat',
#         '[memory]', 'ProfilerStep*', etc.) to a node name in traced_graph_order.
#         Returns (mapped_name, skip_flag).
#           - mapped_name: The final name to store in parsed_layer_data
#           - skip_flag: whether we skip storing this line entirely

#         Adjust logic to match your naming patterns. This is an example approach.
#         """
#         # 1) Skip known lines
#         if "[memory]" in raw_layer_name or "ProfilerStep*" in raw_layer_name:
#             return ("[skip]", True)
#         if raw_layer_name in ["forward_pass"]:
#             # keep if you want to store it, or skip. We'll keep:
#             return ("forward_pass", False)
#         if raw_layer_name.startswith("placeholder"):
#             return ("placeholder", False)
#         if raw_layer_name.startswith("output"):
#             return ("output", False)

#         # 2) Remove .forward if present
#         name = raw_layer_name
#         if name.endswith(".forward"):
#             name = name[:-8]  # remove ".forward"

#         # 3) If there's a known prefix, remove it. Adjust as needed.
#         # You might have a single known prefix. Here's an example:
#         known_prefixes = ["resnet18.", "SimpleCNN.", "simple_cnn_task", "resnet_task"]
#         for prefix in known_prefixes:
#             if name.startswith(prefix):
#                 name = name[len(prefix):]

#         # 4) Replace '.' with '_' if it looks like a submodule path
#         #    but skip if it's an operator like "cat", "add", etc. that has no dot.
#         # We'll do a naive approach: only replace if there's a dot.
#         if "." in name and not (name in ["cat", "add", "flatten", "relu", "conv2d"]):
#             name = name.replace(".", "_")

#         # cleanup double underscores
#         while "__" in name:
#             name = name.replace("__", "_")

#         # 5) Attempt partial matching in traced_graph_order
#         possible_matches = []
#         for fx_node in self.traced_graph_order:
#             if name in fx_node or fx_node in name:
#                 possible_matches.append(fx_node)

#         if len(possible_matches) == 1:
#             return (possible_matches[0], False)
#         elif len(possible_matches) > 1:
#             # pick the first or do a more advanced approach
#             return (possible_matches[0], False)

#         # If no matches, fallback to the transformed name
#         return (name, False)

#     def __repr__(self) -> str:
#         """
#         Returns a string representation for debugging, showing Task ID
#         and a summary of profiling data if available.
#         """
#         lines = [f"Task(task_id={self.task_id}, model={self.model_name})"]
#         if self.parsed_layer_data:
#             lines.append(f"  #parsed_layers={len(self.parsed_layer_data)}")
#         if self.forward_pass_summary:
#             lines.append(f"  forward_pass_summary={self.forward_pass_summary}")
#         if self.stages:
#             lines.append(f"  #stages={len(self.stages)}")
#         return "\n".join(lines)
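
The name-mapping step in `_map_profiler_layer_to_fx_name` can be tried in isolation. The following is a minimal sketch, not the library's implementation: `normalize_layer_name` and its `prefixes` default are hypothetical names introduced only for illustration, and the partial-matching step against `traced_graph_order` is left out.

```python
from typing import Sequence, Tuple


def normalize_layer_name(
    raw: str,
    prefixes: Sequence[str] = ("resnet18.", "SimpleCNN."),  # assumed prefixes
) -> Tuple[str, bool]:
    """Return (name, skip_flag) for a profiler 'Layer' string."""
    # Bookkeeping rows from the profiler carry no layer information.
    if "[memory]" in raw or "ProfilerStep*" in raw:
        return "[skip]", True

    # Drop a trailing ".forward" suffix if present.
    name = raw[: -len(".forward")] if raw.endswith(".forward") else raw

    # Drop a leading model prefix if present.
    for prefix in prefixes:
        if name.startswith(prefix):
            name = name[len(prefix):]

    # torch.fx names submodule calls with underscores instead of dots.
    return name.replace(".", "_"), False


print(normalize_layer_name("resnet18.layer1.0.conv1.forward"))
print(normalize_layer_name("[memory]"))
```

Operator rows such as `cat` contain no dot and pass through unchanged, which matches the intent of step 4 in the commented-out method.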


In [None]:
# load_metric.py
from typing import Protocol, runtime_checkable
from abc import abstractmethod

# We'll assume you have references to Task and Taskset so you can import them or forward-declare:
# from task import Task
# from taskset import Taskset

@runtime_checkable
class LoadMetric(Protocol):
    """
    A protocol (interface) for computing a 'load' or 'utilization' metric
    for a Task in the context of a Taskset.
    """
    @abstractmethod
    def compute(self, task: "Task", taskset: "Taskset") -> float:
        """
        Computes a load metric for `task`, possibly referencing data
        in `taskset`. Returns a float representing the load or utilization.
        """
        ...

class HPCUtilizationMetric:
    """
    Default HPC-style utilization metric:
      utilization = task.busy_time / observation_window

    where observation_window is computed from the Taskset
    (e.g., sum of forward pass times + a slack factor).
    """
    def __init__(self, slack_fraction: float = 0.2):
        self.slack_fraction = slack_fraction

    def compute(self, task: "Task", taskset: "Taskset") -> float:
        # We ask Taskset for an observation window
        obs_window = taskset.compute_observation_window(self.slack_fraction)
        if obs_window <= 0:
            return 0.0

        # HPC approach: utilization = busy_time / observation_window
        return task.busy_time / obs_window
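
To make the utilization formula concrete, here is a small worked sketch. `StubTask`, `StubTaskset`, `LoadMetricLike`, and `BusyOverWindow` are hypothetical stand-ins written for this example (the real `Task` and `Taskset` carry far more state), and the timings are made-up values rather than measurements.

```python
from dataclasses import dataclass
from typing import List, Protocol, runtime_checkable


@dataclass
class StubTask:  # hypothetical stand-in for Task
    task_id: str
    busy_time: float     # seconds the task kept a node busy
    forward_time: float  # seconds for one profiled forward pass


@dataclass
class StubTaskset:  # hypothetical stand-in for Taskset
    tasks: List[StubTask]

    def compute_observation_window(self, slack_fraction: float) -> float:
        return sum(t.forward_time for t in self.tasks) * (1.0 + slack_fraction)


@runtime_checkable
class LoadMetricLike(Protocol):
    def compute(self, task: StubTask, taskset: StubTaskset) -> float: ...


class BusyOverWindow:
    """utilization = busy_time / observation_window"""

    def __init__(self, slack_fraction: float = 0.25):
        self.slack_fraction = slack_fraction

    def compute(self, task: StubTask, taskset: StubTaskset) -> float:
        window = taskset.compute_observation_window(self.slack_fraction)
        return task.busy_time / window if window > 0 else 0.0


tasks = [StubTask("a", busy_time=1.0, forward_time=1.0),
         StubTask("b", busy_time=2.5, forward_time=3.0)]
taskset = StubTaskset(tasks)
metric = BusyOverWindow(slack_fraction=0.25)

# Window = (1.0 + 3.0) * 1.25 = 5.0 s, so the loads are 1.0/5.0 and 2.5/5.0.
loads = {t.task_id: metric.compute(t, taskset) for t in tasks}
print(loads)
print(isinstance(metric, LoadMetricLike))  # structural check, no inheritance needed
```

Because the protocol is `runtime_checkable`, any object with a `compute` method passes the `isinstance` check. This is why `HPCUtilizationMetric` does not need to subclass `LoadMetric`.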

In [None]:
# taskset.py

import threading
import time
from typing import List, Dict, Optional
# from task import Task
# from node import Node
# from utils import resolve_arg

class Taskset:
    """
    Manages a collection of Tasks and orchestrates their execution using DAG-based stage allocation.
    """

    def __init__(self, tasks: List[Task], available_nodes: List['Node'], metric: Optional[LoadMetric] = None):
        """
        Initializes the Taskset.

        Args:
            tasks (List[Task]): The list of tasks to manage.
            available_nodes (List[Node]): The list of available compute nodes.
            metric (Optional[LoadMetric]): Metric used by compute_loads().
                Defaults to HPCUtilizationMetric() when None.
        """
        self.tasks = tasks
        self.available_nodes = available_nodes

        # Performance Metrics
        self.total_utilization: float = 0.0
        self.average_turnaround_time: float = 0.0
        self.throughput: float = 0.0
        self.makespan: float = 0.0
        self.task_completion_rate: float = 0.0

        # from load_metric import HPCUtilizationMetric  # or define up top
        self.metric = metric if metric is not None else HPCUtilizationMetric()

        # Initialize a lock for thread-safe updates if needed
        self.lock = threading.Lock()

        self.loads = {}

    def compute_observation_window(self, slack_fraction: float = 0.2) -> float:
        """
        Sums the forward-pass time of every Task (via Task.get_forward_pass_time())
        and scales the total by (1 + slack_fraction).

        The result is in whatever unit get_forward_pass_time() returns. That unit
        must match Task.busy_time for the resulting utilization to be meaningful.
        """
        total_forward_time = 0.0
        for task in self.tasks:
            total_forward_time += task.get_forward_pass_time()

        return total_forward_time * (1.0 + slack_fraction)

    def reset_loads(self):
        """
        Clears out the load dictionary so it can be recalculated
        or updated without stale values.
        """
        self.loads = {}


    def compute_loads(self, metric: Optional["LoadMetric"] = None) -> Dict[str, float]:
        """
        Computes a load value for each task using the 'metric' argument if given,
        otherwise the default 'self.metric' set in __init__.

        Returns:
            A dictionary mapping task_id -> load value.
        """
        # Fall back to the Taskset's default metric; recreate it only if it was cleared.
        if metric is None:
            if getattr(self, "metric", None) is None:
                self.metric = HPCUtilizationMetric()
            metric = self.metric

        self.loads = {}
        for t in self.tasks:
            load_val = metric.compute(t, self)
            print(f"[Taskset] Load for Task '{t.task_id}': {load_val:.4f}")
            self.loads[t.task_id] = load_val

        return self.loads


    def execute_all(self):
        """
        Executes all tasks in parallel, managing stage allocations and dependencies.
        Each task runs in its own thread, and stages are executed via Node worker threads.
        """
        # Assign nodes to stages for each task using a load-balancing strategy
        for task in self.tasks:
            task.assign_nodes_to_stages(self.available_nodes)
            task.print_stage_allocations()

        threads = []
        for task in self.tasks:
            t = threading.Thread(target=self.execute_task, args=(task,))
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        self.calculate_metrics()

    def execute_task(self, task: Task):
        """
        Executes a single task by running its stages in topological order.

        Args:
            task (Task): The task to execute.
        """
        print(f"[Taskset] Starting execution of Task '{task.task_id}'.")
        task.start_time = time.time()

        try:
            execution_order = task.get_execution_order()
        except ValueError as e:
            print(f"[Taskset] Task '{task.task_id}' execution failed: {e}")
            return

        # Dictionary to keep track of node outputs
        node_outputs = {}

        # Iterate through stages in topological order
        for stage_id in execution_order:
            stage = task.get_stage(stage_id)
            if stage is None:
                print(f"[Taskset] Task '{task.task_id}': Stage '{stage_id}' not found.")
                continue

            # Define the function to execute the stage with proper binding
            def execute(current_stage=stage):
                current_stage.run_stage(node_outputs)

            # Enqueue the stage_execution function to the Node's task queue
            try:
                result_queue = stage.assigned_node.assign_task(execute)
            except AttributeError as e:
                print(f"[Taskset] Task '{task.task_id}': Stage '{stage_id}' failed to assign to Node: {e}")
                # Assign infinite execution and transfer time to indicate failure
                with self.lock:
                    stage.execution_time = float('inf')
                    stage.transfer_time = float('inf')
                continue

            # Wait for the stage to complete
            try:
                # We don't expect any return value from run_stage, so we just wait for completion
                result = result_queue.get()
                # Optionally, you can check if result is None or some status flag
            except Exception as e:
                print(f"[Taskset] Task '{task.task_id}': Stage '{stage_id}' encountered an error during execution: {e}")
                with self.lock:
                    stage.execution_time = float('inf')
                    stage.transfer_time = float('inf')
                continue

            # Handle potential errors by checking execution_time
            if stage.execution_time == float('inf'):
                print(f"[Taskset] Task '{task.task_id}': Stage '{stage_id}' failed during execution.")

        task.finish_time = time.time()
        print(f"[Taskset] Completed execution of Task '{task.task_id}'.")

    def calculate_metrics(self):
        """
        Calculates and updates performance metrics for the taskset.
        """
        # 1) total_busy_time = sum of execution_time across all stages of all tasks.
        #    Failed stages are marked with execution_time = inf and are excluded,
        #    otherwise a single failure would make the utilization infinite.
        total_busy_time = sum(
            stage.execution_time
            for task in self.tasks
            for stage in task.stages.values()
            if stage.execution_time is not None and stage.execution_time != float('inf')
        )

        # 2) total_available_time = observation_window * #nodes
        if not any(task.finish_time for task in self.tasks):
            self.total_utilization = 0.0
            self.average_turnaround_time = 0.0
            self.throughput = 0.0
            self.makespan = 0.0
            self.task_completion_rate = 0.0
            return

        earliest_start = min(task.start_time for task in self.tasks if task.start_time)
        latest_finish = max(task.finish_time for task in self.tasks if task.finish_time)
        observation_window = latest_finish - earliest_start
        # Count only nodes that actually received a stage; an unassigned stage has no node.
        used_node_ids = {
            stage.assigned_node.node_id
            for task in self.tasks
            for stage in task.stages.values()
            if stage.assigned_node is not None
        }
        total_available_time = observation_window * len(used_node_ids)
        self.total_utilization = (total_busy_time / total_available_time) if total_available_time > 0 else 0.0

        # 3) average turnaround
        turnaround_times = [task.get_total_execution_time() for task in self.tasks]
        if turnaround_times:
            self.average_turnaround_time = sum(turnaround_times) / len(turnaround_times)
        else:
            self.average_turnaround_time = 0.0

        # 4) makespan = difference between earliest start and latest finish
        if earliest_start and latest_finish:
            self.makespan = latest_finish - earliest_start
        else:
            self.makespan = 0.0

        # 5) throughput = number_of_tasks / makespan
        if self.makespan > 0:
            self.throughput = len(self.tasks) / self.makespan
        else:
            self.throughput = 0.0

        # 6) task completion rate
        completed_tasks = [t for t in self.tasks if t.output_data is not None]
        self.task_completion_rate = len(completed_tasks) / len(self.tasks) if self.tasks else 0.0

    def __repr__(self):
        return (
            f"Taskset(total_tasks={len(self.tasks)}, "
            f"total_utilization={self.total_utilization:.2%}, "
            f"average_turnaround_time={self.average_turnaround_time:.6f} sec, "
            f"throughput={self.throughput:.2f} tasks/sec, "
            f"makespan={self.makespan:.6f} sec, "
            f"task_completion_rate={self.task_completion_rate:.2%})"
        )
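
The arithmetic in `calculate_metrics` can be checked by hand on a tiny example. The sketch below assumes per-task `(start, finish)` timestamps and per-node busy seconds as plain dictionaries. `summarize_run` is a hypothetical helper, not part of the Taskset API, and the numbers are invented for illustration.

```python
from typing import Dict, Tuple


def summarize_run(
    intervals: Dict[str, Tuple[float, float]],  # task_id -> (start, finish) in seconds
    busy_by_node: Dict[str, float],             # node_id -> busy seconds
) -> Dict[str, float]:
    """Compute makespan, throughput, average turnaround, and utilization."""
    if not intervals:
        return {"makespan": 0.0, "throughput": 0.0,
                "avg_turnaround": 0.0, "utilization": 0.0}

    earliest_start = min(start for start, _ in intervals.values())
    latest_finish = max(finish for _, finish in intervals.values())
    makespan = latest_finish - earliest_start

    # Every node is available for the whole observation window.
    available = makespan * len(busy_by_node)

    return {
        "makespan": makespan,
        "throughput": len(intervals) / makespan if makespan > 0 else 0.0,
        "avg_turnaround": sum(f - s for s, f in intervals.values()) / len(intervals),
        "utilization": sum(busy_by_node.values()) / available if available > 0 else 0.0,
    }


summary = summarize_run(
    intervals={"a": (0.0, 2.0), "b": (1.0, 4.0)},
    busy_by_node={"cpu": 3.0, "cuda:0": 1.0},
)
print(summary)
```

Here the window is 4 s across two nodes, so 8 node-seconds are available and 4 are used. That gives a utilization of 0.5, at a throughput of 0.5 tasks/s.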


In [None]:
# evaluator.py

import torch
from typing import Dict
import time

class Evaluator:
    """
    Evaluator runs tasks in naive mode vs. parallel mode and compares outputs.
    """

    def __init__(self, taskset: 'Taskset', profiler: 'Profiler'):
        self.taskset = taskset
        self.profiler = profiler

        # These track each task's final output (naive vs. parallel).
        self.naive_outputs: Dict[str, torch.Tensor] = {}
        self.parallel_outputs: Dict[str, torch.Tensor] = {}

        # These track each task's individual execution time
        # (used by the sum-of-times comparison).
        self.naive_execution_times: Dict[str, float] = {}
        self.parallel_execution_times: Dict[str, float] = {}

        # We also track each task’s “completion time” (relative to the start
        # of naive or parallel execution). That lets us see
        # when each task actually finished.
        self.naive_completion_times: Dict[str, float] = {}
        self.parallel_completion_times: Dict[str, float] = {}

        # Finally, we store the overall naive and parallel “makespan”.
        self.naive_makespan: float = 0.0
        self.parallel_makespan: float = 0.0

    def run_naive_execution(self):
        """
        Runs each task **sequentially** and measures:
          - Per-task execution time (the original approach),
          - The time each task finishes (relative to the naive start),
          - The overall naive makespan.
        """
        print("[Evaluator] Starting Naive Execution.")

        # Overall start time for the naive run:
        naive_start = time.time()

        for task in self.taskset.tasks:
            model = task.model
            input_tensor = task.input_data
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            print(f"[Evaluator] Task '{task.task_id}': Running on {device}")
            model.to(device)
            input_tensor = input_tensor.to(device)

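            # CUDA kernels launch asynchronously, so synchronize before reading the
            # clock; otherwise the measured time can exclude pending GPU work.
            if device.type == "cuda":
                torch.cuda.synchronize()
            t0 = time.time()
            with torch.no_grad():
                output = model(input_tensor)
            if device.type == "cuda":
                torch.cuda.synchronize()
            t1 = time.time()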

            exec_time = t1 - t0
            print(f"[Evaluator] Task '{task.task_id}' (Naive) executed in {exec_time:.6f} seconds.")

            # Store the sum-of-times approach for later reference
            self.naive_execution_times[task.task_id] = exec_time
            # The “finish time” for this task is how long since naive_start
            self.naive_completion_times[task.task_id] = time.time() - naive_start

            # Record the final output on CPU
            self.naive_outputs[task.task_id] = output.cpu()

        # After all tasks are done (sequentially), measure the total naive makespan
        naive_end = time.time()
        self.naive_makespan = naive_end - naive_start
        print(f"[Evaluator] Naive Execution Completed. Naive makespan = {self.naive_makespan:.6f} s\n")

    def run_parallel_execution(self):
        """
        Executes the entire taskset in parallel and measures:
          - The overall parallel makespan,
          - Each task’s sum-of-times approach,
          - Each task’s finish time (relative to the parallel start).
        """
        print("[Evaluator] Starting Parallel Execution.")

        # Clear old parallel data
        self.parallel_outputs.clear()
        self.parallel_execution_times.clear()
        self.parallel_completion_times.clear()

        # Start measuring overall parallel makespan
        parallel_start = time.time()

        self.taskset.execute_all()

        parallel_end = time.time()
        self.parallel_makespan = parallel_end - parallel_start

        # For each task, gather final outputs + per-task times
        for task in self.taskset.tasks:
            if task.output_data is not None:
                self.parallel_outputs[task.task_id] = task.output_data.cpu()
            else:
                self.parallel_outputs[task.task_id] = None

            # The sum-of-times approach is the existing code's “get_total_execution_time”
            exec_time = task.get_total_execution_time()
            self.parallel_execution_times[task.task_id] = exec_time

            # The finish_time is absolute. We want relative to parallel_start:
            # (task.finish_time is set by the time the Task completed)
            if task.finish_time is not None:
                self.parallel_completion_times[task.task_id] = task.finish_time - parallel_start
            else:
                self.parallel_completion_times[task.task_id] = float('nan')

        print(f"[Evaluator] Parallel Execution Completed. Parallel makespan = {self.parallel_makespan:.6f} s\n")

    def compare_outputs(self):
        """
        Compares naive vs. parallel outputs for correctness.
        """
        print("[Evaluator] Comparing Outputs.")
        all_match = True

        for task_id in self.naive_outputs:
            naive_out = self.naive_outputs[task_id]
            parallel_out = self.parallel_outputs.get(task_id, None)

            if naive_out is None or parallel_out is None:
                print(f"[Evaluator] Task '{task_id}' missing output in one execution.")
                all_match = False
                continue

            # Ensure both are on CPU for a fair comparison
            naive_out = naive_out.cpu()
            parallel_out = parallel_out.cpu()

            if torch.equal(naive_out, parallel_out):
                print(f"[Evaluator] Task '{task_id}' outputs match exactly.")
            elif torch.allclose(naive_out, parallel_out, atol=1e-5):
                print(f"[Evaluator] Task '{task_id}' outputs are close within tolerance.")
            else:
                print(f"[Evaluator] Task '{task_id}' outputs do NOT match.")
                all_match = False

        if all_match:
            print("[Evaluator] All outputs match.\n")
        else:
            print("[Evaluator] Some outputs differ.\n")

    def analyze_speedup_throughput(self):
        """
        Prints out:
          1) The sum-of-times approach (the old logic),
          2) The new makespan-based approach,
          3) Each task’s naive vs. parallel completion time,
          4) Speedup & throughput based on both approaches.
        """
        print("[Evaluator] Analyzing Speedup and Throughput.\n")

        #
        # 1) Print sum-of-times approach (the original code).
        #
        total_naive_time = sum(self.naive_execution_times.values())
        total_parallel_time = sum(self.parallel_execution_times.values())

        print("--- Sum-of-times approach ---")
        print(f"[Evaluator] Naive total time (sum of per-task):    {total_naive_time:.6f} s")
        print(f"[Evaluator] Parallel total time (sum of per-task): {total_parallel_time:.6f} s")

        sum_speedup = (
            total_naive_time / total_parallel_time
            if total_parallel_time > 0
            else float('inf')
        )
        num_tasks = len(self.taskset.tasks)

        # sum-of-times throughput
        naive_thr_sum = num_tasks / total_naive_time if total_naive_time > 0 else 0
        parallel_thr_sum = num_tasks / total_parallel_time if total_parallel_time > 0 else 0

        print(f"[Evaluator] Speedup (sum-of-times) = {sum_speedup:.2f}x")
        print(f"[Evaluator] Naive Throughput (sum-of-times)   = {naive_thr_sum:.2f} tasks/s")
        print(f"[Evaluator] Parallel Throughput (sum-of-times) = {parallel_thr_sum:.2f} tasks/s\n")

        #
        # 2) Print makespan-based approach.
        #
        print("--- Makespan-based approach ---")
        print(f"[Evaluator] Naive makespan:   {self.naive_makespan:.6f} s")
        print(f"[Evaluator] Parallel makespan: {self.parallel_makespan:.6f} s")

        makespan_speedup = (
            self.naive_makespan / self.parallel_makespan
            if self.parallel_makespan > 0
            else float('inf')
        )

        # makespan-based throughput
        naive_thr_makespan = num_tasks / self.naive_makespan if self.naive_makespan > 0 else 0
        parallel_thr_makespan = num_tasks / self.parallel_makespan if self.parallel_makespan > 0 else 0

        print(f"[Evaluator] Speedup (makespan-based) = {makespan_speedup:.2f}x")
        print(f"[Evaluator] Naive Throughput (makespan)   = {naive_thr_makespan:.2f} tasks/s")
        print(f"[Evaluator] Parallel Throughput (makespan) = {parallel_thr_makespan:.2f} tasks/s\n")

        #
        # 3) Print each task’s naive vs. parallel completion time
        #
        print("--- Task Completion Times (relative to start) ---")
        for task_id in self.naive_completion_times:
            naive_finish = self.naive_completion_times[task_id]
            parallel_finish = self.parallel_completion_times.get(task_id, float('nan'))
            print(f" Task {task_id}: naive finish={naive_finish:.6f}s, parallel finish={parallel_finish:.6f}s")

        print()  # extra newline at end
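
The two speedup figures reported by `analyze_speedup_throughput` can disagree, and a small example shows why. This is a sketch with invented timings: `speedup` is a hypothetical helper, and the numbers only illustrate the case where tasks overlap in the parallel run while each one individually takes a little longer.

```python
def speedup(baseline_s: float, candidate_s: float) -> float:
    """Ratio of baseline time to candidate time; inf if the candidate took no time."""
    return baseline_s / candidate_s if candidate_s > 0 else float("inf")


# Invented per-task wall-clock times in seconds.
naive_times = {"a": 2.0, "b": 3.0}     # run back to back
parallel_times = {"a": 2.5, "b": 3.5}  # overlapped, each slowed by contention

naive_makespan = 5.0     # sequential: 2.0 + 3.0
parallel_makespan = 4.0  # overlapped: both finish within 4.0 s

sum_speedup = speedup(sum(naive_times.values()), sum(parallel_times.values()))
makespan_speedup = speedup(naive_makespan, parallel_makespan)

print(f"sum-of-times speedup: {sum_speedup:.2f}x")
print(f"makespan speedup:     {makespan_speedup:.2f}x")
```

The sum-of-times ratio drops below 1 because per-task latency grew, while the makespan ratio rises above 1 because the tasks overlap. For a batch workload, the makespan figure is the one that reflects end-to-end time.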


In [None]:
# test_script.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models
import copy
import time
# from node import Node
# from task import Task
# from taskset import Taskset
# from utils import group_topological_order, resolve_arg
# from profiler import Profiler  # Assuming profiler.py contains the Profiler class
# from evaluator import Evaluator  # Assuming evaluator.py contains the Evaluator class


# Define a SimpleCNN with torch.cat in its forward pass
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 28 * 28, 10)  # Assuming input images are 28x28

    def forward(self, x):
        # First path
        x1 = self.conv1(x)
        x1 = self.relu(x1)

        # Second path
        x2 = self.conv2(x)
        x2 = self.relu(x2)

        # Concatenate along the channel dimension
        x = torch.cat((x1, x2), dim=1)

        x = self.flatten(x)
        x = self.fc(x)
        return x


# Define a PretrainedResNet18 with modified final layer
class PretrainedResNet18(nn.Module):
    def __init__(self, num_classes: int = 10):
        """
        Initializes the ResNet18 model with a modified final layer.

        Args:
            num_classes (int): Number of output classes.
        """
        super(PretrainedResNet18, self).__init__()
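        # 'pretrained=True' is deprecated since torchvision 0.13; IMAGENET1K_V1 selects the same weights.
        self.resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)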
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, num_classes)  # Modify for desired number of classes

    def forward(self, x):
        return self.resnet18(x)


def create_synthetic_dataloader(batch_size: int = 1, num_samples: int = 1, input_size=(3, 28, 28)):
    """
    Creates a synthetic DataLoader of random inputs and labels, sized for SimpleCNN by default.

    Args:
        batch_size (int): Number of samples per batch.
        num_samples (int): Total number of samples.
        input_size (tuple): Shape of one sample as (channels, height, width).

    Returns:
        DataLoader: PyTorch DataLoader with synthetic data.
    """
    inputs = torch.randn(num_samples, *input_size)
    targets = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(inputs, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    return dataloader


def create_resnet_dataloader(batch_size: int = 1, num_samples: int = 1, input_size=(3, 224, 224)):
    """
    Creates a synthetic DataLoader of random inputs and labels, sized for ResNet18 by default.

    Args:
        batch_size (int): Number of samples per batch.
        num_samples (int): Total number of samples.
        input_size (tuple): Shape of one sample as (channels, height, width).

    Returns:
        DataLoader: PyTorch DataLoader with synthetic data.
    """
    inputs = torch.randn(num_samples, *input_size)
    targets = torch.randint(0, 10, (num_samples,))  # 10 classes, matching PretrainedResNet18(num_classes=10)
    dataset = TensorDataset(inputs, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    return dataloader


def initialize_components():
    """
    Discovers the compute nodes, creates the Profiler and the SimpleCNN and ResNet18
    tasks, profiles every task on every node, and builds the Taskset.

    Returns:
        Tuple[Taskset, Profiler, List[Node]]: Initialized Taskset, Profiler, and list of Nodes.
    """
    nodes = Node.discover_nodes()
    print(f"[Main] Discovered Nodes: {nodes}\n")

    profiler = Profiler(mode="init")  # Initialize Profiler
    print("[Main] Initialized Profiler.\n")



    # Create SimpleCNN Tasks
    num_simple_cnn_tasks = 1  # Adjust the number as needed
    simple_cnn_tasks = []
    for i in range(num_simple_cnn_tasks):
        model = SimpleCNN()
        dl = create_synthetic_dataloader(batch_size=10, num_samples=100, input_size=(3, 28, 28))
        single_input_cnn, _ = next(iter(dl))
        task = Task(
            task_id=f"simple_cnn_task{i+1}",
            model=model,
            input_data=single_input_cnn,
            model_name=model.__class__.__name__,
            profiler=profiler
        )
        simple_cnn_tasks.append(task)
    print(f"[Main] Created {num_simple_cnn_tasks} SimpleCNN Tasks.\n")

    # Create ResNet18 Tasks
    num_resnet_tasks = 1  # Adjust the number as needed
    resnet_tasks = []
    for i in range(num_resnet_tasks):
        model = PretrainedResNet18(num_classes=10)
        dl = create_resnet_dataloader(batch_size=10, num_samples=100, input_size=(3, 224, 224))
        single_input_resnet, _ = next(iter(dl))
        task = Task(
            task_id=f"resnet_task{i+1}",
            model=model,
            input_data=single_input_resnet,
            model_name=model.__class__.__name__,
            profiler=profiler
        )
        resnet_tasks.append(task)
    print(f"[Main] Created {num_resnet_tasks} ResNet18 Tasks.\n")

    # Combine all tasks into a single list
    all_tasks = simple_cnn_tasks + resnet_tasks
    print("[Main] Initialized Taskset with SimpleCNN and ResNet18 tasks.\n")

    # Populate the profiler database.
    # 1) Profile each task on each node
    for task in all_tasks:
        for node_curr in nodes:
            input_copy = copy.deepcopy(task.input_data)
            profiler.profile_model(
                model=task.model,
                input_data=input_copy,
                node_id=node_curr.node_id,
                task_id=task.task_id,
            )

    # 2) Now that profiling is done, retrieve the final DataFrame
    profile_db = profiler.get_profile_db()

    # 3) For each task, filter out its profiler rows, build forward_pass aggregates if desired,
    #    then call init_profiling_data
    # for task in all_tasks:
    #     # Gather all rows in 'profile_db' that match this Task's ID
    #     task_rows = profile_db[profile_db['Task_ID'] == task.task_id]

    #     if task_rows.empty:
    #         # No rows found for this task (unusual), skip or log warning
    #         print(f"No profiler rows found for Task '{task.task_id}'")
    #         continue

    #     # Convert these rows into a list-of-dicts for easy consumption
    #     records_for_task = task_rows.to_dict(orient='records')

    #     # (Optional) Build a forward_pass dictionary if you want aggregated metrics:
    #     # e.g. get row(s) where Layer == 'forward_pass'
    #     forward_pass_data = task_rows[task_rows['Layer'] == 'forward_pass']
    #     if not forward_pass_data.empty:
    #         # Example: take max of 'CPU Total (us)' for forward_pass across all devices
    #         total_cpu_time = forward_pass_data['CPU Total (us)'].max()
    #         forward_pass_agg = {
    #             "max_forward_cpu_time_us": total_cpu_time,
    #             # Add other aggregates as you see fit
    #         }
    #     else:
    #         forward_pass_agg = {}

    #     # 4) Call init_profiling_data on the Task
    #     task.init_profiling_data(
    #         profiling_records=records_for_task,
    #         forward_pass_agg=forward_pass_agg
    #     )

    # 5) Print the profile database
    profiler.print_profile_db()

    for task in all_tasks:
        # task.print_stage_allocations()
        # print("Is the model in training mode?", task.model.training)
        # print(task.parsed_layer_data)
        print("\n")
        print(task.prof_records)
        # print()

    # Build the Taskset. The utilization metric below is currently disabled.
    # load_metric = HPCUtilizationMetric(slack_fraction=0.2)
    taskset = Taskset(tasks=all_tasks, available_nodes=nodes)
    print("[Main] Initialized Taskset.\n")

    return taskset, profiler, nodes
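

# --- Illustrative sketch (not part of the original PyDart code) ---
# The commented-out step 3 in initialize_components() describes aggregating
# each task's 'forward_pass' rows. The helper below is one way that could look
# when the rows are plain dicts, e.g. from profile_db.to_dict(orient="records").
# The function name and return shape are assumptions made for illustration;
# the column names ('Task_ID', 'Layer', 'CPU Total (us)') follow the ProfileDB
# printout.
def max_forward_cpu_time_by_task(records):
    """Return {task_id: max 'CPU Total (us)' over that task's forward_pass rows}."""
    result = {}
    for row in records:
        if row.get("Layer") != "forward_pass":
            continue  # skip per-layer rows; only whole-pass rows are aggregated
        task_id = row["Task_ID"]
        cpu_total = float(row["CPU Total (us)"])
        # Keep the largest forward-pass time seen so far across all nodes.
        result[task_id] = max(result.get(task_id, cpu_total), cpu_total)
    return result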


def set_seed(seed: int = 42):
    """
    Sets the seed for Python, NumPy, and PyTorch to ensure reproducibility.

    Args:
        seed (int): The seed value to use. Default is 42.
    """
    import random
    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Seed the CUDA RNGs; these calls are silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # Covers every GPU in multi-GPU setups.

    # Ensure deterministic behavior in CuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f"[Utils] Seed set to {seed} for Python, NumPy, and PyTorch.")


def run_evaluation():
    taskset, profiler, nodes = initialize_components()

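    # Compute loads for the taskset.
    taskset.compute_loads()

    # NOTE: seeding happens after initialize_components(), so any random state
    # consumed there (for example during model construction) may not be covered.
    set_seed(42)
    print("[Main] Seed set to 42 for reproducibility.\n")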

    # Execute tasks in parallel
    # print("[Main] Starting Parallel Execution of Taskset.\n")
    # taskset.execute_all()

    # for task in taskset.tasks:
    #     task.print_stage_allocations()
    #     # print(task.prof_records.keys())
    #     print("Is the model in training mode?", task.model.training)
    # # Initialize evaluator
    # evaluator = Evaluator(taskset=taskset, profiler=profiler)

    # # Run Naive Execution
    # evaluator.run_naive_execution()

    # print("=== Performance Metrics ===")
    # print(taskset)
    # print("===========================")

    # # Run Parallel Execution
    # evaluator.run_parallel_execution()

    # # Compare Outputs
    # evaluator.compare_outputs()

    # # Analyze Speedup and Throughput
    # evaluator.analyze_speedup_throughput()

    # # Print final metrics
    # print("=== Final Performance Metrics ===")
    # print(taskset)
    # print("==================================")

    # # Shutdown nodes
    # print("[Main] Shutting down all Nodes.")
    # for node in nodes:
    #     node.stop()
    # print("[Main] All Nodes have been shut down.")

    # for task in taskset.tasks:
    #   # task.print_stage_allocations()
    #   print(task.task_id)
    #   print()
    #   print(task.prof_records)
    #   # print(task.loads)
    #   print()
    #   print("Is the model in training mode?", task.model.training)

    # print(taskset.loads)
    # print(taskset.tasks[1].loads)


if __name__ == "__main__":
    run_evaluation()


[discover_nodes] Generated DISJOINT nodes: leftover CPU-only plus GPU+CPU with unique cores.
[Main] Discovered Nodes: [Node(CPU-0, cpus=(0,), gpu=None), Node(CPU-1, cpus=(1,), gpu=None)]

[Main] Initialized Profiler.

[Main] Created 1 SimpleCNN Tasks.



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 131MB/s]


[Main] Created 1 ResNet18 Tasks.

[Main] Initialized Taskset with SimpleCNN and ResNet18 tasks.

Starting profiling...


  warn("CUDA is not available, disabling CUDA profiling")
  df = pd.concat([df, new_row], ignore_index=True)
  warn("CUDA is not available, disabling CUDA profiling")


Profiling complete. Data saved to profiling_results.csv
Starting profiling...
Profiling complete. Data saved to profiling_results.csv
Starting profiling...


  warn("CUDA is not available, disabling CUDA profiling")


Profiling complete. Data saved to profiling_results.csv
Starting profiling...


  warn("CUDA is not available, disabling CUDA profiling")


Profiling complete. Data saved to profiling_results.csv
ProfileDB:
         Task_ID              Model                          Layer Compute  Self CPU (us)  CPU Total (us)  CUDA Total (us) Self CPU Mem (bytes) Self CUDA Mem (bytes)  Total Execution Time (us) Total Memory Used (bytes)
simple_cnn_task1          SimpleCNN                   forward_pass   CPU-0       5187.730       23445.961              0.0             -9032480                     0                  23445.961                  -9032480
simple_cnn_task1          SimpleCNN                          conv1   CPU-0       3079.728        8682.799              0.0                    0                     0                   8682.799                         0
simple_cnn_task1          SimpleCNN                         relu_1   CPU-0        375.149        3411.435              0.0                    0                     0                   3411.435                         0
simple_cnn_task1          SimpleCNN                      