From 5670a80b39fcf863e1f12ace30c1268ed78405b8 Mon Sep 17 00:00:00 2001
From: Nathanael See
Date: Mon, 14 Oct 2024 18:41:45 -0700
Subject: [PATCH] check to_copy args in vulkan_partitioner

Summary:
In the exir dialect, to_copy does not take a dtype arg; the target dtype is
instead inferred from the dtype of the output tensor. The args are of length
1, with the sole arg being the input tensor, so the previous check always
returned False because len(node.args) is never > 1.

Differential Revision: D64267104
---
 .../vulkan/partitioner/vulkan_partitioner.py | 25 +++++++++++++++----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 2bfd8b44a70..109a61049d2 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -144,9 +144,24 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
 
         return False
 
-    def is_valid_to_copy(self, node: torch.fx.node) -> bool:  # pyre-ignore[11]
-        # lower only if floating point dtype conversion
-        return len(node.args) > 1 and node.args[1] in (torch.float32, torch.float16)
+    def is_valid_to_copy(self, node: torch.fx.Node) -> bool:
+        float_dtypes = [torch.float16, torch.float32]
+
+        if len(node.args) != 1:
+            return False
+
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if isinstance(in_tensor, FakeTensor) and isinstance(out_tensor, FakeTensor):
+            if out_tensor.dtype in float_dtypes and in_tensor.dtype in float_dtypes:
+                return True
+
+        return False
 
     def is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
@@ -174,13 +189,13 @@ def _is_node_supported(
         if target not in VulkanSupportedOperators._ops:
             return False
 
-        features = VulkanSupportedOperators._ops[target]
-
         if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy(
             node
         ):
             return False
 
+        features = VulkanSupportedOperators._ops[target]
+
         if self.require_dynamic_shapes and not features.supports_dynamic_shape:
             return False
 
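
Reviewer note: the sketch below (not part of the patch) reproduces the behavior
the Summary describes: after export, aten._to_copy.default receives the input
tensor as its only positional arg, so the old `len(node.args) > 1` test could
never pass, and the target dtype has to be read from the FakeTensor stored in
node.meta["val"]. The CastModule class, the torch.export entry point, and the
printed values are illustrative assumptions, not code from this diff.

# Minimal sketch, assuming torch >= 2.1 with torch.export available; CastModule
# is a made-up example module, not code from this patch.
import torch
from torch.export import export


class CastModule(torch.nn.Module):
    def forward(self, x):
        # An fp32 -> fp16 conversion traces to aten._to_copy.default.
        return x.to(torch.float16)


ep = export(CastModule(), (torch.randn(2, 2),))
for node in ep.graph.nodes:
    if node.op == "call_function" and "_to_copy" in str(node.target):
        # The input tensor is the only positional arg, so len(node.args) == 1
        # and the old `len(node.args) > 1` check could never succeed.
        print(len(node.args))  # 1
        # The target dtype travels in kwargs and in the FakeTensor stored in
        # node.meta["val"]; the patched check reads the latter.
        print(node.kwargs)  # e.g. {'dtype': torch.float16}
        print(node.meta["val"].dtype)  # torch.float16

        # Mirror of the patched is_valid_to_copy: lower only when both the
        # input and the output are fp16/fp32 fake tensors.
        float_dtypes = (torch.float16, torch.float32)
        in_val = node.args[0].meta.get("val", None)
        out_val = node.meta.get("val", None)
        print(
            in_val is not None
            and out_val is not None
            and in_val.dtype in float_dtypes
            and out_val.dtype in float_dtypes
        )  # True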