Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create type hint stub files for module torch (#16089)
* create type hint stub files for module torch This is more a basis for discussion that a ready solution, as it does lots of funny things, and for many of them a better solution will be found. - Initial stab at creating a type torch/__init__.pyi . - We only do this for Python 3 because we want to use type hint introspection. - So far, we only aim at doing this for torch functions and torch.Tensor. - We need to import the newly build torch. Thus we do this at the end of the build. We use os.fork to only import the module in a child process. - We use an annotate decorator to be able to put type hints on the native python functions in a way that a) they're available in the usual place for Python 3 b) we stay Python 2 compatible - Some annotatons in torch/functional.py are provided as examples, but the remaining ones still need to be done. This could end up fixing #7318 This was backported from #12500 but with tests removed and the stub file checked in directly, so we don't have to also make sure all build shenanigans work. Master will be properly tested. The stub file was generated by running 'python setup.py create_pyi' and then copying the generated pyi file in the build directory (find -name *.pyi) to torch/ * Ignore pyi in flake8 Signed-off-by: Edward Z. Yang <ezyang@fb.com> * Real fix for lint error Signed-off-by: Edward Z. Yang <ezyang@fb.com>
- Loading branch information
Showing
5 changed files
with
1,768 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
from __future__ import print_function | ||
|
||
import os | ||
import sys | ||
import inspect | ||
|
||
needed_modules = set() | ||
|
||
|
||
def type_to_python(typename): | ||
typename = {'int64_t': 'int', | ||
'Scalar': 'float', | ||
'ScalarType': 'dtype', | ||
'Device': 'Union[device, str, None]', | ||
'IntList': 'Tuple[int, ...]', | ||
'Tensor?': 'Optional[Tensor]', | ||
'Tensor': 'Tensor', | ||
'bool': 'bool', | ||
'double': 'float', | ||
'Generator *': 'Generator', | ||
'Generator*': 'Generator', | ||
'std::vector<Tensor>': 'Tuple[Tensor, ...]', | ||
'TensorList': 'Tuple[Tensor, ...]', | ||
'Storage': 'Storage', # inaccurate (FloatStorage...) | ||
'SparseTensorRef': 'Tensor', | ||
'void': 'None', | ||
'Layout': 'layout', | ||
'void*': 'int', # dataptr | ||
'std::string': 'str', | ||
'real': 'float', | ||
'accreal': 'float', | ||
'Union[int, Tuple[int, ...]]': 'Union[int, Tuple[int, ...]]', | ||
'IntegerTensor': 'Tensor', | ||
'BoolTensor': 'Tensor', | ||
'IndexTensor': 'Tensor', | ||
}[typename] | ||
return typename | ||
|
||
|
||
def arg_to_type_hint(arg): | ||
name = arg['name'] | ||
if name == 'from': # keyword... | ||
name += '_' | ||
typename = type_to_python(arg['dynamic_type']) | ||
if arg.get('is_nullable'): | ||
typename = 'Optional[' + typename + ']' | ||
if 'default' in arg: | ||
default = '=' + str(arg['default']) | ||
else: | ||
default = '' | ||
return name + ': ' + typename + default | ||
|
||
|
||
def generate_type_hints(fname, decls, is_tensor=False): | ||
type_hints = [] | ||
dnames = ([d['name'] for d in decls]) | ||
has_out = fname + '_out' in dnames | ||
if has_out: | ||
decls = [d for d in decls if d['name'] != fname + '_out'] | ||
for decl in decls: | ||
skip = 'Type' in [a['dynamic_type'] for a in decl['arguments']] | ||
if not skip: | ||
has_tensor_options = 'TensorOptions' in [a['dynamic_type'] for a in decl['arguments']] | ||
python_args = [arg_to_type_hint(a) for a in decl['arguments'] if a['dynamic_type'] != 'TensorOptions'] | ||
if is_tensor: | ||
if 'self: Tensor' in python_args: | ||
python_args.remove('self: Tensor') | ||
python_args = ['self'] + python_args | ||
elif 'self: float' in python_args: | ||
python_args.remove('self: float') | ||
python_args = ['self'] + python_args | ||
else: | ||
raise Exception("method without self is unexpected") | ||
if has_out: | ||
python_args += ['*', 'out: Optional[Tensor]=None'] | ||
if has_tensor_options: | ||
if '*' not in python_args: | ||
python_args.append('*') | ||
python_args += ["dtype: dtype=None", | ||
"layout: layout=torch.strided", | ||
"device: device=None", | ||
"requires_grad:bool=False"] | ||
python_args_s = ', '.join(python_args) | ||
python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']] | ||
if len(python_returns) > 1: | ||
python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']' | ||
else: | ||
python_returns_s = python_returns[0] | ||
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) | ||
numargs = len(decl['arguments']) | ||
have_vararg_version = (numargs > 0 and decl['arguments'][0]['dynamic_type'] in {'IntList', 'TensorList'} and | ||
(numargs == 1 or python_args[1] == '*')) | ||
if len(decls) > 1 or have_vararg_version: | ||
type_hint = "@overload\n" + type_hint | ||
type_hints.append(type_hint) | ||
if have_vararg_version: | ||
typelist = decl['arguments'][0]['dynamic_type'] | ||
if typelist == 'IntList': | ||
vararg_type = 'int' | ||
else: | ||
vararg_type = 'Tensor' | ||
# replace first argument and eliminate '*' if present | ||
python_args = ['*' + decl['arguments'][0]['name'] + ': ' + vararg_type] + python_args[2:] | ||
python_args_s = ', '.join(python_args) | ||
type_hint = "@overload\ndef {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) | ||
type_hints.append(type_hint) | ||
return type_hints | ||
|
||
|
||
def parameters_from_signature(sig): | ||
# adapted from standard library inspect | ||
result = [] | ||
render_pos_only_separator = False | ||
render_kw_only_separator = True | ||
for param in sig.parameters.values(): | ||
kind = param.kind | ||
formatted = param._name | ||
|
||
if param._annotation is not inspect._empty: | ||
formatted = '{}:{}'.format(formatted, inspect.formatannotation(param._annotation)) | ||
|
||
if param._default is not inspect._empty: | ||
if type(param._default).__name__ == 'module': # better way? | ||
default_repr = param._default.__name__ | ||
needed_modules.add(default_repr) | ||
else: | ||
default_repr = repr(param._default) | ||
formatted = '{}={}'.format(formatted, default_repr) | ||
|
||
if kind == inspect._VAR_POSITIONAL: | ||
formatted = '*' + formatted | ||
elif kind == inspect._VAR_KEYWORD: | ||
formatted = '**' + formatted | ||
|
||
if kind == inspect._POSITIONAL_ONLY: | ||
render_pos_only_separator = True | ||
elif render_pos_only_separator: | ||
result.append('/') | ||
render_pos_only_separator = False | ||
|
||
if kind == inspect._VAR_POSITIONAL: | ||
render_kw_only_separator = False | ||
elif kind == inspect._KEYWORD_ONLY and render_kw_only_separator: | ||
result.append('*') | ||
render_kw_only_separator = False | ||
|
||
formatted = formatted.replace('torch.Tensor', 'Tensor') | ||
result.append(formatted) | ||
|
||
if render_pos_only_separator: | ||
result.append('/') | ||
return result | ||
|
||
|
||
def type_hint_from_python_fn(fname, fn): | ||
sig = inspect.signature(fn) | ||
python_parameters = parameters_from_signature(sig) | ||
|
||
return_annotation = None | ||
if sig.return_annotation is not inspect._empty: | ||
return_annotation = inspect.formatannotation(sig.return_annotation) | ||
return_annotation = return_annotation.replace('torch.Tensor', 'Tensor') | ||
|
||
python_parameters = ', '.join(python_parameters) | ||
if return_annotation: | ||
return ["def {}({}) -> {}: ...".format(fname, python_parameters, return_annotation)] | ||
else: | ||
return ["def {}({}): ...".format(fname, python_parameters, return_annotation)] | ||
|
||
|
||
def do_gen_py(build_lib_path): | ||
import torch | ||
import yaml | ||
import inspect | ||
import types | ||
import re | ||
import collections | ||
|
||
fns = collections.defaultdict(list) | ||
yaml_loader = getattr(yaml, 'CLoader', yaml.loader) | ||
|
||
for d in yaml.load(open('torch/share/ATen/Declarations.yaml'), | ||
Loader=yaml_loader): | ||
|
||
name = d['name'] | ||
if name.endswith('_out'): | ||
name = name[:-4] | ||
if not name.startswith('_'): | ||
fns[name].append(d) | ||
|
||
type_hints = [] | ||
for fname in dir(torch): | ||
fn = getattr(torch, fname) | ||
docstr = inspect.getdoc(fn) | ||
if docstr: | ||
if isinstance(fn, types.BuiltinFunctionType): | ||
if fname in fns: | ||
type_hints += generate_type_hints(fname, fns[fname]) | ||
else: | ||
pass # todo | ||
elif isinstance(fn, types.FunctionType): | ||
type_hints += type_hint_from_python_fn(fname, fn) | ||
type_hints_s = '\n\n'.join(type_hints) + '\n' | ||
|
||
tensor_type_hints = [] | ||
for fname in dir(torch.Tensor): | ||
fn = getattr(torch.Tensor, fname) | ||
docstr = inspect.getdoc(fn) | ||
if docstr and not fname.startswith('_'): | ||
if getattr(fn, '__qualname__', '').startswith('_TensorBase.'): # better check? | ||
if fname in fns: | ||
tensor_type_hints += generate_type_hints(fname, fns[fname], is_tensor=True) | ||
else: | ||
pass # these require magic... print (fname) | ||
elif isinstance(fn, types.FunctionType): # python defined | ||
tensor_type_hints += type_hint_from_python_fn(fname, fn) | ||
tensor_type_hints_s = """class Tensor:\n""" + '\n\n'.join( | ||
[' ' + re.sub(r"\bTensor\b", "'Tensor'", s.replace('\n', '\n' + ' ')) | ||
for s in tensor_type_hints]) + '\n\n' | ||
|
||
header = """ | ||
from typing import Tuple, Optional, Union | ||
""" | ||
header += '\n'.join(["import " + m for m in needed_modules]) | ||
with open(os.path.join(build_lib_path, 'torch', '__init__.pyi'), 'w') as f: | ||
print(header, file=f) | ||
print(tensor_type_hints_s, file=f) | ||
print(type_hints_s, file=f) | ||
|
||
|
||
def gen_pyi(build_lib_path): | ||
# we import torch, better do that in a subprocess | ||
if os.fork() == 0: | ||
do_gen_py(build_lib_path) | ||
else: | ||
os.wait() |
Oops, something went wrong.