diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 0bdc16616ef..059b3a07be0 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -59,6 +59,7 @@ def __init__(
         texture_limits: utils.ImageExtents,
         buffer_limit: int,
         require_dynamic_shape: bool = False,
+        skip_bool_tensors: bool = False,
         operator_blocklist: Optional[Set[OpKey]] = None,
         operator_allowlist: Optional[Set[OpKey]] = None,
         fusable_subgraphs: Optional[List[PatternMatch]] = None,
@@ -69,6 +70,7 @@ def __init__(
         self.texture_limits: utils.ImageExtents = texture_limits
         self.buffer_limit = buffer_limit
         self.require_dynamic_shapes = require_dynamic_shape
+        self.skip_bool_tensors = skip_bool_tensors
         self.operator_blocklist: Set[OpKey] = (
             operator_blocklist if operator_blocklist is not None else set()
         )
@@ -117,6 +119,11 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
             return False, "no operator implementation"
         features = get_op_features(target)
 
+        # Bool tensors are internally represented with int8 buffers, which may not be
+        # supported by some GPUs. Therefore, provide the option to skip these tensors.
+        if self.skip_bool_tensors and utils.op_contains_bool_tensor(node):
+            return False, f"op {utils.node_io_str(node)} contains bool tensor"
+
         # Get the possible tensor representations for each tensor participating in the
         # this operator. Then check that all tensors are representable as either a
         # buffer or texture.
@@ -398,6 +405,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             texture_limits,
             buffer_limit,
             require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
+            skip_bool_tensors=self.options.get("skip_bool_tensors", False),
             operator_blocklist=self.operator_blocklist,
             operator_allowlist=self.operator_allowlist,
             fusable_subgraphs=fusable_subgraphs,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml
index 81675ae8917..6fe5a67c286 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml
@@ -6,5 +6,6 @@ permute_buffer:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: permute_buffer
diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml
index f68b8dcdd3d..22d1bdd7b51 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml
@@ -6,5 +6,6 @@ permute_texture:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: permute_texture3d
diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml
index f68b2bd1250..62bab110828 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml
@@ -8,6 +8,7 @@ transfer_buffer:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: select_buffer
       OP_NAME: select
diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml
index 6922f120e49..7824801ddb6 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml
@@ -8,6 +8,7 @@ transfer_texture:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: select_texture3d
       OP_NAME: select
diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml
index 33364a25225..e963d253424 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml
@@ -8,5 +8,6 @@ view:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint8
   shader_variants:
     - NAME: view
diff --git a/backends/vulkan/test/tester.py b/backends/vulkan/test/tester.py
index b2066a06ec0..0707c09158f 100644
--- a/backends/vulkan/test/tester.py
+++ b/backends/vulkan/test/tester.py
@@ -44,8 +44,9 @@ def __init__(
 
 
 class Partition(BaseStages.Partition):
     def __init__(self, partitioner: Optional[Partitioner] = None):
+        vk_compile_spec = {"skip_bool_tensors": True}
         super().__init__(
-            partitioner=partitioner or VulkanPartitioner(),
+            partitioner=partitioner or VulkanPartitioner(vk_compile_spec),
         )
 
@@ -55,6 +56,10 @@ def __init__(
         partitioners: Optional[List[Partitioner]] = None,
         edge_compile_config: Optional[EdgeCompileConfig] = None,
     ):
+        if partitioners is None:
+            vk_compile_spec = {"skip_bool_tensors": True}
+            partitioners = [VulkanPartitioner(vk_compile_spec)]
+
         super().__init__(
             default_partitioner_cls=VulkanPartitioner,
             partitioners=partitioners,
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 09c57f649ae..6a510e65925 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -259,6 +259,47 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
     return False
 
 
+def ndim_of(node: Any) -> Optional[int]:
+    """
+    Returns the number of dimensions of the tensor produced by the given node
+    """
+    if not is_single_tensor_node(node):
+        return None
+
+    return node.meta["val"].ndim
+
+
+def is_unsqueezed_vector(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the node's tensor has all dimensions equal to 1 except for the last dimension.
+    """
+    if not is_single_tensor_node(node):
+        return False
+
+    tensor = node.meta["val"]
+    assert isinstance(tensor, FakeTensor)
+
+    if len(tensor.shape) < 1:
+        return False
+    # All dims except the last must be 1; the last dim can be any size
+    return all(dim == 1 for dim in tensor.shape[:-1])
+
+
+def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the operator that computes the given node consumes or produces a bool tensor
+    """
+    if is_tensor_node(node) and tensor_node_is_bool(node):
+        return True
+
+    for arg_node in node.args:
+        # pyre-ignore[6]
+        if is_tensor_node(arg_node) and tensor_node_is_bool(arg_node):
+            return True
+
+    return False
+
+
 def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
     primary_arg_idx: Optional[int] = None
     for i, arg_node in enumerate(node.args):
@@ -568,6 +609,16 @@ def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
             self.valid_texture_layouts & other.valid_texture_layouts,
         )
 
+    def make_union(self, other: "TensorRepSet") -> "TensorRepSet":
+        """
+        Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet
+        with the union of the two.
+        """
+        return TensorRepSet(
+            self.valid_buffer_layouts | other.valid_buffer_layouts,
+            self.valid_texture_layouts | other.valid_texture_layouts,
+        )
+
     def is_compatible(self, storage: TensorRepr) -> bool:
         """
         Check if this TensorRepr is compatible with the given TensorRepSet.
@@ -693,10 +744,6 @@ def make_filtered_tensor_repset(
     if len(tensor_val.shape) > 4:
         return TensorRepSet(tensor_repset.valid_buffer_layouts, set())
 
-    # Bool tensors are currently not supported
-    if tensor_val.dtype == torch.bool:
-        return NO_STORAGE
-
     return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)
 
 
@@ -1230,6 +1277,26 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool:
 ##
 
 
+def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]:
+    """
+    Normalize dimension indices to be non-negative and within [0, ndim).
+    Accepts a single int or a list of ints.
+    """
+    if isinstance(dims, int):
+        if dims < 0:
+            dims += ndim
+
+        return dims
+
+    normalized = []
+    for d in dims:
+        if d < 0:
+            d += ndim
+        normalized.append(d)
+
+    return normalized
+
+
 def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
     # Handle negative indices for nchw_dim
     if nchw_dim < 0:
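For context, a minimal sketch of how the new `skip_bool_tensors` compile option might be enabled when lowering a model. The module, input shape, and the use of `to_edge_transform_and_lower` here are illustrative assumptions, not part of this diff:

```python
# Sketch only: a model whose graph contains a bool tensor, lowered with the
# new option so the partitioner leaves the bool-producing op to the fallback path.
import torch

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge_transform_and_lower


class CompareModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `x > 0` produces a bool tensor; with skip_bool_tensors=True the
        # partitioner should not delegate this op to Vulkan.
        return (x > 0).to(torch.float32) * x


exported = torch.export.export(CompareModule(), (torch.randn(1, 16),))
lowered = to_edge_transform_and_lower(
    exported,
    partitioner=[VulkanPartitioner({"skip_bool_tensors": True})],
)
```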
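The `normalize_dims` helper added to `backends/vulkan/utils.py` mirrors PyTorch's negative-indexing semantics; a hypothetical usage sketch (the import path is assumed from the file's location):

```python
# Negative dim indices are mapped to their non-negative equivalents.
from executorch.backends.vulkan.utils import normalize_dims

assert normalize_dims(-1, ndim=4) == 3                   # single int
assert normalize_dims([0, -2, 3], ndim=4) == [0, 2, 3]   # list of ints
```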