create type hint stub files for module torch #12500

Closed
wants to merge 45 commits into from

Conversation

t-vi
Collaborator

@t-vi t-vi commented Oct 9, 2018

We have:

  • This is an initial stab at creating a type stub torch/__init__.pyi.
  • This is only tested on Python 3, since that's the only Python version mypy
    works on.
  • So far, we only aim at doing this for torch functions and torch.Tensor.
  • Quite a few methods and functions have to be typed manually. These are
    done in torch/__init__.pyi.in.

For me, PyCharm (the non-paid one) didn't seem to indicate errors in the .pyi when opening and seemed to be able to get the type hint for the few functions I tried, but I don't use PyCharm for my usual PyTorch activities, so I didn't extensively try this out.

An example of a generated PYI is at this gist. https://gist.github.com/ezyang/bf9b6a5fa8827c52152858169bcb61b1


Original summary:

This is more a basis for discussion than a ready solution, as it does lots of funny things, and for many of them a better solution will be found, but to facilitate discussion, I'm putting this out.

After many improvements, this could contribute to fixing #7318.

We have

  • This is an initial stab at creating a type stub torch/__init__.pyi .
  • We only do this for Python 3 because we want to use type hint introspection.
  • So far, we only aim at doing this for torch functions and torch.Tensor.
  • We need to import the newly built torch. Thus we do this at the end of the build.
    We use os.fork to only import the module in a child process.
  • We use an annotate decorator to be able to put type hints on the native python functions in a way that
    a) they're available in the usual place for Python 3
    b) we stay Python 2 compatible
  • Some annotations in torch/functional.py are provided as examples, but the remaining ones still need to be done if you are OK with the decorator approach. At first glance, it seems that the annotations are also available to Python3 introspection, e.g. on Jupyter.
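
The `os.fork` step in the list above can be pictured roughly as follows. This is a minimal sketch rather than the PR's code: `run_in_child` is a hypothetical helper, and the standard-library `json` module stands in for the freshly built `torch`. Because it relies on `os.fork`, it only works on POSIX systems.

```python
import importlib
import os


def run_in_child(module_name):
    """Import `module_name` in a forked child so the parent never loads it.

    Returns the child's exit code: 0 if the import succeeded, 1 otherwise.
    (Hypothetical helper, not taken from the PR.)
    """
    pid = os.fork()
    if pid == 0:
        # Child process: attempt the import, then exit immediately without
        # running the parent's cleanup handlers.
        try:
            importlib.import_module(module_name)
        except BaseException:
            os._exit(1)
        os._exit(0)
    # Parent process: wait for the child and report how it exited.
    _, status = os.waitpid(pid, 0)
    return os.WEXITSTATUS(status) if os.WIFEXITED(status) else -1


if __name__ == "__main__":
    # `json` is only a stand-in for the module being built.
    print(run_in_child("json"))
```

The point of the child process is isolation: whatever the import loads or breaks stays out of the build process itself.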

For me, PyCharm (the non-paid one) didn't seem to indicate errors in the .pyi when opening and seemed to be able to get the type hint for the few functions I tried, but I don't use PyCharm for my usual PyTorch activities, so I didn't extensively try this out.

An example of a generated PYI is at this gist.

This is more a basis for discussion than a ready solution,
as it does lots of funny things, and for many of them
a better solution will be found.

- Initial stab at creating a type stub torch/__init__.pyi.
- We only do this for Python 3 because we want to
  use type hint introspection.
- So far, we only aim at doing this for torch functions
  and torch.Tensor.
- We need to import the newly built torch. Thus we
  do this at the end of the build.
  We use os.fork to only import the module in a child
  process.
- We use an annotate decorator to be able to put
  type hints on the native python functions in a way that
  a) they're available in the usual place for Python 3
  b) we stay Python 2 compatible
- Some annotations in torch/functional.py are provided
  as examples, but the remaining ones still need to be done.

This could end up fixing pytorch#7318
Collaborator

@ssnl ssnl left a comment


Nice!

torch/_utils.py
'double': 'float',
'Generator *': 'Generator',
'Generator*': 'Generator',
'std::vector<Tensor>': 'Tuple[Tensor, ...]',


'Layout': 'layout',
'void*': 'int', # dataptr
'std::string': 'str',
'real': 'float',


- Add manual tensor() annotation, thank you, @elliotwaite.
- Create class stubs for device etc.
- Improve type mapping. Thank you, Simon, for the review comments.
'std::string': 'str',
'real': 'Union[float, int]',
'accreal': 'Union[float, int]',
'IntegerTensor': 'Tensor',
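
The mapping entries quoted in these hunks can be read as a plain lookup table from declared C++/ATen type names to Python type-hint strings. A minimal sketch follows, using only the entries visible above (with the later `Union[float, int]` value for `real`). The body of `type_to_python` and its error handling for unmapped names are assumptions, since the hunks show only the function's signature.

```python
# Entries copied from the review hunks in this thread; the actual table in
# the PR is presumably longer.
TYPE_MAP = {
    'double': 'float',
    'Generator *': 'Generator',
    'Generator*': 'Generator',
    'std::vector<Tensor>': 'Tuple[Tensor, ...]',
    'Layout': 'layout',
    'void*': 'int',  # dataptr
    'std::string': 'str',
    'real': 'Union[float, int]',
    'accreal': 'Union[float, int]',
    'IntegerTensor': 'Tensor',
}


def type_to_python(typename):
    """Translate a declared type name into a Python type-hint string."""
    try:
        return TYPE_MAP[typename]
    except KeyError:
        # Assumption: fail loudly on unmapped types instead of guessing.
        raise ValueError('no type hint known for {!r}'.format(typename))


if __name__ == "__main__":
    print(type_to_python('std::vector<Tensor>'))
```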


tools/pyi/gen_pyi.py Outdated

def arg_to_type_hint(arg):
    name = arg['name']
    if name == 'from':  # keyword...
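
The hunk above is cut off, so what the PR does with the `from` argument is not visible here. As a hedged illustration of the underlying problem (an argument name that is a Python keyword cannot appear in a stub signature), the sketch below uses a hypothetical helper `format_arg_hint`. The trailing-underscore rename and the `default` key are assumptions; the `name: type=default` shape follows the `return name + ': ' + typename + default` line quoted later in this thread.

```python
import keyword


def format_arg_hint(arg):
    """Render one argument dict as a 'name: type[=default]' stub fragment.

    `arg` is assumed to have 'name' and 'dynamic_type' keys and an
    optional 'default' key holding the default as source text.
    """
    name = arg['name']
    if keyword.iskeyword(name):
        # Assumption: sidestep reserved words such as `from` by appending
        # an underscore; the PR may handle this differently.
        name += '_'
    default = arg.get('default')
    suffix = '' if default is None else '=' + default
    return name + ': ' + arg['dynamic_type'] + suffix


if __name__ == "__main__":
    print(format_arg_hint({'name': 'from', 'dynamic_type': 'float'}))
    print(format_arg_hint({'name': 'dtype', 'dynamic_type': 'dtype',
                           'default': 'None'}))
```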


@ssnl
Collaborator

ssnl commented Oct 11, 2018

Question: what do you think is a good way to test this? Maybe pick some functions, construct inputs based on the argument types, and then assert that the arg parser doesn't throw an error (it may still error in later stages due to things like shape mismatch)?

@ezyang
Contributor

ezyang commented Oct 11, 2018

We'll definitely need some sort of testing strategy for this, otherwise it will bitrot faster than you can say leroy jenkins.


# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
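
Only the decorator's signature is visible in this hunk. A minimal sketch of how such a decorator might work is given below; the body is an assumption, not the PR's implementation. It stores the hints in `__annotations__`, which is where Python 3 tooling looks for them, while using no Python 3-only syntax at the definition site.

```python
import typing


def annotate(ret, **kwargs):
    """Attach type hints to a function without annotation syntax."""
    def decorator(fn):
        hints = dict(kwargs)
        hints['return'] = ret
        # Python 3 introspection reads hints from this attribute.
        fn.__annotations__ = hints
        return fn
    return decorator


# `add` is a made-up function used only to demonstrate the decorator.
@annotate(ret=int, a=int, b=int)
def add(a, b):
    return a + b


if __name__ == "__main__":
    print(typing.get_type_hints(add))
```

Because the hints live in the standard attribute, `typing.get_type_hints` and similar introspection see them the same way they would see inline annotations.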


needed_modules = set()


def type_to_python(typename):


@t-vi
Collaborator Author

t-vi commented Oct 11, 2018

Re testing: So the type hints should ideally have two properties:

  • Code satisfying the type hints will be "correct" in some way. One could try to check this using an approach similar to the one you mention, parsing the error messages. If we are off here, programs/people will have wrong ideas about what can be passed, but it isn't worse than now.

  • Any correct code will satisfy the type hint restrictions.
    I think this may be the more important property: people will get upset if their correct code is marked as wrong by checkers. One testing approach I could think of is to put "known good code" (our tests? the examples/tutorials?) through a checker (mypy?) and see if it passes. The question is whether the tests are a tad too dynamic to be statically checked, and where to get good test code from if they are.



def do_gen_py(build_lib_path):
    while '' in sys.path:


decls = [d for d in decls if d['name'] != fname + '_out']
for decl in decls:
    skip = 'Type' in [a['dynamic_type'] for a in decl['arguments']]
    if not skip:


return name + ': ' + typename + default


def generate_type_hints(fname, decls, is_tensor=False):


python_returns_s = python_returns[0]
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
numargs = len(decl['arguments'])
have_vararg_version = (numargs > 0 and decl['arguments'][0]['dynamic_type'] in {'IntList', 'TensorList'} and


@ezyang
Contributor

ezyang commented Oct 11, 2018

OK, so I briefly skimmed the generation code, and I think the best thing we can do to make this presentable is to do some Declarations.yaml cleanup, which should simplify a bit of the hacking around you had to do here. #12562 records some overall directions we want to take here, but probably the most important prereqs for this patch are:

  • Centralizing TensorOptions to kwargs transformation
  • Fix some types (e.g., Generator * and std::vector<Tensor>) to no longer be directly C++ types
  • Rename dynamic_type to type (after having eliminated what we previously called type)
  • Eliminate all mentions of Type from Declarations.yaml
  • Centralized logic for determining if a method is inplace or not
  • Change method_of to something simpler like function and/or method. Then don't have to use the "does it have a self argument" heuristic.

@t-vi
Collaborator Author

t-vi commented Oct 11, 2018

So what does the #12562 washlist mean for progressing here?
I'll certainly fix the immediate comments, thanks for these!
Is it fair to say that you don't object to the annotation approach using the decorator? Then I would see to sprinkling that a bit more on functional.py.
Regarding code for testing: On the bug report, @elliotwaite mentioned the documentation examples as a source for code. That seems a great idea and we do have all the docstrings, so maybe we can make a .py from these and run mypy.

@ezyang
Contributor

ezyang commented Oct 11, 2018

> So what does the #12562 washlist mean for progressing here?

Let me look again at the code and work out exactly what would be affected by the changes.

> Is it fair to say that you don't object to the annotation approach using the decorator? Then I would see to sprinkling that a bit more on functional.py.

Yeah, I don't personally object to it. It seems like a quite nice way to add the annotation.

> Regarding code for testing: On the bug report, @elliotwaite mentioned the documentation examples as a source for code. That seems a great idea and we do have all the docstrings, so maybe we can make a .py from these and run mypy.

This sounds like a great idea.
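
The docstring idea could be prototyped with the standard-library `doctest` parser, which already knows how to pull `>>>` examples out of a docstring. The sketch below is an illustration under assumptions: `docstring_examples_to_source` is a hypothetical helper and `SAMPLE` is a made-up docstring, not one from PyTorch.

```python
import doctest

# Made-up docstring used only to exercise the helper.
SAMPLE = '''
Returns the sum of a list.

Example::

    >>> x = [1, 2, 3]
    >>> sum(x)
    6
'''


def docstring_examples_to_source(docstring):
    """Concatenate the source of all doctest examples in `docstring`.

    Expected outputs (the lines after each `>>>` statement) are dropped,
    leaving plain Python that a static checker could be pointed at.
    """
    parser = doctest.DocTestParser()
    return ''.join(ex.source for ex in parser.get_examples(docstring))


if __name__ == "__main__":
    print(docstring_examples_to_source(SAMPLE))
```

The returned text could then be written to a `.py` file and passed to mypy; examples that depend on names defined elsewhere would need extra imports prepended first.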

return type_hints


def parameters_from_signature(sig):


@ezyang
Contributor

ezyang commented Oct 11, 2018

These spots look like they would make Declarations.yaml refactoring harder

has_out = fname + '_out' in dnames

if not name.startswith('_'):

Don't want to hard-code _out handling and inplace handling; should be centralized.

skip = 'Type' in [a['dynamic_type'] for a in decl['arguments']]

Type is going away. This one's pretty harmless though, easy enough to fix once Type is killed.

if 'self: Tensor' in python_args:

Hard-coded handling of self. Maybe this is OK and we'll keep this in the end. CC @zdevito

python_args += ["dtype: dtype=None",

Hard-coded TensorOptions expansion

have_vararg_version = (numargs > 0 and decl['arguments'][0]['dynamic_type'] in {'IntList', 'TensorList'} and

I actually have no idea what is going on with varargs.
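
One possible reading of the vararg branch, offered as a guess because the quoted condition is cut off: PyTorch accepts some list-valued first arguments either as a single sequence or as separate positional values (for example `torch.zeros((2, 3))` and `torch.zeros(2, 3)`), so a stub generator would need to emit a second, variadic signature for those functions. The sketch below illustrates that idea only. `stub_signatures`, `ELEMENT_TYPE`, and the simplified trigger condition are all hypothetical; the `def {}({}) -> {}: ...` format string is the one quoted in the hunk earlier in this thread.

```python
# Assumed element types for the two list types named in the hunk.
ELEMENT_TYPE = {'IntList': 'int', 'TensorList': 'Tensor'}


def stub_signatures(fname, args, ret):
    """Return stub lines for a function.

    `args` is a list of (name, dynamic_type) pairs. A variadic variant is
    added when the first argument is a list type (simplified condition).
    """
    hints = ['{}: {}'.format(name, typename) for name, typename in args]
    lines = ['def {}({}) -> {}: ...'.format(fname, ', '.join(hints), ret)]
    if args and args[0][1] in ELEMENT_TYPE:
        name, typename = args[0]
        star = '*{}: {}'.format(name, ELEMENT_TYPE[typename])
        variadic = ', '.join([star] + hints[1:])
        lines.append('def {}({}) -> {}: ...'.format(fname, variadic, ret))
    return lines


if __name__ == "__main__":
    for line in stub_signatures('zeros', [('size', 'IntList')], 'Tensor'):
        print(line)
```

In an actual stub file the two signatures would each need an `@overload` decorator from `typing` for a checker to accept both.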

@ezyang
Contributor

ezyang commented Oct 11, 2018

As a whole, after reviewing this patch, I feel a bit better about landing it before the refactor happens, after cleanups happen. I don't think it will be that much more burden when we make adjustments. (Maybe this means I'm signing up to do the adjustments :) cc @zdevito @gchanan

@t-vi
Collaborator Author

t-vi commented Oct 11, 2018

The _ filtering for "internal" functions does not work as is BTW, as we need annotations for __mul__ .

setup.py Outdated
class create_pyi(distutils.command.build.build):
    def run(self):
        print("-- Building .pyi --")
        if sys.version_info[0] == 3:


soumith pushed a commit that referenced this pull request Jan 24, 2019
* create type hint stub files for module torch


This was backported from #12500 but with tests removed and
the stub file checked in directly, so we don't have to also
make sure all build shenanigans work.  Master will be properly
tested.  The stub file was generated by running
'python setup.py create_pyi' and then copying the generated
pyi file in the build directory (find -name *.pyi) to torch/

* Ignore pyi in flake8

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Real fix for lint error

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@ezyang
Contributor

ezyang commented Jan 24, 2019

After a lot of whacking, the type annotation script no longer requires import torch to generate annotations.

@t-vi
Collaborator Author

t-vi commented Jan 24, 2019

Awesome stuff, @ezyang . ❤️ Probably would have been easier to not base it on my feeble attempts. 😕

@ezyang
Contributor

ezyang commented Jan 24, 2019

Let's be clear now, your diff helped a lot :) Would have been 10x more work without it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@ezyang
Contributor

ezyang commented Jan 25, 2019

Ugh, CI failure doesn't repro locally. Time to fire up the Docker image, I guess...

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@ezyang
Contributor

ezyang commented Jan 28, 2019

> What I ran into trouble with was to make sure mypy picks up the right PyTorch. (Which is why I did this temporary directory dance.)

So, in the end, you were right! The problem is specific to when you install PyTorch, because mypy treats installed packages differently from local packages. It will refuse to attempt to typecheck using installed packages unless they are registered on typeshed, which explains why the imports cannot be found. Your symlink trick solved the problem, because it made the package look local again. There doesn't seem to be a way to otherwise force mypy to treat the package as typecheckable when it's not registered on typeshed.
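
The symlink trick described above might look roughly like this. It is a sketch under assumptions: `make_local_view` is a hypothetical helper, and the PR's actual test code is not shown in this thread. The idea is to link the installed package into a scratch directory and run mypy from there, so that the package resolves as a local one.

```python
import os
import tempfile


def make_local_view(package_dir, workdir):
    """Symlink `package_dir` into `workdir` and return the link path.

    Running a type checker with `workdir` as the current directory then
    finds the package as if it were local source.
    """
    package_dir = os.path.normpath(os.path.abspath(package_dir))
    link = os.path.join(workdir, os.path.basename(package_dir))
    os.symlink(package_dir, link)
    return link


if __name__ == "__main__":
    # Demonstrate with a throwaway fake package instead of torch.
    with tempfile.TemporaryDirectory() as src, \
            tempfile.TemporaryDirectory() as work:
        pkg = os.path.join(src, "fakepkg")
        os.mkdir(pkg)
        open(os.path.join(pkg, "__init__.pyi"), "w").close()
        print(make_local_view(pkg, work))
```

With the link in place, mypy would be started with the scratch directory as its working directory. Separately, mypy's documentation describes PEP 561 `py.typed` markers as the mechanism for type-checking installed packages; whether that was an option for this PR is not stated in the thread.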

I think I am just going to resurrect this code verbatim with a comment.

soumith pushed a commit that referenced this pull request Jan 29, 2019
* create type hint stub files for module torch

This is more a basis for discussion that a ready solution,
as it does lots of funny things, and for many of them
a better solution will be found.

- Initial stab at creating a type torch/__init__.pyi .
- We only do this for Python 3 because we want to
  use type hint introspection.
- So far, we only aim at doing this for torch functions
  and torch.Tensor.
- We need to import the newly build torch. Thus we
  do this at the end of the build.
  We use os.fork to only import the module in a child
  process.
- We use an annotate decorator to be able to put
  type hints on the native python functions in a way that
  a) they're available in the usual place for Python 3
  b) we stay Python 2 compatible
- Some annotatons in torch/functional.py are provided
  as examples, but the remaining ones still need to be done.

This could end up fixing #7318

This was backported from #12500 but with tests removed and
the stub file checked in directly, so we don't have to also
make sure all build shenanigans work.  Master will be properly
tested.  The stub file was generated by running
'python setup.py create_pyi' and then copying the generated
pyi file in the build directory (find -name *.pyi) to torch/

* Ignore pyi in flake8

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Real fix for lint error

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@ezyang
Contributor

ezyang commented Jan 29, 2019

Oh my god it passed all the tests LOL

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@vsiles

vsiles commented Feb 11, 2019

I'm curious: why didn't you just annotate torch/tensor.py directly?

@ezyang
Contributor

ezyang commented Feb 12, 2019

@vsiles For a pretty silly reason, actually: the original patch @t-vi wrote didn't do that, and when I tried it, PyCharm showed the module as torch.tensor rather than torch (in reality, I tried a version where I just added a type stub for torch._C and then let things work out; in that case, I got torch._C).

This isn't set in stone, esp. if people don't mind the originating module looking ugly. But that's the reason.

7 participants