# PyDart Library – Third Testing Checkpoint

## Overview

In this checkpoint, I focused on reconstructing the DAG (directed acyclic graph) of a model's forward pass, a critical milestone in the development of PyDart. Success at this stage was essential for validating the correctness of the DAG-based execution approach.

## Main Contributions

1. **Exploring DAG Reconstruction**  
   - Investigated the `torch.fx` module for extracting computational graphs.  
   - Studied topological sorting algorithms to determine execution order.  

2. **Implementing Basic Scheduling**  
   - Designed a simple Round-Robin (RR) scheme for stage allocation.  
   - Implemented a preliminary grouping mechanism for:  
     - Layers (referred to as **modules** in the DAG representation of the DNN forward pass).  
     - Functions (operators).  
   - Grouped the resulting nodes into stages of a constant size (`group_size`), experimenting with different values of this constant.

3. **Validating DAG-Based Execution**  
   - Reconstructed the forward pass by executing the traced graph's nodes in topological order.  
   - Verified correct node allocation within the execution framework.  

4. **Output Matching and Evaluation**  
   - Compared outputs against a standard native PyTorch implementation of the `Evaluator` class.  
   - Confirmed that a DAG-based approach would preserve correctness.  
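
Item 1 mentions topological sorting. A traced `torch.fx` graph already stores its nodes in a valid execution order, so the notebooks below simply iterate over `graph.nodes`; for a DAG available only as a dependency list, though, a standard way to derive an order is Kahn's algorithm. The following is a minimal stdlib-only sketch, not PyDart code: the function name `kahn_topological_sort` is hypothetical, and the input is the SimpleCNN dependency list printed by the first code cell.

```python
from collections import deque


def kahn_topological_sort(deps):
    """Return one topological order of a DAG given as {node: [dependencies]}.

    Kahn's algorithm: repeatedly emit a node whose dependencies have all been
    emitted. Raises ValueError if the graph contains a cycle.
    """
    # Count unmet dependencies per node and record the reverse edges (consumers).
    indegree = {node: len(inputs) for node, inputs in deps.items()}
    consumers = {node: [] for node in deps}
    for node, inputs in deps.items():
        for src in inputs:
            consumers[src].append(node)

    # Nodes without dependencies (e.g. the input placeholder) are ready first.
    ready = deque(node for node, count in indegree.items() if count == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for user in consumers[node]:
            indegree[user] -= 1
            if indegree[user] == 0:
                ready.append(user)

    if len(order) != len(deps):
        raise ValueError("graph contains a cycle")
    return order


# Dependency list of SimpleCNN, as printed by trace_and_extract_info.
simple_cnn_deps = {
    "x": [],
    "conv1": ["x"],
    "relu": ["conv1"],
    "conv2": ["x"],
    "relu_1": ["conv2"],
    "cat": ["relu", "relu_1"],
    "flatten": ["cat"],
    "fc": ["flatten"],
    "output": ["fc"],
}

print(kahn_topological_sort(simple_cnn_deps))
# → ['x', 'conv1', 'conv2', 'relu', 'relu_1', 'cat', 'flatten', 'fc', 'output']
```

Kahn's algorithm emits both convolutions before either ReLU, while the order recorded by `torch.fx` finishes one branch before starting the other; both respect every dependency because the `conv1` and `conv2` branches are independent until `cat`.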

## Key Insights and Next Steps

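- This approach **confirmed that DAG-based execution produces correct outputs** (reconstructed outputs were compared with the native PyTorch outputs via `torch.allclose` with `atol=1e-6`), reinforcing the feasibility of this methodology for future iterations.  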
- This validation was crucial for refining DAG-based scheduling and improving execution efficiency in subsequent notebooks.
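
The output checks rely on tolerance-based comparison, not bit-exact equality: `torch.allclose(a, b, rtol, atol)` accepts an element pair when |a - b| <= atol + rtol * |b|, with defaults `rtol=1e-5` and `atol=1e-8`. The same criterion can be restated in plain Python for flat sequences; the helper name `allclose` is hypothetical, and this sketch only illustrates the rule rather than replacing the PyTorch function.

```python
def allclose(actual, expected, rtol=1e-5, atol=1e-8):
    """Pure-Python version of the torch.allclose criterion for flat sequences:
    every pair must satisfy |actual - expected| <= atol + rtol * |expected|.

    Exactly equal values (including equal infinities) count as close; NaN is
    never close to anything, matching torch.allclose's default equal_nan=False.
    """
    if len(actual) != len(expected):
        return False
    return all(
        a == e or abs(a - e) <= atol + rtol * abs(e)
        for a, e in zip(actual, expected)
    )


# The notebooks call torch.allclose(..., atol=1e-6) and keep the default rtol=1e-5.
print(allclose([0.5, -2.0], [0.5000004, -2.000015], atol=1e-6))  # True
print(allclose([1.0], [1.001], atol=1e-6))                        # False
```

Note that the tolerance scales with the magnitude of the reference value, so large outputs are allowed proportionally larger absolute differences.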

---

**Note**: Multiple iterations were performed during the development of these classes. The key checkpoints included here highlight the most significant developments. Subsequent iterations followed a similar approach.


In [None]:
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torchvision import models

# Define a SimpleCNN with torch.cat in its forward pass
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 28 * 28, 10)  # Assuming input images are 28x28

    def forward(self, x):
        # First path
        x1 = self.conv1(x)
        x1 = self.relu(x1)

        # Second path
        x2 = self.conv2(x)
        x2 = self.relu(x2)

        # Concatenate along the channel dimension
        x = torch.cat((x1, x2), dim=1)

        x = self.flatten(x)
        x = self.fc(x)
        return x

def trace_and_extract_info(model, model_name="Model"):
    """
    Traces the given model using torch.fx and extracts:
    - Topological Execution Order
    - Dependency List
    """
    print(f"\n{'='*60}\nTracing and Extracting Information for {model_name}\n{'='*60}")

    # Trace the model
    traced_model = symbolic_trace(model)

    # Extract Topological Order (execution order)
    topological_order = []
    for node in traced_model.graph.nodes:
        topological_order.append(node.name)

    # Extract Dependency List
    dependency_list = {}
    for node in traced_model.graph.nodes:
        dependencies = [arg.name for arg in node.all_input_nodes]
        dependency_list[node.name] = dependencies

    # Print Topological Execution Order
    print("\nTopological Execution Order:")
    for idx, node_name in enumerate(topological_order, 1):
        print(f"{idx}: {node_name}")

    # Print Dependency List
    print("\nDependency List:")
    for node_name in topological_order:
        deps = dependency_list[node_name]
        print(f"{node_name}: {deps}")

def main():
    # Instantiate the SimpleCNN
    simple_cnn = SimpleCNN()
    simple_cnn.eval()  # Set to evaluation mode

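    # Instantiate the ResNet18 model from torchvision (randomly initialized).
    # torchvision >= 0.13 replaces 'pretrained' with 'weights';
    # use weights=models.ResNet18_Weights.DEFAULT for pretrained weights.
    try:
        resnet18 = models.resnet18(weights=None)
    except TypeError:
        # Older torchvision versions only accept 'pretrained'
        resnet18 = models.resnet18(pretrained=False)
    resnet18.eval()  # Set to evaluation mode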

    # Trace and extract information for SimpleCNN
    trace_and_extract_info(simple_cnn, model_name="SimpleCNN")

    # Trace and extract information for ResNet18
    trace_and_extract_info(resnet18, model_name="ResNet18")

if __name__ == "__main__":
    main()





Tracing and Extracting Information for SimpleCNN

Topological Execution Order:
1: x
2: conv1
3: relu
4: conv2
5: relu_1
6: cat
7: flatten
8: fc
9: output

Dependency List:
x: []
conv1: ['x']
relu: ['conv1']
conv2: ['x']
relu_1: ['conv2']
cat: ['relu', 'relu_1']
flatten: ['cat']
fc: ['flatten']
output: ['fc']

Tracing and Extracting Information for ResNet18

Topological Execution Order:
1: x
2: conv1
3: bn1
4: relu
5: maxpool
6: layer1_0_conv1
7: layer1_0_bn1
8: layer1_0_relu
9: layer1_0_conv2
10: layer1_0_bn2
11: add
12: layer1_0_relu_1
13: layer1_1_conv1
14: layer1_1_bn1
15: layer1_1_relu
16: layer1_1_conv2
17: layer1_1_bn2
18: add_1
19: layer1_1_relu_1
20: layer2_0_conv1
21: layer2_0_bn1
22: layer2_0_relu
23: layer2_0_conv2
24: layer2_0_bn2
25: layer2_0_downsample_0
26: layer2_0_downsample_1
27: add_2
28: layer2_0_relu_1
29: layer2_1_conv1
30: layer2_1_bn1
31: layer2_1_relu
32: layer2_1_conv2
33: layer2_1_bn2
34: add_3
35: layer2_1_relu_1
36: layer3_0_conv1
37: layer3_0_bn1
38: laye
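
The printed order and dependency list can be cross-checked against each other: an order is a valid execution order exactly when every node appears after all of its inputs. A minimal stdlib check, applied to the SimpleCNN output above (the function name `is_valid_execution_order` is hypothetical):

```python
def is_valid_execution_order(order, deps):
    """True if `order` lists every node of `deps` once, each after all of its inputs."""
    position = {node: i for i, node in enumerate(order)}
    if len(position) != len(order) or set(position) != set(deps):
        return False  # duplicated, missing, or unknown nodes
    return all(
        src in position and position[src] < position[node]
        for node, inputs in deps.items()
        for src in inputs
    )


# Order and dependency list printed for SimpleCNN above.
fx_order = ["x", "conv1", "relu", "conv2", "relu_1", "cat", "flatten", "fc", "output"]
deps = {
    "x": [], "conv1": ["x"], "relu": ["conv1"], "conv2": ["x"], "relu_1": ["conv2"],
    "cat": ["relu", "relu_1"], "flatten": ["cat"], "fc": ["flatten"], "output": ["fc"],
}

print(is_valid_execution_order(fx_order, deps))  # True
# The two branches are independent, so running the conv2 branch first is also valid.
print(is_valid_execution_order(
    ["x", "conv2", "relu_1", "conv1", "relu", "cat", "flatten", "fc", "output"], deps))  # True
# Running relu before the conv1 it consumes is not.
print(is_valid_execution_order(
    ["x", "relu", "conv1", "conv2", "relu_1", "cat", "flatten", "fc", "output"], deps))  # False
```

Because several orders pass this check, the fx order is one valid schedule among several, which is what leaves room for later grouping and scheduling decisions.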

In [None]:
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torchvision import models
import warnings

# Suppress specific deprecation warnings from torchvision
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.models._utils")

# Define a SimpleCNN with torch.cat in its forward pass
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 28 * 28, 10)  # Assuming input images are 28x28

    def forward(self, x):
        # First path
        x1 = self.conv1(x)
        x1 = self.relu(x1)

        # Second path
        x2 = self.conv2(x)
        x2 = self.relu(x2)

        # Concatenate along the channel dimension
        x = torch.cat((x1, x2), dim=1)

        x = self.flatten(x)
        x = self.fc(x)
        return x

def trace_and_extract_info(model, model_name="Model"):
    """
    Traces the given model using torch.fx and extracts:
    - Topological Execution Order
    - Dependency List
    - Interaction Details: How outputs of dependencies are used as inputs
    """
    print(f"\n{'='*60}\nTracing and Extracting Information for {model_name}\n{'='*60}")

    # Trace the model
    traced_model = symbolic_trace(model)

    # Extract Topological Order (execution order)
    topological_order = [node.name for node in traced_model.graph.nodes]

    # Extract Dependency List
    dependency_list = {node.name: [arg.name for arg in node.all_input_nodes] for node in traced_model.graph.nodes}

    # Interaction details use the same producer -> consumer mapping as the dependency list
    interaction_details = dependency_list

    # Print Topological Execution Order
    print("\nTopological Execution Order:")
    for idx, node_name in enumerate(topological_order, 1):
        print(f"{idx}: {node_name}")

    # Print Dependency List
    print("\nDependency List:")
    for node_name in topological_order:
        deps = dependency_list[node_name]
        print(f"{node_name}: {deps}")

    # Print Interaction Details
    print("\nInteraction Details:")
    for node_name in topological_order:
        inputs = interaction_details[node_name]
        if inputs:
            print(f"{node_name} receives inputs from: {', '.join(inputs)}")
        else:
            print(f"{node_name} receives inputs from: None (Input Placeholder)")

    return traced_model  # Return the traced model for further use

def resolve_arg(arg, node_outputs):
    """
    Recursively replaces Node references in args and kwargs with their actual outputs.
    """
    if isinstance(arg, torch.fx.Node):
        return node_outputs[arg.name]
    elif isinstance(arg, (list, tuple)):
        return type(arg)(resolve_arg(a, node_outputs) for a in arg)
    elif isinstance(arg, dict):
        return {k: resolve_arg(v, node_outputs) for k, v in arg.items()}
    else:
        return arg

def reconstruct_forward(traced_model, input_tensor):
    """
    Reconstructs the forward pass based on the traced graph's topological order and dependencies.
    Returns the reconstructed output.
    """
    print("\nReconstructing Forward Pass Based on Topological Order and Dependencies...\n")

    # Dictionary to hold the outputs of each node
    node_outputs = {}

    # Iterate through nodes in topological order
    for node in traced_model.graph.nodes:
        if node.op == 'placeholder':
            # Assign the input tensor to the node
            node_outputs[node.name] = input_tensor
            print(f"Executing Placeholder: {node.name}")
        elif node.op == 'call_module':
            # Get the submodule by its qualified name (e.g. 'layer1.0.conv1')
            submodule = traced_model.get_submodule(node.target)
            # Resolve the arguments
            args = resolve_arg(node.args, node_outputs)
            kwargs = resolve_arg(node.kwargs, node_outputs)
            # Execute the submodule
            print(f"Executing Module: {node.target} with inputs {[arg.name for arg in node.all_input_nodes]}")
            node_outputs[node.name] = submodule(*args, **kwargs)
        elif node.op == 'call_function':
            # Get the function
            func = node.target
            # Resolve the arguments
            args = resolve_arg(node.args, node_outputs)
            kwargs = resolve_arg(node.kwargs, node_outputs)
            # Execute the function
            print(f"Executing Function: {func.__name__} with inputs {[arg.name for arg in node.all_input_nodes]}")
            node_outputs[node.name] = func(*args, **kwargs)
        elif node.op == 'call_method':
            # Get the method
            method = getattr(resolve_arg(node.args[0], node_outputs), node.target)
            # Resolve the arguments (excluding the first argument which is 'self')
            args = resolve_arg(node.args[1:], node_outputs)
            kwargs = resolve_arg(node.kwargs, node_outputs)
            # Execute the method
            print(f"Executing Method: {node.target} on {node.args[0].name} with inputs {[arg.name for arg in node.all_input_nodes[1:]]}")
            node_outputs[node.name] = method(*args, **kwargs)
        elif node.op == 'output':
            # Assign the final output
            output = resolve_arg(node.args[0], node_outputs)
            print(f"Final Output: {node.name}")
            node_outputs[node.name] = output
        else:
            raise NotImplementedError(f"Operation {node.op} is not supported.")

    # The output node holds the final output
    return node_outputs['output']

def main():
    # Instantiate the SimpleCNN
    simple_cnn = SimpleCNN()
    simple_cnn.eval()  # Set to evaluation mode

    # Instantiate the ResNet18 model from torchvision
    # Note: 'pretrained' is deprecated, use 'weights' instead if using a newer torchvision version
    try:
        resnet18 = models.resnet18(weights=None)  # Set weights=models.ResNet18_Weights.DEFAULT for pretrained
    except TypeError:
        # For older torchvision versions
        resnet18 = models.resnet18(pretrained=False)
    resnet18.eval()  # Set to evaluation mode

    # Trace and extract information for SimpleCNN
    traced_simple_cnn = trace_and_extract_info(simple_cnn, model_name="SimpleCNN")

    # Create a sample input tensor for SimpleCNN
    # Assuming input images are 28x28 with 3 channels
    simple_cnn_input = torch.randn(1, 3, 28, 28)

    # Reconstruct forward pass for SimpleCNN within no_grad context
    with torch.no_grad():
        reconstructed_simple_cnn_output = reconstruct_forward(traced_simple_cnn, simple_cnn_input)

        # Get actual output from SimpleCNN
        actual_simple_cnn_output = simple_cnn(simple_cnn_input)

    # Compare the outputs
    print("\nComparing Reconstructed Output with Actual Output for SimpleCNN:")
    if torch.allclose(reconstructed_simple_cnn_output, actual_simple_cnn_output, atol=1e-6):
        print("Success: The reconstructed output matches the actual output.")
    else:
        print("Warning: The reconstructed output does not match the actual output.")

    # Trace and extract information for ResNet18
    traced_resnet18 = trace_and_extract_info(resnet18, model_name="ResNet18")

    # Create a sample input tensor for ResNet18
    # ResNet18 typically expects 224x224 images with 3 channels
    resnet18_input = torch.randn(1, 3, 224, 224)

    # Reconstruct forward pass for ResNet18 within no_grad context
    with torch.no_grad():
        reconstructed_resnet18_output = reconstruct_forward(traced_resnet18, resnet18_input)

        # Get actual output from ResNet18
        actual_resnet18_output = resnet18(resnet18_input)

    # Compare the outputs
    print("\nComparing Reconstructed Output with Actual Output for ResNet18:")
    if torch.allclose(reconstructed_resnet18_output, actual_resnet18_output, atol=1e-6):
        print("Success: The reconstructed output matches the actual output.")
    else:
        print("Warning: The reconstructed output does not match the actual output.")

    print(reconstructed_simple_cnn_output, actual_simple_cnn_output)
    print(reconstructed_resnet18_output.shape, actual_resnet18_output.shape)

if __name__ == "__main__":
    main()



Tracing and Extracting Information for SimpleCNN

Topological Execution Order:
1: x
2: conv1
3: relu
4: conv2
5: relu_1
6: cat
7: flatten
8: fc
9: output

Dependency List:
x: []
conv1: ['x']
relu: ['conv1']
conv2: ['x']
relu_1: ['conv2']
cat: ['relu', 'relu_1']
flatten: ['cat']
fc: ['flatten']
output: ['fc']

Interaction Details:
x receives inputs from: None (Input Placeholder)
conv1 receives inputs from: x
relu receives inputs from: conv1
conv2 receives inputs from: x
relu_1 receives inputs from: conv2
cat receives inputs from: relu, relu_1
flatten receives inputs from: cat
fc receives inputs from: flatten
output receives inputs from: fc

Reconstructing Forward Pass Based on Topological Order and Dependencies...

Executing Placeholder: x
Executing Module: conv1 with inputs ['x']
Executing Module: relu with inputs ['conv1']
Executing Module: conv2 with inputs ['x']
Executing Module: relu with inputs ['conv2']
Executing Function: cat with inputs ['relu', 'relu_1']
Executing Module: fla
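
The core of `reconstruct_forward` is an environment mapping node names to computed values, together with `resolve_arg`, which swaps node references (including ones nested in tuples, as in `torch.cat((x1, x2), dim=1)`) for those values before each call. The same pattern can be sketched without PyTorch. In this sketch the `Ref` class is a hypothetical stand-in for `torch.fx.Node`, each node carries a single plain callable instead of being dispatched on `node.op`, and the toy list functions merely mimic SimpleCNN's branch-and-concatenate shape:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Ref:
    """Hypothetical stand-in for torch.fx.Node: refers to another node's output by name."""
    name: str


def resolve(arg, env):
    """Replace Refs, including ones nested in lists, tuples, and dicts, with computed values."""
    if isinstance(arg, Ref):
        return env[arg.name]
    if isinstance(arg, (list, tuple)):
        return type(arg)(resolve(a, env) for a in arg)
    if isinstance(arg, dict):
        return {k: resolve(v, env) for k, v in arg.items()}
    return arg


def run_graph(nodes, inputs):
    """Execute (name, fn, args, kwargs) entries in the given topological order."""
    env = dict(inputs)  # placeholder values are seeded before any node runs
    for name, fn, args, kwargs in nodes:
        env[name] = fn(*resolve(args, env), **resolve(kwargs, env))
    return env


# Toy element-wise "layers" on Python lists, standing in for the real modules.
def double(v):
    return [2 * e for e in v]


def minus_one(v):
    return [e - 1 for e in v]


def relu(v):
    return [max(0.0, e) for e in v]


def cat(parts, dim=0):
    # 1-D concatenation; `dim` only mirrors torch.cat's keyword and is ignored.
    return [e for part in parts for e in part]


# Same branch-and-concatenate shape as SimpleCNN.
graph = [
    ("conv1", double, (Ref("x"),), {}),
    ("relu", relu, (Ref("conv1"),), {}),
    ("conv2", minus_one, (Ref("x"),), {}),
    ("relu_1", relu, (Ref("conv2"),), {}),
    ("cat", cat, ((Ref("relu"), Ref("relu_1")),), {"dim": 0}),
    ("fc", sum, (Ref("cat"),), {}),
]

env = run_graph(graph, {"x": [1.0, -0.5, 2.0]})
print(env["cat"])  # [2.0, 0.0, 4.0, 0.0, 0.0, 1.0]
print(env["fc"])   # 7.0
```

Seeding the input values into the environment up front plays the role of the `placeholder` branch in `reconstruct_forward`.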

In [None]:
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torchvision import models
import warnings

# Suppress specific deprecation warnings from torchvision
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.models._utils")

# Define a SimpleCNN with torch.cat in its forward pass
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 28 * 28, 10)  # Assuming input images are 28x28

    def forward(self, x):
        # First path
        x1 = self.conv1(x)
        x1 = self.relu(x1)

        # Second path
        x2 = self.conv2(x)
        x2 = self.relu(x2)

        # Concatenate along the channel dimension
        x = torch.cat((x1, x2), dim=1)

        x = self.flatten(x)
        x = self.fc(x)
        return x

def trace_and_extract_info(model, model_name="Model"):
    """
    Traces the given model using torch.fx and extracts:
    - Topological Execution Order
    - Dependency List
    - Interaction Details: How outputs of dependencies are used as inputs
    """
    print(f"\n{'='*60}\nTracing and Extracting Information for {model_name}\n{'='*60}")

    # Trace the model
    traced_model = symbolic_trace(model)

    # Extract Topological Order (execution order)
    topological_order = [node.name for node in traced_model.graph.nodes]

    # Extract Dependency List
    dependency_list = {node.name: [arg.name for arg in node.all_input_nodes] for node in traced_model.graph.nodes}

    # Extract Interaction Details
    interaction_details = {node.name: [arg.name for arg in node.all_input_nodes] for node in traced_model.graph.nodes}

    # Print Topological Execution Order
    print("\nTopological Execution Order:")
    for idx, node_name in enumerate(topological_order, 1):
        print(f"{idx}: {node_name}")

    # Print Dependency List
    print("\nDependency List:")
    for node_name in topological_order:
        deps = dependency_list[node_name]
        print(f"{node_name}: {deps}")

    # Print Interaction Details
    print("\nInteraction Details:")
    for node_name in topological_order:
        inputs = interaction_details[node_name]
        if inputs:
            print(f"{node_name} receives inputs from: {', '.join(inputs)}")
        else:
            print(f"{node_name} receives inputs from: None (Input Placeholder)")

    return traced_model  # Return the traced model for further use

def resolve_arg(arg, node_outputs):
    """
    Recursively replaces Node references in args and kwargs with their actual outputs.
    """
    if isinstance(arg, torch.fx.Node):
        return node_outputs[arg.name]
    elif isinstance(arg, (list, tuple)):
        return type(arg)(resolve_arg(a, node_outputs) for a in arg)
    elif isinstance(arg, dict):
        return {k: resolve_arg(v, node_outputs) for k, v in arg.items()}
    else:
        return arg

def group_topological_order(topological_order, group_size=2):
    """
    Groups the topological order list into fixed-size groups.
    Each group is named as 'stage-1', 'stage-2', etc.
    """
    groups = {}
    num_groups = (len(topological_order) + group_size - 1) // group_size  # Ceiling division

    for i in range(num_groups):
        start_idx = i * group_size
        end_idx = start_idx + group_size
        group_nodes = topological_order[start_idx:end_idx]
        stage_name = f"stage-{i+1}"
        groups[stage_name] = group_nodes

    return groups

def reconstruct_forward_with_groups(traced_model, input_tensor, group_size=2):
    """
    Reconstructs the forward pass based on grouped topological order.
    Each group of operations is executed sequentially as a stage.
    Also prints dependencies between stages.
    """
    print("\nReconstructing Forward Pass Using Grouped Stages...\n")

    # Get the topological order
    topological_order = [node.name for node in traced_model.graph.nodes]
    print("Topological Order:", topological_order)

    # Group the operations
    grouped_operations = group_topological_order(topological_order, group_size=group_size)

    # Print the grouped stages
    print("Grouped Stages:")
    for stage, nodes in grouped_operations.items():
        print(f"{stage}: {nodes}")

    # Map each node to its stage for dependency analysis
    node_to_stage = {}
    for stage, nodes in grouped_operations.items():
        for node in nodes:
            node_to_stage[node] = stage

    # Analyze dependencies between stages
    stage_dependencies = {stage: set() for stage in grouped_operations.keys()}

    for stage, nodes in grouped_operations.items():
        for node in nodes:
            # Find the node object by name
            node_obj = next(n for n in traced_model.graph.nodes if n.name == node)
            dependencies = node_obj.all_input_nodes
            for dep in dependencies:
                dep_stage = node_to_stage.get(dep.name, None)
                if dep_stage and dep_stage != stage:
                    stage_dependencies[stage].add(dep_stage)

    # Print dependencies between stages
    print("\nDependencies Between Stages:")
    for stage, deps in stage_dependencies.items():
        if deps:
            deps_formatted = ', '.join(sorted(deps))
            print(f"{stage} depends on: {deps_formatted}")
        else:
            print(f"{stage} has no dependencies on other stages.")

    # Dictionary to hold the outputs of each node
    node_outputs = {}

    # Iterate through each stage
    for stage, nodes in grouped_operations.items():
        print(f"\n--- Executing {stage} ---")
        for node_name in nodes:
            node = next(n for n in traced_model.graph.nodes if n.name == node_name)

            if node.op == 'placeholder':
                # Assign the input tensor to the node
                node_outputs[node.name] = input_tensor
                print(f"[{stage}] Executing Placeholder: {node.name}")
            elif node.op == 'call_module':
                # Get the submodule by its qualified name (e.g. 'layer1.0.conv1')
                submodule = traced_model.get_submodule(node.target)
                # Resolve the arguments
                args = resolve_arg(node.args, node_outputs)
                kwargs = resolve_arg(node.kwargs, node_outputs)
                # Execute the submodule
                print(f"[{stage}] Executing Module: {node.target} with inputs {[arg.name for arg in node.all_input_nodes]}")
                node_outputs[node.name] = submodule(*args, **kwargs)
            elif node.op == 'call_function':
                # Get the function
                func = node.target
                # Resolve the arguments
                args = resolve_arg(node.args, node_outputs)
                kwargs = resolve_arg(node.kwargs, node_outputs)
                # Execute the function
                input_names = ', '.join([arg.name for arg in node.all_input_nodes])
                print(f"[{stage}] Executing Function: {func.__name__} with inputs {input_names}")
                node_outputs[node.name] = func(*args, **kwargs)
            elif node.op == 'call_method':
                # Get the method
                method = getattr(resolve_arg(node.args[0], node_outputs), node.target)
                # Resolve the arguments (excluding the first argument which is 'self')
                args = resolve_arg(node.args[1:], node_outputs)
                kwargs = resolve_arg(node.kwargs, node_outputs)
                # Execute the method
                input_names = ', '.join([arg.name for arg in node.all_input_nodes[1:]])
                print(f"[{stage}] Executing Method: {node.target} on {node.args[0].name} with inputs {input_names}")
                node_outputs[node.name] = method(*args, **kwargs)
            elif node.op == 'output':
                # Assign the final output
                output = resolve_arg(node.args[0], node_outputs)
                print(f"[{stage}] Final Output: {node.name}")
                node_outputs[node.name] = output
            else:
                raise NotImplementedError(f"Operation {node.op} is not supported.")

    # The output node holds the final output
    return node_outputs['output']

def main():
    # Instantiate the SimpleCNN
    simple_cnn = SimpleCNN()
    simple_cnn.eval()  # Set to evaluation mode

    # Instantiate the ResNet18 model from torchvision
    # Note: 'pretrained' is deprecated, use 'weights' instead if using a newer torchvision version
    try:
        resnet18 = models.resnet18(weights=None)  # Set weights=models.ResNet18_Weights.DEFAULT for pretrained
    except TypeError:
        # For older torchvision versions
        resnet18 = models.resnet18(pretrained=False)
    resnet18.eval()  # Set to evaluation mode

    # Trace and extract information for SimpleCNN
    traced_simple_cnn = trace_and_extract_info(simple_cnn, model_name="SimpleCNN")

    # Create a sample input tensor for SimpleCNN
    # Assuming input images are 28x28 with 3 channels
    simple_cnn_input = torch.randn(1, 3, 28, 28)

    # Reconstruct forward pass for SimpleCNN within no_grad context
    with torch.no_grad():
        reconstructed_simple_cnn_output = reconstruct_forward_with_groups(traced_simple_cnn, simple_cnn_input, group_size=2)

        # Get actual output from SimpleCNN
        actual_simple_cnn_output = simple_cnn(simple_cnn_input)

    # Compare the outputs
    print(reconstructed_simple_cnn_output,actual_simple_cnn_output)
    print("\nComparing Reconstructed Output with Actual Output for SimpleCNN:")
    if torch.allclose(reconstructed_simple_cnn_output, actual_simple_cnn_output, atol=1e-6):
        print("Success: The reconstructed output matches the actual output.")
    else:
        print("Warning: The reconstructed output does not match the actual output.")

    # Trace and extract information for ResNet18
    traced_resnet18 = trace_and_extract_info(resnet18, model_name="ResNet18")

    # Create a sample input tensor for ResNet18
    # ResNet18 typically expects 224x224 images with 3 channels
    resnet18_input = torch.randn(1, 3, 224, 224)

    # Reconstruct forward pass for ResNet18 within no_grad context
    with torch.no_grad():
        reconstructed_resnet18_output = reconstruct_forward_with_groups(traced_resnet18, resnet18_input, group_size=4)

        # Get actual output from ResNet18
        actual_resnet18_output = resnet18(resnet18_input)

    # Compare the outputs
    print("\nComparing Reconstructed Output with Actual Output for ResNet18:")
    if torch.allclose(reconstructed_resnet18_output, actual_resnet18_output, atol=1e-6):
        print("Success: The reconstructed output matches the actual output.")
    else:
        print("Warning: The reconstructed output does not match the actual output.")

if __name__ == "__main__":
    main()



Tracing and Extracting Information for SimpleCNN

Topological Execution Order:
1: x
2: conv1
3: relu
4: conv2
5: relu_1
6: cat
7: flatten
8: fc
9: output

Dependency List:
x: []
conv1: ['x']
relu: ['conv1']
conv2: ['x']
relu_1: ['conv2']
cat: ['relu', 'relu_1']
flatten: ['cat']
fc: ['flatten']
output: ['fc']

Interaction Details:
x receives inputs from: None (Input Placeholder)
conv1 receives inputs from: x
relu receives inputs from: conv1
conv2 receives inputs from: x
relu_1 receives inputs from: conv2
cat receives inputs from: relu, relu_1
flatten receives inputs from: cat
fc receives inputs from: flatten
output receives inputs from: fc

Reconstructing Forward Pass Using Grouped Stages...

Topological Order: ['x', 'conv1', 'relu', 'conv2', 'relu_1', 'cat', 'flatten', 'fc', 'output']
Grouped Stages:
stage-1: ['x', 'conv1']
stage-2: ['relu', 'conv2']
stage-3: ['relu_1', 'cat']
stage-4: ['flatten', 'fc']
stage-5: ['output']

Dependencies Between Stages:
stage-1 has no dependencies on o
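
The fixed-size grouping in the last cell yields a stage-by-stage executable pipeline only if no stage consumes a value produced by a later stage. With contiguous slices of a topological order this holds by construction, since every input of a node sits earlier in the order. The stdlib sketch below regroups the SimpleCNN order with `group_size=2`, recomputes the inter-stage dependencies, asserts that they all point backwards, and then places stages on workers. The function names are hypothetical, and the round-robin placement (stage *i* on worker *i* mod *n*, here with two workers) is an assumption about how round-robin stage allocation could work, not a description of PyDart's scheduler.

```python
def group_fixed_size(order, group_size):
    """Split a topological order into contiguous stages of `group_size` nodes."""
    return [order[i:i + group_size] for i in range(0, len(order), group_size)]


def stage_dependencies(stages, deps):
    """For each stage index, the set of *other* stage indices whose outputs it consumes."""
    stage_of = {node: s for s, nodes in enumerate(stages) for node in nodes}
    return [
        {stage_of[src] for node in nodes for src in deps[node]} - {s}
        for s, nodes in enumerate(stages)
    ]


def assign_round_robin(num_stages, num_workers):
    """Assumed round-robin policy: stage i is placed on worker i % num_workers."""
    return [s % num_workers for s in range(num_stages)]


order = ["x", "conv1", "relu", "conv2", "relu_1", "cat", "flatten", "fc", "output"]
deps = {
    "x": [], "conv1": ["x"], "relu": ["conv1"], "conv2": ["x"], "relu_1": ["conv2"],
    "cat": ["relu", "relu_1"], "flatten": ["cat"], "fc": ["flatten"], "output": ["fc"],
}

stages = group_fixed_size(order, group_size=2)
needs = stage_dependencies(stages, deps)
workers = assign_round_robin(len(stages), num_workers=2)

# A grouping is executable stage by stage only if every dependency points backwards.
assert all(d < s for s, ds in enumerate(needs) for d in ds)

for s, (nodes, ds, w) in enumerate(zip(stages, needs, workers), start=1):
    upstream = ", ".join(f"stage-{d + 1}" for d in sorted(ds)) or "none"
    print(f"stage-{s} on worker {w}: {nodes}, depends on: {upstream}")
# stage-1 on worker 0: ['x', 'conv1'], depends on: none
# stage-2 on worker 1: ['relu', 'conv2'], depends on: stage-1
# stage-3 on worker 0: ['relu_1', 'cat'], depends on: stage-2
# stage-4 on worker 1: ['flatten', 'fc'], depends on: stage-3
# stage-5 on worker 0: ['output'], depends on: stage-4
```

The stages match the grouped stages printed above. Because each stage depends only on its predecessor, every stage boundary under two-worker round-robin placement is also a worker boundary, so each stage's output would have to be handed to the other worker.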