Skip to content

Commit

Permalink
Update on "[wip] Improve exception support"
Browse files Browse the repository at this point in the history
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
  • Loading branch information
anijain2305 committed May 23, 2024
2 parents 04253b6 + e2bc380 commit e94e224
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 1 deletion.
130 changes: 130 additions & 0 deletions test/dynamo/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.config

import torch._dynamo.test_case
import torch._functorch.config
import torch.utils.checkpoint


class ExceptionTests(torch._dynamo.test_case.TestCase):
def test_exception(self):
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
raise NotImplementedError
except Exception:
x = torch.sigmoid(x)

return x

x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)

def test_exception2(self):
def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
raise NotImplementedError
except AssertionError:
x = torch.sigmoid(x)
except NotImplementedError:
x = torch.cos(x)

return x

x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)

def test_exception_raised_from_child(self):
def gn():
raise NotImplementedError

def fn(x):
x = torch.cos(x)
try:
x = torch.sin(x)
gn()
x = torch.sin(x)
except Exception:
x = torch.sigmoid(x)

return x

x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)

def test_nn_module_getattr(self):
class A:
def __init__(self):
self._b = 20

def __getattr__(self, name):
fixed_name = "_" + name
if fixed_name in self.__dict__:
return self.__dict__[fixed_name]
raise AttributeError(f"{name} absent")

class B(A):
def __init__(self):
self.a = 10

def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return 30

obj = B()

def fn(x):
return x * obj.a * obj.b * obj.c

x = torch.ones(4)
ref = fn(x)
print(ref)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)

@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_custom_getattr_on_module_exception(self):
class Foo(torch.nn.Module):
def __init__(self, a=3):
super().__init__()
self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2))

def __getattr__(self, name):
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
if name == "a_copy":
return self.a
raise

def forward(self, x):
return x * self.a * self.a_copy

mod = Foo()
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)

x = torch.ones(4)
self.assertEqual(mod(x), opt_mod(x))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def var_getattr(self, tx, name):
# For attributes (not methods) that were not caught in the special handling above,
# (e.g. tensor.real), we handle these generically, assuming that the output type is
# a tensor.
if result is None and name != "grad":
if result is None and name not in ("grad", "grad_fn"):

def try_generic_attr_handling():
from .builder import wrap_fx_proxy
Expand Down
4 changes: 4 additions & 0 deletions torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,10 @@ def var_getattr(self, tx, name):
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
return tx.output.side_effects.load_attr(self, name)

if name == "__dict__":
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)

try:
subobj = self._getattr_static(name)
except AttributeError:
Expand Down

0 comments on commit e94e224

Please sign in to comment.