
[dynamo] error_on_graph_break(True) fails to error on graph break in some cases #166589

@williamwen42

Description

This test fails because compilation succeeds when it should raise `Unsupported`, so the graph break inside the `torch._dynamo.error_on_graph_break(True)` block goes unreported.

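    import torch
    import torch._dynamo.testing
    from torch._dynamo.exc import Unsupported
    from torch._dynamo.test_case import TestCase, run_tests


    class ErrorOnGraphBreakTests(TestCase):
        def test_error_on_graph_break_nonempty_checkpoint(self):
            cnts = torch._dynamo.testing.CompileCounter()

            @torch.compile(backend=cnts)
            def fn(x):
                x = x + 1
                x = x + 1
                x = x + 1
                # This graph break should raise: error_on_graph_break(True) is active.
                with torch._dynamo.error_on_graph_break(True):
                    torch._dynamo.graph_break()
                return x + 1

            # Expected: Unsupported is raised (currently, compilation succeeds instead).
            with self.assertRaises(Unsupported):
                fn(torch.ones(3))

            # Expected: no frames are compiled.
            self.assertEqual(cnts.frame_count, 0)


    if __name__ == "__main__":
        run_tests()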
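For contrast, here is a minimal sketch (not part of the original report) of the same function with the graph break left unguarded. Assuming current `torch.compile` behavior, Dynamo splits the frame at `torch._dynamo.graph_break()` into one graph for the three adds and one for the resumed code, so `CompileCounter.frame_count` ends at 2. Wrapping the break in `error_on_graph_break(True)` is supposed to turn this split into an `Unsupported` error with no frames compiled, which is what the test above checks.

```python
import torch
import torch._dynamo
import torch._dynamo.testing

# Baseline sketch (not from the issue): no error_on_graph_break context,
# so the graph break is allowed and compilation proceeds in two pieces.
torch._dynamo.reset()
cnts = torch._dynamo.testing.CompileCounter()


@torch.compile(backend=cnts)
def fn(x):
    x = x + 1
    x = x + 1
    x = x + 1
    torch._dynamo.graph_break()  # splits the frame into two graphs
    return x + 1


out = fn(torch.ones(3))
# One graph for the three adds before the break, one for the resumed code.
print(cnts.frame_count)
```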

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @Lucaskabela
