<a href="https://colab.research.google.com/github/rogerwzeng/BigDataSystems/blob/main/mu_Two.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# $\mu$-Two Implementation

### Harvard University, Spring 2025
### CS265 Big Data Systems - Term Project (Systems)  
#### Roger W. Zeng


In [None]:
import torch
import torch.nn as nn
import torch.fx as fx
from torch.fx import Interpreter, GraphModule, symbolic_trace
from torch.profiler import profile, record_function, ProfilerActivity
import torch.cuda as cuda
import torch.cuda.nvtx as nvtx
import torchvision.models as models
from typing import Dict, Set, List, Any
import operator


In [None]:
# ------------------------------
# Nodes in the computation graph
#   Profiling attributes
#   rank: The position of the node in the topological sort of the graph.
#   gtype: The type of graph this node belongs to, either forward or backward pass.
#   run_time: The runtime of the node in milliseconds.
#   peak_mem: The peak memory usage in bytes.
#   active_mem: The active memory usage in bytes, representing the minimum required memory.
#   Scheduling attributes
#   to_offload: Nodes to be offloaded to host memory after this node is executed.
#   to_delete: Nodes to be deleted after this node is executed, further aiding memory optimization.
#   to_prefetch: Nodes to be prefetched from host memory before this node is executed.
#   to_recompute: Nodes to be recomputed before this node is executed (activation checkpointing).
# ------------------------------

class Node(fx.Node):  # Inherit from torch.fx.Node
    """torch.fx.Node extended with mu-Two profiling and scheduling attributes.

    Takes the same arguments as ``torch.fx.Node`` plus an optional keyword-only
    ``func``: the callable that :meth:`execute` runs and profiles.
    """

    def __init__(self, *args, func=None, **kwargs):
        super().__init__(*args, **kwargs)  # Initialize parent class attributes
        # torch.fx.Node does not define ``func``; store it explicitly so that
        # execute() does not fail with AttributeError.
        self.func = func
        self.rank = 0
        self.gtype = 'forward'
        self.run_time = None
        self.peak_mem = None
        self.active_mem = None
        self.to_offload = []
        self.to_delete = []
        self.to_prefetch = []
        self.to_recompute = []

    def execute(self, *args, **kwargs):
        """Call ``self.func(*args, **kwargs)`` and record its cost.

        Sets ``run_time`` (milliseconds), ``peak_mem`` and ``active_mem``
        (bytes, from torch.cuda's allocator counters).

        Returns:
            The callable's result, or ``None`` if no ``func`` was given.
        """
        if self.func is None:
            return None

        import time

        # CUDA kernels launch asynchronously; synchronize around the call so
        # the host-side timer covers the device work as well.
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        result = self.func(*args, **kwargs)
        if use_cuda:
            torch.cuda.synchronize()
        self.run_time = (time.perf_counter() - start) * 1000  # in milliseconds
        self.peak_mem = torch.cuda.max_memory_allocated()  # in bytes
        self.active_mem = torch.cuda.memory_allocated()  # in bytes
        return result


# -----------------------------------------------------
# Intermediate nodes (feature-map tensors) with
# profiling, swapping and recomputation attributes.
#
# Initialized with a tensor and optional references to the last forward
# access, first backward access, and last backward access nodes.
# -----------------------------------------------------
class IntermediateNode:
    """Bookkeeping for one intermediate tensor: size, swap and recompute stats."""

    def __init__(self, tensor, last_fw_access=None, first_bw_access=None, last_bw_access=None):
        self.tensor = tensor
        # Bytes of the underlying storage; untyped_storage() replaces the
        # deprecated TypedStorage-returning storage().
        self.memory_size = tensor.untyped_storage().nbytes()
        self.inactive_time = None
        self.swap_time = None
        self.last_fw_access = last_fw_access
        self.first_bw_access = first_bw_access
        self.last_bw_access = last_bw_access
        self.prefetch_prompt = None
        self.active_fw_interval = None
        self.active_bw_interval = None
        self.recomp_srcs = []
        self.recomp_graph = None  # callable that regenerates the tensor from recomp_srcs
        self.recomp_cnt = 0
        self.recomp_time = None
        self.total_recomp_time = 0.0
        self.recomp_memory = None
        self.recompute_ratio = None

    def update_inactive_time(self, start_time, end_time):
        """Update the inactive time of the tensor."""
        self.inactive_time = end_time - start_time

    def update_swap_time(self, swap_time):
        """Update the swap time of the tensor."""
        self.swap_time = swap_time

    def update_recomputation(self, recomputation_time, recomputation_memory):
        """Update recomputation attributes; the ratio is bytes saved per unit of recompute time."""
        self.recomp_time = recomputation_time
        self.recomp_memory = recomputation_memory
        self.recompute_ratio = self.memory_size / self.recomp_time if self.recomp_time else 0

    def mark_for_recomputation(self, recomputation_sources):
        """Mark the node for recomputation and set its sources."""
        self.recomp_srcs = recomputation_sources

    def set_prefetch_prompt(self, prompt_node):
        """Set the prefetch prompt node."""
        self.prefetch_prompt = prompt_node

    def set_active_intervals(self, fw_start_node, fw_end_node, bw_start_node, bw_end_node):
        """Set the active intervals for forward and backward passes."""
        self.active_fw_interval = (fw_start_node, fw_end_node)
        self.active_bw_interval = (bw_start_node, bw_end_node)

    def recomp_sub_graph(self, input_tensors):
        """Run ``self.recomp_graph`` on the given source tensor(s) and return its output.

        Args:
            input_tensors: a single tensor or a sequence of tensors, passed
                positionally to the recomputation graph.

        Raises:
            RuntimeError: if no recomputation graph has been set.
        """
        if self.recomp_graph is None:
            raise RuntimeError("no recomputation graph set for this tensor")
        if isinstance(input_tensors, torch.Tensor):
            input_tensors = (input_tensors,)
        return self.recomp_graph(*input_tensors)

    def recompute_tensor(self):
        """Regenerate ``self.tensor`` from its sources; no-op without a graph or sources."""
        if self.recomp_graph is None or not self.recomp_srcs:
            return
        self.tensor = self.recomp_sub_graph(self.recomp_srcs)
        self.recomp_cnt += 1
        # recomp_time stays None until update_recomputation() has run.
        self.total_recomp_time += self.recomp_time or 0.0


In [None]:
class MuTwoProfiler(fx.Interpreter):
    """fx.Interpreter that profiles every node and swaps tensors via pinned host memory.

    Requires a CUDA device: the constructor pins a host buffer and allocates
    on ``'cuda'`` to warm up the caching allocator.
    """

    # Byte alignment of slots in the pinned buffer so that any slot can be
    # reinterpreted as any dtype with Tensor.view(dtype).
    _BUFFER_ALIGNMENT = 64

    def __init__(self, graph_module: fx.GraphModule, garbage_collect_values=False, graph=None):
        super().__init__(graph_module, garbage_collect_values, graph)
        # CPU activity is required for record_function() labels to be captured.
        self.profiler = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        )
        self.compute_nodes: Dict[str, Node] = {}  # Stores compute nodes (forward/backward ops)
        self.intermediate_nodes: Dict[str, IntermediateNode] = {}  # Stores tensor nodes
        # fx.Node -> {'first_use': fx.Node, 'last_use': fx.Node}
        self.tensor_usage: Dict[fx.Node, Dict[str, fx.Node]] = {}

        # Pinned host staging buffer. uint8 so that buffer_size really is a
        # byte count (the default float32 would pin 4x as much memory).
        self.buffer_size = 1024 * 1024 * 1024  # 1GB buffer
        self.pinned_buffer = torch.empty(self.buffer_size, dtype=torch.uint8, pin_memory=True)
        self.buffer_offset = 0
        self.tensor_buffer_map = {}  # Maps tensor_id -> (offset, size), both in bytes

        # Warm up CUDA caching allocator
        self._warmup_cuda_allocator()

    def _warmup_cuda_allocator(self):
        """Warm up the CUDA caching allocator to stabilize measurements"""
        dummy = torch.empty(1024 * 1024, device='cuda')  # 1M float32 elements
        del dummy
        torch.cuda.empty_cache()

    def _allocate_buffer_space(self, tensor_size: int) -> int:
        """Reserve ``tensor_size`` bytes in the pinned ring buffer and return the offset.

        Offsets are aligned to ``_BUFFER_ALIGNMENT``. When the end of the buffer
        is reached allocation wraps to 0; any previously mapped slot that
        overlaps the new region is dropped from ``tensor_buffer_map``, because
        its bytes are about to be overwritten.

        Raises:
            ValueError: if ``tensor_size`` exceeds the whole buffer.
        """
        if tensor_size > self.buffer_size:
            raise ValueError(
                f"tensor of {tensor_size} bytes exceeds the {self.buffer_size}-byte pinned buffer"
            )
        align = self._BUFFER_ALIGNMENT
        offset = (self.buffer_offset + align - 1) // align * align
        if offset + tensor_size > self.buffer_size:
            offset = 0  # Wrap around
        end = offset + tensor_size

        overwritten = [
            tid for tid, (slot_off, slot_size) in self.tensor_buffer_map.items()
            if slot_off < end and offset < slot_off + slot_size
        ]
        for tid in overwritten:
            del self.tensor_buffer_map[tid]

        self.buffer_offset = end
        return offset

    def _pinned_view(self, offset: int, size: int, tensor: torch.Tensor) -> torch.Tensor:
        """View ``size`` bytes of the pinned buffer at ``offset`` with ``tensor``'s dtype and shape."""
        return self.pinned_buffer[offset:offset + size].view(tensor.dtype).view(tensor.shape)

    def run(self, *args, **kwargs) -> Any:
        """Execute the graph with the profiler active for the whole pass."""
        with self.profiler:
            return super().run(*args, **kwargs)

    def run_node(self, node: fx.Node) -> Any:
        """Execute one node and store its cost in ``node.meta``.

        Keys written: ``run_time`` (ms, CUDA-event timed), ``peak_mem`` and
        ``active_mem`` (bytes, from torch.cuda's allocator counters).
        """
        start_event = cuda.Event(enable_timing=True)
        end_event = cuda.Event(enable_timing=True)
        with record_function(str(node)):
            torch.cuda.reset_peak_memory_stats()
            start_event.record()
            result = super().run_node(node)
            end_event.record()
        end_event.synchronize()
        node.meta["run_time"] = start_event.elapsed_time(end_event)
        node.meta["peak_mem"] = torch.cuda.max_memory_allocated()
        node.meta["active_mem"] = torch.cuda.memory_allocated()
        return result

    def _should_swap_out(self, intermediate_node: "IntermediateNode") -> bool:
        """Determine if a tensor should be swapped out based on access patterns"""
        if not intermediate_node.last_fw_access:
            return False

        # Check if tensor won't be needed soon and is large enough to be worth swapping
        memory_threshold = 1024 * 1024  # 1MB
        return (intermediate_node.memory_size > memory_threshold and
                not intermediate_node.first_bw_access)

    def _measure_swap_time(self, tensor: torch.Tensor, operation: str) -> float:
        """Swap ``tensor`` between GPU and the pinned buffer and return the time in ms.

        ``"swap_out"`` copies the tensor into its pinned slot and frees its GPU
        storage; ``"swap_in"`` re-allocates the storage and copies the bytes
        back. Slots are keyed by ``id(tensor)``, so the same tensor object must
        be passed in both directions.

        Raises:
            ValueError: for an unknown operation, or when swapping out a tensor
                that does not span its whole storage from offset 0.
            RuntimeError: when swapping in a tensor whose slot was overwritten.
        """
        if operation not in ("swap_out", "swap_in"):
            raise ValueError(f"unknown swap operation: {operation!r}")

        start_event = cuda.Event(enable_timing=True)
        end_event = cuda.Event(enable_timing=True)

        tensor_id = id(tensor)
        tensor_size = tensor.nelement() * tensor.element_size()

        start_event.record()

        if operation == "swap_out":
            # Storage is freed entirely and restored as tensor_size bytes, which
            # is only correct if the tensor covers its storage from offset 0.
            if tensor.storage_offset() != 0 or tensor.untyped_storage().nbytes() != tensor_size:
                raise ValueError("can only swap tensors that own their entire storage")
            entry = self.tensor_buffer_map.get(tensor_id)
            if entry is None or entry[1] != tensor_size:
                offset = self._allocate_buffer_space(tensor_size)
                self.tensor_buffer_map[tensor_id] = (offset, tensor_size)
            offset, _ = self.tensor_buffer_map[tensor_id]
            # Copy straight from GPU into pinned memory (no pageable staging copy).
            self._pinned_view(offset, tensor_size, tensor).copy_(tensor)
            tensor.untyped_storage().resize_(0)  # Free GPU memory
        else:
            entry = self.tensor_buffer_map.get(tensor_id)
            if entry is not None:
                offset, size = entry
                tensor.untyped_storage().resize_(size)  # size in bytes
                tensor.copy_(self._pinned_view(offset, size, tensor))
            elif tensor.untyped_storage().nbytes() == 0 and tensor.nelement() > 0:
                raise RuntimeError("pinned slot of swapped-out tensor was overwritten")

        end_event.record()
        end_event.synchronize()
        return start_event.elapsed_time(end_event)

    def _get_memory_stats(self) -> Dict:
        """Get current GPU memory statistics"""
        return {
            'allocated': torch.cuda.memory_allocated(),
            'reserved': torch.cuda.memory_reserved(),
            'max_allocated': torch.cuda.max_memory_allocated()
        }

    def analyze_recomputation(self, node: fx.Node):
        """Recompute ``node``'s tensor once and record its time (ms) and peak extra memory (bytes)."""
        intermediate_node = self.intermediate_nodes.get(node.name)
        if intermediate_node is None:
            return

        torch.cuda.synchronize()
        baseline = torch.cuda.memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        # CUDA events time the device work, as in _measure_swap_time.
        start_event = cuda.Event(enable_timing=True)
        end_event = cuda.Event(enable_timing=True)
        start_event.record()
        intermediate_node.recompute_tensor()
        end_event.record()
        end_event.synchronize()

        recomp_time = start_event.elapsed_time(end_event)
        # Peak allocation reached during recomputation beyond what was already resident.
        recomp_memory = torch.cuda.max_memory_allocated() - baseline
        intermediate_node.update_recomputation(recomp_time, recomp_memory)

    def get_profiling_results(self) -> Dict:
        """Return collected profiling statistics"""
        return {
            'compute_nodes': self.compute_nodes,
            'intermediate_nodes': self.intermediate_nodes
        }

    def _track_tensor_usage(self, node: fx.Node):
        """Record ``node`` as the first/last user of each of its input nodes."""
        for input_node in node.all_input_nodes:
            if input_node not in self.tensor_usage:
                self.tensor_usage[input_node] = {'first_use': node, 'last_use': node}
            else:
                self.tensor_usage[input_node]['last_use'] = node

    def _get_backward_tensors(self, node: fx.Node) -> List[torch.Tensor]:
        """Get tensors needed for backward pass"""
        # Implementation depends on specific backward pass requirements
        return []

    def _get_forward_tensors(self, node: fx.Node) -> List[torch.Tensor]:
        """Get tensors from forward pass that might be candidates for swapping"""
        # Implementation depends on specific forward pass requirements
        return []

    def _is_last_use(self, node: fx.Node, tensor: torch.Tensor) -> bool:
        """Check if this is the last use of a tensor"""
        tensor_node = self._find_tensor_node(tensor)
        return (tensor_node in self.tensor_usage and
                self.tensor_usage[tensor_node]['last_use'] == node)

    def _find_tensor_node(self, tensor: torch.Tensor) -> fx.Node:
        """Find the node that produced this tensor"""
        # Implementation depends on how tensors are tracked in the graph
        return None


In [None]:
# Draft of a fuller MuTwoProfiler.run_node, kept as an inert string literal for
# later reuse; it is never executed.
# NOTE(review): before reviving, check that (a) Node(rank=..., gtype=..., func=...)
# matches Node's constructor, which forwards its arguments to torch.fx.Node.__init__,
# and (b) compute_node.execute receives evaluated values rather than the raw
# fx.Node references in node.args / node.kwargs.
"""
  # Profile node execution
  def run_node(self, node: fx.Node) -> Any:
    # Create or get compute node
    if node.name not in self.compute_nodes:
      self.compute_nodes[node.name] = Node(
          rank=len(self.compute_nodes),
          gtype='forward' if not node.name.startswith('backward') else 'backward',
          func=node.target
      )

    compute_node = self.compute_nodes[node.name]

    # Handle swap-ins for required tensors
    for arg in node.args:
      if isinstance(arg, fx.Node) and arg.name in self.intermediate_nodes:
        tensor_node = self.intermediate_nodes[arg.name]
        if tensor_node.last_fw_access and not tensor_node.first_bw_access:
          swap_time = self._measure_swap_time(tensor_node.tensor, "swap_in")
          tensor_node.update_swap_time(swap_time)

    # Execute and profile the node
    result = compute_node.execute(*node.args, **node.kwargs)

    # Create IntermediateNode for output tensor if applicable
    if isinstance(result, torch.Tensor):
      tensor_name = f"{node.name}_output"
      if tensor_name not in self.intermediate_nodes:
        self.intermediate_nodes[tensor_name] = IntermediateNode(
            tensor=result,
            last_fw_access=compute_node
          )

      intermediate_node = self.intermediate_nodes[tensor_name]

      # Check if tensor can be swapped out
      if self._should_swap_out(intermediate_node):
        swap_time = self._measure_swap_time(result, "swap_out")
        intermediate_node.update_swap_time(swap_time)

    return result
  """

In [None]:

class SimpleModel(nn.Module):
    """Small Linear/ReLU MLP used to exercise FX tracing and module rewriting."""

    def __init__(self):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map inputs whose last dimension is 28*28 to 10 outputs."""
        logits = self.linear_relu_stack(x)
        return logits

    @staticmethod
    def modify_model(graph, model):
        """Replace every ``nn.ReLU`` called by a ``call_module`` node in ``graph`` with ``nn.Sigmoid``.

        ``graph`` is expected to be traced from ``model`` (its node targets are
        resolved against ``model``). Targets that do not resolve on ``model``
        are skipped. ``model`` is modified in place and returned.
        """
        for node in graph.nodes:
            if node.op != 'call_module':
                continue
            try:
                # get_submodule handles dotted (nested) targets such as "stack.1".
                current_module = model.get_submodule(node.target)
            except AttributeError:
                continue  # Target not found on this model
            if isinstance(current_module, nn.ReLU):
                print("modified ReLU to Sigmoid")
                parent_name, _, child_name = node.target.rpartition('.')
                # An empty parent_name resolves to ``model`` itself.
                setattr(model.get_submodule(parent_name), child_name, nn.Sigmoid())
        return model

if __name__ == "__main__":

    # Instantiate the model
    model = SimpleModel()

    # Symbolic trace
    symbolic_traced: fx.GraphModule = symbolic_trace(model)
    print(f"Symbolic Traced Graph:{symbolic_traced.graph}")
    print(f"Symbolic Traced Code:{symbolic_traced.code}")

    # Feed the same input before and after modification so outputs are comparable.
    sample_input = torch.rand(1, 28*28)

    # Modify the model (ReLU -> Sigmoid) and retrace it
    print(f"Output before modification:{model(sample_input)}")

    modified_model = SimpleModel.modify_model(symbolic_traced.graph, model)
    modified_traced: fx.GraphModule = symbolic_trace(modified_model)
    print(f"Traced Modified Graph:{modified_traced.graph}")
    print(f"Output after modification:{modified_model(sample_input)}")


Symbolic Traced Graph:graph():
    %x : [num_users=1] = placeholder[target=x]
    %linear_relu_stack_0 : [num_users=1] = call_module[target=linear_relu_stack.0](args = (%x,), kwargs = {})
    %linear_relu_stack_1 : [num_users=1] = call_module[target=linear_relu_stack.1](args = (%linear_relu_stack_0,), kwargs = {})
    %linear_relu_stack_2 : [num_users=1] = call_module[target=linear_relu_stack.2](args = (%linear_relu_stack_1,), kwargs = {})
    %linear_relu_stack_3 : [num_users=1] = call_module[target=linear_relu_stack.3](args = (%linear_relu_stack_2,), kwargs = {})
    %linear_relu_stack_4 : [num_users=1] = call_module[target=linear_relu_stack.4](args = (%linear_relu_stack_3,), kwargs = {})
    %linear_relu_stack_5 : [num_users=1] = call_module[target=linear_relu_stack.5](args = (%linear_relu_stack_4,), kwargs = {})
    return linear_relu_stack_5
Symbolic Traced Code:


def forward(self, x):
    linear_relu_stack_0 = getattr(self.linear_relu_stack, "0")(x);  x = None
    linear_relu_st