From c2a99a863980654d82af675ed3769201ed2ef718 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Mon, 1 Apr 2024 15:37:13 -0700 Subject: [PATCH 1/3] [MPS] Fix static llama AOT tracing --- backends/apple/mps/operators/constant_ops.py | 7 ++++++- backends/apple/mps/operators/node_visitor.py | 6 +++--- backends/apple/mps/operators/op_clone.py | 13 +++++++++---- .../apple/mps/serialization/mps_graph_schema.py | 2 +- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/backends/apple/mps/operators/constant_ops.py b/backends/apple/mps/operators/constant_ops.py index aa92258da86..dacb09215cb 100644 --- a/backends/apple/mps/operators/constant_ops.py +++ b/backends/apple/mps/operators/constant_ops.py @@ -55,7 +55,12 @@ def define_node( elif node.target == exir_ops.edge.aten.empty.memory_format: fill_value = 0 elif node.target == exir_ops.edge.aten.scalar_tensor.default: - fill_value = float(node.args[0]) + fill_value = cast(float, node.args[0]) + + if fill_value == float("-inf"): + fill_value = "-inf" + elif fill_value == float("inf"): + fill_value = "inf" dtype = MPSDataType.mps_data_type_float32 if node.kwargs and "dtype" in node.kwargs and node.kwargs["dtype"] is not None: diff --git a/backends/apple/mps/operators/node_visitor.py b/backends/apple/mps/operators/node_visitor.py index ed2afea7727..0f5411737c9 100644 --- a/backends/apple/mps/operators/node_visitor.py +++ b/backends/apple/mps/operators/node_visitor.py @@ -157,10 +157,10 @@ def define_scalar( """ assert isinstance(val, int) or isinstance(val, float) - if val in self.tensor_to_id: - return self.tensor_to_id[val] - id = self.get_serialized_id(val, mps_graph) + id = len(mps_graph.mps_values) + self.tensor_to_id[val] = id + tensor = torch.tensor(val) constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data( tensor, mps_graph, mps_data_type, id diff --git a/backends/apple/mps/operators/op_clone.py b/backends/apple/mps/operators/op_clone.py index 2310ae02da7..aca2b14dbe1 100644 --- a/backends/apple/mps/operators/op_clone.py +++ b/backends/apple/mps/operators/op_clone.py @@ -8,10 +8,10 @@ NodeVisitor, register_node_visitor, ) -from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph +from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph, MPSView from executorch.backends.apple.mps.utils.mps_utils import get_input_node from executorch.exir.dialects._ops import ops as exir_ops - +from executorch.backends.transforms import get_shape @register_node_visitor class CloneVisitor(NodeVisitor): @@ -31,5 +31,10 @@ def define_node( raise RuntimeError( "aten._to_copy not supported with more than one argument currently" ) - input_id = self.define_tensor(get_input_node(node, 0), mps_graph) - self.tensor_to_id[node] = input_id + mps_node = self.create_unary_node(node, mps_graph, MPSView) + view_shape = get_shape(node) + + mps_node.mpsnode_union.num_dims = len(view_shape) + mps_node.mpsnode_union.shape = view_shape + + mps_graph.mps_nodes.append(mps_node) diff --git a/backends/apple/mps/serialization/mps_graph_schema.py b/backends/apple/mps/serialization/mps_graph_schema.py index 04a41abaa1c..66697b04b7d 100644 --- a/backends/apple/mps/serialization/mps_graph_schema.py +++ b/backends/apple/mps/serialization/mps_graph_schema.py @@ -391,7 +391,7 @@ class MPSFull: @dataclass class MPSFullLike(MPSNode1x1): - fill_value: float = 0.0 + fill_value: Union[float, str] = 0.0 dtype: MPSDataType = MPSDataType.mps_data_type_float32 From 53f89a294ff3c0ba32fdc5067a8dfdab928040b2 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Mon, 1 Apr 2024 16:05:09 -0700 Subject: [PATCH 2/3] Fix lint --- backends/apple/mps/operators/node_visitor.py | 1 - backends/apple/mps/operators/op_clone.py | 9 ++++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/backends/apple/mps/operators/node_visitor.py b/backends/apple/mps/operators/node_visitor.py index 0f5411737c9..e9f879db88a 100644 --- a/backends/apple/mps/operators/node_visitor.py +++ b/backends/apple/mps/operators/node_visitor.py @@ -157,7 +157,6 @@ def define_scalar( """ assert isinstance(val, int) or isinstance(val, float) - id = len(mps_graph.mps_values) self.tensor_to_id[val] = id diff --git a/backends/apple/mps/operators/op_clone.py b/backends/apple/mps/operators/op_clone.py index aca2b14dbe1..ad277973f89 100644 --- a/backends/apple/mps/operators/op_clone.py +++ b/backends/apple/mps/operators/op_clone.py @@ -8,10 +8,13 @@ NodeVisitor, register_node_visitor, ) -from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph, MPSView -from executorch.backends.apple.mps.utils.mps_utils import get_input_node -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.apple.mps.serialization.mps_graph_schema import ( + MPSGraph, + MPSView, +) from executorch.backends.transforms import get_shape +from executorch.exir.dialects._ops import ops as exir_ops + @register_node_visitor class CloneVisitor(NodeVisitor): From 2577dc1871a060f835df27ca0302551650546141 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Mon, 1 Apr 2024 17:31:40 -0700 Subject: [PATCH 3/3] Revert clone changes - not needed --- backends/apple/mps/operators/op_clone.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/backends/apple/mps/operators/op_clone.py b/backends/apple/mps/operators/op_clone.py index ad277973f89..2310ae02da7 100644 --- a/backends/apple/mps/operators/op_clone.py +++ b/backends/apple/mps/operators/op_clone.py @@ -8,11 +8,8 @@ NodeVisitor, register_node_visitor, ) -from executorch.backends.apple.mps.serialization.mps_graph_schema import ( - MPSGraph, - MPSView, -) -from executorch.backends.transforms import get_shape +from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph +from executorch.backends.apple.mps.utils.mps_utils import get_input_node from executorch.exir.dialects._ops import ops as exir_ops @@ -34,10 +31,5 @@ def define_node( raise RuntimeError( "aten._to_copy not supported with more than one argument currently" ) - mps_node = self.create_unary_node(node, mps_graph, MPSView) - view_shape = get_shape(node) - - mps_node.mpsnode_union.num_dims = len(view_shape) - mps_node.mpsnode_union.shape = view_shape - - mps_graph.mps_nodes.append(mps_node) + input_id = self.define_tensor(get_input_node(node, 0), mps_graph) + self.tensor_to_id[node] = input_id