Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
texture_limits: utils.ImageExtents,
buffer_limit: int,
require_dynamic_shape: bool = False,
skip_bool_tensors: bool = False,
operator_blocklist: Optional[Set[OpKey]] = None,
operator_allowlist: Optional[Set[OpKey]] = None,
fusable_subgraphs: Optional[List[PatternMatch]] = None,
Expand All @@ -69,6 +70,7 @@ def __init__(
self.texture_limits: utils.ImageExtents = texture_limits
self.buffer_limit = buffer_limit
self.require_dynamic_shapes = require_dynamic_shape
self.skip_bool_tensors = skip_bool_tensors
self.operator_blocklist: Set[OpKey] = (
operator_blocklist if operator_blocklist is not None else set()
)
Expand Down Expand Up @@ -117,6 +119,11 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
return False, "no operator implementation"
features = get_op_features(target)

# bool tensors are internally represented with int8 buffers, which may not be
# supported by some GPUs. Therefore, provide the option to skip these tensors.
if self.skip_bool_tensors and utils.op_contains_bool_tensor(node):
return False, f"op {utils.node_io_str(node)} contains bool tensor"

# Get the possible tensor representations for each tensor participating in
# this operator. Then check that all tensors are representable as either a
# buffer or texture.
Expand Down Expand Up @@ -398,6 +405,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
texture_limits,
buffer_limit,
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
skip_bool_tensors=self.options.get("skip_bool_tensors", False),
operator_blocklist=self.operator_blocklist,
operator_allowlist=self.operator_allowlist,
fusable_subgraphs=fusable_subgraphs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ permute_buffer:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: permute_buffer
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ permute_texture:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: permute_texture3d
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ transfer_buffer:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: select_buffer
OP_NAME: select
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ transfer_texture:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: select_texture3d
OP_NAME: select
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ view:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: view
7 changes: 6 additions & 1 deletion backends/vulkan/test/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def __init__(

class Partition(BaseStages.Partition):
def __init__(self, partitioner: Optional[Partitioner] = None):
vk_compile_spec = {"skip_bool_tensors": True}
super().__init__(
partitioner=partitioner or VulkanPartitioner(),
partitioner=partitioner or VulkanPartitioner(vk_compile_spec),
)


Expand All @@ -55,6 +56,10 @@ def __init__(
partitioners: Optional[List[Partitioner]] = None,
edge_compile_config: Optional[EdgeCompileConfig] = None,
):
if partitioners is None:
vk_compile_spec = {"skip_bool_tensors": True}
partitioners = [VulkanPartitioner(vk_compile_spec)]

super().__init__(
default_partitioner_cls=VulkanPartitioner,
partitioners=partitioners,
Expand Down
75 changes: 71 additions & 4 deletions backends/vulkan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,47 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
return False


def ndim_of(node: Any) -> Optional[int]:
    """
    Return the dimensionality (``ndim``) of the tensor produced by the given
    node, or None if the node does not produce a single tensor.
    """
    if is_single_tensor_node(node):
        return node.meta["val"].ndim
    return None


def is_unsqueezed_vector(node: torch.fx.Node) -> bool:
    """
    Returns True if the node's tensor has all dimensions equal to 1 except for
    the last dimension — i.e. it is effectively a 1-D vector with leading
    unsqueezed dims. The last dimension may have any size.
    """
    if not is_single_tensor_node(node):
        return False

    fake_tensor = node.meta["val"]
    assert isinstance(fake_tensor, FakeTensor)

    shape = fake_tensor.shape
    # A 0-dim (scalar) tensor is not considered a vector.
    if len(shape) == 0:
        return False

    # Every dimension except the last must be exactly 1.
    for extent in shape[:-1]:
        if extent != 1:
            return False
    return True


def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
    """
    Returns true if the operator used to compute the given node involves a bool
    tensor, either as its output or as any of its input arguments.

    Note that fx nodes may receive tensor arguments nested inside list/tuple
    args (e.g. for cat-like ops) and via keyword arguments; those are checked
    as well, which the previous top-level-args-only scan missed.
    """

    def _contains_bool(arg: Any) -> bool:
        # Recurse into list/tuple args to reach nested Nodes.
        if isinstance(arg, (list, tuple)):
            return any(_contains_bool(a) for a in arg)
        return (
            isinstance(arg, torch.fx.Node)
            and is_tensor_node(arg)
            and tensor_node_is_bool(arg)
        )

    # Check the output tensor(s) of the node itself.
    if is_tensor_node(node) and tensor_node_is_bool(node):
        return True

    # Check both positional and keyword argument nodes.
    if any(_contains_bool(arg) for arg in node.args):
        return True
    return any(_contains_bool(arg) for arg in node.kwargs.values())


def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
primary_arg_idx: Optional[int] = None
for i, arg_node in enumerate(node.args):
Expand Down Expand Up @@ -568,6 +609,16 @@ def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
self.valid_texture_layouts & other.valid_texture_layouts,
)

def make_union(self, other: "TensorRepSet") -> "TensorRepSet":
"""
Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet
with the union of the two.
"""
return TensorRepSet(
self.valid_buffer_layouts | other.valid_buffer_layouts,
self.valid_texture_layouts | other.valid_texture_layouts,
)

def is_compatible(self, storage: TensorRepr) -> bool:
"""
Check if this TensorRepr is compatible with the given TensorRepSet.
Expand Down Expand Up @@ -693,10 +744,6 @@ def make_filtered_tensor_repset(
if len(tensor_val.shape) > 4:
return TensorRepSet(tensor_repset.valid_buffer_layouts, set())

# Bool tensors are currently not supported
if tensor_val.dtype == torch.bool:
return NO_STORAGE

return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)


Expand Down Expand Up @@ -1230,6 +1277,26 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool:
##


def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]:
    """
    Normalize dimension indices to be non-negative, i.e. within [0, ndim).

    Accepts either a single int or a list of ints. Negative indices are
    interpreted relative to ``ndim``, following standard Python/PyTorch
    indexing semantics. The return value has the same container kind as the
    input (a bare int in, a bare int out; a list in, a new list out).
    """
    if isinstance(dims, int):
        return dims + ndim if dims < 0 else dims

    # List input: normalize each entry (comprehension over an append loop).
    return [d + ndim if d < 0 else d for d in dims]


def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
# Handle negative indices for nchw_dim
if nchw_dim < 0:
Expand Down
Loading