From 801bde1a28381ed775d0df886b7a29abe696a480 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 27 Jan 2025 16:40:28 -0800 Subject: [PATCH 1/2] Fix memory profiling for memory.view ops Pull Request resolved: https://github.com/pytorch/executorch/pull/7925 ATT ghstack-source-id: 263342054 @exported-using-ghexport Differential Revision: [D68448333](https://our.internmc.facebook.com/intern/diff/D68448333/) --- exir/memory_planning.py | 1 + util/activation_memory_profiler.py | 26 +++++++++++++++++--------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 6f0ab2a3922..be471b6f745 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -747,6 +747,7 @@ def apply_algo( storage with tensors in the outer module. TODO: make these optimizations once we have some baseline working. """ + specs = update_all_tensors_lifetime(graph_module, graph_signature) bufsizes: List[int] = algo( graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output diff --git a/util/activation_memory_profiler.py b/util/activation_memory_profiler.py index c149a461224..80e4fac56e2 100644 --- a/util/activation_memory_profiler.py +++ b/util/activation_memory_profiler.py @@ -15,7 +15,8 @@ import torch from executorch.exir import ExecutorchProgramManager from executorch.exir.memory_planning import get_node_tensor_specs -from executorch.exir.tensor import num_bytes_from_shape_and_dtype + +from executorch.exir.tensor import num_bytes_from_shape_and_dtype, TensorSpec from torch.export import ExportedProgram @@ -53,10 +54,11 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline] """ nodes = graph.nodes memory_timeline: List[Optional[MemoryTimeline]] = [None for _ in range(len(nodes))] + unique_specs: set[TensorSpec] = set() for _, node in enumerate(nodes): if node.op == "output": continue - if node.target == memory.alloc: + if node.target == memory.alloc or node.target == memory.view: continue tensor_specs = get_node_tensor_specs(node) if tensor_specs is None: @@ -65,6 +67,9 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline] # TODO: Make use of mem_id in the allocation info if tensor_spec is None or tensor_spec.mem_id is None or tensor_spec.const: continue + if tensor_spec in unique_specs: + continue + unique_specs.add(tensor_spec) start, end = tensor_spec.lifetime size = num_bytes_from_shape_and_dtype( typing.cast(torch.Size, tensor_spec.shape), tensor_spec.dtype @@ -75,6 +80,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline] memory_timeline_j = memory_timeline[j] if memory_timeline_j is None: memory_timeline_j = MemoryTimeline() + memory_timeline[j] = memory_timeline_j assert memory_timeline_j memory_timeline_j.allocations.append( Allocation( @@ -106,6 +112,7 @@ def generate_memory_trace( chrome_trace_filename: str, enable_memory_offsets: bool = False, method_name: str = "forward", + ommit_metadata: bool = False, ): """ Generate the memory timeline from the given ExecuTorch program. @@ -151,13 +158,14 @@ def generate_memory_trace( e["pid"] = int(allocation.memory_id) e["tid"] = tid e["args"] = {} - e["args"]["op_name"] = f"{allocation.op_name}" - # ID refers to memory space, typically from 1 to N. - # For CPU, everything is allocated on one "space", other backends may have multiple. 
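# A minimal, self-contained sketch of the de-duplication idea behind the
# create_tensor_allocation_info change above: memory.view nodes alias the
# TensorSpec of their source tensor, so the same spec must only be counted once
# when building the memory timeline. The names below (SpecRecord, build_timeline)
# are hypothetical and are not part of the ExecuTorch API.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class SpecRecord:
    name: str
    lifetime: Tuple[int, int]  # inclusive [first_node_index, last_node_index]
    nbytes: int


def build_timeline(records: List[SpecRecord], num_steps: int) -> List[int]:
    """Return total live bytes per node index, counting each spec exactly once."""
    seen: set = set()
    live_bytes = [0] * num_steps
    for rec in records:
        if rec in seen:  # a view re-reports the spec of its source tensor; skip it
            continue
        seen.add(rec)
        start, end = rec.lifetime
        for step in range(start, end + 1):
            live_bytes[step] += rec.nbytes
    return live_bytes


# A 1 KiB tensor reported twice (once through a view) is only counted once.
records = [
    SpecRecord("x", (0, 3), 1024),
    SpecRecord("x", (0, 3), 1024),  # duplicate report via a view of "x"
    SpecRecord("y", (2, 4), 512),
]
assert build_timeline(records, 5) == [1024, 1024, 1536, 1536, 512]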
- e["args"]["Memory ID"] = allocation.memory_id - e["args"]["fqn"] = f"{allocation.fqn}" - e["args"]["source"] = f"{allocation.file_and_line_num}" - e["args"]["bytes"] = allocation.size_bytes + if not ommit_metadata: + e["args"]["op_name"] = f"{allocation.op_name}" + # ID refers to memory space, typically from 1 to N. + # For CPU, everything is allocated on one "space", other backends may have multiple. + e["args"]["Memory ID"] = allocation.memory_id + e["args"]["fqn"] = f"{allocation.fqn}" + e["args"]["source"] = f"{allocation.file_and_line_num}" + e["args"]["bytes"] = allocation.size_bytes start_time += allocation_size_kb trace_events.append(e) tid += 1 From 510d7955bb6fa0a24d62771049810928efe0b868 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 27 Jan 2025 16:40:30 -0800 Subject: [PATCH 2/2] [ET][Memory planning] Improve greedy memory planning. Pull Request resolved: https://github.com/pytorch/executorch/pull/7926 This diff replaces the old greedy algorithm. Older algorithm resulted in 35% worse compared to theoretical optimum. THis matter for long context even more since additional overhead can be few hundred MB. For example the theorical optimial for llama3_2 8B, 4-bit quantized modelw ith context length of 2k needs about 1G of memory. This theoretcial max can be observed by looking at the peaks in memory profile. Current agorithm resulted in about 1.6GB of planned memory. New algorithm reduce that to about 1.1G. ghstack-source-id: 263342052 @exported-using-ghexport Differential Revision: [D68448332](https://our.internmc.facebook.com/intern/diff/D68448332/) --- backends/vulkan/vulkan_preprocess.py | 8 +- exir/memory_planning.py | 225 +++++++++++++++++++++++---- exir/passes/memory_planning_pass.py | 22 ++- exir/tests/test_joint_graph.py | 4 +- exir/tests/test_memory_planning.py | 44 +++++- 5 files changed, 270 insertions(+), 33 deletions(-) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 6e406a10ba6..02ca8d2bec5 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -6,6 +6,8 @@ # pyre-strict +from functools import partial + from typing import Any, Dict, final, List import executorch.backends.vulkan.utils as utils @@ -17,7 +19,6 @@ from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform - from executorch.backends.vulkan._passes import ( insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, @@ -41,6 +42,8 @@ PreprocessResult, ) from executorch.exir.backend.utils import DelegateMappingBuilder + +from executorch.exir.memory_planning import greedy from executorch.exir.pass_base import ExportPass, PassBase from executorch.exir.passes import MemoryPlanningPass, SpecPropPass @@ -189,11 +192,12 @@ def preprocess( # noqa: C901 # Finally, apply dynamic shape passes and memory planning pass. These passes # must be applied only when the graph structure is finalized. 
+ greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False) program = apply_passes( program, [ ConstraintBasedSymShapeEvalPass(), - MemoryPlanningPass(), + MemoryPlanningPass(memory_planning_algo=greedy_memory_planning), ], ) diff --git a/exir/memory_planning.py b/exir/memory_planning.py index be471b6f745..1fc1f0e02fd 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -11,7 +11,7 @@ import operator import typing from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool: return has_overlap + @classmethod + def _debug_message_from_specs( + cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec + ) -> str: + message = ( + f"lhs life time: {lhs_spec.lifetime}, rhs lifetime: {rhs_spec.lifetime} " + ) + message += f"lhs: mem_id {lhs_spec.mem_id} storage: {lhs_spec.mem_offset}, {lhs_spec.allocated_memory} " + message += f"rhs: mem_id {rhs_spec.mem_id} storage: {rhs_spec.mem_offset}, {rhs_spec.allocated_memory}" + return message + def verify_storage_reuse( self, allow_lifetime_and_storage_overlap: bool = False ) -> int: @@ -159,7 +170,7 @@ def verify_storage_reuse( lhs_spec, rhs_spec ): raise InternalError( - f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}" + f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}" ) # Check that each mem_obj_id is consistent with whether the tensors have @@ -454,6 +465,18 @@ def update_all_tensors_lifetime( return specs +@dataclass +class AllocationSpec: + """ + AllocationSpec is used to represent the allocation of a tensor. + """ + + # The offset of the tensor in the shared object/pool. + offset: int + # TensorSpec + spec: TensorSpec + + @dataclass class SharedObject: r""" @@ -470,8 +493,15 @@ class SharedObject: offset: int # size of this shared object in bytes size: int + # When the object is first created + first_used_index: int # the object will be available for index (last_used_index + 1) last_used_index: int + # list of allocations belong to this shared object + allocations: List[AllocationSpec] = field(default_factory=list) + + def __repr__(self) -> str: + return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])" def materialize_buffer( @@ -489,35 +519,124 @@ def materialize_buffer( return total_size -def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int: +def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool: r""" - Calculate the absolute different between the size of a shared object and - a tensor. + Check if a shared object and a tensor do not overlap. 
""" - return abs(sobj.size - spec.allocated_memory) + for alloc in sobj.allocations: + if not ( + spec.lifetime[1] < alloc.spec.lifetime[0] + or spec.lifetime[0] > alloc.spec.lifetime[1] + ): + return False + return True + + +def _find_max_overlapping_allocations_offset( + sobj: SharedObject, spec: TensorSpec +) -> int: + max_offset = 0 + for alloc in sobj.allocations: + if ( + spec.lifetime[1] < alloc.spec.lifetime[0] + or spec.lifetime[0] > alloc.spec.lifetime[1] + ): + continue + max_offset = max(alloc.offset + alloc.spec.allocated_memory, max_offset) + return max_offset def pick_shared_obj( - shared_objects: List[SharedObject], spec: TensorSpec + shared_objects: List[SharedObject], + spec: TensorSpec, + allow_overlapping_allocations: bool = True, ) -> SharedObject: r""" - Pick the available shared object with closest size to the tensor. - If there are no available shared object left, create a new one. + Pick the available shared object to which to assign this spec, + or create a new one + Algorithm details + Previous: Look at every spec in chronological order. Find if previously allocated object + allows it to fit in. If not, allocate a new object. + New: + - Sort all the specs by allocation size + - Process the specs in order + - If the spec's size in smaller than previously allocated buckets: + - Conditions under which previously allocated bucket can be used: + - Lifetime of the spec does not overlap with lifetime of the bucket. + - In this case allocate spec to that bucket and expand its lifetime. + - Spec is allocated at offset = 0 in this bucket. + - Add this spec to allocated object's list of specs. + - Lifetime of the spec overlaps with lifetime of the bucket, + partially or fully (e.g. spec's lifetime subset of bucket's lifetime) + - If none of the specs in the bucket overlaps with spec's lifetime. + - Allocate spec to the bucket at offset = 0. + - Add this spec to the bucket's list of specs. + - Expand bucket's lifetime accounting for added spec's lifetime. + - If one or more specs in the bucket overlaps with spec's lifetime. + - Collect offsets (at which the given overlapping spec is allocated in the bucket). + of all the overlapping specs, and find the max offset. + - Allocate spec to the bucket at offset = max_offset + max_offset_spec_size. + - Add this spec to the bucket's list of specs. + - Expand bucket's lifetime accounting for added spec's lifetime. + - If none of these conditions are met, allocate a new bucket. + - Add spec to this bucket. + - Update bucket's lifetime to that of the spec. + - If the spec's size is larger than previously allocated buckets, allocate a new bucket. + - Size and lifetime of this bucket is that of the spec + + Proof of correctness: + - If allocating a new bucket, it is correct. + - If allocating spec to an existing bucket, whose lifetime does not overlap with any + of the previously allocated specs' lifetime, then the allocation is correct. + Proof of correctness by induction when adding spec to an existing bucket: + - If all previous allocations in the given bucket are correct: + - Then the new one being added must be correct because when the requested allocation + overlaps with one or more previous allocations, we find the largest offset among + all the overlapping allocations, and allocate the new spec at that offset. Hence, + the allocation at such an offset, will not overlap with any previous allocations. 
+ Base case: A newly added allocation within a bucket with single allocation is correct: + because a) it must fit and b) its lifetime must not overlap with object's lifetime. + This holds true because of the following invariants: + - Once a bucket is created, it is never resized. + - All the allocations within a bucket follow this: + - Span, defined by allocation's offset + size, of two allocations can only overlap, + if their timelines do not overlap. """ - # TODO: do better than linear scan picked = None for sobj in shared_objects: - if spec.lifetime[0] > sobj.last_used_index: - if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif( - picked, spec - ): - picked = sobj - sobj.last_used_index = spec.lifetime[1] - sobj.size = max(sobj.size, spec.allocated_memory) + if _does_not_overlap(sobj, spec): + assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted" + picked = sobj + sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0]) + sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1]) + allocation_spec = AllocationSpec(0, spec) + picked.allocations.append(allocation_spec) + break + + if picked is None and allow_overlapping_allocations: + for sobj in shared_objects: + max_offset = _find_max_overlapping_allocations_offset(sobj, spec) + if max_offset > 0: + if max_offset + spec.allocated_memory <= sobj.size: + picked = sobj + sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0]) + sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1]) + allocation_spec = AllocationSpec(max_offset, spec) + picked.allocations.append(allocation_spec) + break + if picked is None: picked = SharedObject( - len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1] + len(shared_objects), + -1, + spec.allocated_memory, + spec.lifetime[0], + spec.lifetime[1], ) + allocation_spec = AllocationSpec(0, spec) + picked.allocations.append(allocation_spec) + picked.first_used_index = spec.lifetime[0] + picked.last_used_index = spec.lifetime[1] shared_objects.append(picked) return picked @@ -550,13 +669,50 @@ def get_node_tensor_specs( ] +# Little bit hacky to check if the graph contains +# XNNPACK delegate +# Why? + + +def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool: + for node in graph_module.graph.nodes: + if node.target == executorch_call_delegate: + lowered_module = getattr( + graph_module.graph.owning_module, node.args[0].target + ) + if "xnnpack" in lowered_module.backend_id.lower(): + return True + return False + + def greedy( graph_module: torch.fx.GraphModule, alignment: int, graph_signature: Optional[ExportGraphSignature] = None, alloc_graph_input: bool = True, alloc_graph_output: bool = True, + allow_overlapping_allocations: bool = True, ) -> List[int]: + r"""Greedy algorithm to allocate memory for tensors in the graph. + alloc_graph_input: If set to true, the algorithm will allocate memory for graph input. + alloc_graph_output: If set to true, the algorithm will allocate memory for graph output. + allow_overlapping_allocations: If set to true, allows for allocations that overlap + in their lifetime but are at different offsets in the storage. By default true. + This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping + allocations disabled + """ + # padding allocation with 64 bytes. + # this requirement is really for XNNPACK backend which can read tensors + # beyond the end of the tensor. This is done for performance + # optimizations in XNNPACK. 
+ # While accounting for backend specific requirement is not the right choice + # in backend agnostic memory planning, we do it here as it seems most appropriate. + # Right now this applies to greedy only so any other + # algorithm that plans memory for XNNPACK backend will + # not have this. + extra_padded_bytes = 0 + if _contains_xnnpack_delegate(graph_module): + extra_padded_bytes = 64 spec2obj = {} shared_objects = defaultdict(list) # Don't do assertion in collect_specs_from_nodes if we have already encountered @@ -565,6 +721,9 @@ def greedy( # For each tensor, pick the available shared object with closest size to # the tensor. If there are no available shared object left, create a new # one. + import bisect + + sorted_specs = [] for spec in collect_specs_from_nodes( graph_module.graph.nodes, graph_signature, @@ -572,10 +731,16 @@ def greedy( ignore_graph_input=not alloc_graph_input, ignore_graph_output=not alloc_graph_output, ): + bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory) + sorted_specs.reverse() + + for spec in sorted_specs: if spec.mem_id is None: spec.mem_id = 1 spec.realign(alignment) - spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec) + spec2obj[spec] = pick_shared_obj( + shared_objects[spec.mem_id], spec, allow_overlapping_allocations + ) if len(shared_objects) == 0: # Cannot find any tensor in the graph that needs to be allocated. @@ -583,6 +748,7 @@ def greedy( total_sizes = [0, 0] else: total_sizes = [0] * (max(shared_objects.keys()) + 1) + num_specs_processed = 0 for mem_id in shared_objects: input_total_size = 0 if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None): @@ -594,13 +760,20 @@ def greedy( total_sizes[mem_id] = materialize_buffer( shared_objects[mem_id], input_total_size ) - - # Since we now know the number of shared objects we need and the size of - # each shared object, we can assign offset in the memory buffer for each - # shared object. - for spec, sobj in spec2obj.items(): - spec.mem_obj_id = sobj.idx - spec.mem_offset = sobj.offset + total_sizes[mem_id] += extra_padded_bytes + + # Since we now know the number of shared objects we need and the size of + # each shared object, we can assign offset in the memory buffer for each + # shared object. 
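# A simplified, self-contained sketch (not the ExecuTorch implementation) of the
# bucket-packing rule pick_shared_obj now follows: specs are visited largest-first,
# a spec reuses an existing bucket at offset 0 when its lifetime is disjoint from
# every allocation already in that bucket, otherwise it is stacked above the
# largest overlapping end offset if it still fits, and a new bucket is opened as a
# last resort. The names Bucket, Alloc, and place are hypothetical.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Alloc:
    offset: int
    size: int
    lifetime: Tuple[int, int]


@dataclass
class Bucket:
    size: int
    allocations: List[Alloc] = field(default_factory=list)


def _overlaps(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    return not (a[1] < b[0] or a[0] > b[1])


def place(buckets: List[Bucket], size: int, lifetime: Tuple[int, int]) -> Tuple[int, int]:
    """Return (bucket_index, offset); specs must arrive sorted by size, descending."""
    # 1) Reuse a bucket at offset 0 if no lifetime overlap (it fits because of the sort order).
    for i, bucket in enumerate(buckets):
        if not any(_overlaps(lifetime, a.lifetime) for a in bucket.allocations):
            bucket.allocations.append(Alloc(0, size, lifetime))
            return i, 0
    # 2) Otherwise stack above the highest overlapping allocation, if it still fits.
    for i, bucket in enumerate(buckets):
        top = max(
            (a.offset + a.size for a in bucket.allocations if _overlaps(lifetime, a.lifetime)),
            default=0,
        )
        if top > 0 and top + size <= bucket.size:
            bucket.allocations.append(Alloc(top, size, lifetime))
            return i, top
    # 3) Last resort: open a new bucket sized to this spec.
    buckets.append(Bucket(size, [Alloc(0, size, lifetime)]))
    return len(buckets) - 1, 0


buckets: List[Bucket] = []
print(place(buckets, 1024, (0, 1)))  # (0, 0): opens bucket 0, size 1024
print(place(buckets, 256, (2, 5)))   # (0, 0): lifetime disjoint, reuses bucket 0 at offset 0
print(place(buckets, 256, (3, 6)))   # (0, 256): overlaps the previous spec, stacked above it
# In the real pass, each spec's final mem_offset is its bucket's buffer offset
# plus the per-bucket allocation offset chosen here.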
+ for sobj in shared_objects[mem_id]: + for alloc in sobj.allocations: + spec = alloc.spec + alloc.spec.mem_obj_id = sobj.idx + alloc.spec.mem_offset = sobj.offset + alloc.offset + num_specs_processed += 1 + assert ( + len(spec2obj) == num_specs_processed + ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs" logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}") return total_sizes diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py index 112b8f5fc52..710042fcd00 100644 --- a/exir/passes/memory_planning_pass.py +++ b/exir/passes/memory_planning_pass.py @@ -6,7 +6,8 @@ import logging import warnings -from typing import Callable, List, Optional +from functools import partial +from typing import Any, Callable, List, Optional import torch from executorch.exir.error import internal_assert @@ -24,6 +25,17 @@ from torch.export.exported_program import ExportGraphSignature +# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function +def _callable_name(any_callable: Callable[..., Any]) -> str: + if isinstance(any_callable, partial): + return any_callable.func.__name__ + + try: + return any_callable.__name__ + except AttributeError: + return str(any_callable) + + class MemoryPlanningPass(PassBase): def __init__( self, @@ -127,4 +139,12 @@ def run( f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pair of tensors" ) verifier.verify_graph_input_output() + if ( + callable(self.memory_planning_algo) + and _callable_name(self.memory_planning_algo) == "greedy" + ): + # Only verify storage reuse for greedy algorithm + # At the moment cadence backends memory planning fails this + # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function + verifier.verify_storage_reuse() return PassResult(graph_module, True) diff --git a/exir/tests/test_joint_graph.py b/exir/tests/test_joint_graph.py index f3b6f0ed557..349fa92e826 100644 --- a/exir/tests/test_joint_graph.py +++ b/exir/tests/test_joint_graph.py @@ -84,13 +84,13 @@ def forward(self, x, y): et.executorch_program.execution_plan[0] .values[0] .val.allocation_info.memory_offset_low, - 0, + 96, ) self.assertEqual( et.executorch_program.execution_plan[0] .values[1] .val.allocation_info.memory_offset_low, - 48, + 224, ) loss = m(*example_inputs) diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index 1f94f0341f1..de1dd8abc36 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -106,6 +106,28 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]: return (torch.randn(2),) +class LinearsWithDifferentSizeAndViewOps(torch.nn.Module): + def __init__(self) -> None: + super(LinearsWithDifferentSizeAndViewOps, self).__init__() + self.linears = torch.nn.ModuleList() + for x in [8, 16, 32, 64]: + self.linears.append(torch.nn.Linear(x, x * 2)) + + def forward(self, i: torch.Tensor) -> torch.Tensor: + o1 = i + for linear in self.linears: + o1 = linear(o1) + o1 = o1.view(-1, 64, 2) + o1 = o1 + 1 + o2 = i + for linear in self.linears: + o2 = linear(o2) + return o1.view(-1, 128) + o2 + + def get_random_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(3, 8),) + + class ModuleReturnTwo(nn.Module): def __init__(self) -> None: super(ModuleReturnTwo, self).__init__() @@ -360,6 +382,13 @@ def 
verify_overlap_placeholders( ], ) + test_linear_with_view: Callable[..., None] = maketest( + LinearsWithDifferentSizeAndViewOps, + criteria=[ + (greedy, True), + ], + ) + # greedy algorithm will reuse memory if we let the algorithm allocate # memory for both graph input and output. test_list_arg: Callable[..., None] = maketest( @@ -508,15 +537,26 @@ def test_multiple_pools( verifier.verify_graph_input_output() idx = 0 + reference_output = {} + actual_output = {} for node in graph_module.graph.nodes: if node.op == "placeholder" or ( node.op == "call_function" and node.target in (torch.ops.aten.add.out, torch.ops.aten.mul.out) ): mem_id, mem_offset = expected_allocs[idx] - self.assertEqual(node.meta["spec"].mem_id, mem_id) - self.assertEqual(node.meta["spec"].mem_offset, mem_offset) + actual_mem_id, actual_mem_offset = ( + node.meta["spec"].mem_id, + node.meta["spec"].mem_offset, + ) + if (mem_id, mem_offset) not in reference_output: + reference_output[(mem_id, mem_offset)] = 1 + actual_output[(actual_mem_id, actual_mem_offset)] = 1 + else: + reference_output[(mem_id, mem_offset)] += 1 + actual_output[(actual_mem_id, actual_mem_offset)] += 1 idx += 1 + self.assertEqual(reference_output, actual_output) self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes) def test_constants_not_memory_planned(self) -> None:
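# A short sketch of the comparison style the updated test_multiple_pools uses:
# because the size-sorted greedy pass may hand the same offsets to nodes in a
# different order, the test compares the multiset of (mem_id, mem_offset) pairs
# rather than per-node values. collections.Counter expresses the same check; the
# example data below is made up.
from collections import Counter

expected_allocs = [(1, 0), (1, 64), (2, 0)]
actual_allocs = [(1, 64), (2, 0), (1, 0)]  # same assignments, different node order
assert Counter(expected_allocs) == Counter(actual_allocs)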