From ca2b27ade0540502fea0d73e6c96fc4822e34833 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 8 Jan 2025 16:40:57 -0300 Subject: [PATCH 1/4] Update [ghstack-poisoned] --- test/dynamo/test_generator.py | 31 +++++++++++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 12 ++++++++++++ 2 files changed, 43 insertions(+) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index ca62ae3b4ac6..2bfc7a8af302 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -601,6 +601,37 @@ def fn(t): self.assertEqual(i, 3) self.assertEqual(y, [(0, t), (1, t + 1), (2, t + 2)]) + @unittest.expectedFailure + def test_cleanup_throw(self): + def nested_generator(): + try: + yield 1 + yield 2 + except StopIteration: + return 123 # noqa: B901 + + def outer_generator(): + yield from nested_generator() + yield 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = outer_generator() + next(gen) # Start the outer generator and enter the nested generato + + i = 0 + try: + # Force an exception while the generator is running + i = gen.throw(StopIteration("stop")) + except RuntimeError: + pass + return (i, t.sin()) + + t = torch.randn(3) + i, y = fn(t) + self.assertEqual(i, 3) + self.assertEqual(y, t.sin()) + class GeneratorCPythonTests(GeneratorTestsBase): # Taken from commit diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index a51fcc532cd7..f3c67e9ea2d4 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1489,6 +1489,18 @@ def RAISE_VARARGS(self, inst): self._raise_exception_variable(inst) unimplemented("raise ... from ...") + def CLEANUP_THROW(self, inst): + tos = self.stack[-1] + assert isinstance(tos, ExceptionVariable) + if tos.exc_type is StopIteration: + unimplemented("CLEANUP_THROW") + # _type = self.pop() + # value = self.pop() + # _tb = self.pop() + # self.stack.append(value) + else: + self.RERAISE(inst) + def RERAISE(self, inst): if sys.version_info >= (3, 11): # RERAISE is currently supported in a narrow case of `raise ... from None` From 4775444d2b2736690565276277db771b6a2c5e67 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 10 Jan 2025 00:04:23 +0000 Subject: [PATCH 2/4] Update [ghstack-poisoned] --- test/dynamo/test_generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 55b3c3a0883d..d809297ff2e9 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -27,8 +27,9 @@ def tearDown(self): def _compile_check(self, fn): eager = EagerAndRecordGraphs() t = torch.randn(2) - torch.compile(fn, backend=eager, fullgraph=True)(t) + r = torch.compile(fn, backend=eager, fullgraph=True)(t) self.assertGreater(len(eager.graphs), 0) + return t, r class GeneratorTests(GeneratorTestsBase): @@ -629,8 +630,7 @@ def fn(t): pass return (i, t.sin()) - t = torch.randn(3) - i, y = fn(t) + t, (i, y) = self._compile_check(fn) self.assertEqual(i, 3) self.assertEqual(y, t.sin()) From 431b2335b8705b852647930cd2806213a600f02a Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 17 Jan 2025 01:46:12 +0000 Subject: [PATCH 3/4] Update [ghstack-poisoned] --- test/dynamo/test_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 95d6f3b2c8bb..f066a1a583a6 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1,7 +1,7 @@ # Owner(s): ["module: dynamo"] import itertools -import unittest import sys +import unittest from collections import OrderedDict import torch From e9199ec8e9b4fccabe2c896de22caa746ff515fa Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 4 Feb 2025 15:15:19 +0000 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- torch/_dynamo/symbolic_convert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 131c3c88caac..fef89213fa2f 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1492,10 +1492,11 @@ def RAISE_VARARGS(self, inst): unimplemented("raise ... from ...") def CLEANUP_THROW(self, inst): + # https://github.com/python/cpython/pull/96010 tos = self.stack[-1] assert isinstance(tos, ExceptionVariable) if tos.exc_type is StopIteration: - unimplemented("CLEANUP_THROW") + unimplemented("CLEANUP_THROW with StopIteration") else: self.RERAISE(inst)