diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 2ef9a4c444d0..ce2b3b8d8880 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -1720,6 +1720,7 @@ def generate_stubs(options: Options) -> None: ) # Separately analyse C modules using different logic. + all_modules = sorted(m.module for m in (py_modules + c_modules)) for mod in c_modules: if any(py_mod.module.startswith(mod.module + ".") for py_mod in py_modules + c_modules): target = mod.module.replace(".", "/") + "/__init__.pyi" @@ -1728,7 +1729,9 @@ def generate_stubs(options: Options) -> None: target = os.path.join(options.output_dir, target) files.append(target) with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): - generate_stub_for_c_module(mod.module, target, sig_generators=sig_generators) + generate_stub_for_c_module( + mod.module, target, known_modules=all_modules, sig_generators=sig_generators + ) num_modules = len(py_modules) + len(c_modules) if not options.quiet and num_modules > 0: print("Processed %d modules" % num_modules) diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index add33e66cee3..da0fc5cee6b9 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -44,6 +44,17 @@ class SignatureGenerator: """Abstract base class for extracting a list of FunctionSigs for each function.""" + def remove_self_type( + self, inferred: list[FunctionSig] | None, self_var: str + ) -> list[FunctionSig] | None: + """Remove type annotation from self/cls argument""" + if inferred: + for signature in inferred: + if signature.args: + if signature.args[0].name == self_var: + signature.args[0].type = None + return inferred + @abstractmethod def get_function_sig( self, func: object, module_name: str, name: str @@ -52,7 +63,7 @@ def get_function_sig( @abstractmethod def get_method_sig( - self, func: object, module_name: str, class_name: str, name: str, self_var: str + self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str ) -> list[FunctionSig] | None: pass @@ -83,7 +94,7 @@ def get_function_sig( return None def get_method_sig( - self, func: object, module_name: str, class_name: str, name: str, self_var: str + self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str ) -> list[FunctionSig] | None: if ( name in ("__new__", "__init__") @@ -94,10 +105,11 @@ def get_method_sig( FunctionSig( name=name, args=infer_arg_sig_from_anon_docstring(self.class_sigs[class_name]), - ret_type="None" if name == "__init__" else "Any", + ret_type=infer_method_ret_type(name), ) ] - return self.get_function_sig(func, module_name, name) + inferred = self.get_function_sig(func, module_name, name) + return self.remove_self_type(inferred, self_var) class DocstringSignatureGenerator(SignatureGenerator): @@ -114,9 +126,19 @@ def get_function_sig( return inferred def get_method_sig( - self, func: object, module_name: str, class_name: str, name: str, self_var: str + self, + cls: type, + func: object, + module_name: str, + class_name: str, + func_name: str, + self_var: str, ) -> list[FunctionSig] | None: - return self.get_function_sig(func, module_name, name) + inferred = self.get_function_sig(func, module_name, func_name) + if not inferred and func_name == "__init__": + # look for class-level constructor signatures of the form () + inferred = self.get_function_sig(cls, module_name, class_name) + return self.remove_self_type(inferred, self_var) class FallbackSignatureGenerator(SignatureGenerator): @@ -132,19 +154,22 @@ def get_function_sig( ] def get_method_sig( - self, func: object, module_name: str, class_name: str, name: str, self_var: str + self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str ) -> list[FunctionSig] | None: return [ FunctionSig( name=name, - args=infer_method_sig(name, self_var), - ret_type="None" if name == "__init__" else "Any", + args=infer_method_args(name, self_var), + ret_type=infer_method_ret_type(name), ) ] def generate_stub_for_c_module( - module_name: str, target: str, sig_generators: Iterable[SignatureGenerator] + module_name: str, + target: str, + known_modules: list[str], + sig_generators: Iterable[SignatureGenerator], ) -> None: """Generate stub for C module. @@ -166,11 +191,17 @@ def generate_stub_for_c_module( imports: list[str] = [] functions: list[str] = [] done = set() - items = sorted(module.__dict__.items(), key=lambda x: x[0]) + items = sorted(get_members(module), key=lambda x: x[0]) for name, obj in items: if is_c_function(obj): generate_c_function_stub( - module, name, obj, functions, imports=imports, sig_generators=sig_generators + module, + name, + obj, + output=functions, + known_modules=known_modules, + imports=imports, + sig_generators=sig_generators, ) done.add(name) types: list[str] = [] @@ -179,7 +210,13 @@ def generate_stub_for_c_module( continue if is_c_type(obj): generate_c_type_stub( - module, name, obj, types, imports=imports, sig_generators=sig_generators + module, + name, + obj, + output=types, + known_modules=known_modules, + imports=imports, + sig_generators=sig_generators, ) done.add(name) variables = [] @@ -187,7 +224,9 @@ def generate_stub_for_c_module( if name.startswith("__") and name.endswith("__"): continue if name not in done and not inspect.ismodule(obj): - type_str = strip_or_import(get_type_fullname(type(obj)), module, imports) + type_str = strip_or_import( + get_type_fullname(type(obj)), module, known_modules, imports + ) variables.append(f"{name}: {type_str}") output = sorted(set(imports)) for line in variables: @@ -218,6 +257,22 @@ def add_typing_import(output: list[str]) -> list[str]: return output[:] +def get_members(obj: object) -> list[tuple[str, Any]]: + obj_dict: Mapping[str, Any] = getattr(obj, "__dict__") # noqa: B009 + results = [] + for name in obj_dict: + if is_skipped_attribute(name): + continue + # Try to get the value via getattr + try: + value = getattr(obj, name) + except AttributeError: + continue + else: + results.append((name, value)) + return results + + def is_c_function(obj: object) -> bool: return inspect.isbuiltin(obj) or type(obj) is type(ord) @@ -257,10 +312,13 @@ def generate_c_function_stub( module: ModuleType, name: str, obj: object, + *, + known_modules: list[str], + sig_generators: Iterable[SignatureGenerator], output: list[str], imports: list[str], - sig_generators: Iterable[SignatureGenerator], self_var: str | None = None, + cls: type | None = None, class_name: str | None = None, ) -> None: """Generate stub for a single function or method. @@ -273,13 +331,16 @@ def generate_c_function_stub( inferred: list[FunctionSig] | None = None if class_name: # method: + assert cls is not None, "cls should be provided for methods" assert self_var is not None, "self_var should be provided for methods" for sig_gen in sig_generators: - inferred = sig_gen.get_method_sig(obj, module.__name__, class_name, name, self_var) + inferred = sig_gen.get_method_sig( + cls, obj, module.__name__, class_name, name, self_var + ) if inferred: # add self/cls var, if not present for sig in inferred: - if not sig.args or sig.args[0].name != self_var: + if not sig.args or sig.args[0].name not in ("self", "cls"): sig.args.insert(0, ArgSig(name=self_var)) break else: @@ -295,7 +356,6 @@ def generate_c_function_stub( "if FallbackSignatureGenerator is provided" ) - is_classmethod = self_var == "cls" is_overloaded = len(inferred) > 1 if inferred else False if is_overloaded: imports.append("from typing import overload") @@ -303,35 +363,35 @@ def generate_c_function_stub( for signature in inferred: args: list[str] = [] for arg in signature.args: - if arg.name == self_var: - arg_def = self_var - else: - arg_def = arg.name - if arg_def == "None": - arg_def = "_none" # None is not a valid argument name + arg_def = arg.name + if arg_def == "None": + arg_def = "_none" # None is not a valid argument name - if arg.type: - arg_def += ": " + strip_or_import(arg.type, module, imports) + if arg.type: + arg_def += ": " + strip_or_import(arg.type, module, known_modules, imports) - if arg.default: - arg_def += " = ..." + if arg.default: + arg_def += " = ..." args.append(arg_def) if is_overloaded: output.append("@overload") - if is_classmethod: + # a sig generator indicates @classmethod by specifying the cls arg + if class_name and signature.args and signature.args[0].name == "cls": output.append("@classmethod") output.append( "def {function}({args}) -> {ret}: ...".format( function=name, args=", ".join(args), - ret=strip_or_import(signature.ret_type, module, imports), + ret=strip_or_import(signature.ret_type, module, known_modules, imports), ) ) -def strip_or_import(typ: str, module: ModuleType, imports: list[str]) -> str: +def strip_or_import( + typ: str, module: ModuleType, known_modules: list[str], imports: list[str] +) -> str: """Strips unnecessary module names from typ. If typ represents a type that is inside module or is a type coming from builtins, remove @@ -340,21 +400,33 @@ def strip_or_import(typ: str, module: ModuleType, imports: list[str]) -> str: Arguments: typ: name of the type module: in which this type is used + known_modules: other modules being processed imports: list of import statements (may be modified during the call) """ + local_modules = ["builtins"] + if module: + local_modules.append(module.__name__) + stripped_type = typ if any(c in typ for c in "[,"): for subtyp in re.split(r"[\[,\]]", typ): - strip_or_import(subtyp.strip(), module, imports) - if module: - stripped_type = re.sub(r"(^|[\[, ]+)" + re.escape(module.__name__ + "."), r"\1", typ) - elif module and typ.startswith(module.__name__ + "."): - stripped_type = typ[len(module.__name__) + 1 :] + stripped_subtyp = strip_or_import(subtyp.strip(), module, known_modules, imports) + if stripped_subtyp != subtyp: + stripped_type = re.sub( + r"(^|[\[, ]+)" + re.escape(subtyp) + r"($|[\], ]+)", + r"\1" + stripped_subtyp + r"\2", + stripped_type, + ) elif "." in typ: - arg_module = typ[: typ.rindex(".")] - if arg_module == "builtins": - stripped_type = typ[len("builtins") + 1 :] + for module_name in local_modules + list(reversed(known_modules)): + if typ.startswith(module_name + "."): + if module_name in local_modules: + stripped_type = typ[len(module_name) + 1 :] + arg_module = module_name + break else: + arg_module = typ[: typ.rindex(".")] + if arg_module not in local_modules: imports.append(f"import {arg_module}") if stripped_type == "NoneType": stripped_type = "None" @@ -373,6 +445,7 @@ def generate_c_property_stub( ro_properties: list[str], readonly: bool, module: ModuleType | None = None, + known_modules: list[str] | None = None, imports: list[str] | None = None, ) -> None: """Generate property stub using introspection of 'obj'. @@ -392,10 +465,6 @@ def infer_prop_type(docstr: str | None) -> str | None: else: return None - # Ignore special properties/attributes. - if is_skipped_attribute(name): - return - inferred = infer_prop_type(getattr(obj, "__doc__", None)) if not inferred: fget = getattr(obj, "fget", None) @@ -403,8 +472,8 @@ def infer_prop_type(docstr: str | None) -> str | None: if not inferred: inferred = "Any" - if module is not None and imports is not None: - inferred = strip_or_import(inferred, module, imports) + if module is not None and imports is not None and known_modules is not None: + inferred = strip_or_import(inferred, module, known_modules, imports) if is_static_property(obj): trailing_comment = " # read-only" if readonly else "" @@ -422,6 +491,7 @@ def generate_c_type_stub( class_name: str, obj: type, output: list[str], + known_modules: list[str], imports: list[str], sig_generators: Iterable[SignatureGenerator], ) -> None: @@ -430,69 +500,75 @@ def generate_c_type_stub( The result lines will be appended to 'output'. If necessary, any required names will be added to 'imports'. """ - # typeshed gives obj.__dict__ the not quite correct type Dict[str, Any] - # (it could be a mappingproxy!), which makes mypyc mad, so obfuscate it. - obj_dict: Mapping[str, Any] = getattr(obj, "__dict__") # noqa: B009 - items = sorted(obj_dict.items(), key=lambda x: method_name_sort_key(x[0])) + raw_lookup = getattr(obj, "__dict__") # noqa: B009 + items = sorted(get_members(obj), key=lambda x: method_name_sort_key(x[0])) + names = set(x[0] for x in items) methods: list[str] = [] types: list[str] = [] static_properties: list[str] = [] rw_properties: list[str] = [] ro_properties: list[str] = [] - done: set[str] = set() + attrs: list[tuple[str, Any]] = [] for attr, value in items: + # use unevaluated descriptors when dealing with property inspection + raw_value = raw_lookup.get(attr, value) if is_c_method(value) or is_c_classmethod(value): - done.add(attr) - if not is_skipped_attribute(attr): - if attr == "__new__": - # TODO: We should support __new__. - if "__init__" in obj_dict: - # Avoid duplicate functions if both are present. - # But is there any case where .__new__() has a - # better signature than __init__() ? - continue - attr = "__init__" - if is_c_classmethod(value): - self_var = "cls" - else: - self_var = "self" - generate_c_function_stub( - module, - attr, - value, - methods, - imports=imports, - self_var=self_var, - class_name=class_name, - sig_generators=sig_generators, - ) - elif is_c_property(value): - done.add(attr) - generate_c_property_stub( + if attr == "__new__": + # TODO: We should support __new__. + if "__init__" in names: + # Avoid duplicate functions if both are present. + # But is there any case where .__new__() has a + # better signature than __init__() ? + continue + attr = "__init__" + if is_c_classmethod(value): + self_var = "cls" + else: + self_var = "self" + generate_c_function_stub( + module, attr, value, + output=methods, + known_modules=known_modules, + imports=imports, + self_var=self_var, + cls=obj, + class_name=class_name, + sig_generators=sig_generators, + ) + elif is_c_property(raw_value): + generate_c_property_stub( + attr, + raw_value, static_properties, rw_properties, ro_properties, - is_c_property_readonly(value), + is_c_property_readonly(raw_value), module=module, + known_modules=known_modules, imports=imports, ) elif is_c_type(value): generate_c_type_stub( - module, attr, value, types, imports=imports, sig_generators=sig_generators + module, + attr, + value, + types, + imports=imports, + known_modules=known_modules, + sig_generators=sig_generators, ) - done.add(attr) + else: + attrs.append((attr, value)) - for attr, value in items: - if is_skipped_attribute(attr): - continue - if attr not in done: - static_properties.append( - "{}: ClassVar[{}] = ...".format( - attr, strip_or_import(get_type_fullname(type(value)), module, imports) - ) + for attr, value in attrs: + static_properties.append( + "{}: ClassVar[{}] = ...".format( + attr, + strip_or_import(get_type_fullname(type(value)), module, known_modules, imports), ) + ) all_bases = type.mro(obj) if all_bases[-1] is object: # TODO: Is this always object? @@ -510,7 +586,8 @@ def generate_c_type_stub( bases.append(base) if bases: bases_str = "(%s)" % ", ".join( - strip_or_import(get_type_fullname(base), module, imports) for base in bases + strip_or_import(get_type_fullname(base), module, known_modules, imports) + for base in bases ) else: bases_str = "" @@ -559,6 +636,7 @@ def is_pybind_skipped_attribute(attr: str) -> bool: def is_skipped_attribute(attr: str) -> bool: return attr in ( + "__class__", "__getattribute__", "__str__", "__repr__", @@ -571,7 +649,7 @@ def is_skipped_attribute(attr: str) -> bool: ) -def infer_method_sig(name: str, self_var: str | None = None) -> list[ArgSig]: +def infer_method_args(name: str, self_var: str | None = None) -> list[ArgSig]: args: list[ArgSig] | None = None if name.startswith("__") and name.endswith("__"): name = name[2:-2] @@ -673,3 +751,18 @@ def infer_method_sig(name: str, self_var: str | None = None) -> list[ArgSig]: if args is None: args = [ArgSig(name="*args"), ArgSig(name="**kwargs")] return [ArgSig(name=self_var or "self")] + args + + +def infer_method_ret_type(name: str) -> str: + if name.startswith("__") and name.endswith("__"): + name = name[2:-2] + if name in ("float", "bool", "bytes", "int"): + return name + # Note: __eq__ and co may return arbitrary types, but bool is good enough for stubgen. + elif name in ("eq", "ne", "lt", "le", "gt", "ge", "contains"): + return "bool" + elif name in ("len", "hash", "sizeof", "trunc", "floor", "ceil"): + return "int" + elif name in ("init", "setitem"): + return "None" + return "Any" diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index c7b576f89389..47b664a46d72 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -38,7 +38,8 @@ generate_c_function_stub, generate_c_property_stub, generate_c_type_stub, - infer_method_sig, + infer_method_args, + infer_method_ret_type, is_c_property_readonly, ) from mypy.stubutil import common_dir_prefix, remove_misplaced_type_comments, walk_packages @@ -768,16 +769,18 @@ class StubgencSuite(unittest.TestCase): """ def test_infer_hash_sig(self) -> None: - assert_equal(infer_method_sig("__hash__"), [self_arg]) + assert_equal(infer_method_args("__hash__"), [self_arg]) + assert_equal(infer_method_ret_type("__hash__"), "int") def test_infer_getitem_sig(self) -> None: - assert_equal(infer_method_sig("__getitem__"), [self_arg, ArgSig(name="index")]) + assert_equal(infer_method_args("__getitem__"), [self_arg, ArgSig(name="index")]) def test_infer_setitem_sig(self) -> None: assert_equal( - infer_method_sig("__setitem__"), + infer_method_args("__setitem__"), [self_arg, ArgSig(name="index"), ArgSig(name="object")], ) + assert_equal(infer_method_ret_type("__setitem__"), "None") def test_infer_binary_op_sig(self) -> None: for op in ( @@ -794,11 +797,19 @@ def test_infer_binary_op_sig(self) -> None: "mul", "rmul", ): - assert_equal(infer_method_sig(f"__{op}__"), [self_arg, ArgSig(name="other")]) + assert_equal(infer_method_args(f"__{op}__"), [self_arg, ArgSig(name="other")]) + + def test_infer_equality_op_sig(self) -> None: + for op in ("eq", "ne", "lt", "le", "gt", "ge", "contains"): + assert_equal(infer_method_ret_type(f"__{op}__"), "bool") def test_infer_unary_op_sig(self) -> None: for op in ("neg", "pos"): - assert_equal(infer_method_sig(f"__{op}__"), [self_arg]) + assert_equal(infer_method_args(f"__{op}__"), [self_arg]) + + def test_infer_cast_sig(self) -> None: + for op in ("float", "bool", "bytes", "int"): + assert_equal(infer_method_ret_type(f"__{op}__"), op) def test_generate_c_type_stub_no_crash_for_object(self) -> None: output: list[str] = [] @@ -809,7 +820,8 @@ def test_generate_c_type_stub_no_crash_for_object(self) -> None: "alias", object, output, - imports, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(imports, []) @@ -828,7 +840,8 @@ class TestClassVariableCls: "C", TestClassVariableCls, output, - imports, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(imports, []) @@ -846,7 +859,8 @@ class TestClass(KeyError): "C", TestClass, output, - imports, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["class C(KeyError): ..."]) @@ -861,7 +875,8 @@ def test_generate_c_type_inheritance_same_module(self) -> None: "C", TestClass, output, - imports, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["class C(TestBaseClass): ..."]) @@ -881,7 +896,8 @@ class TestClass(argparse.Action): "C", TestClass, output, - imports, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["class C(argparse.Action): ..."]) @@ -899,7 +915,8 @@ class TestClass(type): "C", TestClass, output, - imports, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["class C(type): ..."]) @@ -919,10 +936,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: int) -> Any: ..."]) @@ -942,10 +961,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: int) -> Any: ..."]) @@ -964,10 +985,12 @@ def test(cls, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="cls", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["@classmethod", "def test(cls, *args, **kwargs) -> Any: ..."]) @@ -990,10 +1013,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="cls", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal( @@ -1023,10 +1048,12 @@ def test(self, arg0: str = "") -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: str = ...) -> Any: ..."]) @@ -1048,22 +1075,23 @@ def test(arg0: str) -> None: mod, "test", test, - output, - imports, + output=output, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(arg0: argparse.Action) -> Any: ..."]) assert_equal(imports, ["import argparse"]) - def test_generate_c_function_same_module_arg(self) -> None: - """Test that if argument references type from same module but using full path, no module + def test_generate_c_function_same_module(self) -> None: + """Test that if annotation references type from same module but using full path, no module will be imported, and type specification will be striped to local reference. """ # Provide different type in python spec than in docstring to make sure, that docstring # information is used. def test(arg0: str) -> None: """ - test(arg0: argparse.Action) + test(arg0: argparse.Action) -> argparse.Action """ output: list[str] = [] @@ -1073,19 +1101,20 @@ def test(arg0: str) -> None: mod, "test", test, - output, - imports, + output=output, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) - assert_equal(output, ["def test(arg0: Action) -> Any: ..."]) + assert_equal(output, ["def test(arg0: Action) -> Action: ..."]) assert_equal(imports, []) - def test_generate_c_function_other_module_ret(self) -> None: - """Test that if return type references type from other module, module will be imported.""" + def test_generate_c_function_other_module(self) -> None: + """Test that if annotation references type from other module, module will be imported.""" def test(arg0: str) -> None: """ - test(arg0: str) -> argparse.Action + test(arg0: argparse.Action) -> argparse.Action """ output: list[str] = [] @@ -1095,21 +1124,49 @@ def test(arg0: str) -> None: mod, "test", test, - output, - imports, + output=output, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) - assert_equal(output, ["def test(arg0: str) -> argparse.Action: ..."]) - assert_equal(imports, ["import argparse"]) + assert_equal(output, ["def test(arg0: argparse.Action) -> argparse.Action: ..."]) + assert_equal(set(imports), {"import argparse"}) - def test_generate_c_function_same_module_ret(self) -> None: - """Test that if return type references type from same module but using full path, - no module will be imported, and type specification will be striped to local reference. + def test_generate_c_function_same_module_nested(self) -> None: + """Test that if annotation references type from same module but using full path, no module + will be imported, and type specification will be stripped to local reference. """ + # Provide different type in python spec than in docstring to make sure, that docstring + # information is used. + def test(arg0: str) -> None: + """ + test(arg0: list[argparse.Action]) -> list[argparse.Action] + """ + output: list[str] = [] + imports: list[str] = [] + mod = ModuleType("argparse", "") + generate_c_function_stub( + mod, + "test", + test, + output=output, + imports=imports, + known_modules=[mod.__name__], + sig_generators=get_sig_generators(parse_options([])), + ) + assert_equal(output, ["def test(arg0: list[Action]) -> list[Action]: ..."]) + assert_equal(imports, []) + + def test_generate_c_function_same_module_compound(self) -> None: + """Test that if annotation references type from same module but using full path, no module + will be imported, and type specification will be stripped to local reference. + """ + # Provide different type in python spec than in docstring to make sure, that docstring + # information is used. def test(arg0: str) -> None: """ - test(arg0: str) -> argparse.Action + test(arg0: Union[argparse.Action, NoneType]) -> Tuple[argparse.Action, NoneType] """ output: list[str] = [] @@ -1119,13 +1176,38 @@ def test(arg0: str) -> None: mod, "test", test, - output, - imports, + output=output, + imports=imports, + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) - assert_equal(output, ["def test(arg0: str) -> Action: ..."]) + assert_equal(output, ["def test(arg0: Union[Action,None]) -> Tuple[Action,None]: ..."]) assert_equal(imports, []) + def test_generate_c_function_other_module_nested(self) -> None: + """Test that if annotation references type from other module, module will be imported, + and the import will be restricted to one of the known modules.""" + + def test(arg0: str) -> None: + """ + test(arg0: foo.bar.Action) -> other.Thing + """ + + output: list[str] = [] + imports: list[str] = [] + mod = ModuleType(self.__module__, "") + generate_c_function_stub( + mod, + "test", + test, + output=output, + imports=imports, + known_modules=["foo", "foo.spangle", "bar"], + sig_generators=get_sig_generators(parse_options([])), + ) + assert_equal(output, ["def test(arg0: foo.bar.Action) -> other.Thing: ..."]) + assert_equal(set(imports), {"import foo", "import other"}) + def test_generate_c_property_with_pybind11(self) -> None: """Signatures included by PyBind11 inside property.fget are read.""" @@ -1190,10 +1272,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: List[int]) -> Any: ..."]) @@ -1213,10 +1297,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: Dict[str,int]) -> Any: ..."]) @@ -1236,10 +1322,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: Dict[str,List[int]]) -> Any: ..."]) @@ -1259,10 +1347,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: Dict[argparse.Action,int]) -> Any: ..."]) @@ -1282,10 +1372,12 @@ def test(self, arg0: str) -> None: mod, "test", TestClass.test, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal(output, ["def test(self, arg0: Dict[str,argparse.Action]) -> Any: ..."]) @@ -1310,10 +1402,12 @@ def __init__(self, arg0: str) -> None: mod, "__init__", TestClass.__init__, - output, - imports, + output=output, + imports=imports, self_var="self", + cls=TestClass, class_name="TestClass", + known_modules=[mod.__name__], sig_generators=get_sig_generators(parse_options([])), ) assert_equal( @@ -1329,6 +1423,42 @@ def __init__(self, arg0: str) -> None: ) assert_equal(set(imports), {"from typing import overload"}) + def test_generate_c_type_with_overload_shiboken(self) -> None: + class TestClass: + """ + TestClass(self: TestClass, arg0: str) -> None + TestClass(self: TestClass, arg0: str, arg1: str) -> None + """ + + def __init__(self, arg0: str) -> None: + pass + + output: list[str] = [] + imports: list[str] = [] + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, + "__init__", + TestClass.__init__, + output=output, + imports=imports, + self_var="self", + cls=TestClass, + class_name="TestClass", + known_modules=[mod.__name__], + sig_generators=get_sig_generators(parse_options([])), + ) + assert_equal( + output, + [ + "@overload", + "def __init__(self, arg0: str) -> None: ...", + "@overload", + "def __init__(self, arg0: str, arg1: str) -> None: ...", + ], + ) + assert_equal(set(imports), {"from typing import overload"}) + class ArgSigSuite(unittest.TestCase): def test_repr(self) -> None: