[JIT] adding torch.jit.isinstance support #46062

Closed
wants to merge 12 commits into from

Contributor

@Lilyjjo Lilyjjo commented Oct 8, 2020

Stack from ghstack:

Adds support for torch.jit.isinstance in both eager and script mode

Example use:

import torch
from typing import Any, List

class TestModule(torch.nn.Module):
    def __init__(self):
        super(TestModule, self).__init__()
    
    def call(self, input1: str, input2: str) -> str:
        return input1

    def forward(self, input: Any) -> None:
        if torch.jit.isinstance(input, List[str]):
            for el in input:
                print(el)

TestModule().forward(["1","2"])
scripted_module = torch.jit.script(TestModule())
scripted_module(["1", "2"])

Differential Revision: D24264415

@Lilyjjo Lilyjjo requested a review from apaszke as a code owner October 8, 2020 23:02
Lilyjjo added a commit that referenced this pull request Oct 8, 2020
ghstack-source-id: 429519585f7fb5aada39b91103fdeebc9a94ebed
Pull Request resolved: #46062
dr-ci bot commented Oct 8, 2020

💊 CI failures summary and remediations

As of commit dd82816 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

@facebook-github-bot facebook-github-bot added the oncall: jit label Oct 8, 2020
Currently the eager implementation only works for Python 3.8+.

There is also an issue with shadowing isinstance within the eager implementation, which is why it is named torch.jit.isinstance2.

[ghstack-poisoned]
Lilyjjo added a commit that referenced this pull request Oct 9, 2020
ghstack-source-id: 5c165eb874b6015036e110ee5f3668b7d47125f2
Pull Request resolved: #46062
@Lilyjjo Lilyjjo linked an issue Oct 9, 2020 that may be closed by this pull request
@Lilyjjo Lilyjjo changed the title from "[WIP] adding torch.jit.isinstance" to "[JIT] adding torch.jit.isinstance support" Oct 9, 2020
return getattr(the_type, "__args__", None)


def check_args_exist(the_type):
Contributor

Can you reuse methods here?

def is_list(ann):

Contributor Author

yes, can do

Contributor Author

So actually, looking more into this, the is_list/is_dict/etc. functions in _jit_internal.py do more checks than torch.jit.isinstance actually needs. The first check in each _jit_internal.py function, which triggers the shared error message, is the only thing that _isinstance needs; the return values from the _jit_internal.py functions aren't needed. I see the benefit of using the _jit_internal.py functions, because the error messages should be the same and having them twice is code duplication. I can see a couple of options:

  1. use the _jit_internal.py functions in _isinstance and perform the unnecessary checks
  2. rewrite the error messages in a new function that both the _jit_internal.py functions and the _isinstance functions can use
  3. keep the code as is in this PR

        return False
    else:
        return False
elif origin_type is Union:  # TODO actually handles Optional Case
Contributor

OOC, why not check for Optional[] here?

Contributor Author

Calling getattr(the_type, "__origin__", None) on an Optional[str] type returns typing.Union instead of typing.Optional. Whoever wrote the is_optional logic in _jit_internal.py used similarly odd logic to address this.
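
The behaviour described in this reply can be reproduced with the standard typing module alone. The sketch below is illustrative and is not code from this PR; `origin_of` is a hypothetical helper that mirrors the getattr call quoted above, and Python 3.7+ is assumed.

```python
from typing import Optional, Union


def origin_of(the_type):
    # Same lookup the reply describes: read __origin__, defaulting to None.
    return getattr(the_type, "__origin__", None)


# Optional[str] is shorthand for Union[str, None], so its origin is Union.
optional_origin_is_union = origin_of(Optional[str]) is Union

# The None alternative shows up as NoneType among the type arguments.
none_is_an_argument = type(None) in Optional[str].__args__

print(optional_origin_is_union, none_is_an_argument)
```

This is why a branch keyed on `origin_type is Union` ends up handling the Optional case.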

)


def generics_checker(the_obj, the_type):
Contributor

Doesn't have to be in this diff.

Could you add a TODO for implementing a check for "future"?

origin_type = get_origin(the_type)
if origin_type:
    return generics_checker(the_obj, the_type)
# handle odd case of non typed optional origin returning as none
Contributor

OOC, why is this the case? Is this a quirk of Python?

Contributor Author

@Lilyjjo Lilyjjo Oct 12, 2020

It's an oddity of how Python treats the Optional type. For the bare Optional type without an inner specification (e.g. Optional rather than Optional[str]), calling getattr(the_type, "__origin__", None) returns None. By contrast, for Dict/List/Tuple with no inner specification the returned origin is dict/list/tuple rather than None.
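
A small check of this asymmetry, assuming Python 3.7 or newer (Python 3.6 behaves differently for bare containers). `origin_of` is a hypothetical helper, not a function from this PR.

```python
from typing import Dict, List, Optional, Tuple


def origin_of(the_type):
    # Read __origin__ if the typing object has one, otherwise return None.
    return getattr(the_type, "__origin__", None)


# Bare Optional carries no __origin__ attribute at all.
bare_optional_origin = origin_of(Optional)

# Bare containers report the matching builtin class as their origin.
container_origins = [origin_of(t) for t in (List, Dict, Tuple)]

print(bare_optional_origin, container_origins)
```

Because the two cases differ, code that branches on a truthy origin has to treat a bare Optional separately.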

Lilyjjo added a commit that referenced this pull request Oct 12, 2020
ghstack-source-id: 1de7208c3804b1484f9c2a22aaf2725b92197eeb
Pull Request resolved: #46062
    x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
    self.checkScript(list_in_while_loop, (x,))

def test_switch_on_type(self):

It appears that this test is testing refinement. Consider renaming it to reflect that more clearly.

Comment on lines 4 to 9
def get_origin(target_type):
    return getattr(target_type, "__origin__", None)


def get_args(target_type):
    return getattr(target_type, "__args__", None)

These will behave differently between python3.6 and python3.7+:

Python 3.6.2+ (heads/master-dirty:404de642c0, Sep 15 2017, 15:52:01) 
[GCC 4.2.1 Compatible Apple LLVM 9.0.0 (clang-900.0.37)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from typing import List
>>> a = List
>>> a.__args__
>>> a.__origin__
Python 3.7.4 (default, Aug 13 2019, 15:17:50) 
[Clang 4.0.1 (tags/RELEASE_401/final)] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from typing import List
>>> a = List
>>> a.__args__
(~T,)
>>> a.__origin__
<class 'list'>

which (I think) will cause _isinstance to skip directly to return isinstance(obj, target_type) in python3.6, where target_type in my example is typing.List, but not in python3.7.

Contributor Author

I think the code works here because in 3.6 List is treated differently from List[str]:

Python 3.6.2 |Anaconda, Inc.| (default, Oct  5 2017, 07:59:26) 
[GCC 7.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from typing import List
>>> a = List[str]
>>> a.__args__
(<class 'str'>,)
>>> a.__origin__
typing.List
>>> a = List
>>> a.__args__
>>> a.__origin__

However, on Python 3.6 this lets eager mode accept a bare List/Dict/Tuple as the target type while scripting does not. I'll see what I can do to make the behavior the same.
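
One possible way to make eager mode reject bare containers regardless of Python version is an explicit identity check before any origin lookup. This is only a sketch: `require_parameterized` and its error message are hypothetical, and this thread does not show how the PR ultimately resolved the mismatch.

```python
from typing import Dict, List, Optional, Tuple

# Container types that carry no inner type information (assumed list, for illustration).
_BARE_CONTAINERS = (List, Dict, Tuple, Optional, list, dict, tuple)


def require_parameterized(target_type):
    """Hypothetical guard: raise if target_type is a container with no inner type."""
    # Identity comparison avoids relying on __origin__/__args__, which differ
    # between Python versions for bare typing containers.
    if any(target_type is bare for bare in _BARE_CONTAINERS):
        raise RuntimeError(
            f"Expected a parameterized container type such as List[str], got {target_type!r}"
        )
    return target_type
```

Calling `require_parameterized(List)` raises, while `require_parameterized(List[str])` returns its argument unchanged.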

elif target_type is Dict or target_type is dict:
    _jit_internal.is_dict(target_type)
elif target_type is None or target_type is Optional:
    _jit_internal.is_optional(target_type)
Contributor Author

Here I'm using the _jit_internal.py functions only for the error messages they throw; the return values of these function calls aren't used. I'm still unsure whether I should use the _jit_internal.py functions this way, or extract the error messages from the _jit_internal.py functions, put them in a different location, and have both this part of _isinstance and the _jit_internal.py functions call them separately.

Contributor

I think either way is fine with a slight preference to refactor the error message part and reuse. Just to make sure I understand why we don't use those return values, it is because you are rolling your own logic to check origin and typing like on line 1108 and 1109?

Contributor Author

@Lilyjjo Lilyjjo Oct 14, 2020

The _jit_internal.py functions don't help with checking nested containers; they only check that the outer container is of the right type and that it contains an inner type. generics_checker (lines 1102-1157) needs to process nested containers, and I think the code is cleaner when I don't use the _jit_internal.py functions.

I'm also confused about what the __module__ checks in the _jit_internal.py functions, shown below, are doing:

    if not hasattr(ann, '__module__'):
        return False
    return ann.__module__ == 'typing' and \
        (getattr(ann, '__origin__', None) is List or
            getattr(ann, '__origin__', None) is list)

Should I be checking for this __module__ attribute in my logic somewhere? I'm not familiar enough yet with modules to know what the importance of this __module__ attribute is

Contributor

I don't think this is necessary since you are already checking equality against List, which is from typing module. I am not entirely sure why it is needed in _jit_internal.py. @SplitInfinity Could you clarify?

Yeah, I'm not sure either.

I think it might also make sense to move this code to _jit_internal.py to reuse those error messages. I would prefer that to calling is_list, etc. for the error message and throwing away the return value. ignore, export, and is_scripting live there and those are certainly not internal APIs (but I think they are in that file in order to avoid circular import problems).

Contributor Author

Okay, I'll move the code and factor out the error messages 💯
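
The nested-container handling discussed in this thread can be sketched in plain Python. This is not the PR's generics_checker: `check_generic` is a hypothetical name, it assumes Python 3.8+ for `typing.get_origin` and `typing.get_args`, it expects fully parameterized types, and it does not handle variadic tuples such as Tuple[int, ...].

```python
from typing import Dict, List, Optional, Tuple, Union, get_args, get_origin


def check_generic(obj, target_type) -> bool:
    """Illustrative recursive check of obj against a parameterized typing annotation."""
    origin = get_origin(target_type)
    if origin is None:
        # Plain class such as int, str, or NoneType.
        return isinstance(obj, target_type)
    args = get_args(target_type)
    if origin is list:
        return isinstance(obj, list) and all(check_generic(el, args[0]) for el in obj)
    if origin is dict:
        return isinstance(obj, dict) and all(
            check_generic(key, args[0]) and check_generic(value, args[1])
            for key, value in obj.items()
        )
    if origin is tuple:
        return (
            isinstance(obj, tuple)
            and len(obj) == len(args)
            and all(check_generic(el, el_type) for el, el_type in zip(obj, args))
        )
    if origin is Union:
        # Covers Optional[T], which is Union[T, None].
        return any(check_generic(obj, member) for member in args)
    raise RuntimeError(f"Unsupported type: {target_type!r}")


nested_ok = check_generic([{"a": 1}, {"b": 2}], List[Dict[str, int]])
nested_bad = check_generic([{"a": "1"}], List[Dict[str, int]])
pair_ok = check_generic((1, "x"), Tuple[int, str])
none_ok = check_generic(None, Optional[str])
print(nested_ok, nested_bad, pair_ok, none_ok)
```

The recursion is what an outer-container-only check such as is_list cannot provide: each inner type argument is checked against every element.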

Lilyjjo added a commit that referenced this pull request Oct 13, 2020
ghstack-source-id: 79618038d10e321a2c3716c8272041ff9bfb790c
Pull Request resolved: #46062
codecov bot commented Oct 14, 2020

Codecov Report

Merging #46062 into gh/Lilyjjo/9/base will increase coverage by 0.04%.
The diff coverage is 100.00%.

@@                  Coverage Diff                  @@
##           gh/Lilyjjo/9/base   #46062      +/-   ##
=====================================================
+ Coverage              68.39%   68.43%   +0.04%     
=====================================================
  Files                    411      411              
  Lines                  53972    54051      +79     
=====================================================
+ Hits                   36914    36992      +78     
- Misses                 17058    17059       +1     

def optional_test_none(x: Any):
    assert torch.jit.isinstance(x, Optional[torch.Tensor])
    # assert not torch.jit.isinstance(x, Optional[str])
    # TODO: above line fails in TS interpreter need to investigate
Contributor

Are you looking into this?

Contributor Author

I'm running into a difficulty where eager mode and scripting behave differently. Eager mode lets both isinstance calls return True (since in Python None can belong to either Optional[torch.Tensor] or Optional[str]). From what I can tell, during scripting type refinement happens on the first isinstance call, refining the type to Optional[torch.Tensor]; this then makes the second isinstance call return False since the type isn't Optional[str].

I'm not sure how to go about resolving this. I don't know how, in Python eager mode, to carry the type refinement outcome between isinstance calls so that the second isinstance call returns False.

Contributor

Good finding.

I don't think we should change eager mode behavior as we would want users to mechanically migrate from isinstance to torch.jit.isinstance.

I think what we can do, though, is change the behavior of prim::isinstance in TorchScript. From what I can tell, prim::isinstance ultimately calls into here, which checks for a subtyping relationship. I think this logic is only correct when two types either completely overlap (subtype relationship) or do not overlap at all (not an instance of). Optional[T] doesn't fall into either situation. It would be even worse for Union[T1, T2, T3] when we support it in the future.

We should therefore fix this by changing the implementation of prim::isinstance. It doesn't have to be in this diff, though, since this case is rare and unconventional enough. Feel free to leave a TODO here for now and we can discuss how to deal with it with a wider audience.
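
The partial overlap described here can be seen with a value-based check in plain Python. This sketch is only an illustration of the eager-mode side: `value_matches_optional` is a hypothetical helper, int stands in for torch.Tensor so the example runs without PyTorch, and Python 3.8+ is assumed.

```python
from typing import Optional, Union, get_args, get_origin


def value_matches_optional(obj, target_type):
    """Hypothetical value-based check for Optional[T], where T is a plain class."""
    if get_origin(target_type) is not Union:
        raise TypeError(f"expected Optional[T], got {target_type!r}")
    return any(isinstance(obj, member) for member in get_args(target_type))


# None belongs to every Optional[T], so a value-based check accepts both targets.
none_in_int = value_matches_optional(None, Optional[int])
none_in_str = value_matches_optional(None, Optional[str])

# Neither type contains the other: each admits values the other rejects.
one_only_in_int = value_matches_optional(1, Optional[int]) and not value_matches_optional(
    1, Optional[str]
)
text_only_in_str = value_matches_optional("a", Optional[str]) and not value_matches_optional(
    "a", Optional[int]
)

print(none_in_int, none_in_str, one_only_in_int, text_only_in_str)
```

Since the two Optional types share None but neither is a subtype of the other, a pure subtype test cannot give the same answer as a value-based test for a None input.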

def optional_test(x: Any):
    assert torch.jit.isinstance(x, Optional[torch.Tensor])
    assert not torch.jit.isinstance(x, Optional[str])
    # TODO: successful torch.jit.isinstance makes sets type?
Contributor Author

I'll delete this TODO 😬

Contributor

@gmagogsfm gmagogsfm left a comment

Looks good, approving. But please wait for SplitInfinity to clarify one of the comments before merging.

@SplitInfinity SplitInfinity left a comment

Could you add docstrings for the new functions that you introduced?

Lilyjjo added a commit that referenced this pull request Oct 15, 2020
ghstack-source-id: 3be8fd333c19cee69134cb344616f1149bc6a55b
Pull Request resolved: #46062
print(val)

m = torch.jit.script(MyModule())
"""
Contributor Author

@SplitInfinity is this what you had in mind for a doc string? I'm not exactly clear on what doc strings should look like.

Another docstring question: should this go here, or should it be added to the torch/jit/__init__.py call site on line 75 of that file?

I think this particular docstring with example code that uses torch.jit.isinstance should go on isinstance in torch/jit/__init__.py.

It's also a good idea to include sections describing the arguments and return values. For example, here is the docstring from torch.jit.wait:

def wait(future):
    """
    Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the
    result of the task. See :func:`~fork` for docs and examples.
    Arguments:
        func (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`
    Returns:
        `T`: the return value of the the completed task
    """

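Applied to a type-checking helper, the same Arguments/Returns layout might look like the sketch below. `isinstance_sketch` is a hypothetical stand-in written only to show the docstring structure; it is not torch.jit.isinstance and is not the docstring that was merged. Python 3.8+ is assumed.

```python
from typing import List, get_args, get_origin


def isinstance_sketch(obj, target_type):
    """
    Checks whether ``obj`` matches a plain class or a ``List[T]`` annotation.

    Hypothetical stand-in used only to illustrate docstring layout.

    Arguments:
        obj: the object to check
        target_type: a plain class, or ``List[T]`` where ``T`` is a plain class
    Returns:
        ``bool``: True if ``obj`` matches ``target_type``, otherwise False
    Example::
        >>> isinstance_sketch(["1", "2"], List[str])
        True
    """
    if get_origin(target_type) is list:
        # Unpack the single inner type and check every element against it.
        (inner,) = get_args(target_type)
        return isinstance(obj, list) and all(isinstance(el, inner) for el in obj)
    return isinstance(obj, target_type)
```
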
@SplitInfinity

Also, can you take a look at the coverage report and add some tests for a few lines that aren't being covered if possible? For example, I think the line that returns False when a value in a dictionary isn't of the expected type isn't being hit: https://codecov.io/gh/pytorch/pytorch/pull/46062/diff#D1-1132
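
A test that exercises a wrong-value-type branch could be shaped like the following. This is a self-contained sketch: `dict_matches` is a hypothetical stand-in for the dictionary branch of the checker, not the PR's code, and the real test would call torch.jit.isinstance through the existing JIT test harness instead. Python 3.8+ is assumed.

```python
import unittest
from typing import Dict, get_args, get_origin


def dict_matches(obj, target_type):
    """Hypothetical stand-in for the dictionary branch of a generic type checker."""
    if get_origin(target_type) is not dict:
        raise TypeError(f"expected Dict[K, V], got {target_type!r}")
    key_type, value_type = get_args(target_type)
    if not isinstance(obj, dict):
        return False
    for key, value in obj.items():
        if not isinstance(key, key_type):
            return False
        if not isinstance(value, value_type):
            # Analogous to the branch reported as uncovered: a value of the wrong type.
            return False
    return True


class DictBranchTest(unittest.TestCase):
    def test_value_of_wrong_type(self):
        self.assertFalse(dict_matches({"a": 1, "b": "two"}, Dict[str, int]))

    def test_key_of_wrong_type(self):
        self.assertFalse(dict_matches({1: 1}, Dict[str, int]))

    def test_all_entries_match(self):
        self.assertTrue(dict_matches({"a": 1, "b": 2}, Dict[str, int]))
```

The first test keeps the keys valid so that only the value check can produce the False result.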

Lilyjjo added a commit that referenced this pull request Oct 16, 2020
ghstack-source-id: 9752d038dab2473f0e06dca823cbe7cb01a70868
Pull Request resolved: #46062
Lilyjjo added a commit that referenced this pull request Oct 16, 2020
ghstack-source-id: c0ec83732e1086a9655d62e0b5f7aa97d76bab33
Pull Request resolved: #46062
Lilyjjo added a commit that referenced this pull request Oct 16, 2020
ghstack-source-id: 4d5ef5ddd6799d46d05b5deaf14fa6daf33356ab
Pull Request resolved: #46062
Lilyjjo added a commit that referenced this pull request Oct 19, 2020
ghstack-source-id: b0e7cf8ba59cf65eb27a4d793ca0b2b8e72eb2cb
Pull Request resolved: #46062
@facebook-github-bot
Contributor

@Lilyjjo merged this pull request in f83cf2d.

Labels
Merged oncall: jit Add this issue/PR to JIT oncall triage queue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[jit] support isinstance(foo, dict)
4 participants