Skip to content

Commit

Permalink
Prefer .pyi stubs (#2375)
Browse files Browse the repository at this point in the history
Exempt numpy from these changes for now
  • Loading branch information
jacobtylerwalls committed May 6, 2024
1 parent 2ec0115 commit 0984386
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 34 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ What's New in astroid 3.2.0?
============================
Release date: TBA

* ``.pyi`` stub files are now preferred over ``.py`` files when resolving imports, (except for numpy).

Closes pylint-dev/#9185

* ``igetattr()`` returns the last same-named function in a class (instead of
the first). This avoids false positives in pylint with ``@overload``.

Expand Down
7 changes: 6 additions & 1 deletion astroid/interpreter/_import/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,14 @@ def find_module(
pass
submodule_path = sys.path

# We're looping on pyi first because if a pyi exists there's probably a reason
# (i.e. the code is hard or impossible to parse), so we take pyi into account
# But we're not quite ready to do this for numpy, see https://github.com/pylint-dev/astroid/pull/2375
suffixes = (".pyi", ".py", importlib.machinery.BYTECODE_SUFFIXES[0])
numpy_suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
for entry in submodule_path:
package_directory = os.path.join(entry, modname)
for suffix in (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0]):
for suffix in numpy_suffixes if "numpy" in entry else suffixes:
package_file_name = "__init__" + suffix
file_path = os.path.join(package_directory, package_file_name)
if os.path.isfile(file_path):
Expand Down
9 changes: 5 additions & 4 deletions astroid/modutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@


if sys.platform.startswith("win"):
PY_SOURCE_EXTS = ("py", "pyw", "pyi")
PY_SOURCE_EXTS = ("pyi", "pyw", "py")
PY_COMPILED_EXTS = ("dll", "pyd")
else:
PY_SOURCE_EXTS = ("py", "pyi")
PY_SOURCE_EXTS = ("pyi", "py")
PY_COMPILED_EXTS = ("so",)


Expand Down Expand Up @@ -499,7 +499,7 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
base, orig_ext = os.path.splitext(filename)
if orig_ext == ".pyi" and os.path.exists(f"{base}{orig_ext}"):
return f"{base}{orig_ext}"
for ext in PY_SOURCE_EXTS:
for ext in PY_SOURCE_EXTS if "numpy" not in filename else reversed(PY_SOURCE_EXTS):
source_path = f"{base}.{ext}"
if os.path.exists(source_path):
return source_path
Expand Down Expand Up @@ -671,7 +671,8 @@ def _has_init(directory: str) -> str | None:
else return None.
"""
mod_or_pack = os.path.join(directory, "__init__")
for ext in (*PY_SOURCE_EXTS, "pyc", "pyo"):
exts = reversed(PY_SOURCE_EXTS) if "numpy" in directory else PY_SOURCE_EXTS
for ext in (*exts, "pyc", "pyo"):
if os.path.exists(mod_or_pack + "." + ext):
return mod_or_pack + "." + ext
return None
Expand Down
2 changes: 1 addition & 1 deletion requirements_full.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Packages used to run additional tests
attrs
nose
numpy>=1.17.0; python_version<"3.11"
numpy>=1.17.0; python_version<"3.12"
python-dateutil
PyQt6
regex
Expand Down
30 changes: 3 additions & 27 deletions tests/brain/test_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Eggs:
def test_attrs_transform(self) -> None:
"""Test brain for decorators of the 'attrs' package.
Package added support for 'attrs' a long side 'attr' in v21.3.0.
Package added support for 'attrs' alongside 'attr' in v21.3.0.
See: https://github.com/python-attrs/attrs/releases/tag/21.3.0
"""
module = astroid.parse(
Expand Down Expand Up @@ -153,36 +153,12 @@ class Eggs:
@frozen
class Legs:
d = attrs.field(default=attrs.Factory(dict))
m = Legs(d=1)
m.d['answer'] = 42
@define
class FooBar:
d = attrs.field(default=attrs.Factory(dict))
n = FooBar(d=1)
n.d['answer'] = 42
@mutable
class BarFoo:
d = attrs.field(default=attrs.Factory(dict))
o = BarFoo(d=1)
o.d['answer'] = 42
@my_mutable
class FooFoo:
d = attrs.field(default=attrs.Factory(dict))
p = FooFoo(d=1)
p.d['answer'] = 42
"""
)

for name in ("f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p"):
for name in ("f", "g", "h", "i", "j", "k", "l"):
should_be_unknown = next(module.getattr(name)[0].infer()).getattr("d")[0]
self.assertIsInstance(should_be_unknown, astroid.Unknown)
self.assertIsInstance(should_be_unknown, astroid.Unknown, name)

def test_special_attributes(self) -> None:
"""Make sure special attrs attributes exist"""
Expand Down
9 changes: 8 additions & 1 deletion tests/test_modutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,18 @@ def test(self) -> None:
def test_raise(self) -> None:
self.assertRaises(modutils.NoSourceFile, modutils.get_source_file, "whatever")

def test_(self) -> None:
def test_pyi(self) -> None:
package = resources.find("pyi_data")
module = os.path.join(package, "__init__.pyi")
self.assertEqual(modutils.get_source_file(module), os.path.normpath(module))

def test_pyi_preferred(self) -> None:
package = resources.find("pyi_data/find_test")
module = os.path.join(package, "__init__.py")
self.assertEqual(
modutils.get_source_file(module), os.path.normpath(module) + "i"
)


class IsStandardModuleTest(resources.SysPathSetup, unittest.TestCase):
"""
Expand Down

0 comments on commit 0984386

Please sign in to comment.