diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 64e672fd695..cb14e96962d 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -51,11 +51,15 @@ class VulkanSupportedOperators(OperatorSupportBase):
     def __init__(
-        self, texture_limits: utils.ImageExtents, require_dynamic_shape: bool = False
+        self,
+        texture_limits: utils.ImageExtents,
+        buffer_limit: int,
+        require_dynamic_shape: bool = False,
     ) -> None:
         super().__init__()
-        self.require_dynamic_shapes = require_dynamic_shape
         self.texture_limits: utils.ImageExtents = texture_limits
+        self.buffer_limit = buffer_limit
+        self.require_dynamic_shapes = require_dynamic_shape

     def op_node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -83,6 +87,7 @@ def op_node_is_compatible(
             node, self.texture_limits
         )
+        can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit)
         for i, arg in enumerate(node.args):
             if (
                 isinstance(arg, torch.fx.Node)
@@ -95,10 +100,19 @@ def op_node_is_compatible(
                 valid_texture_layouts = valid_texture_layouts.intersection(
                     arg_texture_layouts
                 )
+                can_use_buffers = can_use_buffers and utils.within_buffer_limit(
+                    arg, self.buffer_limit
+                )

         # If there are no valid texture memory layouts, then buffer storage must be
         # supported by the operator implementation.
         if len(valid_texture_layouts) == 0:
+            if not can_use_buffers:
+                return (
+                    False,
+                    f"op requires buffers that exceed the buffer limit ({self.buffer_limit})",
+                )
+
             compatible = VkStorageType.BUFFER in features.supported_storage_types()
             reason = "op is compatible"
             if not compatible:
@@ -309,10 +323,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         texture_limits: utils.ImageExtents = self.options.get(
             "texture_limits", utils.DEFAULT_TEXTURE_LIMITS
         )
+        buffer_limit: int = self.options.get("buffer_limit", utils.DEFAULT_BUFFER_LIMIT)
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
             VulkanSupportedOperators(
                 texture_limits,
+                buffer_limit,
                 require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
             ),
             allows_single_node_partition=True,
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 2e9fbba01c7..a6db780309d 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -87,6 +87,7 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
 ImageExtents = Tuple[int, int, int]

 DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
+DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)


 class PackedDim(IntEnum):
@@ -113,6 +114,22 @@ class PackedDim(IntEnum):
 }


+def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
+    """
+    Checks whether the tensors produced by the given node can fit within the device's
+    GPU buffer limit, which represents the maximum number of elements that can be stored
+    in a GPU buffer.
+    """
+    assert is_tensor_node(node)
+
+    if isinstance(node.meta["val"], FakeTensor):
+        return node.meta["val"].numel() < buffer_limit
+    elif isinstance(node.meta["val"], (list, tuple)):
+        return all(x.numel() < buffer_limit for x in node.meta["val"])
+    else:
+        raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
+
+
 def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
     """
     Calculate the image extents that will be used to represent a tensor with the given sizes
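
For reference, a minimal sketch of how the new within_buffer_limit helper is exercised. This is illustrative only: it assumes executorch's import path for the utils module and that is_tensor_node accepts a call_function node whose meta["val"] is a FakeTensor; the traced lambda is a placeholder, not part of this diff.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    from executorch.backends.vulkan import utils

    # Trace a trivial graph and attach FakeTensor metadata to its compute node,
    # mirroring the node.meta["val"] entries the partitioner inspects.
    gm = torch.fx.symbolic_trace(lambda x: x + 1)
    with FakeTensorMode() as mode:
        fake = mode.from_tensor(torch.empty(1024, 1024))

    for node in gm.graph.nodes:
        if node.op == "call_function":
            node.meta["val"] = fake
            # ~1M elements is well under the 128M-element default limit.
            print(utils.within_buffer_limit(node, utils.DEFAULT_BUFFER_LIMIT))  # True
            # A limit smaller than the tensor's numel() rejects the node, which
            # is what causes the partitioner to skip it for buffer storage.
            print(utils.within_buffer_limit(node, 1024))  # False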