diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 61ab7b4c40f..e173d4b66a4 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2327,10 +2327,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # Cast the const_arg to the dtype of the x_arg full_arg = self.resolve_full_arg(x_arg, const_arg) + full_output_dtype = ( + torch.int32 if isinstance(full_arg, int) else torch.float32 + ) + # Extract an argument to a separate full op. with graph_module.graph.inserting_before(mul_node): full_node = graph_module.graph.call_function( - torch.ops.aten.full.default, args=([1], full_arg) + torch.ops.aten.full.default, + args=([1], full_arg), + kwargs={"dtype": full_output_dtype}, ) full_node.meta = mul_node.meta full_node.meta["val"] = [1]