Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/_passes/insert_prepack_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:

# Mark that this node is going to be represented as a TensorRef type in the
# Vulkan compute graph. This annotation is used in later graph passes.
node.meta["vkdg_tensorref"] = True
node.meta["etvk_tensorref"] = True

# Get the list of node users that do not handle their own prepacking
nodes_to_replace_input = []
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/_passes/remove_local_scalar_dense_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:

for user in node.users:
if node_is_local_scalar_dense_chain(user):
node.meta["vkdg_is_scalar_tensor"] = True
node.meta["etvk_is_scalar_tensor"] = True


def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
Expand All @@ -74,7 +74,7 @@ def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node)
if replace_node.args[0].meta["val"].numel() == 1:
replace_node = replace_node.args[0]
assert isinstance(replace_node, torch.fx.Node)
assert replace_node.meta.get("vkdg_is_scalar_tensor", True)
assert replace_node.meta.get("etvk_is_scalar_tensor", True)

with graph.inserting_after(node):
node.replace_all_uses_with(replace_node)
Expand Down
607 changes: 371 additions & 236 deletions backends/vulkan/_passes/tag_memory_meta_pass.py

Large diffs are not rendered by default.

Loading
Loading