diff --git a/backends/apple/mps/operators/constant_ops.py b/backends/apple/mps/operators/constant_ops.py index aa92258da86..dacb09215cb 100644 --- a/backends/apple/mps/operators/constant_ops.py +++ b/backends/apple/mps/operators/constant_ops.py @@ -55,7 +55,12 @@ def define_node( elif node.target == exir_ops.edge.aten.empty.memory_format: fill_value = 0 elif node.target == exir_ops.edge.aten.scalar_tensor.default: - fill_value = float(node.args[0]) + fill_value = cast(float, node.args[0]) + + if fill_value == float("-inf"): + fill_value = "-inf" + elif fill_value == float("inf"): + fill_value = "inf" dtype = MPSDataType.mps_data_type_float32 if node.kwargs and "dtype" in node.kwargs and node.kwargs["dtype"] is not None: diff --git a/backends/apple/mps/operators/node_visitor.py b/backends/apple/mps/operators/node_visitor.py index ed2afea7727..e9f879db88a 100644 --- a/backends/apple/mps/operators/node_visitor.py +++ b/backends/apple/mps/operators/node_visitor.py @@ -157,10 +157,9 @@ def define_scalar( """ assert isinstance(val, int) or isinstance(val, float) - if val in self.tensor_to_id: - return self.tensor_to_id[val] + id = len(mps_graph.mps_values) + self.tensor_to_id[val] = id - id = self.get_serialized_id(val, mps_graph) tensor = torch.tensor(val) constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data( tensor, mps_graph, mps_data_type, id diff --git a/backends/apple/mps/serialization/mps_graph_schema.py b/backends/apple/mps/serialization/mps_graph_schema.py index 04a41abaa1c..66697b04b7d 100644 --- a/backends/apple/mps/serialization/mps_graph_schema.py +++ b/backends/apple/mps/serialization/mps_graph_schema.py @@ -391,7 +391,7 @@ class MPSFull: @dataclass class MPSFullLike(MPSNode1x1): - fill_value: float = 0.0 + fill_value: Union[float, str] = 0.0 dtype: MPSDataType = MPSDataType.mps_data_type_float32