51 changes: 42 additions & 9 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -28,11 +28,13 @@ def insert_transition_node(
node: torch.fx.Node,
arg: torch.fx.Node,
arg_node_repr: utils.TensorRepr,
) -> None:
) -> torch.fx.Node:
"""
Insert a clone node to transition the tensor associated with `arg` to a tensor with
the requested representation `arg_node_repr`, and use the cloned node as an argument
to `node` instead of `arg`.

Returns the newly created clone node.
"""
with graph_module.graph.inserting_before(node):
clone_node = graph_module.graph.create_node(
@@ -45,6 +47,7 @@ def insert_transition_node(
clone_node.meta["spec"].const = False
utils.set_node_repr(clone_node, arg_node_repr)
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
return clone_node


def set_arg_node_repr_or_transition(
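
For reference, a minimal FX-level sketch (toy module and node names, not the backend's code) of the rewrite `insert_transition_node` performs: a clone is created just before the consumer, and only that consumer's use of the argument is rewired, so any other users keep the original tensor.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        a = x + 1   # the shared argument
        b = a * 2   # the user to be transitioned
        c = a - 3   # a user that keeps the original
        return b, c


gm = torch.fx.symbolic_trace(M())
arg = next(n for n in gm.graph.nodes if n.name == "add")
user = next(n for n in gm.graph.nodes if n.name == "mul")

# Insert the clone immediately before the consumer, as the pass does.
with gm.graph.inserting_before(user):
    clone = gm.graph.call_function(torch.clone, args=(arg,))

# Rewire only `user`'s input to the clone; `sub` still consumes `add`.
arg.replace_all_uses_with(clone, delete_user_cb=lambda u, y=user: u == y)
gm.recompile()
print(gm.code)  # mul now reads clone(add); sub is unchanged
```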
@@ -53,14 +56,16 @@ def set_arg_node_repr_or_transition(
arg_i: int,
arg_node_repr: utils.TensorRepr,
dirty: bool,
transition_cache: dict | None = None,
) -> bool:
"""
Does one of the following:
1. Sets the `node_repr` of the argument at `arg_i` of `op_node` if the argument node
does not currently have a `node_repr`
2. No-op if the current `node_repr` is already the same as the requested representation.
3. Insert a transition node to create a copy of the argument with the desired `node_repr`
if the current `node_repr` is different than what is needed.
if the current `node_repr` is different than what is needed. If a transition clone
already exists for the same (source, target_repr) pair, reuse it.
"""
arg_node = op_node.args[arg_i]

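
The three cases reduce to a small stand-alone sketch (illustrative names, not the actual API); the returned flag mirrors `transitions_inserted`:

```python
def resolve_arg_repr(current, requested, make_transition):
    if current is None:                  # case 1: argument not yet tagged
        return requested, False          #   tag in place, no clone needed
    if current == requested:             # case 2: representations already match
        return current, False            #   no-op
    return make_transition(requested), True  # case 3: clone to the new repr
```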
@@ -78,15 +83,33 @@ def single_node_impl(node: torch.fx.Node) -> bool:
if cur_node_repr == arg_node_repr:
return False

assert utils.is_single_tensor_node(node)

# Check if a transition clone already exists for this (source, target_repr).
cache_key = (
node,
arg_node_repr.storage_type,
arg_node_repr.memory_layout,
)
if transition_cache is not None and cache_key in transition_cache:
cached_clone = transition_cache[cache_key]
node.replace_all_uses_with(cached_clone, lambda x, y=op_node: x == y)
if not dirty:
logger.info(
f"[Vulkan Delegate] Reusing transition for {op_node.format_node()}:"
)
logger.info(f" arg {arg_i} ({node}): reusing {cached_clone}")
return True

if not dirty:
logger.info(
f"[Vulkan Delegate] Inserting transition(s) for {op_node.format_node()}:"
)

# Existing node representation is different; insert a transition node
# Currently, the transition node insertion logic can only handle single tensor nodes
assert utils.is_single_tensor_node(node)
insert_transition_node(graph_module, op_node, node, arg_node_repr)
clone_node = insert_transition_node(graph_module, op_node, node, arg_node_repr)

if transition_cache is not None:
transition_cache[cache_key] = clone_node

logger.info(f" arg {arg_i} ({node}): ({cur_node_repr}) -> ({arg_node_repr})")

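A minimal sketch of the deduplication the cache buys. Stand-in strings are used for nodes and representations here; the real key is the source `torch.fx.Node` plus the target storage type and memory layout:

```python
from typing import Dict, Tuple

TransitionKey = Tuple[str, str, str]  # (source, storage_type, memory_layout)
transition_cache: Dict[TransitionKey, str] = {}


def get_or_create_transition(source: str, storage: str, layout: str) -> str:
    key = (source, storage, layout)
    if key in transition_cache:
        return transition_cache[key]  # reuse: no second clone is created
    # Stand-in for insert_transition_node.
    clone = f"clone({source} -> {storage}/{layout})"
    transition_cache[key] = clone
    return clone


a = get_or_create_transition("conv_out", "TEXTURE_3D", "WIDTH_PACKED")
b = get_or_create_transition("conv_out", "TEXTURE_3D", "WIDTH_PACKED")
assert a is b  # the second consumer reuses the first clone
```

Because `call` creates one cache per pass invocation (see below), the reuse spans every op in the graph without leaking state between compilations.
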
@@ -407,7 +430,10 @@ def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None:
self.constrain_op_out_repset(op_repsets)

def set_op_node_tensor_reprs(
self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node
self,
graph_module: torch.fx.GraphModule,
op_node: torch.fx.Node,
transition_cache: dict | None = None,
) -> None:
"""
For an operator represented by `op_node`, get the OpRepSets associated with
Expand Down Expand Up @@ -458,7 +484,12 @@ def set_op_node_tensor_reprs(
if isinstance(arg_node, torch.fx.Node):
transitions_inserted = (
set_arg_node_repr_or_transition(
graph_module, op_node, i, arg_node_repr, transitions_inserted
graph_module,
op_node,
i,
arg_node_repr,
transitions_inserted,
transition_cache,
)
or transitions_inserted
)
@@ -473,12 +504,14 @@ def set_op_node_tensor_reprs(
i,
arg_node_repr,
transitions_inserted,
transition_cache,
)
or transitions_inserted
)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
transition_cache: dict = {}
for node in graph_module.graph.nodes:
self.set_op_node_tensor_reprs(graph_module, node)
self.set_op_node_tensor_reprs(graph_module, node, transition_cache)

return PassResult(graph_module, True)
13 changes: 9 additions & 4 deletions backends/vulkan/utils.py
@@ -1511,14 +1511,19 @@ def try_constrain_with_arg_repset(
if not arg_current_repset.any_in_common(source_repset):
return False

# Compute the narrowed repset (intersection of current arg and source).
narrowed = arg_current_repset.make_intersect(source_repset)

if self.sync_primary_io_repr:
if not self.get_out_repset(0).has_compatible_packed_dim_info_set(
source_repset
):
# Check that the narrowed result is compatible with the output.
# Using the intersection rather than the raw source_repset avoids
# rejecting valid constraints where the source has extra layouts
# (e.g. ANY_TEXTURE includes HP/CP) that don't exist in the output
# but also don't appear in the intersection.
if not self.get_out_repset(0).has_compatible_packed_dim_info_set(narrowed):
return False

# If this point is reached, then it is possible to constrain
narrowed = arg_current_repset.make_intersect(source_repset)
self.args_repset_list[arg_i] = narrowed

# Propagate to other synced args via packed-dim compatibility
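A set-based sketch of why checking the narrowed repset is the right fix, modeling `has_compatible_packed_dim_info_set` as a plain subset test purely for illustration (the layout names are also illustrative):

```python
arg_current = {"WIDTH_PACKED", "CHANNELS_PACKED"}
source = {"WIDTH_PACKED", "HEIGHT_PACKED", "CHANNELS_PACKED"}  # e.g. ANY_TEXTURE
out = {"WIDTH_PACKED", "CHANNELS_PACKED"}

narrowed = arg_current & source  # {"WIDTH_PACKED", "CHANNELS_PACKED"}

# The old check compared the raw source against the output and rejected,
# because HEIGHT_PACKED is absent from the output even though it could
# never survive the intersection anyway.
old_ok = source <= out        # False
# The new check compares the narrowed set, so every surviving layout is valid.
new_ok = narrowed <= out      # True
assert not old_ok and new_ok
```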