diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index c2561d06487..a6ddf1c82d2 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -856,7 +856,7 @@ def test_non_const_buffer_sizes(self) -> None: class Add(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: b = 3 + 1 - return x + b + return x + torch.tensor(b) f = Add() @@ -1325,7 +1325,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Find the multiplication node in the graph that was emitted. for node in program_mul.exported_program().graph.nodes: - if node.target == torch.ops.aten.mul.out: + if ( + node.target == torch.ops.aten.mul.out + or node.target == torch.ops.aten.mul.Scalar_out + ): break self.assertIsNotNone(node) diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index d96e8a24143..fca8bd2212f 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -292,9 +292,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).exported_program() for node in ep.graph.nodes: self.assertNotEqual(node.op, "get_attr") - self.assertEqual( - len([node for node in ep.graph.nodes if node.op == "placeholder"]), 2 - ) def test_constraint_present_after_dce(self): import executorch.exir as exir