Skip to content

Commit

Permalink
Support 'BaseOutput' and subclasses from 'diffusers' in dynamo
Browse files Browse the repository at this point in the history
ghstack-source-id: 6548b7cd6fd1a60b65eb6bef7afd8a6c6bc0a145
Pull Request resolved: #111978
  • Loading branch information
BowenBao committed Oct 27, 2023
1 parent 4b51724 commit 34fa2be
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 30 deletions.
97 changes: 97 additions & 0 deletions test/dynamo/test_base_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Owner(s): ["module: dynamo"]
import unittest.mock

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same

try:
from diffusers.models import unet_2d
except ImportError:
unet_2d = None


def maybe_skip(fn):
    """Decorator: run *fn* only when the optional ``diffusers`` package imported."""
    # `unet_2d` is None when the module-level `from diffusers.models import
    # unet_2d` failed, so these tests have nothing to exercise.
    if unet_2d is not None:
        return fn
    return unittest.skip("requires diffusers")(fn)


class TestBaseOutput(torch._dynamo.test_case.TestCase):
    """Dynamo handling of diffusers' ``BaseOutput`` subclasses (``UNet2DOutput``):
    construction, attribute assignment, attr/item/index access, and to_tuple()."""

    @maybe_skip
    def test_create(self):
        # Construct a BaseOutput subclass inside the compiled region.
        def fn(a):
            return unet_2d.UNet2DOutput(a + 1)

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=1)

    @maybe_skip
    def test_assign(self):
        # Assign to a field of a BaseOutput instance inside the compiled region.
        def fn(a):
            out = unet_2d.UNet2DOutput(a + 1)
            out.sample = a + 2
            return out

        inputs = [torch.randn(10)]
        eager_out = fn(*inputs)

        counter = torch._dynamo.testing.CompileCounter()
        compiled_fn = torch._dynamo.optimize_assert(counter)(fn)
        compiled_out = compiled_fn(*inputs)
        self.assertTrue(same(eager_out.sample, compiled_out.sample))
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)

    def _common(self, fn, op_count):
        # Run `fn` eagerly and compiled on a fresh UNet2DOutput argument,
        # compare results, and check a single frame with `op_count` ops.
        inputs = [
            unet_2d.UNet2DOutput(
                sample=torch.randn(10),
            )
        ]
        eager_out = fn(*inputs)
        counter = torch._dynamo.testing.CompileCounter()
        compiled_fn = torch._dynamo.optimize_assert(counter)(fn)
        compiled_out = compiled_fn(*inputs)
        self.assertTrue(same(eager_out, compiled_out))
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, op_count)

    @maybe_skip
    def test_getattr(self):
        def fn(obj: unet_2d.UNet2DOutput):
            return obj.sample * 10

        self._common(fn, 1)

    @maybe_skip
    def test_getitem(self):
        def fn(obj: unet_2d.UNet2DOutput):
            return obj["sample"] * 10

        self._common(fn, 1)

    @maybe_skip
    def test_tuple(self):
        def fn(obj: unet_2d.UNet2DOutput):
            return obj.to_tuple()[0] * 10

        self._common(fn, 1)

    @maybe_skip
    def test_index(self):
        def fn(obj: unet_2d.UNet2DOutput):
            return obj[0] * 10

        self._common(fn, 1)


# Allow running this test file directly; run_tests applies dynamo's
# standard test-case setup before invoking unittest discovery.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
6 changes: 5 additions & 1 deletion torch/_dynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,11 @@ def _convert_frame_assert(
):
return None
if code.co_name == "<genexpr>" and code.co_filename.endswith(
("transformers/file_utils.py", "transformers/utils/generic.py")
(
"transformers/file_utils.py",
"transformers/utils/generic.py",
"diffusers/utils/outputs.py",
)
):
# not needed, but cleans up torchbench error stats
return None
Expand Down
82 changes: 53 additions & 29 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,30 @@ def call_method(
return super().call_method(tx, name, args, kwargs)


def _is_matching_transformers_cls(cls) -> bool:
if not cls.__module__.startswith("transformers."):
return False

try:
from transformers.file_utils import ModelOutput

return issubclass(cls, ModelOutput)
except ImportError:
return False


def _is_matching_diffusers_cls(cls) -> bool:
if not cls.__module__.startswith("diffusers."):
return False

try:
from diffusers.utils import BaseOutput

return issubclass(cls, BaseOutput)
except ImportError:
return False


class DataClassVariable(ConstDictVariable):
"""
This is a bit of a hack to deal with
Expand All @@ -311,20 +335,27 @@ class DataClassVariable(ConstDictVariable):
@staticmethod
@functools.lru_cache(None)
def _patch_once():
from transformers.file_utils import ModelOutput
try:
from transformers.file_utils import ModelOutput

for obj in ModelOutput.__dict__.values():
if callable(obj):
skip_code(obj.__code__)
for obj in ModelOutput.__dict__.values():
if callable(obj):
skip_code(obj.__code__)
except ImportError:
pass

@staticmethod
def is_matching_cls(cls):
try:
from transformers.file_utils import ModelOutput
from diffusers.utils import BaseOutput

return issubclass(cls, ModelOutput)
for obj in BaseOutput.__dict__.values():
if callable(obj):
skip_code(obj.__code__)
except ImportError:
return False
pass

@staticmethod
def is_matching_cls(cls):
    # Matches HF transformers ModelOutput subclasses and diffusers
    # BaseOutput subclasses — both dict-like output containers that
    # this variable tracker models.
    return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

@classmethod
def is_matching_object(cls, obj):
Expand Down Expand Up @@ -437,26 +468,19 @@ def var_getattr(self, tx, name: str) -> "VariableTracker":
class CustomizedDictVariable(ConstDictVariable):
@staticmethod
def is_matching_cls(cls):
try:
# True if using default OrderedDict.__init__ and did not implement __post_init__
if (
issubclass(cls, collections.OrderedDict)
and cls.__init__ is collections.OrderedDict.__init__
and not hasattr(cls, "__post_init__")
):
return True
# hack for HF usecase:
# assume dataclass annotation for ModelOutput subclass
# assume self.create is AA to ModelOutput.__post_init__
# for non-HF usecase:
# check __module__ string to avoid costy HF import
if cls.__module__ != "transformers.modeling_outputs":
return False
from transformers.file_utils import ModelOutput

return issubclass(cls, ModelOutput)
except ImportError:
return False
# True if using default OrderedDict.__init__ and did not implement __post_init__
if (
issubclass(cls, collections.OrderedDict)
and cls.__init__ is collections.OrderedDict.__init__
and not hasattr(cls, "__post_init__")
):
return True
# hack for HF usecase:
# assume dataclass annotation for ModelOutput subclass
# assume self.create is AA to ModelOutput.__post_init__
# for non-HF usecase:
# check __module__ string to avoid costly HF import
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

@classmethod
def is_matching_object(cls, obj):
Expand Down

0 comments on commit 34fa2be

Please sign in to comment.