Enums + jit still not working in PyTorch 2.1 w/ Python 3.11 #108933

Closed
rwightman opened this issue Sep 9, 2023 · 4 comments
Labels: oncall: jit, triage review, triaged

@rwightman

rwightman commented Sep 9, 2023

🐛 Describe the bug

I reported this issue a few months back against the torch nightlies: TorchScript was failing on a number of timm models under Python 3.11. There seemed to be awareness of the issue, and a plan to fix it was being discussed on Slack, but the issue persists in the 2.1 RC.

It looks like the problem is the enum being used: any timm model that contains an enum fails the TorchScript tests on Python 3.11. Recent PyTorch versions have no problems on Python < 3.11.

```
  model = torch.jit.script(model)
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 556, in create_script_module
    concrete_type = get_module_concrete_type(nn_module, share_types)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 505, in get_module_concrete_type
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 438, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 284, in infer_concrete_type_builder
    sub_concrete_type = get_module_concrete_type(item, share_types)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 505, in get_module_concrete_type
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 438, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 395, in infer_concrete_type_builder
    attr_type, inferred = infer_type(name, value)
                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_recursive.py", line 226, in infer_type
    ann_to_type = torch.jit.annotations.ann_to_type(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/annotations.py", line 508, in ann_to_type
    the_type = try_ann_to_type(ann, loc, rcb)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/annotations.py", line 489, in try_ann_to_type
    scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/jit/_script.py", line 1506, in _recursive_compile_class
    rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/_jit_internal.py", line 457, in createResolutionCallbackForClassMethods
    captures.update(get_closure(fn))
                    ^^^^^^^^^^^^^^^
  File "/blah/.conda/envs/pytorch-210n/lib/python3.11/site-packages/torch/_jit_internal.py", line 207, in get_closure
    captures.update(fn.__globals__)
                    ^^^^^^^^^^^^^^
AttributeError: 'wrapper_descriptor' object has no attribute '__globals__'
```
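The last frame shows `get_closure` reading `fn.__globals__` from a `wrapper_descriptor`. Below is a minimal, torch-free sketch of why that attribute access fails. It assumes (this is not taken from the TorchScript source) that a C-level slot wrapper such as `int.__repr__` can end up among the routines collected from a typed enum class. Such objects pass `inspect.isroutine`, but unlike Python functions they carry no module globals. `merge_globals` is a hypothetical helper showing one way to skip them.

```python
import inspect


def plain_function() -> int:
    # An ordinary Python function: it keeps a reference to its module globals.
    return 42


# A slot wrapper inherited from a C base type; int.__repr__ is one example.
slot_wrapper = int.__repr__


def merge_globals(fns):
    """Hypothetical helper: merge the module globals of each routine,
    skipping C-level routines (builtins, slot wrappers) that have none."""
    captures = {}
    for fn in fns:
        if not hasattr(fn, "__globals__"):
            continue  # e.g. a wrapper_descriptor; reading fn.__globals__ would raise
        captures.update(fn.__globals__)
    return captures


if __name__ == "__main__":
    print(type(slot_wrapper).__name__)             # wrapper_descriptor
    print(inspect.isroutine(slot_wrapper))         # True
    print(hasattr(slot_wrapper, "__globals__"))    # False
    print(hasattr(plain_function, "__globals__"))  # True
```

Calling `fn.__globals__` unconditionally on such a mix of objects reproduces the same `AttributeError`; the `hasattr` guard is one defensive option.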

Versions

Affects the PyTorch 2.1 RC on Python 3.11 (timm fails its TorchScript tests). Recent PyTorch versions have no problems on Python < 3.11.

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel

@rwightman rwightman changed the title from "Enums + jit still not working in PyTorch 2.10 w/ Python 3.11" to "Enums + jit still not working in PyTorch 2.1 w/ Python 3.11" Sep 9, 2023
@drisspg drisspg added the oncall: jit, triaged, and triage review labels Sep 11, 2023
@rwightman
Author

FYI to reproduce, w/ a newer (0.9.x) timm release:

```python
import torch
import timm
mm = timm.create_model('vit_base_patch16_224', pretrained=True)
torch.jit.script(mm)
```

@malfet malfet added this to the 2.1.1 milestone Sep 19, 2023
@malfet malfet self-assigned this Sep 19, 2023
@malfet
Contributor

malfet commented Sep 19, 2023

Original PR that I thought should have unblocked enums for 3.11: #91805

Here is the minimal reproducer for the issue:

```python
import torch
from enum import Enum

class Color(int, Enum):
    RED = 1
    GREEN = 2

def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True
    return x == y

m = torch.jit.script(enum_fn)
```

malfet added a commit that referenced this issue Sep 20, 2023
In Python 3.11+, typed enums (such as `enum.IntEnum`) retain the `__new__`, `__str__`, and other methods of the base class via the `__init_subclass__()` method (see https://docs.python.org/3/whatsnew/3.11.html#enum).

This change makes typed enums scriptable on 3.11 by explicitly marking several `enum.Enum` methods to be dropped by jit script.
Add a test that typed enums are jit-scriptable.

Fixes #108933
malfet added a commit that referenced this issue Sep 21, 2023
In Python 3.11+, typed enums (such as `enum.IntEnum`) retain the `__new__`, `__str__`, and other methods of the base class via the `__init_subclass__()` method (see https://docs.python.org/3/whatsnew/3.11.html#enum), i.e. the following code
```python
import sys
import inspect
from enum import Enum

class IntColor(int, Enum):
    RED = 1
    GREEN = 2

class Color(Enum):
    RED = 1
    GREEN = 2

def get_methods(cls):
    def predicate(m):
        if not inspect.isfunction(m) and not inspect.ismethod(m):
            return False
        return m.__name__ in cls.__dict__
    return inspect.getmembers(cls, predicate=predicate)

if __name__ == "__main__":
    print(sys.version)
    print(f"IntColor methods {get_methods(IntColor)}")
    print(f"Color methods {get_methods(Color)}")
```

returns an empty list in both cases on older Python, but on Python 3.11+ it returns a list containing the enum constructor and other methods:
```shell
% conda run -n py310 python bar.py
3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:41:52) [Clang 15.0.7 ]
IntColor methods []
Color methods []
% conda run -n py311 python bar.py
3.11.0 | packaged by conda-forge | (main, Oct 25 2022, 06:21:25) [Clang 14.0.4 ]
IntColor methods [('__format__', <function Enum.__format__ at 0x105006ac0>), ('__new__', <function Enum.__new__ at 0x105006660>), ('__repr__', <function Enum.__repr__ at 0x1050068e0>)]
Color methods []
```

This change makes typed enums scriptable on 3.11 by explicitly marking several `enum.Enum` methods to be dropped by jit script, and adds a test that typed enums are jit-scriptable.

Fixes #108933

Pull Request resolved: #109717
Approved by: https://github.com/atalman, https://github.com/davidberard98

(cherry picked from commit 55685d5)
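The commit's idea is to keep certain `enum.Enum` methods out of the set of methods TorchScript tries to compile. The sketch below illustrates that idea in plain Python; it is not the code from #109717. Rather than marking the functions as the commit describes, it keeps a deny-list of `enum.Enum` function objects. `DROPPED_ENUM_METHODS`, `collect_compilable_methods`, the `Shade` enum, and the exact list of method names are all hypothetical, chosen for illustration.

```python
import inspect
from enum import Enum

# Hypothetical deny-list: enum.Enum functions that Python 3.11+ may copy into
# the class dict of a typed enum such as `class IntColor(int, Enum)`.
_CANDIDATE_NAMES = ("__new__", "__repr__", "__str__", "__format__", "__reduce_ex__")
DROPPED_ENUM_METHODS = frozenset(
    fn
    for fn in (getattr(Enum, name, None) for name in _CANDIDATE_NAMES)
    if inspect.isfunction(fn)
)


def collect_compilable_methods(cls):
    """Hypothetical helper: return (name, function) pairs defined directly in
    cls.__dict__, skipping the Enum functions listed in DROPPED_ENUM_METHODS."""
    methods = []
    for name, value in vars(cls).items():
        # Unwrap staticmethod/classmethod objects stored in the class dict.
        if isinstance(value, (staticmethod, classmethod)):
            value = value.__func__
        if inspect.isfunction(value) and value not in DROPPED_ENUM_METHODS:
            methods.append((name, value))
    return methods


class Shade(int, Enum):
    LIGHT = 1
    DARK = 2

    def describe(self) -> str:
        # A user-defined method that should survive the filter.
        return self.name.lower()


if __name__ == "__main__":
    print([name for name, _ in collect_compilable_methods(Shade)])
```

Comparing function objects by identity, instead of by name, avoids dropping a user-defined `__repr__` or `__format__` that overrides the `Enum` one.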
malfet added a commit that referenced this issue Sep 21, 2023
Cherry-pick of #109717 into release/2.1 branch.
Approved by: https://github.com/atalman, https://github.com/davidberard98

(cherry picked from commit 55685d5)
@malfet malfet modified the milestones: 2.1.1, 2.1.0 Sep 22, 2023
@malfet
Contributor

malfet commented Sep 22, 2023

Fixed in the latest 2.1.0 release candidate; tested on my M1 Mac by running the following:

```shell
% python -c "import sys;import torch;import timm;print(sys.version, torch.__version__, torch.version.git_version);torch.jit.script(timm.create_model('vit_base_patch16_224', pretrained=True))"
3.11.5 | packaged by conda-forge | (main, Aug 27 2023, 03:33:12) [Clang 15.0.7 ] 2.1.0 7bcf7da3a268b435777fe87c7794c382f444e86d
```

@rwightman
Author

@malfet Thanks, tested in timm with the latest RC and it looks good.
