diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 33f8d10a7b71b..8ecfe493650d7 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -3346,8 +3346,7 @@ def fn(x):
             return y
 
         x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
-        # FIXME It should be KeyError here
-        self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError, lambda: fn(x))
+        self.assertRaises(KeyError, lambda: fn(x))
 
     def test_attached_attribute_in_dir(self):
         class MyModule(torch.nn.Module):
@@ -4949,6 +4948,20 @@ def f(a, tmp):
             # grad state may not be properly reset after the error
             self.assertTrue(torch.is_grad_enabled())
 
+    def test_const_dict_keyerror(self):
+        d = {}
+
+        def fn(x):
+            try:
+                y = d[0]
+            except KeyError:
+                y = 1
+            return x + y
+
+        opt_fn = torch.compile(fn, backend="eager")
+        inp = torch.randn(3, 3)
+        self.assertEqual(fn(inp), opt_fn(inp))
+
 
 instantiate_parametrized_tests(ReproTests)
 
diff --git a/test/dynamo_expected_failures/TestTensorBoardEmbedding.test_embedding b/test/dynamo_expected_failures/TestTensorBoardEmbedding.test_embedding
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/test/dynamo_expected_failures/TestTensorBoardEmbedding.test_embedding_64 b/test/dynamo_expected_failures/TestTensorBoardEmbedding.test_embedding_64
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/test/dynamo_expected_failures/TestTensorBoardSummary.test_image_with_one_channel b/test/dynamo_expected_failures/TestTensorBoardSummary.test_image_with_one_channel
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/test/dynamo_expected_failures/TestTensorBoardSummary.test_image_without_channel b/test/dynamo_expected_failures/TestTensorBoardSummary.test_image_without_channel
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index cbe6079ab907c..791d19ffb4c1a 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -787,27 +787,29 @@ def call_self_handler(tx, args, kwargs):
 
                 def constant_fold_handler(tx, args, kwargs):
                     # fast path
-                    return builder(
-                        tx,
-                        fn(
+                    try:
+                        res = fn(
                             *[x.as_python_constant() for x in args],
-                        ),
-                    )
+                        )
+                    except Exception as exc:
+                        unimplemented(f"constant fold exception: {repr(exc)}")
+                    return builder(tx, res)
 
             else:
 
                 def constant_fold_handler(tx, args, kwargs):
                     # path with a runtime check
                     if check_unspec_or_constant_args(args, kwargs):
-                        return builder(
-                            tx,
-                            fn(
+                        try:
+                            res = fn(
                                 *[x.as_python_constant() for x in args],
                                 **{
                                     k: v.as_python_constant()
                                     for k, v in kwargs.items()
                                 },
-                            ),
-                        )
+                            )
+                        except Exception as exc:
+                            unimplemented(f"constant fold exception: {repr(exc)}")
+                        return builder(tx, res)
 
             handlers.append(constant_fold_handler)
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 723cf6ac77ef6..77da855b69aec 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -195,7 +195,7 @@ def reconstruct(self, codegen):
     def getitem_const(self, arg: VariableTracker):
         key = ConstDictVariable._HashableTracker(arg)
         if key not in self.items:
-            raise KeyError(arg.value)
+            unimplemented(f"dict KeyError: {arg.value}")
         return self.items[key]
 
     def call_method(