Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
arg, self.buffer_limit
)

op_available_layouts = features.supported_memory_layouts(
VkStorageType.TEXTURE_3D
)

can_use_texture = any(
layout in op_available_layouts for layout in valid_texture_layouts
)

# If there are no valid texture memory layouts, then buffer storage must be
# supported by the operator implementation.
if len(valid_texture_layouts) == 0:
if not can_use_texture:
if not can_use_buffers:
return (
False,
Expand All @@ -131,17 +139,8 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
reason = "op requires buffers which is not supported by op impl"
return compatible, reason

op_available_layouts = features.supported_memory_layouts(
VkStorageType.TEXTURE_3D
)

is_compatible = any(
layout in op_available_layouts for layout in valid_texture_layouts
)
if not is_compatible:
return False, "Required texutre memory layout not supported"

return is_compatible, "Op is compatible"
return True, "Op is compatible"

def node_is_compatible(
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
Expand Down Expand Up @@ -220,7 +219,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo

def log_skip(self, node: torch.fx.Node, reason: str) -> None:
if node.op == "call_function":
logger.info(
print(
f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
)

Expand All @@ -231,6 +230,7 @@ def is_node_supported(
return r

def _is_node_supported(self, node: torch.fx.Node) -> bool:
print("is_node_supported")
target = node.target
if node.target == torch.ops.higher_order.auto_functionalized:
first_arg = node.args[0]
Expand Down Expand Up @@ -340,6 +340,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# subgraphs containing the nodes with the tags
partition_tags = {}

logger.setLevel(logging.INFO)
print("partition")
print("set level but no logging...")

texture_limits: utils.ImageExtents = self.options.get(
"texture_limits", utils.DEFAULT_TEXTURE_LIMITS
)
Expand Down
1 change: 1 addition & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,7 @@ def to_edge_transform_and_lower(
for name, partitioner_list in partitioner.items():
if i < len(partitioner_list):
method_to_partitioner[name] = partitioner_list[i]
print("to_backen")
edge_manager = edge_manager.to_backend(method_to_partitioner)

for name, program in edge_manager._edge_programs.items():
Expand Down
Loading