New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for overloaded functions in stubgenc generated by pybind11 #5975
Changes from 9 commits
0834954
7035a97
cf4de08
38872dc
860de83
ae08bd3
d1aee10
13048b5
769577a
c77fe63
824c65a
382946b
fb7ad2d
eebb0e1
cd02f06
e0ace0f
d4be948
e42097b
8082b30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,8 @@ | |
|
||
from mypy.stubutil import ( | ||
is_c_module, write_header, infer_sig_from_docstring, | ||
infer_prop_type_from_docstring | ||
infer_prop_type_from_docstring, ArgList, TypedArgSig, | ||
infer_arg_sig_from_docstring, TypedFunctionSig | ||
) | ||
|
||
|
||
|
@@ -123,51 +124,46 @@ def generate_c_function_stub(module: ModuleType, | |
) -> None: | ||
ret_type = 'None' if name == '__init__' and class_name else 'Any' | ||
|
||
if self_var: | ||
self_arg = '%s, ' % self_var | ||
else: | ||
self_arg = '' | ||
if (name in ('__new__', '__init__') and name not in sigs and class_name and | ||
class_name in class_sigs): | ||
sig = class_sigs[class_name] | ||
inferred = [TypedFunctionSig(name=name, | ||
args=infer_arg_sig_from_docstring(class_sigs[class_name]), | ||
ret_type=ret_type)] # type: Optional[List[TypedFunctionSig]] | ||
else: | ||
docstr = getattr(obj, '__doc__', None) | ||
inferred = infer_sig_from_docstring(docstr, name) | ||
if inferred: | ||
sig, ret_type = inferred | ||
else: | ||
if not inferred: | ||
if class_name and name not in sigs: | ||
sig = infer_method_sig(name) | ||
inferred = [TypedFunctionSig(name, args=infer_method_sig(name), ret_type=ret_type)] | ||
else: | ||
sig = sigs.get(name, '(*args, **kwargs)') | ||
# strip away parentheses | ||
sig = sig[1:-1] | ||
if sig: | ||
if self_var: | ||
# remove annotation on self from signature if present | ||
groups = sig.split(',', 1) | ||
if groups[0] == self_var or groups[0].startswith(self_var + ':'): | ||
self_arg = '' | ||
sig = '{},{}'.format(self_var, groups[1]) if len(groups) > 1 else self_var | ||
else: | ||
self_arg = self_arg.replace(', ', '') | ||
|
||
if sig: | ||
sig_types = [] | ||
# convert signature in form of "self: TestClass, arg0: str" to | ||
# list [[self, TestClass], [arg0, str]] | ||
for arg in sig.split(','): | ||
arg_type = arg.split(':', 1) | ||
if len(arg_type) == 1: | ||
# there is no type provided in docstring | ||
sig_types.append(arg_type[0].strip()) | ||
else: | ||
arg_type_name = strip_or_import(arg_type[1].strip(), module, imports) | ||
sig_types.append('%s: %s' % (arg_type[0].strip(), arg_type_name)) | ||
sig = ", ".join(sig_types) | ||
inferred = [TypedFunctionSig(name=name, | ||
args=infer_arg_sig_from_docstring( | ||
sigs.get(name, '(*args, **kwargs)')), | ||
ret_type=ret_type)] | ||
|
||
is_overloaded = len(inferred) > 1 if inferred else False | ||
if is_overloaded: | ||
imports.append('from typing import overload') | ||
if inferred: | ||
for signature in inferred: | ||
sig = [] | ||
for arg in signature.args: | ||
if arg.name == self_var or not arg.type: | ||
# no type | ||
sig.append(arg.name) | ||
else: | ||
# type info | ||
sig.append('{}: {}'.format(arg.name, strip_or_import(arg.type, | ||
module, | ||
imports))) | ||
|
||
ret_type = strip_or_import(ret_type, module, imports) | ||
output.append('def %s(%s%s) -> %s: ...' % (name, self_arg, sig, ret_type)) | ||
if is_overloaded: | ||
output.append('@overload') | ||
output.append('def {function}({args}) -> {ret}: ...'.format( | ||
function=name, | ||
args=", ".join(sig), | ||
ret=strip_or_import(signature.ret_type, module, imports) | ||
)) | ||
|
||
|
||
def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str: | ||
|
@@ -307,29 +303,38 @@ def is_skipped_attribute(attr: str) -> bool: | |
'__weakref__') # For pickling | ||
|
||
|
||
def infer_method_sig(name: str) -> str: | ||
def infer_method_sig(name: str) -> ArgList: | ||
if name.startswith('__') and name.endswith('__'): | ||
name = name[2:-2] | ||
if name in ('hash', 'iter', 'next', 'sizeof', 'copy', 'deepcopy', 'reduce', 'getinitargs', | ||
'int', 'float', 'trunc', 'complex', 'bool'): | ||
return '()' | ||
return [] | ||
if name == 'getitem': | ||
return '(index)' | ||
return [TypedArgSig(name='index', type=None, default=None)] | ||
if name == 'setitem': | ||
return '(index, object)' | ||
return [ | ||
TypedArgSig(name='index', type=None, default=None), | ||
TypedArgSig(name='object', type=None, default=None) | ||
] | ||
if name in ('delattr', 'getattr'): | ||
return '(name)' | ||
return [TypedArgSig(name='name', type=None, default=None)] | ||
if name == 'setattr': | ||
return '(name, value)' | ||
return [ | ||
TypedArgSig(name='name', type=None, default=None), | ||
TypedArgSig(name='value', type=None, default=None) | ||
] | ||
if name == 'getstate': | ||
return '()' | ||
return [] | ||
if name == 'setstate': | ||
return '(state)' | ||
return [TypedArgSig(name='state', type=None, default=None)] | ||
if name in ('eq', 'ne', 'lt', 'le', 'gt', 'ge', | ||
'add', 'radd', 'sub', 'rsub', 'mul', 'rmul', | ||
'mod', 'rmod', 'floordiv', 'rfloordiv', 'truediv', 'rtruediv', | ||
'divmod', 'rdivmod', 'pow', 'rpow'): | ||
return '(other)' | ||
return [TypedArgSig(name='other', type=None, default=None)] | ||
if name in ('neg', 'pos'): | ||
return '()' | ||
return '(*args, **kwargs)' | ||
return [] | ||
return [ | ||
TypedArgSig(name='*args', type=None, default=None), | ||
TypedArgSig(name='**kwargs', type=None, default=None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would make |
||
] |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,14 +1,31 @@ | ||||||
import enum | ||||||
import io | ||||||
import re | ||||||
import sys | ||||||
import os | ||||||
import tokenize | ||||||
|
||||||
from typing import Optional, Tuple, Sequence, MutableSequence, List, MutableMapping, IO | ||||||
from typing import Optional, Tuple, Sequence, MutableSequence, List, MutableMapping, IO, NamedTuple | ||||||
from types import ModuleType | ||||||
|
||||||
|
||||||
# Type Alias for Signatures | ||||||
Sig = Tuple[str, str] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this still used somewhere? If yes, maybe we should switch to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's still used in:
|
||||||
|
||||||
TypedArgSig = NamedTuple('TypedArgSig', [ | ||||||
wiktorn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
('name', str), | ||||||
('type', Optional[str]), | ||||||
('default', Optional[str]) | ||||||
wiktorn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
]) | ||||||
|
||||||
ArgList = List[TypedArgSig] | ||||||
wiktorn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
TypedFunctionSig = NamedTuple('TypedFunctionSig', [ | ||||||
('name', str), | ||||||
('args', ArgList), | ||||||
('ret_type', str) | ||||||
]) | ||||||
|
||||||
|
||||||
def parse_signature(sig: str) -> Optional[Tuple[str, | ||||||
List[str], | ||||||
|
@@ -106,32 +123,129 @@ def write_header(file: IO[str], module_name: Optional[str] = None, | |||||
'# NOTE: This dynamically typed stub was automatically generated by stubgen.\n\n') | ||||||
|
||||||
|
||||||
def infer_sig_from_docstring(docstr: str, name: str) -> Optional[Tuple[str, str]]: | ||||||
class State(enum.Enum): | ||||||
INIT = 1 | ||||||
FUNCTION_NAME = 2 | ||||||
ARGUMENT_LIST = 3 | ||||||
ARGUMENT_TYPE = 4 | ||||||
ARGUMENT_DEFAULT = 5 | ||||||
RETURN_VALUE = 6 | ||||||
OPEN_BRACKET = 7 | ||||||
wiktorn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
|
||||||
def infer_sig_from_docstring(docstr: str, name: str) -> Optional[List[TypedFunctionSig]]: | ||||||
if not docstr: | ||||||
return None | ||||||
docstr = docstr.lstrip() | ||||||
# look for function signature, which is any string of the format | ||||||
# <function_name>(<signature>) -> <return type> | ||||||
# or perhaps without the return type | ||||||
|
||||||
# in the signature, we allow the following characters: | ||||||
# colon/equal: to match default values, like "a: int=1" | ||||||
# comma/space/brackets: for type hints like "a: Tuple[int, float]" | ||||||
# dot: for classes annotating using full path, like "a: foo.bar.baz" | ||||||
# to capture return type, | ||||||
sig_str = r'\([a-zA-Z0-9_=:, \[\]\.]*\)' | ||||||
sig_match = r'%s(%s)' % (name, sig_str) | ||||||
# first, try to capture return type; we just match until end of line | ||||||
m = re.match(sig_match + ' -> ([a-zA-Z].*)$', docstr, re.MULTILINE) | ||||||
if m: | ||||||
# strip potential white spaces at the right of return type | ||||||
return m.group(1), m.group(2).rstrip() | ||||||
|
||||||
# try to not match return type | ||||||
m = re.match(sig_match, docstr) | ||||||
if m: | ||||||
return m.group(1), 'Any' | ||||||
return None | ||||||
|
||||||
state = [State.INIT] | ||||||
accumulator = "" | ||||||
arg_type = None | ||||||
arg_name = "" | ||||||
arg_default = None | ||||||
ret_type = "Any" | ||||||
found = False | ||||||
args = [] # type: List[TypedArgSig] | ||||||
signatures = [] # type: List[TypedFunctionSig] | ||||||
try: | ||||||
for token in tokenize.tokenize(io.BytesIO(docstr.encode('utf-8')).readline): | ||||||
if token.type == tokenize.NAME and token.string == name and state[-1] == State.INIT: | ||||||
state.append(State.FUNCTION_NAME) | ||||||
|
||||||
elif (token.type == tokenize.OP and token.string == '(' and | ||||||
state[-1] == State.FUNCTION_NAME): | ||||||
state.pop() | ||||||
accumulator = "" | ||||||
found = True | ||||||
state.append(State.ARGUMENT_LIST) | ||||||
|
||||||
elif state[-1] == State.FUNCTION_NAME: | ||||||
# reset state, function name not followed by '(' | ||||||
state.pop() | ||||||
|
||||||
elif (token.type == tokenize.OP and token.string in ('[', '(', '{') and | ||||||
state[-1] != State.INIT): | ||||||
accumulator += token.string | ||||||
state.append(State.OPEN_BRACKET) | ||||||
|
||||||
elif (token.type == tokenize.OP and token.string in (']', ')', '}') and | ||||||
state[-1] == State.OPEN_BRACKET): | ||||||
accumulator += token.string | ||||||
state.pop() | ||||||
|
||||||
elif (token.type == tokenize.OP and token.string == ':' and | ||||||
state[-1] == State.ARGUMENT_LIST): | ||||||
arg_name = accumulator | ||||||
accumulator = "" | ||||||
state.append(State.ARGUMENT_TYPE) | ||||||
|
||||||
elif (token.type == tokenize.OP and token.string == '=' and | ||||||
state[-1] in (State.ARGUMENT_LIST, State.ARGUMENT_TYPE)): | ||||||
if state[-1] == State.ARGUMENT_TYPE: | ||||||
arg_type = accumulator | ||||||
state.pop() | ||||||
else: | ||||||
arg_name = accumulator | ||||||
accumulator = "" | ||||||
state.append(State.ARGUMENT_DEFAULT) | ||||||
|
||||||
elif (token.type == tokenize.OP and token.string in (',', ')') and | ||||||
state[-1] in (State.ARGUMENT_LIST, State.ARGUMENT_DEFAULT, State.ARGUMENT_TYPE)): | ||||||
if state[-1] == State.ARGUMENT_DEFAULT: | ||||||
arg_default = accumulator | ||||||
state.pop() | ||||||
elif state[-1] == State.ARGUMENT_TYPE: | ||||||
arg_type = accumulator | ||||||
state.pop() | ||||||
elif state[-1] == State.ARGUMENT_LIST: | ||||||
arg_name = accumulator | ||||||
|
||||||
if token.string == ')': | ||||||
state.pop() | ||||||
args.append(TypedArgSig(name=arg_name, type=arg_type, default=arg_default)) | ||||||
arg_name = "" | ||||||
arg_type = None | ||||||
arg_default = None | ||||||
accumulator = "" | ||||||
|
||||||
elif token.type == tokenize.OP and token.string == '->' and state[-1] == State.INIT: | ||||||
accumulator = "" | ||||||
state.append(State.RETURN_VALUE) | ||||||
|
||||||
# ENDMARKER is necessary for Python 3.4 and 3.5 | ||||||
elif (token.type in (tokenize.NEWLINE, tokenize.ENDMARKER) and | ||||||
state[-1] in (State.INIT, State.RETURN_VALUE)): | ||||||
if state[-1] == State.RETURN_VALUE: | ||||||
ret_type = accumulator | ||||||
accumulator = "" | ||||||
state.pop() | ||||||
|
||||||
if found: | ||||||
signatures.append(TypedFunctionSig(name=name, args=args, ret_type=ret_type)) | ||||||
found = False | ||||||
args = [] | ||||||
ret_type = 'Any' | ||||||
# leave state as INIT | ||||||
else: | ||||||
accumulator += token.string | ||||||
|
||||||
return signatures | ||||||
except tokenize.TokenError: | ||||||
wiktorn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# return as much as collected | ||||||
return signatures | ||||||
|
||||||
|
||||||
def infer_arg_sig_from_docstring(docstr: str) -> ArgList: | ||||||
""" | ||||||
convert signature in form of "(self: TestClass, arg0: str='ada')" to ArgList | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
:param docstr: | ||||||
:return: ArgList with inferred argument names and their types | ||||||
""" | ||||||
wiktorn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
ret = infer_sig_from_docstring("stub" + docstr, "stub") | ||||||
if ret: | ||||||
return ret[0].args | ||||||
|
||||||
return [] | ||||||
|
||||||
|
||||||
def infer_prop_type_from_docstring(docstr: str) -> Optional[str]: | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For multiline we use this style:
(also many of these will not need to be multiline if you implement the suggestion above).