# PyDart Library – First Iteration

## Overview

This notebook represents the initial iteration of the complete PyDart Library, developed following the completion of the Profiler. In this version, the library includes an explicit `Scheduler` class inspired by the DART research paper.

## Main Contributions

1. **Forward Pass Decomposition**  
   - The decomposition was performed on the forward pass of the deep neural network rather than on its DAG (Directed Acyclic Graph).

2. **Event Aggregation**  
   - The decomposition process aggregated individual events captured by the Profiler.

3. **Dynamic Programming (DP) Algorithm**  
   - A simplified version of the DP algorithm, as described in the research paper, was implemented.  
   - Note: This implementation was not fully complete.
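
To make the DP idea concrete, here is a minimal sketch of a min-max contiguous partition. It assumes a single cost per layer and ignores which node runs each block; the name `partition_min_max` and the sample costs are hypothetical and are not taken from the DART paper.

```python
from typing import List, Tuple

def partition_min_max(costs: List[float], k: int) -> Tuple[float, List[Tuple[int, int]]]:
    """Split costs into exactly k contiguous, non-empty blocks,
    minimizing the largest block sum."""
    n = len(costs)
    if not 1 <= k <= n:
        raise ValueError("k must be between 1 and len(costs)")

    # prefix[i] = total cost of the first i layers
    prefix = [0.0] * (n + 1)
    for i, c in enumerate(costs, start=1):
        prefix[i] = prefix[i - 1] + c

    INF = float("inf")
    # dp[i][s] = best achievable maximum block sum for the first i layers in s blocks
    dp = [[INF] * (k + 1) for _ in range(n + 1)]
    split = [[-1] * (k + 1) for _ in range(n + 1)]
    dp[0][0] = 0.0
    for i in range(1, n + 1):
        for s in range(1, min(i, k) + 1):
            for x in range(s - 1, i):
                # last block covers layers x..i-1
                cand = max(dp[x][s - 1], prefix[i] - prefix[x])
                if cand < dp[i][s]:
                    dp[i][s] = cand
                    split[i][s] = x

    # walk the split table backwards to recover the block boundaries
    blocks = []
    i, s = n, k
    while s > 0:
        x = split[i][s]
        blocks.append((x, i))
        i, s = x, s - 1
    blocks.reverse()
    return dp[n][k], blocks

best, blocks = partition_min_max([4.0, 1.0, 1.0, 2.0, 3.0, 1.0], k=3)
print(best, blocks)
```

Each block is a half-open index range `(start, end)` over the layer list, the same convention the `Scheduler` in this notebook uses when slicing `leaf_layers`.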

## Error Analysis and Iterative Improvements

During profiling, several issues were identified:

- **Profiling Errors:**  
  - Most errors were related to improper or missing data retrieval.

- **Decomposition Errors:**  
  - The forward pass was decomposed incorrectly.  
  - These errors have been addressed in subsequent iterations (detailed in the next notebook).

Since these errors were not directly related to the DP algorithm, the focus was placed on correcting the earlier stages of the process before refining the algorithm further.

## Development and Testing Phases

Throughout the development process, multiple Jupyter notebooks were used for testing and developing individual classes. The workflow is divided into three main phases:

1. **Initialization Phase**  
   - Configuration and setup of initial components
     - Profiler class: profiles the model and stores the results in a persistent DB.
     - Stage class: the fundamental unit of work that actually runs on a node.
     - Task class: the DNN inference task for a model. It encapsulates the stages, their dependencies and dependent stages, and their allocation to nodes.
     - Scheduler class: allocates the stages offline using the DP-based approach, then dispatches them during the runtime phase.
     - Taskset class: holds all the tasks and the scheduler object, and is the entry point for running the library from the user's perspective.

2. **Runtime Phase**  
   - The phase intended for end-user interaction with the library.
   - It is started by calling the taskset's execute method (`execute_all` in this notebook).

3. **Evaluation Phase**  
   - The phase where the library and framework are evaluated for performance and correctness.
   - Here the library is run on particular system and task configurations to measure its speedup and throughput.
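
As a rough illustration of the two evaluation metrics named above, the sketch below computes speedup and throughput from wall-clock timings. The helper names and the timings are made up for the example; they are not measurements from this notebook.

```python
def speedup(baseline_s: float, pipeline_s: float) -> float:
    """How many times faster the pipeline run is than the baseline run."""
    if pipeline_s <= 0:
        raise ValueError("pipeline_s must be positive")
    return baseline_s / pipeline_s

def throughput(num_tasks: int, elapsed_s: float) -> float:
    """Completed tasks per second over the measured interval."""
    if elapsed_s <= 0:
        raise ValueError("elapsed_s must be positive")
    return num_tasks / elapsed_s

# Made-up timings, purely for illustration.
example_speedup = speedup(baseline_s=2.0, pipeline_s=0.5)
example_throughput = throughput(num_tasks=3, elapsed_s=0.5)
print(example_speedup, example_throughput)
```

A speedup below 1.0 means the pipeline run was slower than the baseline, which is a plausible outcome for small models where dispatch overhead dominates.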

---

**Note:** These classes went through multiple iterations during development. I have included what I felt were the main checkpoints.
The iterations that follow take a very similar approach.


In [None]:
import os
import torch
import threading
import queue
from typing import Callable, Any, List

class Node:
    """
    Represents a computational resource: either a CPU-only node (1 CPU core)
    or a GPU+CPU pair. Each Node has:
      - node_id (e.g., 'CPU-0', 'GPU-0-CPU-1')
      - A worker thread + queue to run tasks
    """

    def __init__(self, node_id: str, cpus=None, gpu=None):
        self._node_id = node_id
        self._cpus = tuple(cpus or [])
        self._gpu = gpu

        self._original_affinity = os.sched_getaffinity(0)
        self._task_queue = queue.Queue()
        self._stop_signal = False

        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

    @property
    def node_id(self):
        return self._node_id

    @property
    def cpus(self):
        return self._cpus

    @property
    def gpu(self):
        return self._gpu

    def assign_task(self, func: Callable, *args, **kwargs) -> queue.Queue:
        """
        Enqueue a function to this node. Returns a queue from which
        the caller can retrieve the result (blocking).
        """
        result_queue = queue.Queue(maxsize=1)
        self._task_queue.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        """
        Signal the node to stop after processing queued tasks.
        """
        self._stop_signal = True
        self._task_queue.put(None)
        self._worker_thread.join()

    def _worker_loop(self):
        while not self._stop_signal:
            item = self._task_queue.get()
            if item is None:
                break
            func, args, kwargs, result_queue = item
            try:
                self._set_context()
                result = func(*args, **kwargs)
            except Exception as e:
                result = e
            finally:
                self._reset_context()

            result_queue.put(result)

    def _set_context(self):
        if self._cpus:
            os.sched_setaffinity(0, self._cpus)
        if self._gpu is not None and torch.cuda.is_available():
            torch.cuda.set_device(self._gpu)

    def _reset_context(self):
        os.sched_setaffinity(0, self._original_affinity)
        # optionally reset GPU device if needed

    @staticmethod
    def discover_nodes() -> List['Node']:
        """
        Create a Node for each CPU core, and for each GPU+CPU pair.
        """
        nodes = []
        num_cpus = os.cpu_count() or 1
        ngpus = torch.cuda.device_count()

        # CPU-only nodes
        for core_id in range(num_cpus):
            node = Node(node_id=f"CPU-{core_id}", cpus=[core_id])
            nodes.append(node)

        # GPU+CPU nodes
        for g in range(ngpus):
            for core_id in range(num_cpus):
                node = Node(node_id=f"GPU-{g}-CPU-{core_id}", cpus=[core_id], gpu=g)
                nodes.append(node)

        return nodes

    def __repr__(self):
        return f"Node({self._node_id}, cpus={self._cpus}, gpu={self._gpu})"
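
The worker loop above stores a raised exception as the task's result instead of re-raising it, so a caller that only does `rq.get()` will not notice a failure. The sketch below shows one way a caller could surface such failures. `MiniNode` and `unwrap` are hypothetical stand-ins written for this example; `MiniNode` omits the CPU affinity and GPU handling of the real `Node`.

```python
import queue
import threading

class MiniNode:
    """Hypothetical stand-in for Node: one worker thread, no CPU/GPU pinning."""

    def __init__(self):
        self._tasks = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def assign_task(self, func, *args, **kwargs):
        result_queue = queue.Queue(maxsize=1)
        self._tasks.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        self._tasks.put(None)  # sentinel: leave the loop once earlier tasks are done
        self._thread.join()

    def _loop(self):
        while True:
            item = self._tasks.get()
            if item is None:
                break
            func, args, kwargs, result_queue = item
            try:
                result = func(*args, **kwargs)
            except Exception as exc:  # hand the failure to the caller as the result
                result = exc
            result_queue.put(result)

def unwrap(result_queue):
    """Block for the result and re-raise it if the task failed."""
    result = result_queue.get()
    if isinstance(result, Exception):
        raise result
    return result

node = MiniNode()
ok = unwrap(node.assign_task(pow, 2, 10))
try:
    unwrap(node.assign_task(lambda: 1 / 0))
    failed = False
except ZeroDivisionError:
    failed = True
node.stop()
print(ok, failed)
```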

In [None]:
import time
import os
import re
import pandas as pd
import torch
import torch.nn as nn
import torch.profiler

class Profiler:
    """
    In 'init' mode: Gather detailed profiling info for each leaf layer on each Node,
    storing results in a CSV-based ProfileDB.
    In 'runtime' mode: Potentially gather minimal logs (optional).
    """

    def __init__(self, mode: str, profile_db_path='profiling_results.csv', log_dir='logs'):
        assert mode in ['init', 'runtime']
        self.mode = mode
        self.profile_db_path = profile_db_path
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        columns = [
            'Model', 'Layer', 'Compute',
            'Self CPU (us)', 'CPU Total (us)', 'CUDA Total (us)',
            'Self CPU Mem (bytes)', 'Self CUDA Mem (bytes)',
            'Total Execution Time (us)', 'Total Memory Used (bytes)'
        ]
        if os.path.exists(self.profile_db_path):
            self.profile_db = pd.read_csv(self.profile_db_path)
        else:
            self.profile_db = pd.DataFrame(columns=columns)

        self.runtime_csv = os.path.join(self.log_dir, 'runtime_results.csv')
        if not os.path.exists(self.runtime_csv):
            rt_cols = ['Model', 'Layer', 'Compute', 'Execution Time (us)']
            pd.DataFrame(columns=rt_cols).to_csv(self.runtime_csv, index=False)

    def _register_hooks(self, model: nn.Module):
        def hook_wrapper(layer_name):
            def hook(mod, inp, out):
                with torch.profiler.record_function(layer_name):
                    pass
            return hook

        for idx, (name, layer) in enumerate(model.named_modules()):
            if not isinstance(layer, (nn.Sequential, nn.ModuleList)) and layer is not model:
                layer.register_forward_hook(hook_wrapper(f"{name}_{idx}"))

    def profile_model(self, model: nn.Module, dataloader, node, model_name: str, warmup_iters=3):
        """
        Schedule a profiling task on 'node'. In 'init' mode, we gather
        full per-layer times.
        """

        def profiling_task():
            device = torch.device(f"cuda:{node.gpu}" if node.gpu is not None and torch.cuda.is_available() else "cpu")
            model.to(device)

            if self.mode == 'init':
                # warmup
                with torch.no_grad():
                    cnt = 0
                    for inputs, targets in dataloader:
                        inputs = inputs.to(device)
                        targets = targets.to(device)
                        model(inputs)
                        cnt += 1
                        if cnt >= warmup_iters:
                            break
                self._profile_init(model, dataloader, node, model_name, device)
            else:
                self._profile_runtime(model, dataloader, node, model_name, device)

        rq = node.assign_task(profiling_task)
        rq.get()  # block

    def _profile_init(self, model, dataloader, node, model_name, device):
        self._register_hooks(model)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            profile_memory=True
        ) as prof:
            with torch.no_grad():
                for step, (inp, tgt) in enumerate(dataloader):
                    if step >= 1:
                        break
                    inp = inp.to(device)
                    tgt = tgt.to(device)
                    model(inp)
                    prof.step()

        stats = self._process_events(prof, model, node, runtime=False)
        self._update_profile_db(stats, model_name, node, runtime=False)

    def _profile_runtime(self, model, dataloader, node, model_name, device):
        self._register_hooks(model)
        with torch.no_grad():
            for step, (inp, tgt) in enumerate(dataloader):
                if step >= 1:
                    break
                inp = inp.to(device)
                tgt = tgt.to(device)
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
                ) as prof:
                    model(inp)
                    prof.step()
                stats = self._process_events(prof, model, node, runtime=True)
                self._append_runtime_csv(stats, model_name, node)

    def _process_events(self, profiler, model, node, runtime=False):
        recognized = set()
        for n, m in model.named_modules():
            if n:
                recognized.add(n)

        aggregated = {
            'forward_pass': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                                 self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id),
            'misc': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                         self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id)
        }

        events = list(profiler.events())
        found_root = False

        def strip_suffix(s):
            return re.sub(r'(\.|_)\d+$', '', s)

        for e in events:
            if e.name == "":
                found_root = True
                aggregated['forward_pass']['self_cpu_time_total'] += e.self_cpu_time_total
                aggregated['forward_pass']['cpu_time_total'] += e.cpu_time_total
                aggregated['forward_pass']['cuda_time_total'] += e.device_time_total
                if not runtime:
                    aggregated['forward_pass']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                    aggregated['forward_pass']['self_cuda_memory_usage'] += e.self_device_memory_usage
            else:
                base = strip_suffix(e.name)
                if base in recognized:
                    if base not in aggregated:
                        aggregated[base] = dict(
                            self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                            self_cpu_memory_usage=0, self_cuda_memory_usage=0,
                            compute=node.node_id
                        )
                    aggregated[base]['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated[base]['cpu_time_total'] += e.cpu_time_total
                    aggregated[base]['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated[base]['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated[base]['self_cuda_memory_usage'] += e.self_device_memory_usage
                else:
                    aggregated['misc']['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated['misc']['cpu_time_total'] += e.cpu_time_total
                    aggregated['misc']['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated['misc']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated['misc']['self_cuda_memory_usage'] += e.self_device_memory_usage

        # If no root event found, sum all into forward_pass
        if not found_root:
            for k in list(aggregated.keys()):
                if k not in ('forward_pass', 'misc'):
                    aggregated['forward_pass']['self_cpu_time_total'] += aggregated[k]['self_cpu_time_total']
                    aggregated['forward_pass']['cpu_time_total'] += aggregated[k]['cpu_time_total']
                    aggregated['forward_pass']['cuda_time_total'] += aggregated[k]['cuda_time_total']
                    if not runtime:
                        aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated[k]['self_cpu_memory_usage']
                        aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated[k]['self_cuda_memory_usage']

            aggregated['forward_pass']['self_cpu_time_total'] += aggregated['misc']['self_cpu_time_total']
            aggregated['forward_pass']['cpu_time_total'] += aggregated['misc']['cpu_time_total']
            aggregated['forward_pass']['cuda_time_total'] += aggregated['misc']['cuda_time_total']
            if not runtime:
                aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated['misc']['self_cpu_memory_usage']
                aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated['misc']['self_cuda_memory_usage']

        return aggregated

    def _update_profile_db(self, stats, model_name, node, runtime=False):
        if runtime:
            return
        for layer_name, data in stats.items():
            total_t = data['cpu_time_total'] + data['cuda_time_total']
            total_m = data['self_cpu_memory_usage'] + data['self_cuda_memory_usage']
            row = {
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Self CPU (us)': data['self_cpu_time_total'],
                'CPU Total (us)': data['cpu_time_total'],
                'CUDA Total (us)': data['cuda_time_total'],
                'Self CPU Mem (bytes)': data['self_cpu_memory_usage'],
                'Self CUDA Mem (bytes)': data['self_cuda_memory_usage'],
                'Total Execution Time (us)': total_t,
                'Total Memory Used (bytes)': total_m
            }
            self.profile_db = self._upsert(self.profile_db, row)
        self.profile_db.to_csv(self.profile_db_path, index=False)

    def _upsert(self, df, row):
        mask = (
            (df['Model'] == row['Model']) &
            (df['Layer'] == row['Layer']) &
            (df['Compute'] == row['Compute'])
        )
        if not df[mask].empty:
            existing_time = df.loc[mask, 'Total Execution Time (us)'].max()
            if row['Total Execution Time (us)'] > existing_time:
                for k, v in row.items():
                    df.loc[mask, k] = v
        else:
            df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
        return df

    def _append_runtime_csv(self, stats, model_name, node):
        rows = []
        for layer_name, data in stats.items():
            exec_time = data['cpu_time_total'] + data['cuda_time_total']
            rows.append({
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Execution Time (us)': exec_time
            })
        if rows:
            rdf = pd.read_csv(self.runtime_csv)
            rdf = pd.concat([rdf, pd.DataFrame(rows)], ignore_index=True)
            rdf.to_csv(self.runtime_csv, index=False)

    def get_profile_db(self):
        return self.profile_db

    def print_profile_db(self):
        if self.profile_db.empty:
            print("ProfileDB is empty.")
        else:
            print("ProfileDB:\n", self.profile_db.to_string(index=False))
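
The aggregation step in `_process_events` can be hard to follow with real profiler events, so here is a reduced sketch using plain `(name, cpu_us)` tuples. The regular expression is the one used in the class above; the `aggregate` helper, the module names, and the timings are hypothetical.

```python
import re
from collections import defaultdict

def strip_suffix(name: str) -> str:
    # Same pattern as in _process_events: drop one trailing ".<digits>" or "_<digits>"
    return re.sub(r'(\.|_)\d+$', '', name)

def aggregate(events, recognized):
    """Sum (name, cpu_us) pairs per recognized module; everything else goes to 'misc'."""
    totals = defaultdict(int)
    for name, cpu_us in events:
        base = strip_suffix(name)
        totals[base if base in recognized else 'misc'] += cpu_us
    return dict(totals)

# Hypothetical module names and event timings, for illustration only.
recognized = {'net', 'net.0', 'net.1'}
events = [('net.0_2', 120), ('net.1_3', 15), ('aten::addmm', 90), ('net.0_2', 30)]
totals = aggregate(events, recognized)
print(totals)
```

One thing this makes visible: the pattern also strips a module's own numeric suffix, so `strip_suffix('net.0')` returns `'net'`. An event named after the bare module path would therefore be credited to the parent container rather than to the layer.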

In [None]:
import torch

class Stage:
    """
    A contiguous block of layers assigned to a single Node.
    """

    def __init__(self, stage_id: str, layers: torch.nn.ModuleList, assigned_node: 'Node'):
        self.stage_id = stage_id
        self.layers = layers
        self.assigned_node = assigned_node

        self.input_data = None
        self.output_data = None
        self.execution_time = 0.0

        self.dependencies = set()
        self.dependents = set()

    def run_stage(self):
        import time
        start = time.time()
        device = torch.device(
            f"cuda:{self.assigned_node.gpu}" if self.assigned_node.gpu is not None and torch.cuda.is_available()
            else "cpu"
        )
        inp = self.input_data.to(device)

        with torch.no_grad():
            out = inp
            for layer in self.layers:
                out = layer(out)

        if device.type == 'cuda':
            out = out.cpu()

        self.output_data = out
        end = time.time()
        self.execution_time = end - start
        return self.output_data

In [None]:
import threading
import numpy as np
from typing import List

class Scheduler:
    """
    A DP-based scheduler that partitions the entire set of layers into
    contiguous stages, assigns each stage to one node, and then
    can run them in a pipeline-parallel manner.
    """

    def __init__(self, profiler: 'Profiler', nodes: List['Node'], observation_window: float):
        self.profiler = profiler
        self.nodes = nodes
        self.observation_window = observation_window

        # We store a global stage graph: stage_id -> {node_idx, layer_names, deps, dependents}
        self.global_stage_graph = {}

    def decompose_and_allocate(self, model: torch.nn.Module, model_name: str, k_stages: int = None):
        """
        1) Identify leaf layers in a sequential order.
        2) Use DP to form exactly k_stages contiguous blocks, each assigned to a single node.
        3) Minimize the maximum per-stage cost. Nothing prevents several stages
           from picking the same node, so this bounds stage cost, not node load.
        4) Build the global_stage_graph with linear adjacency.
        """
        # gather leaf layers in order
        leaf_layers = []
        for name, module in model.named_modules():
            if len(list(module.children())) == 0 and name:
                leaf_layers.append(name)

        n_layers = len(leaf_layers)
        if n_layers == 0:
            print("No leaf layers found in model!")
            return

        if k_stages is None:
            k_stages = min(len(self.nodes), n_layers)

        profile_db = self.profiler.get_profile_db()
        model_df = profile_db[profile_db['Model'] == model_name]

        # times[i][j] = time (seconds) for layer i on node j
        times = [[0.0]*len(self.nodes) for _ in range(n_layers)]
        for i, lname in enumerate(leaf_layers):
            for j, node in enumerate(self.nodes):
                row = model_df[(model_df['Layer'] == lname) & (model_df['Compute'] == node.node_id)]
                if not row.empty:
                    t_us = row['Total Execution Time (us)'].values[0]
                    times[i][j] = t_us / 1e6  # convert microseconds to sec
                else:
                    times[i][j] = 1.0  # fallback

        # build prefix sums
        prefix_times = [[0.0]*(n_layers+1) for _ in range(len(self.nodes))]
        for j in range(len(self.nodes)):
            for i in range(1, n_layers+1):
                prefix_times[j][i] = prefix_times[j][i-1] + times[i-1][j]

        def block_cost(x, i, node_j):
            return prefix_times[node_j][i] - prefix_times[node_j][x]

        INF = float('inf')
        DP = [[INF]*(k_stages+1) for _ in range(n_layers+1)]
        back_node = [[-1]*(k_stages+1) for _ in range(n_layers+1)]
        back_split = [[-1]*(k_stages+1) for _ in range(n_layers+1)]

        DP[0][0] = 0.0

        for i in range(1, n_layers+1):
            for s in range(1, k_stages+1):
                for j in range(len(self.nodes)):
                    for x in range(i):
                        c = block_cost(x, i, j)
                        cand = max(DP[x][s-1], c)
                        if cand < DP[i][s]:
                            DP[i][s] = cand
                            back_node[i][s] = j
                            back_split[i][s] = x

        min_max_load = DP[n_layers][k_stages]
        print(f"[Scheduler] DP => min possible max load = {min_max_load:.4f} sec")

        # reconstruct
        i = n_layers
        s = k_stages
        partitions = []
        while i > 0 and s > 0:
            nj = back_node[i][s]
            xx = back_split[i][s]
            partitions.append((xx, i, nj))
            i = xx
            s -= 1

        partitions.reverse()

        # build global_stage_graph
        self.global_stage_graph.clear()
        idx = 0
        for (start_i, end_i, node_j) in partitions:
            stage_id = f"stage-{idx}"
            assigned_node = self.nodes[node_j]
            layer_subset = leaf_layers[start_i:end_i]

            self.global_stage_graph[stage_id] = {
                'node_idx': node_j,
                'layer_names': layer_subset,
                'dependencies': [],
                'dependents': []
            }
            idx += 1

        # link them linearly (stage i depends on stage i-1)
        keys = list(self.global_stage_graph.keys())
        for k in range(1, len(keys)):
            prev_stg = keys[k-1]
            curr_stg = keys[k]
            self.global_stage_graph[curr_stg]['dependencies'].append(prev_stg)
            self.global_stage_graph[prev_stg]['dependents'].append(curr_stg)

    def create_stages_for_task(self, task: 'Task'):
        """
        Convert global_stage_graph into actual Stage objects for the given task.
        """
        import torch.nn as nn
        for stage_id, info in self.global_stage_graph.items():
            node_idx = info['node_idx']
            node = self.nodes[node_idx]

            layer_list = []
            for lname in info['layer_names']:
                layer_list.append(task.model.get_submodule(lname))

            stg = Stage(stage_id, nn.ModuleList(layer_list), node)
            stg.dependencies = set(info['dependencies'])
            stg.dependents = set(info['dependents'])
            task.stages[stage_id] = stg

    def run_pipeline_parallel(self, task: 'Task'):
        """
        Pipeline parallel approach:
         - track how many dependencies remain
         - each stage that has 0 dependencies is enqueued
         - on completion, reduce the dependent's dep count
        """
        lock = threading.Lock()
        dep_count = {}
        for sid, stg in task.stages.items():
            dep_count[sid] = len(stg.dependencies)

        total_stages = len(task.stages)
        completed = 0
        done_event = threading.Event()

        def stage_done_callback(sid: str):
            nonlocal completed
            with lock:
                for dep_id in task.stages[sid].dependents:
                    dep_count[dep_id] -= 1
                    if dep_count[dep_id] == 0:
                        enqueue_stage(dep_id)
                completed += 1
                if completed >= total_stages:
                    done_event.set()

        def enqueue_stage(sid: str):
            stg = task.stages[sid]

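            def run_fn():
                # Feed the stage: the task input for the first stage, otherwise
                # the output of its single predecessor (the graph is linear).
                if stg.dependencies:
                    prev_id = next(iter(stg.dependencies))
                    stg.input_data = task.stages[prev_id].output_data
                else:
                    stg.input_data = task.input_data
                return stg.run_stage()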

            def worker():
                rq = stg.assigned_node.assign_task(run_fn)
                _ = rq.get()
                stage_done_callback(sid)

            threading.Thread(target=worker, daemon=True).start()

        # initial
        for sid, c in dep_count.items():
            if c == 0:
                enqueue_stage(sid)

        done_event.wait()
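
The dependency-counting logic in `run_pipeline_parallel` is easier to check without threads. The sketch below is a single-threaded version of the same bookkeeping (essentially Kahn's topological ordering). The function name `dispatch_order` is hypothetical, and it assumes every dependency id also appears as a key of the input mapping.

```python
from collections import deque

def dispatch_order(dependencies):
    """Return stage ids in an order where every stage follows all its dependencies.

    dependencies maps stage_id -> set of stage ids it waits on.
    """
    remaining = {sid: len(deps) for sid, deps in dependencies.items()}
    dependents = {sid: [] for sid in dependencies}
    for sid, deps in dependencies.items():
        for dep in deps:
            dependents[dep].append(sid)

    # stages with no outstanding dependencies can start immediately
    ready = deque(sid for sid, count in remaining.items() if count == 0)
    order = []
    while ready:
        sid = ready.popleft()
        order.append(sid)  # a threaded scheduler would enqueue the stage on its node here
        for nxt in dependents[sid]:
            remaining[nxt] -= 1
            if remaining[nxt] == 0:
                ready.append(nxt)

    if len(order) != len(dependencies):
        raise ValueError("dependency cycle detected")
    return order

graph = {'stage-0': set(), 'stage-1': {'stage-0'}, 'stage-2': {'stage-1'}}
order = dispatch_order(graph)
print(order)
```

With the linear stage graph built by `decompose_and_allocate`, only one stage of a given task is ever ready at a time, so any overlap between nodes comes from running several tasks at once.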

In [None]:
from typing import Dict, Optional

class Task:
    """
    Represents a single inference pass for a model + input.
    Stages are built by the Scheduler's create_stages_for_task call.
    """

    def __init__(self, task_id: str, model: torch.nn.Module, input_data: torch.Tensor, model_name: str):
        self.task_id = task_id
        self.model = model
        self.input_data = input_data
        self.model_name = model_name

        self.stages: Dict[str, Stage] = {}

        self.start_time: Optional[float] = None
        self.finish_time: Optional[float] = None

    def get_total_execution_time(self):
        if self.start_time is not None and self.finish_time is not None:
            return self.finish_time - self.start_time
        return 0.0


############################################################
# 6) Taskset Class
############################################################

import time

class Taskset:
    """
    A collection of tasks to run with the DP-based pipeline approach.
    """

    def __init__(self, tasks: List[Task], scheduler: Scheduler):
        self.tasks = tasks
        self.scheduler = scheduler

    def schedule_all_tasks(self):
        # Build stage objects for each Task from the global stage graph
        for t in self.tasks:
            self.scheduler.create_stages_for_task(t)

    def execute_all(self):
        # We run each task in parallel
        import threading
        start_global = time.time()
        threads = []

        def run_task(task: Task):
            task.start_time = time.time()
            self.scheduler.run_pipeline_parallel(task)
            task.finish_time = time.time()

        for t in self.tasks:
            th = threading.Thread(target=run_task, args=(t,), daemon=True)
            th.start()
            threads.append(th)

        for th in threads:
            th.join()

        end_global = time.time()
        print(f"[Taskset] All tasks completed in {end_global - start_global:.2f}s")

    def collect_metrics(self):
        # For demonstration, we do not track node usage in detail here.
        return {}

In [None]:
class Evaluator:
    """
    Compares naive single-device run vs. pipeline approach.
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.naive_run_time = None
        self.pipeline_run_time = None

    def run_naive_pytorch(self, tasks: List[Task]):
        """
        Runs each Task sequentially on self.device (layer-by-layer in order).
        Measures total time across all tasks.
        """
        import time
        dev = torch.device(self.device)
        start = time.time()

        with torch.no_grad():
            for task in tasks:
                task.model.to(dev)
                out = task.model(task.input_data.to(dev))
                _ = out.cpu() if dev.type == 'cuda' else out

        end = time.time()
        self.naive_run_time = end - start
        print(f"[Evaluator] Naive run on device={self.device} for {len(tasks)} tasks => {self.naive_run_time:.2f}s")

    def run_pipeline(self, taskset: Taskset):
        import time
        start = time.time()
        taskset.schedule_all_tasks()
        taskset.execute_all()
        end = time.time()

        self.pipeline_run_time = end - start
        print(f"[Evaluator] Pipeline approach => {self.pipeline_run_time:.2f}s")

    def compare_results(self, node_usage_info):
        print("\n=== Comparison ===")
        if self.naive_run_time is None or self.pipeline_run_time is None:
            print("Cannot compare, missing times.")
            return
        print(f"Naive total : {self.naive_run_time:.2f}s")
        print(f"Pipeline total : {self.pipeline_run_time:.2f}s")
        if self.pipeline_run_time > 0:
            speedup = self.naive_run_time / self.pipeline_run_time
        else:
            speedup = 1.0
        print(f"Speedup = {speedup:.2f}x")

        if node_usage_info:
            print("Node Usage Info:")
            for k, v in node_usage_info.items():
                print(f"  {k}: {v}")


In [None]:
# def main_demo():
#     import torch
#     import torch.nn as nn
#     import torch.utils.data as data
#     import gc  # for Python garbage collection

#     # 1) Build a sample model
#     class BiggerFFN(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.net = nn.Sequential(
#                 nn.Linear(256, 512),
#                 nn.ReLU(),
#                 nn.Linear(512, 256),
#                 nn.ReLU(),
#                 nn.Linear(256, 128),
#                 nn.ReLU(),
#                 nn.Linear(128, 10)
#             )
#         def forward(self, x):
#             return self.net(x)

#     model = BiggerFFN()
#     input_tensor = torch.randn(1, 256)

#     # 2) A dummy dataset for profiling
#     class DummyDataset(data.Dataset):
#         def __init__(self, size=5):
#             self.size = size
#         def __len__(self):
#             return self.size
#         def __getitem__(self, idx):
#             return torch.randn(1, 256), torch.tensor(0)

#     dloader = data.DataLoader(DummyDataset(), batch_size=1)

#     # 3) Profiler in 'init' mode
#     profiler_init = Profiler(mode='init', profile_db_path='dp_partition_profiling.csv')

#     # 4) Discover Nodes
#     nodes = Node.discover_nodes()
#     if not nodes:
#         print("No nodes discovered!")
#         return

#     print("Discovered Nodes (up to first 4):")
#     for nd in nodes[:4]:
#         print(" ", nd)
#     if len(nodes) > 4:
#         print("  ...")

#     # 5) Profile the model on each node
#     model_name = "Partitioned_BiggerFFN"
#     for node in nodes:
#         print(f"\nProfiling model on {node.node_id} ...")
#         profiler_init.profile_model(model, dloader, node, model_name)

#     # 6) Build DP-based Scheduler
#     df = profiler_init.get_profile_db()
#     total_us = df[df['Model'] == model_name]['Total Execution Time (us)'].sum()
#     total_s = total_us / 1e6
#     observation_window = total_s*1.1  # arbitrary slack

#     scheduler = Scheduler(profiler_init, nodes, observation_window)
#     scheduler.decompose_and_allocate(model, model_name, k_stages=None)  # let it choose up to len(nodes) or n_layers

#     # 7) Create multiple tasks
#     task1 = Task("Task1", model, input_tensor.clone(), model_name)
#     task2 = Task("Task2", model, input_tensor.clone(), model_name)
#     task3 = Task("Task3", model, input_tensor.clone(), model_name)
#     tasks = [task1, task2, task3]

#     # 8) Create a Taskset
#     taskset = Taskset(tasks, scheduler)

#     # 9) Evaluator
#     evaluator = Evaluator(device='cuda' if torch.cuda.is_available() else 'cpu')  # 'gpu' is not a valid torch device string

#     # 10) Naive run
#     evaluator.run_naive_pytorch(tasks)

#     # **CLEAR CACHES** after naive run, before pipeline run
#     # (If running on GPU, torch.cuda.empty_cache() helps free memory
#     #  also we can do Python-level gc to ensure minimal leftover references)
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#     gc.collect()

#     # 11) Pipeline run
#     print("\n=== Running pipeline approach (DP-based grouping) ===")
#     evaluator.run_pipeline(taskset)

#     # 12) Compare results
#     usage_info = taskset.collect_metrics()
#     evaluator.compare_results(usage_info)

#     # 13) Stop all nodes
#     for nd in nodes:
#         nd.stop()

#     # Print final ProfileDB
#     print("\nFinal ProfileDB:")
#     profiler_init.print_profile_db()
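Step 6 above derives the observation window by summing `Total Execution Time (us)` over every ProfileDB row of the model and adding slack. The sketch below reproduces that arithmetic on plain dictionaries instead of the pandas DataFrame. The `observation_window_s` helper, its `compute` filter, and the sample rows are hypothetical and only illustrate the calculation. It assumes one row per (Model, Layer, Compute); in that case, filtering on `Model` alone adds up the timings of every profiled node.

```python
def observation_window_s(rows, model_name, slack=0.1, compute=None):
    """Sum 'Total Execution Time (us)' for one model and return seconds plus slack.

    If `compute` is given, only rows profiled on that node are counted.
    """
    total_us = sum(
        row["Total Execution Time (us)"]
        for row in rows
        if row["Model"] == model_name
        and (compute is None or row["Compute"] == compute)
    )
    return (total_us / 1e6) * (1.0 + slack)


# Made-up timings, for illustration only.
rows = [
    {"Model": "Partitioned_BiggerFFN", "Layer": "net.0", "Compute": "CPU-0",
     "Total Execution Time (us)": 400.0},
    {"Model": "Partitioned_BiggerFFN", "Layer": "net.2", "Compute": "CPU-0",
     "Total Execution Time (us)": 600.0},
    {"Model": "Partitioned_BiggerFFN", "Layer": "net.0", "Compute": "CPU-1",
     "Total Execution Time (us)": 500.0},
    {"Model": "Partitioned_ResNet18", "Layer": "conv1", "Compute": "CPU-0",
     "Total Execution Time (us)": 9000.0},
]

window_all = observation_window_s(rows, "Partitioned_BiggerFFN")
window_cpu0 = observation_window_s(rows, "Partitioned_BiggerFFN", compute="CPU-0")
print(window_all, window_cpu0)
```

Restricting the sum to a single node gives a smaller window than summing across all nodes, which matters when the window is later used to turn stage times into utilization fractions.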


In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import gc  # for garbage collection if needed
import torchvision.models as models  # for ResNet

def main_demo():

    #####################################################
    # 1) Example Models Definition
    #####################################################

    # a) Our original BiggerFFN model
    class BiggerFFN(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, 10)
            )
        def forward(self, x):
            return self.net(x)

    # b) A ResNet18 (untrained) from torchvision
    resnet_model = models.resnet18(weights=None)  # `pretrained=` is deprecated since torchvision 0.13
    # ResNet18 expects input of shape [N, 3, H, W], e.g. [1, 3, 224, 224]

    # We'll use these model names to store separate profiles
    ffn_model_name = "Partitioned_BiggerFFN"
    resnet_model_name = "Partitioned_ResNet18"

    #####################################################
    # 2) Dummy Datasets + Dataloaders for Profiling
    #####################################################

    # a) For BiggerFFN, we want an input of shape (1,256)
    class FFNDataset(data.Dataset):
        def __init__(self, size=5):
            self.size = size
        def __len__(self):
            return self.size
        def __getitem__(self, idx):
            # Return one sample of shape [256]; the DataLoader (batch_size=1)
            # adds the batch dimension, so each batch has shape [1, 256].
            return torch.randn(256), torch.tensor(0)

    # b) For ResNet18, we want an input of shape [1,3,224,224]
    class ResNetDataset(data.Dataset):
        def __init__(self, size=5):
            self.size = size
        def __len__(self):
            return self.size
        def __getitem__(self, idx):
            # Typically ResNet18 expects 3-channel images
            return torch.randn(3, 224, 224), torch.tensor(0)

    # Instantiate data loaders for each model
    ffn_loader = data.DataLoader(FFNDataset(), batch_size=1)
    resnet_loader = data.DataLoader(ResNetDataset(), batch_size=1)

    #####################################################
    #    INIT PHASE
    #####################################################
    def init_phase():
        print("=== INIT PHASE ===\n")

        # 1) Create a Profiler in 'init' mode
        profiler_init = Profiler(mode='init', profile_db_path='dp_partition_profiling.csv')

        # 2) Discover nodes
        nodes = Node.discover_nodes()
        if not nodes:
            print("No nodes discovered! Exiting init phase.")
            return None, None, None

        print("Discovered Nodes (up to first 4):")
        for nd in nodes[:4]:
            print(" ", nd)
        if len(nodes) > 4:
            print("  ...")

        # 3) Profile BOTH models on each node

        # a) Profile the BiggerFFN
        print("\n[INIT] Profiling BiggerFFN on each node ...")
        for node in nodes:
            profiler_init.profile_model(BiggerFFN(), ffn_loader, node, ffn_model_name)

        # b) Profile the ResNet18
        print("\n[INIT] Profiling ResNet18 on each node ...")
        for node in nodes:
            profiler_init.profile_model(resnet_model, resnet_loader, node, resnet_model_name)

        # 4) Build TWO DP-based schedulers (one for each model),
        #    because each model has different layer sets.

        # ---- Scheduler for BiggerFFN ----
        df = profiler_init.get_profile_db()
        total_us_ffn = df[df['Model'] == ffn_model_name]['Total Execution Time (us)'].sum()
        total_s_ffn = total_us_ffn / 1e6
        obs_window_ffn = total_s_ffn * 1.1  # 10% slack

        scheduler_ffn = Scheduler(profiler_init, nodes, obs_window_ffn)
        scheduler_ffn.decompose_and_allocate(BiggerFFN(), ffn_model_name, k_stages=None)

        # ---- Scheduler for ResNet18 ----
        total_us_res = df[df['Model'] == resnet_model_name]['Total Execution Time (us)'].sum()
        total_s_res = total_us_res / 1e6
        obs_window_res = total_s_res * 1.1

        scheduler_res = Scheduler(profiler_init, nodes, obs_window_res)
        scheduler_res.decompose_and_allocate(resnet_model, resnet_model_name, k_stages=None)

        return profiler_init, nodes, (scheduler_ffn, scheduler_res)

    #####################################################
    #    RUNTIME PHASE
    #####################################################
    def runtime_phase(profiler_init, nodes, schedulers):
        print("\n=== RUNTIME PHASE ===\n")

        scheduler_ffn, scheduler_res = schedulers

        # a) Create tasks for each model
        # - For the FFN, shape [1,256]
        input_ffn = torch.randn(1, 256)
        task_ffn1 = Task("TaskFFN1", BiggerFFN(), input_ffn.clone(), ffn_model_name)
        task_ffn2 = Task("TaskFFN2", BiggerFFN(), input_ffn.clone(), ffn_model_name)

        # - For the ResNet, shape [1,3,224,224]
        input_res = torch.randn(1, 3, 224, 224)
        task_res1 = Task("TaskRes1", resnet_model, input_res.clone(), resnet_model_name)
        task_res2 = Task("TaskRes2", resnet_model, input_res.clone(), resnet_model_name)

        # We'll do a separate pipeline run for FFN tasks vs. ResNet tasks
        # because each set uses a different scheduler/stage decomposition.

        ############################################
        # b) Evaluate FFN tasks
        ############################################
        tasks_ffn = [task_ffn1, task_ffn2]
        taskset_ffn = Taskset(tasks_ffn, scheduler_ffn)
        evaluator_ffn = Evaluator(device='cuda')

        print("\n--- Evaluating BiggerFFN Tasks ---")

        # (1) Naive
        evaluator_ffn.run_naive_pytorch(tasks_ffn)

        # Clear caches: collect garbage first, then release cached GPU memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # (2) Pipeline
        print("\n=== Running pipeline approach for BiggerFFN ===")
        evaluator_ffn.run_pipeline(taskset_ffn)

        usage_info_ffn = taskset_ffn.collect_metrics()
        evaluator_ffn.compare_results(usage_info_ffn)

        # c) Evaluate ResNet tasks
        tasks_res = [task_res1, task_res2]
        taskset_res = Taskset(tasks_res, scheduler_res)
        evaluator_res = Evaluator(device='cuda')

        print("\n--- Evaluating ResNet18 Tasks ---")

        # (1) Naive
        evaluator_res.run_naive_pytorch(tasks_res)

        # Clear caches: collect garbage first, then release cached GPU memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # (2) Pipeline
        print("\n=== Running pipeline approach for ResNet18 ===")
        evaluator_res.run_pipeline(taskset_res)

        usage_info_res = taskset_res.collect_metrics()
        evaluator_res.compare_results(usage_info_res)

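        # NOTE: a mixed FFN + ResNet taskset is not evaluated here. scheduler_ffn
        # only holds the BiggerFFN decomposition, so this unused draft stays disabled.
        # tasks_ffn_resnet = [task_ffn1, task_ffn2, task_res1, task_res2]
        # taskset_ffn = Taskset(tasks_ffn_resnet, scheduler_ffn)
        # evaluator_ffn = Evaluator(device='cuda')
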
        # d) Stop nodes
        for nd in nodes:
            nd.stop()

        # e) Print final ProfileDB
        print("\nFinal ProfileDB:")
        profiler_init.print_profile_db()

    #####################################################
    # main_demo flow
    #####################################################

    profiler_init, nodes, schedulers = init_phase()
    if profiler_init is not None and nodes is not None and schedulers is not None:
        runtime_phase(profiler_init, nodes, schedulers)
    else:
        print("Init phase failed. No runtime execution.")
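`runtime_phase` evaluates the FFN and ResNet tasks separately because each scheduler holds the decomposition of a single model. One way to express that routing is to group tasks by model name and hand each group to the matching scheduler. The sketch below assumes every task exposes a `model_name` attribute (each `Task` above is constructed with one). `DemoTask` and `group_tasks_by_model` are hypothetical names, not part of PyDart.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class DemoTask:
    """Hypothetical stand-in for Task; only the fields needed for routing."""
    task_id: str
    model_name: str


def group_tasks_by_model(tasks):
    """Return {model_name: [tasks...]}, keeping input order within each group."""
    groups = defaultdict(list)
    for task in tasks:
        groups[task.model_name].append(task)
    return dict(groups)


tasks = [
    DemoTask("TaskFFN1", "Partitioned_BiggerFFN"),
    DemoTask("TaskRes1", "Partitioned_ResNet18"),
    DemoTask("TaskFFN2", "Partitioned_BiggerFFN"),
    DemoTask("TaskRes2", "Partitioned_ResNet18"),
]
groups = group_tasks_by_model(tasks)

# Placeholder strings stand in for scheduler_ffn and scheduler_res.
schedulers = {
    "Partitioned_BiggerFFN": "scheduler_ffn",
    "Partitioned_ResNet18": "scheduler_res",
}
for model_name, group in groups.items():
    print(model_name, "->", schedulers[model_name], [t.task_id for t in group])
```

With this grouping, a mixed task list never reaches a scheduler that lacks the stage graph for one of its models.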


In [None]:
# # Suppose all your classes are in a module named pipeline_lib
# # For example:
# # from pipeline_lib import Node, Profiler, Stage, Scheduler, Task, Taskset, Evaluator

# import torch
# import torch.nn as nn
# import torch.utils.data as data
# import gc
# import torchvision.models as models

# def demo_single_scheduler():
#     """
#     Demonstrates a SINGLE unified scheduler that can handle multiple models
#     by storing multiple stage graphs internally. The user only sees a single 'Scheduler'.
#     """

#     print("\n=== Demo: Single Unified Scheduler ===")

#     # 1) Define two models
#     class BiggerFFN(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.net = nn.Sequential(
#                 nn.Linear(256, 512),
#                 nn.ReLU(),
#                 nn.Linear(512, 256),
#                 nn.ReLU(),
#                 nn.Linear(256, 128),
#                 nn.ReLU(),
#                 nn.Linear(128, 10)
#             )
#         def forward(self, x):
#             return self.net(x)

#     resnet_model = models.resnet18(weights=None)  # `pretrained=` is deprecated since torchvision 0.13
#     ffn_model_name = "Partitioned_BiggerFFN"
#     resnet_model_name = "Partitioned_ResNet18"

#     # 2) Profile both models
#     profiler_init = Profiler(mode='init', profile_db_path='profiling_results_single_sched.csv')
#     nodes = Node.discover_nodes()
#     if not nodes:
#         print("No nodes discovered, exiting.")
#         return

#     # trivial data
#     class FFNDataset(data.Dataset):
#         def __len__(self):
#             return 5
#         def __getitem__(self, idx):
#             return torch.randn(256), torch.tensor(0)  # DataLoader adds the batch dimension -> [1, 256]

#     class ResNetDataset(data.Dataset):
#         def __len__(self):
#             return 5
#         def __getitem__(self, idx):
#             return torch.randn(3,224,224), torch.tensor(0)

#     ffn_loader = data.DataLoader(FFNDataset(), batch_size=1)
#     resnet_loader = data.DataLoader(ResNetDataset(), batch_size=1)

#     # Profile each model on each node
#     for node in nodes:
#         profiler_init.profile_model(BiggerFFN(), ffn_loader, node, ffn_model_name)
#         profiler_init.profile_model(resnet_model, resnet_loader, node, resnet_model_name)

#     df = profiler_init.get_profile_db()

#     # 3) Single "MultiModelScheduler"
#     class MultiModelScheduler(Scheduler):
#         def __init__(self, profiler, nodes):
#             # We won't call parent init with observation_window because each model might differ
#             # We'll store a dict of model_name -> (observation_window, stage_graph, etc.)
#             self.profiler = profiler
#             self.nodes = nodes
#             self.model_plans = {}  # e.g. { model_name: { 'obs_window': X, 'stage_graph': {...} } }

#         def decompose_and_allocate_multi(self, model_obj, model_name):
#             """
#             Similar to decompose_and_allocate, but we store the result in self.model_plans[model_name].
#             """
#             # Build times
#             df_local = self.profiler.get_profile_db()
#             sub = df_local[df_local['Model'] == model_name]
#             total_us = sub['Total Execution Time (us)'].sum()
#             total_s = total_us / 1e6
#             obs_window = total_s * 1.2  # some slack

#             # Rather than replicating the DP code, reuse the parent's logic:
#             # re-initialise the parent with this model's window, run its
#             # decomposition, then copy the resulting stage graph into self.model_plans.

#             super().__init__(self.profiler, self.nodes, obs_window)  # hack: re-runs the parent init for each model

#             # Actually call parent's "decompose_and_allocate"
#             super().decompose_and_allocate(model_obj, model_name, k_stages=None)

#             # Now store the parent's result
#             plan = self.global_stage_graph
#             self.model_plans[model_name] = {
#                 'obs_window': obs_window,
#                 'stage_graph': plan
#             }
#             # Clear parent's global_stage_graph so we can do the next model
#             self.global_stage_graph = {}

#         def create_stages_for_task(self, task):
#             """
#             For a task referencing 'task.model_name', we retrieve the stored stage_graph.
#             Then we build Stage objects accordingly.
#             """
#             if task.model_name not in self.model_plans:
#                 raise ValueError(f"No stage graph found for model {task.model_name}")
#             stage_graph = self.model_plans[task.model_name]['stage_graph']

#             import torch.nn as nn
#             for stage_id, info in stage_graph.items():
#                 node_idx = info['node_idx']
#                 node = self.nodes[node_idx]

#                 layer_list = []
#                 for lname in info['layer_names']:
#                     layer_list.append(task.model.get_submodule(lname))

#                 stg = Stage(stage_id, nn.ModuleList(layer_list), node)
#                 stg.dependencies = set(info['dependencies'])
#                 stg.dependents = set(info['dependents'])
#                 task.stages[stage_id] = stg

#         def run_pipeline_parallel(self, task):
#             """
#             Same pipeline approach, but we do it for the task's stage objects.
#             """
#             # We can reuse the parent's method
#             super().run_pipeline_parallel(task)

#     # Instantiate the multi-model scheduler
#     multi_sched = MultiModelScheduler(profiler_init, nodes)

#     # Decompose both models
#     multi_sched.decompose_and_allocate_multi(BiggerFFN(), ffn_model_name)
#     multi_sched.decompose_and_allocate_multi(resnet_model, resnet_model_name)

#     # 4) Create tasks
#     input_ffn = torch.randn(1, 256)
#     task_ffn1 = Task("TaskFFN1", BiggerFFN(), input_ffn.clone(), ffn_model_name)
#     task_ffn2 = Task("TaskFFN2", BiggerFFN(), input_ffn.clone(), ffn_model_name)

#     input_res = torch.randn(1, 3, 224, 224)
#     task_res1 = Task("TaskRes1", resnet_model, input_res.clone(), resnet_model_name)
#     task_res2 = Task("TaskRes2", resnet_model, input_res.clone(), resnet_model_name)

#     tasks_all = [task_ffn1, task_ffn2, task_res1, task_res2]

#     # 5) Single Taskset referencing multi_sched
#     class MultiModelTaskset:
#         def __init__(self, tasks, scheduler):
#             self.tasks = tasks
#             self.scheduler = scheduler
#         def schedule_all_tasks(self):
#             for t in self.tasks:
#                 self.scheduler.create_stages_for_task(t)
#         def execute_all(self):
#             import time, threading
#             start = time.time()
#             threads = []
#             def worker(task):
#                 self.scheduler.run_pipeline_parallel(task)
#             for t in self.tasks:
#                 th = threading.Thread(target=worker, args=(t,), daemon=True)
#                 th.start()
#                 threads.append(th)
#             for th in threads:
#                 th.join()
#             end = time.time()
#             print(f"[SingleScheduler] All tasks done in {end-start:.4f}s")
#         def collect_metrics(self):
#             return {}

#     mm_taskset = MultiModelTaskset(tasks_all, multi_sched)

#     # 6) Evaluate
#     evaluator = Evaluator(device='cuda')
#     evaluator.run_naive_pytorch(tasks_all)

#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#     gc.collect()

#     print("\n=== Pipeline for All Tasks (Single Unified Scheduler) ===")
#     evaluator.run_pipeline(mm_taskset)
#     usage_info = mm_taskset.collect_metrics()
#     evaluator.compare_results(usage_info)

#     # Stop nodes
#     for nd in nodes:
#         nd.stop()

#     print("\n=== Done: Single Unified Scheduler Demo ===")
#     profiler_init.print_profile_db()
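`MultiModelTaskset.execute_all` above starts one thread per task and joins them all before reporting the elapsed time. A minimal standard-library sketch of that pattern follows. `run_all` and `fake_pipeline` are hypothetical. Unlike the draft above, the sketch also records exceptions, because an exception raised inside a thread is not propagated to the code that calls `join()`.

```python
import threading
import time


def run_all(task_ids, run_one):
    """Run run_one(task_id) in one thread per task; return (results, elapsed_seconds)."""
    results = {}
    lock = threading.Lock()

    def worker(task_id):
        try:
            outcome = run_one(task_id)
        except Exception as exc:  # keep the failure instead of losing it in the thread
            outcome = exc
        with lock:
            results[task_id] = outcome

    threads = [threading.Thread(target=worker, args=(tid,)) for tid in task_ids]
    start = time.perf_counter()
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return results, time.perf_counter() - start


def fake_pipeline(task_id):
    """Hypothetical stand-in for scheduler.run_pipeline_parallel(task)."""
    if task_id == "TaskBad":
        raise RuntimeError("stage failed")
    time.sleep(0.01)
    return f"{task_id}:done"


results, elapsed = run_all(["TaskFFN1", "TaskRes1", "TaskBad"], fake_pipeline)
print(f"All tasks finished in {elapsed:.4f}s")
```

Returning the collected outcomes lets the caller decide whether a failed task should abort the comparison or only be reported.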

In [None]:
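# NOTE: demo_single_scheduler is defined in the commented-out cell above;
# uncomment that cell before running this call.
demo_single_scheduler()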

In [None]:
# main_demo()
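The commented-out `Node` class in the next cell runs every callable on a dedicated worker thread and returns the outcome through a one-slot queue; exceptions are put on that queue as ordinary results. The sketch below isolates that hand-off using only the standard library. `MiniNode` and `wait_result` are hypothetical names, and CPU affinity and GPU selection are left out. Re-raising in `wait_result` is a suggested addition, not something the draft does.

```python
import queue
import threading


class MiniNode:
    """Hypothetical, stripped-down Node: one worker thread fed by a task queue."""

    def __init__(self):
        self._tasks = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def assign_task(self, func, *args, **kwargs):
        """Queue func(*args, **kwargs) and return a one-slot queue for its result."""
        result_queue = queue.Queue(maxsize=1)
        self._tasks.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        """Send the sentinel and wait for the worker thread to exit."""
        self._tasks.put(None)
        self._thread.join()

    def _loop(self):
        while True:
            item = self._tasks.get()
            if item is None:  # sentinel pushed by stop()
                break
            func, args, kwargs, result_queue = item
            try:
                result = func(*args, **kwargs)
            except Exception as exc:  # hand the exception back to the caller
                result = exc
            result_queue.put(result)


def wait_result(result_queue):
    """Block until the task finishes; re-raise if the worker returned an exception."""
    result = result_queue.get()
    if isinstance(result, Exception):
        raise result
    return result


node = MiniNode()
print(wait_result(node.assign_task(pow, 2, 5)))
node.stop()
```

A caller that only does `rq.get()` without this check would treat a failed task as a successful one.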

In [None]:
# # -*- coding: utf-8 -*-
# """Poc_final_dart.ipynb

# Automatically generated by Colab.
# """

# import os
# import torch
# import threading
# import queue
# from typing import Callable, Any, Dict, List  # Dict is used in Scheduler's annotations

# ############################################
# # 1) Node Class (Unchanged)
# ############################################

# class Node:
#     def __init__(self, node_id: str, cpus=None, gpu=None):
#         self._node_id = node_id
#         self._cpus = tuple(cpus or [])
#         self._gpu = gpu

#         self._original_affinity = os.sched_getaffinity(0)
#         self._task_queue = queue.Queue()
#         self._stop_signal = False

#         self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
#         self._worker_thread.start()

#     @property
#     def node_id(self):
#         return self._node_id

#     @property
#     def cpus(self):
#         return self._cpus

#     @property
#     def gpu(self):
#         return self._gpu

#     def assign_task(self, func: Callable, *args, **kwargs) -> queue.Queue:
#         result_queue = queue.Queue(maxsize=1)
#         self._task_queue.put((func, args, kwargs, result_queue))
#         return result_queue

#     def stop(self):
#         self._stop_signal = True
#         self._task_queue.put(None)
#         self._worker_thread.join()

#     def _worker_loop(self):
#         while not self._stop_signal:
#             item = self._task_queue.get()
#             if item is None:
#                 break
#             func, args, kwargs, result_queue = item
#             try:
#                 self._set_context()
#                 result = func(*args, **kwargs)
#             except Exception as e:
#                 result = e
#             finally:
#                 self._reset_context()

#             result_queue.put(result)

#     def _set_context(self):
#         if self._cpus:
#             os.sched_setaffinity(0, self._cpus)
#         if self._gpu is not None and torch.cuda.is_available():
#             torch.cuda.set_device(self._gpu)

#     def _reset_context(self):
#         os.sched_setaffinity(0, self._original_affinity)

#     @staticmethod
#     def discover_nodes() -> List['Node']:
#         nodes = []
#         num_cpus = os.cpu_count() or 1
#         ngpus = torch.cuda.device_count()
#         for core_id in range(num_cpus):
#             node = Node(node_id=f"CPU-{core_id}", cpus=[core_id])
#             nodes.append(node)
#         for g in range(ngpus):
#             for core_id in range(num_cpus):
#                 node = Node(node_id=f"GPU-{g}-CPU-{core_id}", cpus=[core_id], gpu=g)
#                 nodes.append(node)
#         return nodes

#     def __repr__(self):
#         return f"Node({self._node_id}, cpus={self._cpus}, gpu={self._gpu})"


# ############################################
# # 2) Profiler Class (Unchanged)
# ############################################

# import time
# import os
# import re
# import pandas as pd
# import torch
# import torch.nn as nn
# import torch.profiler

# class Profiler:
#     def __init__(self, mode: str, profile_db_path='profiling_results.csv', log_dir='logs'):
#         assert mode in ['init', 'runtime']
#         self.mode = mode
#         self.profile_db_path = profile_db_path
#         self.log_dir = log_dir
#         os.makedirs(self.log_dir, exist_ok=True)

#         columns = [
#             'Model', 'Layer', 'Compute',
#             'Self CPU (us)', 'CPU Total (us)', 'CUDA Total (us)',
#             'Self CPU Mem (bytes)', 'Self CUDA Mem (bytes)',
#             'Total Execution Time (us)', 'Total Memory Used (bytes)'
#         ]
#         if os.path.exists(self.profile_db_path):
#             self.profile_db = pd.read_csv(self.profile_db_path)
#         else:
#             self.profile_db = pd.DataFrame(columns=columns)

#         self.runtime_csv = os.path.join(self.log_dir, 'runtime_results.csv')
#         if not os.path.exists(self.runtime_csv):
#             rt_cols = ['Model', 'Layer', 'Compute', 'Execution Time (us)']
#             pd.DataFrame(columns=rt_cols).to_csv(self.runtime_csv, index=False)

#     def _register_hooks(self, model: nn.Module):
#         def hook_wrapper(layer_name):
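#             def hook(mod, inp, out):
#                 # NOTE: a forward hook fires after the layer has run, so this empty
#                 # range only marks the layer boundary; it does not time the layer itself.
#                 with torch.profiler.record_function(layer_name):
#                     pass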
#             return hook

#         for idx, (name, layer) in enumerate(model.named_modules()):
#             if layer is not model and not isinstance(layer, (nn.Sequential, nn.ModuleList)):
#                 layer.register_forward_hook(hook_wrapper(f"{name}_{idx}"))

#     def profile_model(self, model: nn.Module, dataloader, node, model_name: str, warmup_iters=3):
#         def profiling_task():
#             device = torch.device(f"cuda:{node.gpu}" if node.gpu is not None and torch.cuda.is_available() else "cpu")
#             model.to(device)

#             if self.mode == 'init':
#                 with torch.no_grad():
#                     cnt = 0
#                     for inputs, targets in dataloader:
#                         inputs = inputs.to(device)
#                         targets = targets.to(device)
#                         model(inputs)
#                         cnt += 1
#                         if cnt >= warmup_iters:
#                             break
#                 self._profile_init(model, dataloader, node, model_name, device)
#             else:
#                 self._profile_runtime(model, dataloader, node, model_name, device)

#         rq = node.assign_task(profiling_task)
#         rq.get()

#     def _profile_init(self, model, dataloader, node, model_name, device):
#         self._register_hooks(model)
#         with torch.profiler.profile(
#             activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
#             profile_memory=True
#         ) as prof:
#             with torch.no_grad():
#                 for step, (inp, tgt) in enumerate(dataloader):
#                     if step >= 1:
#                         break
#                     inp = inp.to(device)
#                     tgt = tgt.to(device)
#                     model(inp)
#                     prof.step()

#         stats = self._process_events(prof, model, node, runtime=False)
#         self._update_profile_db(stats, model_name, node, runtime=False)

#     def _profile_runtime(self, model, dataloader, node, model_name, device):
#         self._register_hooks(model)
#         with torch.no_grad():
#             for step, (inp, tgt) in enumerate(dataloader):
#                 if step >= 1:
#                     break
#                 inp = inp.to(device)
#                 tgt = tgt.to(device)
#                 with torch.profiler.profile(
#                     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
#                 ) as prof:
#                     model(inp)
#                     prof.step()
#                 stats = self._process_events(prof, model, node, runtime=True)
#                 self._append_runtime_csv(stats, model_name, node)

#     def _process_events(self, profiler, model, node, runtime=False):
#         recognized = set()
#         for n, m in model.named_modules():
#             if n:
#                 recognized.add(n)

#         aggregated = {
#             'forward_pass': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
#                                  self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id),
#             'misc': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
#                          self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id)
#         }

#         events = list(profiler.events())
#         found_root = False

#         def strip_suffix(s):
#             return re.sub(r'(\.|_)\d+$', '', s)

#         for e in events:
#             if e.name == "":
#                 found_root = True
#                 aggregated['forward_pass']['self_cpu_time_total'] += e.self_cpu_time_total
#                 aggregated['forward_pass']['cpu_time_total'] += e.cpu_time_total
#                 aggregated['forward_pass']['cuda_time_total'] += e.device_time_total
#                 if not runtime:
#                     aggregated['forward_pass']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
#                     aggregated['forward_pass']['self_cuda_memory_usage'] += e.self_device_memory_usage
#             else:
#                 base = strip_suffix(e.name)
#                 if base in recognized:
#                     if base not in aggregated:
#                         aggregated[base] = dict(
#                             self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
#                             self_cpu_memory_usage=0, self_cuda_memory_usage=0,
#                             compute=node.node_id
#                         )
#                     aggregated[base]['self_cpu_time_total'] += e.self_cpu_time_total
#                     aggregated[base]['cpu_time_total'] += e.cpu_time_total
#                     aggregated[base]['cuda_time_total'] += e.device_time_total
#                     if not runtime:
#                         aggregated[base]['self_cpu_memory_usage'] += e.self_cpu_memory_usage
#                         aggregated[base]['self_cuda_memory_usage'] += e.self_device_memory_usage
#                 else:
#                     aggregated['misc']['self_cpu_time_total'] += e.self_cpu_time_total
#                     aggregated['misc']['cpu_time_total'] += e.cpu_time_total
#                     aggregated['misc']['cuda_time_total'] += e.device_time_total
#                     if not runtime:
#                         aggregated['misc']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
#                         aggregated['misc']['self_cuda_memory_usage'] += e.self_device_memory_usage

#         if not found_root:
#             for k in list(aggregated.keys()):
#                 if k not in ('forward_pass', 'misc'):
#                     aggregated['forward_pass']['self_cpu_time_total'] += aggregated[k]['self_cpu_time_total']
#                     aggregated['forward_pass']['cpu_time_total'] += aggregated[k]['cpu_time_total']
#                     aggregated['forward_pass']['cuda_time_total'] += aggregated[k]['cuda_time_total']
#                     if not runtime:
#                         aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated[k]['self_cpu_memory_usage']
#                         aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated[k]['self_cuda_memory_usage']

#             aggregated['forward_pass']['self_cpu_time_total'] += aggregated['misc']['self_cpu_time_total']
#             aggregated['forward_pass']['cpu_time_total'] += aggregated['misc']['cpu_time_total']
#             aggregated['forward_pass']['cuda_time_total'] += aggregated['misc']['cuda_time_total']
#             if not runtime:
#                 aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated['misc']['self_cpu_memory_usage']
#                 aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated['misc']['self_cuda_memory_usage']

#         return aggregated

#     def _update_profile_db(self, stats, model_name, node, runtime=False):
#         if runtime:
#             return
#         for layer_name, data in stats.items():
#             total_t = data['cpu_time_total'] + data['cuda_time_total']
#             total_m = data['self_cpu_memory_usage'] + data['self_cuda_memory_usage']
#             row = {
#                 'Model': model_name,
#                 'Layer': layer_name,
#                 'Compute': data['compute'],
#                 'Self CPU (us)': data['self_cpu_time_total'],
#                 'CPU Total (us)': data['cpu_time_total'],
#                 'CUDA Total (us)': data['cuda_time_total'],
#                 'Self CPU Mem (bytes)': data['self_cpu_memory_usage'],
#                 'Self CUDA Mem (bytes)': data['self_cuda_memory_usage'],
#                 'Total Execution Time (us)': total_t,
#                 'Total Memory Used (bytes)': total_m
#             }
#             self.profile_db = self._upsert(self.profile_db, row)
#         self.profile_db.to_csv(self.profile_db_path, index=False)

#     def _upsert(self, df, row):
#         mask = (
#             (df['Model'] == row['Model']) &
#             (df['Layer'] == row['Layer']) &
#             (df['Compute'] == row['Compute'])
#         )
#         if not df[mask].empty:
#             # Keep the worst-case (largest) execution time seen for this model/layer/node
#             existing_time = df.loc[mask, 'Total Execution Time (us)'].max()
#             if row['Total Execution Time (us)'] > existing_time:
#                 for k, v in row.items():
#                     df.loc[mask, k] = v
#         else:
#             df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
#         return df

#     def _append_runtime_csv(self, stats, model_name, node):
#         rows = []
#         for layer_name, data in stats.items():
#             exec_time = data['cpu_time_total'] + data['cuda_time_total']
#             rows.append({
#                 'Model': model_name,
#                 'Layer': layer_name,
#                 'Compute': data['compute'],
#                 'Execution Time (us)': exec_time
#             })
#         if rows:
#             rdf = pd.read_csv(self.runtime_csv)
#             rdf = pd.concat([rdf, pd.DataFrame(rows)], ignore_index=True)
#             rdf.to_csv(self.runtime_csv, index=False)

#     def get_profile_db(self):
#         return self.profile_db

#     def print_profile_db(self):
#         if self.profile_db.empty:
#             print("ProfileDB is empty.")
#         else:
#             print("ProfileDB:\n", self.profile_db.to_string(index=False))


# ############################################
# # 3) Stage Class (Unchanged)
# ############################################

# class Stage:
#     def __init__(self, stage_id: str, layers: nn.ModuleList, assigned_node: 'Node'):
#         self.stage_id = stage_id
#         self.layers = layers
#         self.assigned_node = assigned_node

#         self.input_data = None
#         self.output_data = None
#         self.execution_time = 0.0

#         self.dependencies = set()
#         self.dependents = set()

#     def run_stage(self):
#         import time
#         start = time.time()
#         device = torch.device(
#             f"cuda:{self.assigned_node.gpu}" if self.assigned_node.gpu is not None and torch.cuda.is_available()
#             else "cpu"
#         )
#         inp = self.input_data.to(device)

#         with torch.no_grad():
#             out = inp
#             for layer in self.layers:
#                 out = layer(out)

#         if device.type == 'cuda':
#             out = out.cpu()

#         self.output_data = out
#         end = time.time()
#         self.execution_time = end - start
#         return self.output_data


# ############################################
# # 4) Scheduler Class (Updated to Use Observation Window)
# ############################################

# class Scheduler:
#     """
#     A DP-based scheduler that:
#       - Uses an observation_window to interpret times as a fraction => utilization
#       - w[k] is the existing utilization on node k (0 <= w[k] <= 1, if we interpret it strictly).
#       - For each new Task, runs a DP up to k_stages partitions, picks best s,
#         and updates w[k] accordingly.
#     """

#     def __init__(self, profiler: 'Profiler', nodes: List['Node'], observation_window: float, slack_percent: float = 0.1):
#         """
#         :param profiler: A Profiler with layer->node times in a ProfileDB.
#         :param nodes: A list of Node objects
#         :param observation_window: The base time window for converting raw times -> utilization fraction
#         :param slack_percent: fractional slack added to the window (0.1 = 10%); effective window = observation_window * (1 + slack_percent)
#         """
#         self.profiler = profiler
#         self.nodes = nodes

#         # incorporate slack into the final used window
#         self.observation_window = observation_window * (1.0 + slack_percent)

#         # node_usage[k] is w[k]: the fraction of the observation window already committed on node k
#         self.node_usage = [0.0] * len(nodes)

#         # store per-task stage graph
#         self.task_stage_graphs: Dict[str, Dict[str, dict]] = {}

#     def _get_block_fraction(self, block_time: float) -> float:
#         """
#         Convert raw block_time in seconds -> fraction of the observation window.
#         """
#         frac = block_time / self.observation_window
#         return frac

#     def _get_block_time(self, prefix_times, x, i, node_j) -> float:
#         return prefix_times[node_j][i] - prefix_times[node_j][x]

#     def decompose_and_allocate_task(self, task: 'Task', k_stages_max: int = None):
#         """
#         DP with w[k] usage and observation_window.
#         We sum up block_time/observation_window to w[k].
#         """
#         model = task.model
#         model_name = task.model_name
#         task_id = task.task_id

#         # Gather leaf layers
#         leaf_layers = []
#         for name, module in model.named_modules():
#             if len(list(module.children())) == 0 and name:
#                 leaf_layers.append(name)
#         n_layers = len(leaf_layers)
#         if n_layers == 0:
#             print(f"[Scheduler] No leaf layers found for model={model_name}")
#             self.task_stage_graphs[task_id] = {}
#             return

#         if k_stages_max is None:
#             k_stages_max = min(len(self.nodes), n_layers)

#         profile_db = self.profiler.get_profile_db()
#         model_df = profile_db[profile_db['Model'] == model_name]

#         times = [[0.0]*len(self.nodes) for _ in range(n_layers)]
#         for i2, lname in enumerate(leaf_layers):
#             for j, node in enumerate(self.nodes):
#                 row = model_df[(model_df['Layer'] == lname) & (model_df['Compute'] == node.node_id)]
#                 if not row.empty:
#                     t_us = row['Total Execution Time (us)'].values[0]
#                     times[i2][j] = t_us / 1e6  # raw seconds
#                 else:
#                     times[i2][j] = 1.0  # no profile row for this layer/node: fall back to 1.0 s

#         # prefix sums
#         prefix_times = [[0.0]*(n_layers+1) for _ in range(len(self.nodes))]
#         for j in range(len(self.nodes)):
#             for i2 in range(1, n_layers+1):
#                 prefix_times[j][i2] = prefix_times[j][i2-1] + times[i2-1][j]

#         INF = float('inf')
#         DP = [[INF]*(k_stages_max+1) for _ in range(n_layers+1)]
#         back_node = [[-1]*(k_stages_max+1) for _ in range(n_layers+1)]
#         back_split = [[-1]*(k_stages_max+1) for _ in range(n_layers+1)]

#         DP[0][0] = 0.0

#         # fill DP
#         for i2 in range(1, n_layers+1):
#             for s in range(1, k_stages_max+1):
#                 for j in range(len(self.nodes)):
#                     wj = self.node_usage[j]  # existing usage fraction
#                     for x in range(i2):
#                         block_t = self._get_block_time(prefix_times, x, i2, j) # raw time
#                         block_frac = self._get_block_fraction(block_t)         # fraction
#                         cand = max(DP[x][s-1], wj + block_frac)
#                         if cand < DP[i2][s]:
#                             DP[i2][s] = cand
#                             back_node[i2][s] = j
#                             back_split[i2][s] = x

#         best_val = INF
#         best_s = None
#         for s in range(1, k_stages_max+1):
#             if DP[n_layers][s] < best_val:
#                 best_val = DP[n_layers][s]
#                 best_s = s
#         print(f"[Scheduler] Task={task_id}, Model={model_name}, best s in [1..{k_stages_max}] => {best_s}, final utilization={best_val:.4f}")

#         i2 = n_layers
#         s = best_s
#         partitions = []
#         while i2 > 0 and s > 0:
#             j = back_node[i2][s]
#             x = back_split[i2][s]
#             partitions.append((x, i2, j))
#             i2 = x
#             s -= 1
#         partitions.reverse()

#         stage_graph = {}
#         idx = 0
#         for (start_i, end_i, node_j) in partitions:
#             stage_id = f"{task_id}-stage-{idx}"
#             layer_subset = leaf_layers[start_i:end_i]
#             stage_graph[stage_id] = {
#                 'node_idx': node_j,
#                 'layer_names': layer_subset,
#                 'dependencies': [],
#                 'dependents': []
#             }
#             idx += 1

#         keys = list(stage_graph.keys())
#         for k2 in range(1, len(keys)):
#             prev_stg = keys[k2-1]
#             curr_stg = keys[k2]
#             stage_graph[curr_stg]['dependencies'].append(prev_stg)
#             stage_graph[prev_stg]['dependents'].append(curr_stg)

#         self.task_stage_graphs[task_id] = stage_graph

#         # update w[k]
#         for (start_i, end_i, node_j) in partitions:
#             raw_block_time = self._get_block_time(prefix_times, start_i, end_i, node_j)
#             block_frac = self._get_block_fraction(raw_block_time)
#             self.node_usage[node_j] += block_frac

#         self.print_task_stages(task_id)

#     def print_task_stages(self, task_id: str):
#         if task_id not in self.task_stage_graphs:
#             print(f"[Scheduler] No stage graph for task={task_id}")
#             return
#         stg_graph = self.task_stage_graphs[task_id]
#         print(f"[Scheduler] Stage Decomposition for Task={task_id}:")
#         for sid, info in stg_graph.items():
#             node_j = info['node_idx']
#             node_id = self.nodes[node_j].node_id
#             layers = info['layer_names']
#             deps = info['dependencies']
#             print(f"  Stage={sid}, Node={node_id}, Layers={layers}, DependsOn={list(deps)}")

#     def create_stages_for_task(self, task: 'Task'):
#         import torch.nn as nn
#         t_id = task.task_id
#         if t_id not in self.task_stage_graphs:
#             print(f"[Scheduler] No stage graph found for Task={t_id}.")
#             return
#         stg_graph = self.task_stage_graphs[t_id]
#         for stage_id, info in stg_graph.items():
#             node_idx = info['node_idx']
#             node = self.nodes[node_idx]

#             layer_list = []
#             for lname in info['layer_names']:
#                 layer_list.append(task.model.get_submodule(lname))

#             stg = Stage(stage_id, nn.ModuleList(layer_list), node)
#             stg.dependencies = set(info['dependencies'])
#             stg.dependents = set(info['dependents'])
#             task.stages[stage_id] = stg

#     def run_pipeline_parallel(self, task: 'Task'):
#         lock = threading.Lock()
#         dep_count = {}
#         for sid, stg in task.stages.items():
#             dep_count[sid] = len(stg.dependencies)

#         total_stages = len(task.stages)
#         completed = 0
#         done_event = threading.Event()

#         def stage_done_callback(sid: str):
#             nonlocal completed
#             with lock:
#                 for dep_id in task.stages[sid].dependents:
#                     dep_count[dep_id] -= 1
#                     if dep_count[dep_id] == 0:
#                         enqueue_stage(dep_id)
#                 completed += 1
#                 if completed >= total_stages:
#                     done_event.set()

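#         def enqueue_stage(sid: str):
#             stg = task.stages[sid]
#             # Wire the stage input: root stages read the task input, later stages
#             # read the output of their single predecessor in the chain.
#             if stg.dependencies:
#                 stg.input_data = task.stages[next(iter(stg.dependencies))].output_data
#             else:
#                 stg.input_data = task.input_data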
#             def run_fn():
#                 return stg.run_stage()
#             def worker():
#                 rq = stg.assigned_node.assign_task(run_fn)
#                 _ = rq.get()
#                 stage_done_callback(sid)
#             threading.Thread(target=worker, daemon=True).start()

#         for sid, c in dep_count.items():
#             if c == 0:
#                 enqueue_stage(sid)

#         done_event.wait()


# ############################################
# # 5) Task Class (Unchanged)
# ############################################

# from typing import Dict, Optional

# class Task:
#     def __init__(self, task_id: str, model: nn.Module, input_data: torch.Tensor, model_name: str):
#         self.task_id = task_id
#         self.model = model
#         self.input_data = input_data
#         self.model_name = model_name

#         self.stages: Dict[str, Stage] = {}

#         self.start_time: Optional[float] = None
#         self.finish_time: Optional[float] = None

#     def get_total_execution_time(self):
#         if self.start_time is not None and self.finish_time is not None:
#             return self.finish_time - self.start_time
#         return 0.0


# ############################################
# # 6) Taskset Class (Unchanged)
# ############################################

# import time

# class Taskset:
#     def __init__(self, tasks: List[Task], scheduler: Scheduler):
#         self.tasks = tasks
#         self.scheduler = scheduler

#     def schedule_all_tasks(self):
#         for t in self.tasks:
#             self.scheduler.create_stages_for_task(t)

#     def execute_all(self):
#         import threading
#         start_global = time.time()
#         threads = []
#         def run_task(task: Task):
#             task.start_time = time.time()
#             self.scheduler.run_pipeline_parallel(task)
#             task.finish_time = time.time()

#         for t in self.tasks:
#             th = threading.Thread(target=run_task, args=(t,), daemon=True)
#             th.start()
#             threads.append(th)

#         for th in threads:
#             th.join()

#         end_global = time.time()
#         print(f"[Taskset] All tasks completed in {end_global - start_global:.2f}s")

#     def collect_metrics(self):
#         return {}


# ############################################
# # 7) Evaluator Class (Unchanged)
# ############################################

# class Evaluator:
#     def __init__(self, device='cpu', print_precision=2, rtol=1e-03, atol=1e-05):
#         self.device = device
#         self.print_precision = print_precision
#         self.rtol = rtol
#         self.atol = atol

#         self.naive_run_time = None
#         self.pipeline_run_time = None

#         self.naive_outputs = {}
#         self.pipeline_outputs = {}

#     def run_naive_pytorch(self, tasks: List['Task']):
#         import time
#         dev = torch.device(self.device)
#         start = time.time()

#         with torch.no_grad():
#             for task in tasks:
#                 task.model.to(dev)
#                 out = task.model(task.input_data.to(dev))
#                 out_cpu = out.cpu() if dev.type == 'cuda' else out
#                 self.naive_outputs[task.task_id] = out_cpu

#         end = time.time()
#         self.naive_run_time = end - start
#         print(f"[Evaluator] Naive run on device='{self.device}' for {len(tasks)} tasks => "
#               f"{self.naive_run_time:.{self.print_precision}f}s")

#     def run_pipeline(self, taskset: 'Taskset'):
#         import time
#         start = time.time()
#         taskset.schedule_all_tasks()
#         taskset.execute_all()
#         end = time.time()
#         self.pipeline_run_time = end - start
#         print(f"[Evaluator] Pipeline approach => {self.pipeline_run_time:.{self.print_precision}f}s")

#         for task in taskset.tasks:
#             final_out = self._get_task_final_output(task)
#             self.pipeline_outputs[task.task_id] = final_out

#     def _get_task_final_output(self, task: 'Task'):
#         if not task.stages:
#             return None
#         final_stage = None
#         for sid, stg in task.stages.items():
#             if not stg.dependents:
#                 final_stage = stg
#                 break
#         if final_stage is None:
#             sorted_ids = sorted(task.stages.keys())
#             final_stage = task.stages[sorted_ids[-1]]
#         return final_stage.output_data

#     def compare_results(self, node_usage_info: dict = None):
#         print("\n=== Comparison ===")
#         if self.naive_run_time is None or self.pipeline_run_time is None:
#             print("Cannot compare times; missing run times.")
#         else:
#             print(f"Naive total   : {self.naive_run_time:.{self.print_precision}f}s")
#             print(f"Pipeline total: {self.pipeline_run_time:.{self.print_precision}f}s")
#             if self.pipeline_run_time == 0.0:
#                 speedup = float('inf')
#             else:
#                 speedup = self.naive_run_time / self.pipeline_run_time
#             print(f"Speedup = {speedup:.{self.print_precision}f}x")

#         print("\n--- Output Similarity Checks ---")
#         for task_id, naive_out in self.naive_outputs.items():
#             pipe_out = self.pipeline_outputs.get(task_id, None)
#             if pipe_out is None:
#                 print(f"Task={task_id}: No pipeline output found.")
#                 continue
#             if naive_out.shape != pipe_out.shape:
#                 print(f"Task={task_id}: Output shape mismatch.")
#                 continue
#             close = torch.allclose(naive_out, pipe_out, rtol=self.rtol, atol=self.atol)
#             if close:
#                 print(f"Task={task_id}: Outputs match within tolerance.")
#             else:
#                 print(f"Task={task_id}: Outputs differ beyond tolerance.")

#         if node_usage_info:
#             print("\nNode Usage Info:")
#             for k, v in node_usage_info.items():
#                 print(f"  Node={k}, usage={v}")

#         print("\n=== End Comparison ===")


# ############################################
# # 8) main_demo test script (Updated)
# ############################################

# import torch
# import torch.nn as nn
# import torch.utils.data as data
# import gc
# import torchvision.models as models

# def main_demo():
#     print("=== main_demo with Observation Window Logic ===")

#     # 1) Setup init phase
#     def init_phase():
#         print("=== INIT PHASE ===")
#         profiler = Profiler(mode='init', profile_db_path='dp_partition_profiling.csv')
#         nodes = Node.discover_nodes()
#         if not nodes:
#             print("No nodes discovered. Exiting.")
#             return None, None

#         # For demonstration, we profile 2 models
#         print("\n-- Profiling models --")
#         # a) Dummy dataset for a simple FFN
#         class FFNDataset(data.Dataset):
#             def __init__(self, size=5):
#                 self.size = size
#             def __len__(self):
#                 return self.size
#             def __getitem__(self, idx):
#                 return torch.randn(1,256), torch.tensor(0)

#         # b) Dummy dataset for ResNet
#         class ResNetDataset(data.Dataset):
#             def __init__(self, size=5):
#                 self.size = size
#             def __len__(self):
#                 return self.size
#             def __getitem__(self, idx):
#                 return torch.randn(3,224,224), torch.tensor(0)

#         ffn_loader = data.DataLoader(FFNDataset(), batch_size=1)
#         res_loader = data.DataLoader(ResNetDataset(), batch_size=1)

#         # Profile both
#         ffn_model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512,128))
#         # or the bigger ffn if you prefer

#         resnet_model = models.resnet18(pretrained=False)

#         ffn_model_name = "InitPhaseFFN"
#         res_model_name = "InitPhaseResNet"

#         for node in nodes:
#             profiler.profile_model(ffn_model, ffn_loader, node, ffn_model_name)
#             profiler.profile_model(resnet_model, res_loader, node, res_model_name)

#         return profiler, nodes


#     # 2) Runtime phase
#     def runtime_phase(profiler, nodes):
#         print("\n=== RUNTIME PHASE ===")

#         # Choose an observation window in seconds; here 2.0 s.
#         # With 20% slack the effective window is 2.0 * 1.2 = 2.4 s.
#         # An HPC or real-time system might choose these values differently.
#         observation_window = 2.0
#         slack_percent = 0.2

#         # Create our DP-based Scheduler with the observation window
#         scheduler = Scheduler(profiler, nodes, observation_window=observation_window, slack_percent=slack_percent)

#         # Let's define tasks with actual models
#         # a) Our bigger FFN model
#         class BiggerFFN(nn.Module):
#             def __init__(self):
#                 super().__init__()
#                 self.net = nn.Sequential(
#                     nn.Linear(256, 512),
#                     nn.ReLU(),
#                     nn.Linear(512, 256),
#                     nn.ReLU(),
#                     nn.Linear(256, 128),
#                     nn.ReLU(),
#                     nn.Linear(128, 10)
#                 )
#             def forward(self, x):
#                 return self.net(x)

#         bigger_ffn_model_name = "BiggerFFN_ObsWin"

#         # b) Our ResNet
#         resnet_model_name = "ResNet18_ObsWin"
#         resnet_model = models.resnet18(pretrained=False)

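#         # These models are not re-profiled here. Their names differ from the ones profiled in
#         # init_phase, so unless the ProfileDB CSV already holds rows for them, each layer lookup
#         # misses and uses the 1.0 s fallback.
#         # Create the tasks: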
#         input_ffn = torch.randn(1, 256)
#         task_ffn1 = Task("TaskFFN1", BiggerFFN(), input_ffn.clone(), bigger_ffn_model_name)
#         task_ffn2 = Task("TaskFFN2", BiggerFFN(), input_ffn.clone(), bigger_ffn_model_name)

#         input_res = torch.randn(1, 3, 224, 224)
#         task_res1 = Task("TaskRes1", resnet_model, input_res.clone(), resnet_model_name)
#         task_res2 = Task("TaskRes2", resnet_model, input_res.clone(), resnet_model_name)
#         task_res3 = Task("TaskRes3", resnet_model, input_res.clone(), resnet_model_name)

#         all_tasks = [task_ffn1, task_ffn2, task_res1, task_res2, task_res3]

#         # Decompose & allocate each
#         for t in all_tasks:
#             scheduler.decompose_and_allocate_task(t, k_stages_max=None)

#         # Build Taskset
#         taskset = Taskset(all_tasks, scheduler)
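#         # Fall back to CPU so the naive baseline also runs on machines without CUDA
#         eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'
#         evaluator = Evaluator(device=eval_device, print_precision=3)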

#         # 1) naive
#         evaluator.run_naive_pytorch(all_tasks)
#         if torch.cuda.is_available():
#             torch.cuda.empty_cache()
#         gc.collect()

#         # 2) pipeline
#         print("\n-- Pipeline run with observation window approach --")
#         evaluator.run_pipeline(taskset)

#         usage_info = {}
#         # usage_info = {scheduler.nodes[i].node_id: scheduler.node_usage[i] for i in range(len(scheduler.nodes))}

#         evaluator.compare_results(usage_info)

#         # stop nodes
#         for nd in nodes:
#             nd.stop()

#         # print final profile
#         print("\nFinal ProfileDB:")
#         profiler.print_profile_db()


#     # actually run
#     profiler, nodes = init_phase()
#     if profiler is not None and nodes is not None:
#         runtime_phase(profiler, nodes)
#     else:
#         print("Init phase failed; skipping runtime.")
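
The recurrence used by `Scheduler.decompose_and_allocate_task` in the cell above can be tried in isolation. The following is a minimal sketch under stated assumptions, not the library's code: `partition_min_max` and the toy timings are hypothetical, slack is omitted, and the ProfileDB lookup is replaced by a plain `times[layer][node]` table in seconds.

```python
from typing import List, Tuple

def partition_min_max(
    times: List[List[float]],   # times[i][j]: seconds for layer i on node j (hypothetical input)
    usage: List[float],         # usage[j]: utilization already committed on node j
    window: float,              # observation window in seconds
    k_max: int,                 # maximum number of stages
) -> Tuple[float, List[Tuple[int, int, int]]]:
    """Return (peak utilization, [(start_layer, end_layer, node), ...])."""
    n_layers, n_nodes = len(times), len(usage)

    # prefix[j][i] = total time of layers [0, i) on node j
    prefix = [[0.0] * (n_layers + 1) for _ in range(n_nodes)]
    for j in range(n_nodes):
        for i in range(1, n_layers + 1):
            prefix[j][i] = prefix[j][i - 1] + times[i - 1][j]

    INF = float("inf")
    # dp[i][s]: smallest peak utilization when the first i layers form s stages
    dp = [[INF] * (k_max + 1) for _ in range(n_layers + 1)]
    back = [[None] * (k_max + 1) for _ in range(n_layers + 1)]
    dp[0][0] = 0.0

    for i in range(1, n_layers + 1):
        for s in range(1, k_max + 1):
            for j in range(n_nodes):
                for x in range(i):
                    # last stage = layers [x, i) on node j, as a fraction of the window
                    frac = (prefix[j][i] - prefix[j][x]) / window
                    cand = max(dp[x][s - 1], usage[j] + frac)
                    if cand < dp[i][s]:
                        dp[i][s] = cand
                        back[i][s] = (x, j)

    # pick the stage count with the smallest peak, then walk the back-pointers
    best_s = min(range(1, k_max + 1), key=lambda s: dp[n_layers][s])
    parts = []
    i, s = n_layers, best_s
    while i > 0:
        x, j = back[i][s]
        parts.append((x, i, j))
        i, s = x, s - 1
    parts.reverse()
    return dp[n_layers][best_s], parts

# Four layers of 0.5 s each, two idle nodes, 2.0 s window, at most two stages
peak, parts = partition_min_max([[0.5, 0.5]] * 4, [0.0, 0.0], 2.0, 2)
print(peak, parts)
```

In this toy run the sketch picks two stages of two layers each and reports a peak of 0.5. Both stages land on node 0: the recurrence only reads the usage that existed before the task and does not add the first stage's share when placing the second. The `Scheduler` loop above uses the same recurrence and tie-breaking, so it appears to share this behaviour.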


In [None]:
# main_demo()

=== main_demo with Observation Window Logic ===
=== INIT PHASE ===

-- Profiling models --


  df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)



=== RUNTIME PHASE ===


NameError: name 'task_res3' is not defined
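
The `NameError` above was visible because it was raised on the main thread. Errors raised inside a node's worker thread behave differently: `Node._worker_loop` in the next cell places the exception object on the result queue, and callers such as `Profiler.profile_model` call `rq.get()` without inspecting the value. The following is a minimal sketch of that pattern with an explicit check; `MiniWorker` and `get_or_raise` are hypothetical names used for illustration and are not part of the library.

```python
import queue
import threading
from typing import Any, Callable

class MiniWorker:
    """Hypothetical stand-in for Node: one thread draining a job queue."""

    def __init__(self) -> None:
        self._jobs: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def submit(self, func: Callable[..., Any], *args: Any) -> queue.Queue:
        # One single-slot queue per job carries the result back to the caller
        result_q: queue.Queue = queue.Queue(maxsize=1)
        self._jobs.put((func, args, result_q))
        return result_q

    def stop(self) -> None:
        self._jobs.put(None)  # sentinel that ends the loop
        self._thread.join()

    def _loop(self) -> None:
        while True:
            item = self._jobs.get()
            if item is None:
                break
            func, args, result_q = item
            try:
                result = func(*args)
            except Exception as exc:
                result = exc  # returned as a value, like Node._worker_loop
            result_q.put(result)

def get_or_raise(result_q: queue.Queue, timeout: float = 5.0) -> Any:
    """Fetch a job result and re-raise it if the worker returned an exception."""
    result = result_q.get(timeout=timeout)
    if isinstance(result, Exception):
        raise result
    return result

worker = MiniWorker()
ok = get_or_raise(worker.submit(lambda a, b: a + b, 2, 3))
try:
    get_or_raise(worker.submit(lambda: 1 / 0))
    failed = False
except ZeroDivisionError:
    failed = True
worker.stop()
print(ok, failed)
```

A check of this kind at each `rq.get()` call site would surface a failed profiling run or stage instead of letting it pass silently.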

In [None]:
# -*- coding: utf-8 -*-
"""Poc_final_dart.ipynb

Automatically generated by Colab.
"""

import os
import torch
import threading
import queue
from typing import Callable, Any, List

############################################
# 1) Node Class (Unchanged)
############################################

class Node:
    def __init__(self, node_id: str, cpus=None, gpu=None):
        self._node_id = node_id
        self._cpus = tuple(cpus or [])
        self._gpu = gpu

        self._original_affinity = os.sched_getaffinity(0)
        self._task_queue = queue.Queue()
        self._stop_signal = False

        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

    @property
    def node_id(self):
        return self._node_id

    @property
    def cpus(self):
        return self._cpus

    @property
    def gpu(self):
        return self._gpu

    def assign_task(self, func: Callable, *args, **kwargs) -> queue.Queue:
        result_queue = queue.Queue(maxsize=1)
        self._task_queue.put((func, args, kwargs, result_queue))
        return result_queue

    def stop(self):
        self._stop_signal = True
        self._task_queue.put(None)
        self._worker_thread.join()

    def _worker_loop(self):
        while not self._stop_signal:
            item = self._task_queue.get()
            if item is None:
                break
            func, args, kwargs, result_queue = item
            try:
                self._set_context()
                result = func(*args, **kwargs)
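            except Exception as e:
                # Hand the exception back to the caller through the result queue
                # instead of letting it kill the worker thread.
                result = e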
            finally:
                self._reset_context()

            result_queue.put(result)

    def _set_context(self):
        if self._cpus:
            os.sched_setaffinity(0, self._cpus)
        if self._gpu is not None and torch.cuda.is_available():
            torch.cuda.set_device(self._gpu)

    def _reset_context(self):
        os.sched_setaffinity(0, self._original_affinity)

    @staticmethod
    def discover_nodes() -> List['Node']:
        nodes = []
        num_cpus = os.cpu_count() or 1
        ngpus = torch.cuda.device_count()
        for core_id in range(num_cpus):
            node = Node(node_id=f"CPU-{core_id}", cpus=[core_id])
            nodes.append(node)
        for g in range(ngpus):
            for core_id in range(num_cpus):
                node = Node(node_id=f"GPU-{g}-CPU-{core_id}", cpus=[core_id], gpu=g)
                nodes.append(node)
        return nodes

    def __repr__(self):
        return f"Node({self._node_id}, cpus={self._cpus}, gpu={self._gpu})"


############################################
# 2) Profiler Class (Unchanged)
############################################

import time
import os
import re
import pandas as pd
import torch
import torch.nn as nn
import torch.profiler

class Profiler:
    def __init__(self, mode: str, profile_db_path='profiling_results.csv', log_dir='logs'):
        assert mode in ['init', 'runtime']
        self.mode = mode
        self.profile_db_path = profile_db_path
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        columns = [
            'Model', 'Layer', 'Compute',
            'Self CPU (us)', 'CPU Total (us)', 'CUDA Total (us)',
            'Self CPU Mem (bytes)', 'Self CUDA Mem (bytes)',
            'Total Execution Time (us)', 'Total Memory Used (bytes)'
        ]
        if os.path.exists(self.profile_db_path):
            self.profile_db = pd.read_csv(self.profile_db_path)
        else:
            self.profile_db = pd.DataFrame(columns=columns)

        self.runtime_csv = os.path.join(self.log_dir, 'runtime_results.csv')
        if not os.path.exists(self.runtime_csv):
            rt_cols = ['Model', 'Layer', 'Compute', 'Execution Time (us)']
            pd.DataFrame(columns=rt_cols).to_csv(self.runtime_csv, index=False)

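    def _register_hooks(self, model: nn.Module):
        # Label each layer in the profiler trace. A forward hook fires after the
        # layer's forward() has returned, so the record_function range below marks
        # the layer but does not enclose its computation.
        def hook_wrapper(layer_name):
            def hook(mod, inp, out):
                with torch.profiler.record_function(layer_name):
                    pass
            return hook

        for idx, (name, layer) in enumerate(model.named_modules()):
            # Skip pure containers and the root module itself
            if not isinstance(layer, (nn.Sequential, nn.ModuleList)) and layer is not model:
                layer.register_forward_hook(hook_wrapper(f"{name}_{idx}"))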

    def profile_model(self, model: nn.Module, dataloader, node, model_name: str, warmup_iters=3):
        def profiling_task():
            device = torch.device(f"cuda:{node.gpu}" if node.gpu is not None and torch.cuda.is_available() else "cpu")
            model.to(device)

            if self.mode == 'init':
                with torch.no_grad():
                    cnt = 0
                    for inputs, targets in dataloader:
                        inputs = inputs.to(device)
                        targets = targets.to(device)
                        model(inputs)
                        cnt += 1
                        if cnt >= warmup_iters:
                            break
                self._profile_init(model, dataloader, node, model_name, device)
            else:
                self._profile_runtime(model, dataloader, node, model_name, device)

        rq = node.assign_task(profiling_task)
        rq.get()

    def _profile_init(self, model, dataloader, node, model_name, device):
        self._register_hooks(model)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            profile_memory=True
        ) as prof:
            with torch.no_grad():
                for step, (inp, tgt) in enumerate(dataloader):
                    if step >= 1:
                        break
                    inp = inp.to(device)
                    tgt = tgt.to(device)
                    model(inp)
                    prof.step()

        stats = self._process_events(prof, model, node, runtime=False)
        self._update_profile_db(stats, model_name, node, runtime=False)

    def _profile_runtime(self, model, dataloader, node, model_name, device):
        self._register_hooks(model)
        with torch.no_grad():
            for step, (inp, tgt) in enumerate(dataloader):
                if step >= 1:
                    break
                inp = inp.to(device)
                tgt = tgt.to(device)
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
                ) as prof:
                    model(inp)
                    prof.step()
                stats = self._process_events(prof, model, node, runtime=True)
                self._append_runtime_csv(stats, model_name, node)

    def _process_events(self, profiler, model, node, runtime=False):
        recognized = set()
        for n, m in model.named_modules():
            if n:
                recognized.add(n)

        aggregated = {
            'forward_pass': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                                 self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id),
            'misc': dict(self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                         self_cpu_memory_usage=0, self_cuda_memory_usage=0, compute=node.node_id)
        }

        events = list(profiler.events())
        found_root = False

        def strip_suffix(s):
            return re.sub(r'(\.|_)\d+$', '', s)

        for e in events:
            if e.name == "":
                found_root = True
                aggregated['forward_pass']['self_cpu_time_total'] += e.self_cpu_time_total
                aggregated['forward_pass']['cpu_time_total'] += e.cpu_time_total
                aggregated['forward_pass']['cuda_time_total'] += e.device_time_total
                if not runtime:
                    aggregated['forward_pass']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                    aggregated['forward_pass']['self_cuda_memory_usage'] += e.self_device_memory_usage
            else:
                base = strip_suffix(e.name)
                if base in recognized:
                    if base not in aggregated:
                        aggregated[base] = dict(
                            self_cpu_time_total=0, cpu_time_total=0, cuda_time_total=0,
                            self_cpu_memory_usage=0, self_cuda_memory_usage=0,
                            compute=node.node_id
                        )
                    aggregated[base]['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated[base]['cpu_time_total'] += e.cpu_time_total
                    aggregated[base]['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated[base]['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated[base]['self_cuda_memory_usage'] += e.self_device_memory_usage
                else:
                    aggregated['misc']['self_cpu_time_total'] += e.self_cpu_time_total
                    aggregated['misc']['cpu_time_total'] += e.cpu_time_total
                    aggregated['misc']['cuda_time_total'] += e.device_time_total
                    if not runtime:
                        aggregated['misc']['self_cpu_memory_usage'] += e.self_cpu_memory_usage
                        aggregated['misc']['self_cuda_memory_usage'] += e.self_device_memory_usage

        if not found_root:
            for k in list(aggregated.keys()):
                if k not in ('forward_pass', 'misc'):
                    aggregated['forward_pass']['self_cpu_time_total'] += aggregated[k]['self_cpu_time_total']
                    aggregated['forward_pass']['cpu_time_total'] += aggregated[k]['cpu_time_total']
                    aggregated['forward_pass']['cuda_time_total'] += aggregated[k]['cuda_time_total']
                    if not runtime:
                        aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated[k]['self_cpu_memory_usage']
                        aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated[k]['self_cuda_memory_usage']

            aggregated['forward_pass']['self_cpu_time_total'] += aggregated['misc']['self_cpu_time_total']
            aggregated['forward_pass']['cpu_time_total'] += aggregated['misc']['cpu_time_total']
            aggregated['forward_pass']['cuda_time_total'] += aggregated['misc']['cuda_time_total']
            if not runtime:
                aggregated['forward_pass']['self_cpu_memory_usage'] += aggregated['misc']['self_cpu_memory_usage']
                aggregated['forward_pass']['self_cuda_memory_usage'] += aggregated['misc']['self_cuda_memory_usage']

        return aggregated

    def _update_profile_db(self, stats, model_name, node, runtime=False):
        if runtime:
            return
        for layer_name, data in stats.items():
            total_t = data['cpu_time_total'] + data['cuda_time_total']
            total_m = data['self_cpu_memory_usage'] + data['self_cuda_memory_usage']
            row = {
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Self CPU (us)': data['self_cpu_time_total'],
                'CPU Total (us)': data['cpu_time_total'],
                'CUDA Total (us)': data['cuda_time_total'],
                'Self CPU Mem (bytes)': data['self_cpu_memory_usage'],
                'Self CUDA Mem (bytes)': data['self_cuda_memory_usage'],
                'Total Execution Time (us)': total_t,
                'Total Memory Used (bytes)': total_m
            }
            self.profile_db = self._upsert(self.profile_db, row)
        self.profile_db.to_csv(self.profile_db_path, index=False)

    def _upsert(self, df, row):
        mask = (
            (df['Model'] == row['Model']) &
            (df['Layer'] == row['Layer']) &
            (df['Compute'] == row['Compute'])
        )
        if not df[mask].empty:
            existing_time = df.loc[mask, 'Total Execution Time (us)'].max()
            if row['Total Execution Time (us)'] > existing_time:
                for k, v in row.items():
                    df.loc[mask, k] = v
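        else:
            # Append the new row. Skip concat while df is still empty to avoid the
            # pandas FutureWarning about concatenating empty frames.
            new_row = pd.DataFrame([row])
            df = new_row if df.empty else pd.concat([df, new_row], ignore_index=True)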
        return df

    def _append_runtime_csv(self, stats, model_name, node):
        rows = []
        for layer_name, data in stats.items():
            exec_time = data['cpu_time_total'] + data['cuda_time_total']
            rows.append({
                'Model': model_name,
                'Layer': layer_name,
                'Compute': data['compute'],
                'Execution Time (us)': exec_time
            })
        if rows:
            # Append only when the runtime CSV exists and is non-empty;
            # pd.read_csv raises EmptyDataError on a zero-byte file
            if os.path.exists(self.runtime_csv) and os.path.getsize(self.runtime_csv) > 0:
                rdf = pd.read_csv(self.runtime_csv)
                rdf = pd.concat([rdf, pd.DataFrame(rows)], ignore_index=True)
            else:
                rdf = pd.DataFrame(rows)
            rdf.to_csv(self.runtime_csv, index=False)

    def get_profile_db(self):
        return self.profile_db

    def print_profile_db(self):
        if self.profile_db.empty:
            print("ProfileDB is empty.")
        else:
            print("ProfileDB:\n", self.profile_db.to_string(index=False))


############################################
# 3) Stage Class (Unchanged)
############################################

class Stage:
    def __init__(self, stage_id: str, layers: nn.ModuleList, assigned_node: 'Node'):
        self.stage_id = stage_id
        self.layers = layers
        self.assigned_node = assigned_node

        self.input_data = None
        self.output_data = None
        self.execution_time = 0.0

        self.dependencies = set()
        self.dependents = set()

    def run_stage(self):
        import time
        start = time.time()
        device = torch.device(
            f"cuda:{self.assigned_node.gpu}" if self.assigned_node.gpu is not None and torch.cuda.is_available()
            else "cpu"
        )
        if self.input_data is None:
            print(f"[Stage] {self.stage_id}: No input data provided.")
            out = torch.tensor([])
        else:
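            inp = self.input_data.to(device)

            # Apply the leaf layers in order. This reproduces the model's forward
            # pass only for purely sequential models; functional ops and skip
            # connections inside forward() are not replayed here.
            with torch.no_grad():
                out = inp
                for layer in self.layers:
                    out = layer.to(device)(out)  # keep layer and input on the same device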

            if device.type == 'cuda':
                out = out.cpu()

        self.output_data = out
        end = time.time()
        self.execution_time = end - start
        return self.output_data


############################################
# 4) Scheduler Class (Updated to Use Observation Window and Data Propagation)
############################################

class Scheduler:
    """
    A DP-based scheduler that:
      - Converts raw times to utilization: time / observation_window
      - Tracks w[k], the utilization already committed on node k, as a fraction
        of the observation window (nominally 0 <= w[k] <= 1; not enforced)
      - For each new Task, runs a DP over up to k_stages_max partitions, picks
        the stage count s with the lowest value, and updates w[k] accordingly.
    """

    def __init__(self, profiler: 'Profiler', nodes: List['Node'], observation_window: float, slack_percent: float = 0.1):
        """
        :param profiler: A Profiler with layer->node times in a ProfileDB.
        :param nodes: A list of Node objects
        :param observation_window: The base time window for converting raw times -> utilization fraction
        :param slack_percent: fractional slack added to the window; the window actually used is observation_window * (1 + slack_percent)
        """
        self.profiler = profiler
        self.nodes = nodes

        # Incorporate slack into the final used window
        self.observation_window = observation_window * (1.0 + slack_percent)

        # node_usage[k] (w[k]) stores the fraction of the observation window already committed on node k
        self.node_usage = [0.0] * len(nodes)

        # store per-task stage graph
        self.task_stage_graphs: Dict[str, Dict[str, dict]] = {}

    def _get_block_fraction(self, block_time: float) -> float:
        """
        Convert raw block_time in seconds -> fraction of the observation window.
        """
        frac = block_time / self.observation_window
        return frac

    def _get_block_time(self, prefix_times, x, i, node_j) -> float:
        return prefix_times[node_j][i] - prefix_times[node_j][x]

    def decompose_and_allocate_task(self, task: 'Task', k_stages_max: int = None):
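        """
        Partition the task's leaf layers into contiguous stages with a DP that
        minimizes the maximum of w[k] + block_time / observation_window over the
        chosen stages, then add each block's fraction to its node's w[k].
        """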
        model = task.model
        model_name = task.model_name
        task_id = task.task_id

        # Gather leaf layers
        leaf_layers = []
        for name, module in model.named_modules():
            if len(list(module.children())) == 0 and name:
                leaf_layers.append(name)
        n_layers = len(leaf_layers)
        if n_layers == 0:
            print(f"[Scheduler] No leaf layers found for model={model_name}")
            self.task_stage_graphs[task_id] = {}
            return

        if k_stages_max is None:
            k_stages_max = min(len(self.nodes), n_layers)

        profile_db = self.profiler.get_profile_db()
        model_df = profile_db[profile_db['Model'] == model_name]

        times = [[0.0]*len(self.nodes) for _ in range(n_layers)]
        for i2, lname in enumerate(leaf_layers):
            for j, node in enumerate(self.nodes):
                row = model_df[(model_df['Layer'] == lname) & (model_df['Compute'] == node.node_id)]
                if not row.empty:
                    t_us = row['Total Execution Time (us)'].values[0]
                    times[i2][j] = t_us / 1e6  # raw seconds
                else:
                    times[i2][j] = 1.0  # fallback: no profile row for this layer/node, assume 1.0 s

        # prefix sums
        prefix_times = [[0.0]*(n_layers+1) for _ in range(len(self.nodes))]
        for j in range(len(self.nodes)):
            for i2 in range(1, n_layers+1):
                prefix_times[j][i2] = prefix_times[j][i2-1] + times[i2-1][j]

        INF = float('inf')
        DP = [[INF]*(k_stages_max+1) for _ in range(n_layers+1)]
        back_node = [[-1]*(k_stages_max+1) for _ in range(n_layers+1)]
        back_split = [[-1]*(k_stages_max+1) for _ in range(n_layers+1)]

        DP[0][0] = 0.0

        # Fill DP
        for i2 in range(1, n_layers+1):
            for s in range(1, k_stages_max+1):
                for j in range(len(self.nodes)):
                    wj = self.node_usage[j]  # existing usage fraction
                    for x in range(i2):
                        block_t = self._get_block_time(prefix_times, x, i2, j)  # raw time
                        block_frac = self._get_block_fraction(block_t)         # fraction
                        cand = max(DP[x][s-1], wj + block_frac)
                        if cand < DP[i2][s]:
                            DP[i2][s] = cand
                            back_node[i2][s] = j
                            back_split[i2][s] = x

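        # Find best s
        best_val = INF
        best_s = None
        for s in range(1, k_stages_max+1):
            if DP[n_layers][s] < best_val:
                best_val = DP[n_layers][s]
                best_s = s
        if best_s is None:
            # No feasible partition (e.g. an empty node list or k_stages_max < 1)
            print(f"[Scheduler] Task={task_id}, Model={model_name}: no feasible partition found.")
            self.task_stage_graphs[task_id] = {}
            return
        print(f"[Scheduler] Task={task_id}, Model={model_name}, best s in [1..{k_stages_max}] => {best_s}, final utilization={best_val:.4f}")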

        # Reconstruct partitions
        i2 = n_layers
        s = best_s
        partitions = []
        while i2 > 0 and s > 0:
            j = back_node[i2][s]
            x = back_split[i2][s]
            partitions.append((x, i2, j))
            i2 = x
            s -= 1
        partitions.reverse()

        # Build a stage graph for this task
        stage_graph = {}
        idx = 0
        for (start_i, end_i, node_j) in partitions:
            stage_id = f"{task_id}-stage-{idx}"
            layer_subset = leaf_layers[start_i:end_i]
            stage_graph[stage_id] = {
                'node_idx': node_j,
                'layer_names': layer_subset,
                'dependencies': [],
                'dependents': []
            }
            idx += 1

        # Link stages linearly
        keys = list(stage_graph.keys())
        for k2 in range(1, len(keys)):
            prev_stg = keys[k2-1]
            curr_stg = keys[k2]
            stage_graph[curr_stg]['dependencies'].append(prev_stg)
            stage_graph[prev_stg]['dependents'].append(curr_stg)

        # Store the stage graph
        self.task_stage_graphs[task_id] = stage_graph

        # Update w[k] with assigned block fractions
        for (start_i, end_i, node_j) in partitions:
            raw_block_time = self._get_block_time(prefix_times, start_i, end_i, node_j)
            block_frac = self._get_block_fraction(raw_block_time)
            self.node_usage[node_j] += block_frac

        # Print final stage decomposition
        self.print_task_stages(task_id)

    def print_task_stages(self, task_id: str):
        """
        Print: which stage (plus consecutive layers) is on which node, for the given task.
        """
        if task_id not in self.task_stage_graphs:
            print(f"[Scheduler] No stage graph for task={task_id}")
            return

        stg_graph = self.task_stage_graphs[task_id]
        print(f"[Scheduler] Stage Decomposition for Task={task_id}:")
        for sid, info in stg_graph.items():
            node_j = info['node_idx']
            node_id = self.nodes[node_j].node_id
            layers = info['layer_names']
            deps = info['dependencies']
            print(f"  Stage={sid}, Node={node_id}, Layers={layers}, DependsOn={list(deps)}")

    def create_stages_for_task(self, task: 'Task'):
        """
        Convert the stored stage graph for this task into Stage objects for pipeline execution.
        """
        import torch.nn as nn
        t_id = task.task_id
        if t_id not in self.task_stage_graphs:
            print(f"[Scheduler] No stage graph found for Task={t_id}.")
            return

        stg_graph = self.task_stage_graphs[t_id]
        for stage_id, info in stg_graph.items():
            node_idx = info['node_idx']
            node = self.nodes[node_idx]

            layer_list = []
            for lname in info['layer_names']:
                # get_submodule raises AttributeError for an unknown name; it never returns None
                try:
                    layer_list.append(task.model.get_submodule(lname))
                except AttributeError:
                    print(f"[Scheduler] Warning: Layer '{lname}' not found in model '{task.model_name}'.")

            stg = Stage(stage_id, nn.ModuleList(layer_list), node)
            stg.dependencies = set(info['dependencies'])
            stg.dependents = set(info['dependents'])
            task.stages[stage_id] = stg

    def run_pipeline_parallel(self, task: 'Task'):
        """
        The pipeline logic for one task:
         - Start all stages with zero deps
         - Each stage on finishing decrements the dependents
         - Launch a worker thread for each stage
         - Propagate output data to dependent stages
        """
        lock = threading.Lock()
        dep_count = {}
        for sid, stg in task.stages.items():
            dep_count[sid] = len(stg.dependencies)

        total_stages = len(task.stages)
        if total_stages == 0:
            # Nothing to run; waiting on done_event would block forever
            return
        completed = 0
        done_event = threading.Event()

        def stage_done_callback(sid: str):
            nonlocal completed
            with lock:
                completed_stage = task.stages[sid]
                for dep_id in completed_stage.dependents:
                    dependent_stage = task.stages[dep_id]
                    # Propagate output data to dependent stage's input
                    dependent_stage.input_data = completed_stage.output_data
                    dep_count[dep_id] -= 1
                    if dep_count[dep_id] == 0:
                        enqueue_stage(dep_id)
                completed += 1
                if completed >= total_stages:
                    done_event.set()

        def enqueue_stage(sid: str):
            stg = task.stages[sid]
            def run_fn():
                return stg.run_stage()
            def worker():
                rq = stg.assigned_node.assign_task(run_fn)
                result = rq.get()
                if isinstance(result, Exception):
                    print(f"[Pipeline] Stage {sid} encountered an exception: {result}")
                stage_done_callback(sid)
            threading.Thread(target=worker, daemon=True).start()

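        # Snapshot the zero-dependency stages first so a stage released by a fast
        # worker is not enqueued a second time, then seed them with the task input
        initial = [sid for sid, c in dep_count.items() if c == 0]
        for sid in initial:
            if task.stages[sid].input_data is None:
                task.stages[sid].input_data = task.input_data
            enqueue_stage(sid)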

        done_event.wait()
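

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original library): the min-max DP from
# Scheduler.decompose_and_allocate_task, reduced to plain Python lists so it
# can be tried without a Profiler or Node objects. The function name and its
# arguments are assumptions made for this example only. It assumes at least
# one node and k_max >= 1.
# With seven layers at 1.0 s each, two idle nodes, a 244.8 s window and
# k_max = 4, it returns the value 2 / 244.8 and the contiguous blocks
# [0,1), [1,3), [3,5), [5,7), all on node 0.
# ---------------------------------------------------------------------------
def sketch_minmax_partition(times, node_usage, window, k_max):
    """times[i][j]: seconds for layer i on node j. Returns (best_value, partitions)."""
    n_layers, n_nodes = len(times), len(node_usage)
    # prefix[j][i] = total time of layers [0, i) on node j
    prefix = [[0.0] * (n_layers + 1) for _ in range(n_nodes)]
    for j in range(n_nodes):
        for i in range(1, n_layers + 1):
            prefix[j][i] = prefix[j][i - 1] + times[i - 1][j]

    inf = float("inf")
    dp = [[inf] * (k_max + 1) for _ in range(n_layers + 1)]
    back = [[None] * (k_max + 1) for _ in range(n_layers + 1)]
    dp[0][0] = 0.0
    for i in range(1, n_layers + 1):
        for s in range(1, k_max + 1):
            for j in range(n_nodes):
                for x in range(i):
                    # Utilization of block [x, i) on node j, on top of its existing load
                    frac = (prefix[j][i] - prefix[j][x]) / window
                    cand = max(dp[x][s - 1], node_usage[j] + frac)
                    if cand < dp[i][s]:
                        dp[i][s] = cand
                        back[i][s] = (x, j)

    # Smallest stage count that attains the lowest objective value
    best_s = min(range(1, k_max + 1), key=lambda s: dp[n_layers][s])
    partitions, i, s = [], n_layers, best_s
    while i > 0 and s > 0:
        x, j = back[i][s]
        partitions.append((x, i, j))  # (start layer, end layer exclusive, node index)
        i, s = x, s - 1
    partitions.reverse()
    return dp[n_layers][best_s], partitions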


############################################
# 5) Task Class (Unchanged)
############################################

from typing import Dict, Optional

class Task:
    """
    Represents a single inference pass for a model + input.
    Stages are built by the Scheduler's create_stages_for_task call.
    """

    def __init__(self, task_id: str, model: nn.Module, input_data: torch.Tensor, model_name: str):
        self.task_id = task_id
        self.model = model
        self.input_data = input_data
        self.model_name = model_name

        self.stages: Dict[str, Stage] = {}

        self.start_time: Optional[float] = None
        self.finish_time: Optional[float] = None

    def get_total_execution_time(self):
        if self.start_time is not None and self.finish_time is not None:
            return self.finish_time - self.start_time
        return 0.0
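

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original library): the dependency
# counting that Scheduler.run_pipeline_parallel uses to decide when a Task's
# stage may start, shown single-threaded so the order is deterministic. The
# notebook version launches a worker thread per ready stage instead. The
# function name is an assumption, and every dependency is assumed to be a key
# of the input mapping.
# ---------------------------------------------------------------------------
from collections import deque

def sketch_ready_order(dependencies):
    """dependencies: {stage_id: iterable of stage_ids it waits on}. Returns a valid run order."""
    remaining = {sid: len(set(deps)) for sid, deps in dependencies.items()}
    dependents = {sid: [] for sid in dependencies}
    for sid, deps in dependencies.items():
        for dep in set(deps):
            dependents[dep].append(sid)

    # Stages with no unmet dependencies can start immediately
    ready = deque(sid for sid, count in remaining.items() if count == 0)
    order = []
    while ready:
        sid = ready.popleft()
        order.append(sid)
        # Finishing a stage releases each dependent whose counter reaches zero
        for nxt in dependents[sid]:
            remaining[nxt] -= 1
            if remaining[nxt] == 0:
                ready.append(nxt)
    return order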


############################################
# 6) Taskset Class (Unchanged)
############################################

import time

class Taskset:
    """
    A collection of tasks to run with the DP-based pipeline approach.
    """

    def __init__(self, tasks: List[Task], scheduler: Scheduler):
        self.tasks = tasks
        self.scheduler = scheduler

    def schedule_all_tasks(self):
        # Build stage objects for each Task from the global stage graph
        for t in self.tasks:
            self.scheduler.create_stages_for_task(t)

    def execute_all(self):
        # We run each task in parallel
        import threading
        start_global = time.time()
        threads = []

        def run_task(task: Task):
            task.start_time = time.time()
            self.scheduler.run_pipeline_parallel(task)
            task.finish_time = time.time()

        for t in self.tasks:
            th = threading.Thread(target=run_task, args=(t,), daemon=True)
            th.start()
            threads.append(th)

        for th in threads:
            th.join()

        end_global = time.time()
        print(f"[Taskset] All tasks completed in {end_global - start_global:.2f}s")

    def collect_metrics(self):
        # For demonstration, we do not track node usage in detail here.
        return {}


############################################
# 7) Evaluator Class (Unchanged)
############################################

class Evaluator:
    """
    Compares:
      1) Naive single-device run
      2) Pipeline approach (DP-based)
    and checks whether the final outputs match.
    """

    def __init__(self, device='cpu', print_precision=2, rtol=1e-03, atol=1e-05):
        self.device = device
        self.print_precision = print_precision
        self.rtol = rtol
        self.atol = atol

        self.naive_run_time = None
        self.pipeline_run_time = None

        # Store final outputs
        #   self.naive_outputs[task_id] = tensor
        #   self.pipeline_outputs[task_id] = tensor
        self.naive_outputs = {}
        self.pipeline_outputs = {}

    def run_naive_pytorch(self, tasks: List['Task']):
        """
        Runs each Task sequentially on self.device with a single full forward pass.
        Also stores each Task's final output in self.naive_outputs[task.task_id].
        """
        import time
        dev = torch.device(self.device)
        start = time.time()

        with torch.no_grad():
            for task in tasks:
                task.model.to(dev)
                out = task.model(task.input_data.to(dev))
                out_cpu = out.cpu() if dev.type == 'cuda' else out
                self.naive_outputs[task.task_id] = out_cpu  # store the final naive output

        end = time.time()
        self.naive_run_time = end - start
        print(f"[Evaluator] Naive run on device='{self.device}' for {len(tasks)} tasks => "
              f"{self.naive_run_time:.{self.print_precision}f}s")

    def run_pipeline(self, taskset: 'Taskset'):
        """
        1) schedule_all_tasks -> builds Stage objects for each Task
        2) execute_all -> runs pipeline parallel
        3) gather final outputs from each Task
        """
        import time
        start = time.time()

        taskset.schedule_all_tasks()   # build stage objects
        taskset.execute_all()          # pipeline parallel run

        end = time.time()
        self.pipeline_run_time = end - start
        print(f"[Evaluator] Pipeline approach => {self.pipeline_run_time:.{self.print_precision}f}s")

        # Collect final outputs from tasks
        for task in taskset.tasks:
            # we assume the last stage's output_data is the final output
            # let's gather that
            final_out = self._get_task_final_output(task)
            self.pipeline_outputs[task.task_id] = final_out

    def _get_task_final_output(self, task: 'Task'):
        """
        Retrieve the final stage's output_data (the stage that has no dependents).
        Or just take the largest stage index if we name them in ascending order.
        """
        if not task.stages:
            return None
        # find stage with no dependents (or the max stage ID)
        final_stage = None
        for sid, stg in task.stages.items():
            if not stg.dependents:  # no dependents => final stage
                final_stage = stg
                break
        if final_stage is None:
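            # fallback: pick the highest stage index; sort numerically on the
            # "<task_id>-stage-<idx>" suffix so that stage-10 sorts after stage-2
            sorted_ids = sorted(task.stages.keys(), key=lambda s: int(s.rsplit('-', 1)[-1]))
            final_stage = task.stages[sorted_ids[-1]]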
        return final_stage.output_data

    def compare_results(self, node_usage_info: dict = None):
        """
        Compares:
          - naive_run_time vs pipeline_run_time
          - naive_outputs vs pipeline_outputs
        """
        print("\n=== Comparison ===")
        if self.naive_run_time is None or self.pipeline_run_time is None:
            print("Cannot compare times; missing run times.")
        else:
            print(f"Naive total   : {self.naive_run_time:.{self.print_precision}f}s")
            print(f"Pipeline total: {self.pipeline_run_time:.{self.print_precision}f}s")
            if self.pipeline_run_time == 0.0:
                speedup = float('inf')
            else:
                speedup = self.naive_run_time / self.pipeline_run_time
            print(f"Speedup = {speedup:.{self.print_precision}f}x")

        # Compare final outputs
        print("\n--- Output Similarity Checks ---")
        for task_id, naive_out in self.naive_outputs.items():
            pipe_out = self.pipeline_outputs.get(task_id, None)
            if pipe_out is None:
                print(f"Task={task_id}: No pipeline output found.")
                continue
            # Check if shape matches
            if naive_out.shape != pipe_out.shape:
                print(f"Task={task_id}: Output shape mismatch.")
                continue
            # Compare with torch.allclose
            close = torch.allclose(naive_out, pipe_out, rtol=self.rtol, atol=self.atol)
            if close:
                print(f"Task={task_id}: Outputs match within tolerance.")
            else:
                print(f"Task={task_id}: Outputs differ beyond tolerance.")

        # Optionally show node usage info
        if node_usage_info:
            print("\nNode Usage Info:")
            for k, v in node_usage_info.items():
                print(f"  Node={k}, usage={v}")

        print("\n=== End Comparison ===")
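

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original library): the observation
# window arithmetic used in main_demo, on plain lists. For every layer of
# every task it takes the fastest profiled time across nodes, or a fallback
# when no time is available, then adds fractional slack. The function name and
# input layout are assumptions made for this example only.
# ---------------------------------------------------------------------------
def sketch_observation_window(per_layer_times, slack_percent=0.2, fallback=1.0):
    """per_layer_times: one list per task, holding one list of per-node times (s) per layer.

    Returns (total_expected_time, observation_window).
    """
    total = 0.0
    for task_layers in per_layer_times:
        for node_times in task_layers:
            # An empty list models a layer with no profile rows
            total += min(node_times) if node_times else fallback
    return total, total * (1.0 + slack_percent)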


############################################
# 8) main_demo Test Script (Updated)
############################################

import torch
import torch.nn as nn
import torch.utils.data as data
import gc
import torchvision.models as models

def main_demo():
    print("=== main_demo with Observation Window Logic ===")

    # 1) Setup init phase
    def init_phase():
        print("=== INIT PHASE ===")
        profiler = Profiler(mode='init', profile_db_path='dp_partition_profiling.csv')
        nodes = Node.discover_nodes()
        if not nodes:
            print("No nodes discovered. Exiting.")
            return None, None, None, None

        # For demonstration, we profile 2 models
        print("\n-- Profiling models --")
        # a) Dummy dataset for a simple FFN
        class FFNDataset(data.Dataset):
            def __init__(self, size=5):
                self.size = size
            def __len__(self):
                return self.size
            def __getitem__(self, idx):
                return torch.randn(1,256), torch.tensor(0)

        # b) Dummy dataset for ResNet
        class ResNetDataset(data.Dataset):
            def __init__(self, size=5):
                self.size = size
            def __len__(self):
                return self.size
            def __getitem__(self, idx):
                return torch.randn(3,224,224), torch.tensor(0)

        ffn_loader = data.DataLoader(FFNDataset(), batch_size=1)
        res_loader = data.DataLoader(ResNetDataset(), batch_size=1)

        # Profile both
        ffn_model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512,128))
        # Updated ResNet initialization to use 'weights' instead of 'pretrained'
        resnet_model = models.resnet18(weights=None)

        ffn_model_name = "InitPhaseFFN"
        res_model_name = "InitPhaseResNet"

        for node in nodes:
            profiler.profile_model(ffn_model, ffn_loader, node, ffn_model_name)
            profiler.profile_model(resnet_model, res_loader, node, res_model_name)

        # Define tasks to calculate observation window
        # For simplicity, assume tasks are similar to those in runtime_phase
        class BiggerFFN(nn.Module):
            def __init__(self):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(256, 512),
                    nn.ReLU(),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Linear(128, 10)
                )
            def forward(self, x):
                return self.net(x)

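        # Note: these names differ from the names profiled above ("InitPhaseFFN",
        # "InitPhaseResNet"), so unless the ProfileDB already holds rows for them,
        # every lookup below falls back to 1.0 s per layer.
        bigger_ffn_model_name = "BiggerFFN_ObsWin"
        resnet_model_name = "ResNet18_ObsWin"
        resnet_model_updated = models.resnet18(weights=None)  # Updated ResNet initialization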

        # Create dummy tasks to compute observation window
        input_ffn = torch.randn(1, 256)
        task_ffn1 = Task("TaskFFN1", BiggerFFN(), input_ffn.clone(), bigger_ffn_model_name)
        task_ffn2 = Task("TaskFFN2", BiggerFFN(), input_ffn.clone(), bigger_ffn_model_name)

        input_res = torch.randn(1, 3, 224, 224)
        task_res1 = Task("TaskRes1", resnet_model_updated, input_res.clone(), resnet_model_name)
        task_res2 = Task("TaskRes2", resnet_model_updated, input_res.clone(), resnet_model_name)
        task_res3 = Task("TaskRes3", resnet_model_updated, input_res.clone(), resnet_model_name)

        all_tasks = [task_ffn1, task_ffn2, task_res1, task_res2, task_res3]

        # Compute expected execution time for each task (sum of minimal layer times)
        profile_db = profiler.get_profile_db()

        total_expected_time = 0.0
        for task in all_tasks:
            model_df = profile_db[profile_db['Model'] == task.model_name]
            leaf_layers = [name for name, m in task.model.named_modules() if len(list(m.children())) == 0 and name]
            task_time = 0.0
            for lname in leaf_layers:
                # Find the minimal execution time across all nodes for this layer
                layer_times = model_df[model_df['Layer'] == lname]['Total Execution Time (us)']
                if not layer_times.empty:
                    min_time = layer_times.min() / 1e6  # Convert to seconds
                else:
                    min_time = 1.0  # Fallback
                task_time += min_time
            total_expected_time += task_time

        print(f"\nTotal Expected Execution Time (sum of all tasks' minimal times): {total_expected_time:.4f}s")

        # Apply slack_percent
        slack_percent = 0.2
        observation_window = total_expected_time * (1.0 + slack_percent)
        print(f"Observation Window (with {slack_percent*100}% slack): {observation_window:.4f}s")

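        # Create Scheduler with observation window.
        # Note: Scheduler.__init__ applies slack_percent again, so the window it
        # actually uses is observation_window * (1 + slack_percent).
        scheduler = Scheduler(profiler, nodes, observation_window=observation_window, slack_percent=slack_percent)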

        # Decompose & allocate all tasks
        for task in all_tasks:
            scheduler.decompose_and_allocate_task(task, k_stages_max=None)

        return profiler, nodes, all_tasks, scheduler

    # 2) Runtime phase
    def runtime_phase(profiler, nodes, all_tasks, scheduler):
        print("\n=== RUNTIME PHASE ===")

        # Build Taskset
        taskset = Taskset(all_tasks, scheduler)
        evaluator = Evaluator(device='cuda' if torch.cuda.is_available() else 'cpu', print_precision=3)

        # Run evaluator
        # 1) Naive run
        evaluator.run_naive_pytorch(all_tasks)

        # Clear caches if desired
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        # 2) Pipeline run
        print("\n-- Pipeline run with observation window approach --")
        evaluator.run_pipeline(taskset)

        # Gather node usage info
        usage_info = {scheduler.nodes[i].node_id: scheduler.node_usage[i] for i in range(len(scheduler.nodes))}

        evaluator.compare_results(usage_info)

        # Stop nodes
        for nd in nodes:
            nd.stop()

        # Print final ProfileDB
        print("\nFinal ProfileDB:")
        profiler.print_profile_db()

    # actually run
    profiler, nodes, all_tasks, scheduler = init_phase()
    if profiler is not None and nodes is not None:
        runtime_phase(profiler, nodes, all_tasks, scheduler)
    else:
        print("Init phase failed; skipping runtime.")

# Run the main_demo function
if __name__ == "__main__":
    main_demo()


=== main_demo with Observation Window Logic ===
=== INIT PHASE ===

-- Profiling models --


Total Expected Execution Time (sum of all tasks' minimal times): 170.0000s
Observation Window (with 20.0% slack): 204.0000s
[Scheduler] Task=TaskFFN1, Model=BiggerFFN_ObsWin, best s in [1..4] => 4, final utilization=0.0082
[Scheduler] Stage Decomposition for Task=TaskFFN1:
  Stage=TaskFFN1-stage-0, Node=CPU-0, Layers=['net.0'], DependsOn=[]
  Stage=TaskFFN1-stage-1, Node=CPU-0, Layers=['net.1', 'net.2'], DependsOn=['TaskFFN1-stage-0']
  Stage=TaskFFN1-stage-2, Node=CPU-0, Layers=['net.3', 'net.4'], DependsOn=['TaskFFN1-stage-1']
  Stage=TaskFFN1-stage-3, Node=CPU-0, Layers=['net.5', 'net.6'], DependsOn=['TaskFFN1-stage-2']
[Scheduler] Task=TaskFFN2, Model=BiggerFFN_ObsWin, best s in [1..4] => 4, final utilization=0.0082
[Scheduler] Stage Decomposition for Task=TaskFFN2:
  Stage=TaskFFN2-stage-0, Node=CPU-1, Layers=['net.0'], DependsOn=[]
  Stage=TaskFFN2-stage-1, Node=CPU-1, Layers=['net.1', 'net.2'], DependsOn=['TaskFFN2-stage-0']
  Stage=TaskFFN2-stage-2, Node=CPU-1, Layers=['net.3'