Skip to content

Commit

Permalink
Merge pull request #449 from dimbleby/type-annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
jaraco committed Apr 22, 2023
2 parents 512a3df + 2e78162 commit 705a757
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 43 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
v6.6.0
======

* #449: Expanded type annotations.

v6.5.1
======

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
nitpick_ignore = [
# Workaround for #316
('py:class', 'importlib_metadata.EntryPoints'),
('py:class', 'importlib_metadata.PackagePath'),
('py:class', 'importlib_metadata.SelectableGroups'),
('py:class', 'importlib_metadata._meta._T'),
# Workaround for #435
Expand Down
95 changes: 53 additions & 42 deletions importlib_metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import zipp
import email
import inspect
import pathlib
import operator
import textwrap
Expand All @@ -14,12 +15,12 @@
import posixpath
import contextlib
import collections
import inspect

from . import _adapters, _meta, _py39compat
from ._collections import FreezableDefaultDict, Pair
from ._compat import (
NullFinder,
StrPath,
install,
pypy_partial,
)
Expand All @@ -31,8 +32,7 @@
from importlib import import_module
from importlib.abc import MetaPathFinder
from itertools import starmap
from typing import List, Mapping, Optional, cast

from typing import Iterable, List, Mapping, Optional, Set, cast

__all__ = [
'Distribution',
Expand All @@ -53,11 +53,11 @@
class PackageNotFoundError(ModuleNotFoundError):
"""The package was not found."""

def __str__(self):
def __str__(self) -> str:
return f"No package metadata was found for {self.name}"

@property
def name(self):
def name(self) -> str: # type: ignore[override]
(name,) = self.args
return name

Expand Down Expand Up @@ -123,7 +123,7 @@ def read(text, filter_=None):
yield Pair(name, value)

@staticmethod
def valid(line):
def valid(line: str):
return line and not line.startswith('#')


Expand Down Expand Up @@ -198,7 +198,7 @@ class EntryPoint(DeprecatedTuple):

dist: Optional['Distribution'] = None

def __init__(self, name, value, group):
def __init__(self, name: str, value: str, group: str) -> None:
vars(self).update(name=name, value=value, group=group)

def load(self):
Expand All @@ -212,18 +212,21 @@ def load(self):
return functools.reduce(getattr, attrs, module)

@property
def module(self):
def module(self) -> str:
match = self.pattern.match(self.value)
assert match is not None
return match.group('module')

@property
def attr(self):
def attr(self) -> str:
match = self.pattern.match(self.value)
assert match is not None
return match.group('attr')

@property
def extras(self):
def extras(self) -> List[str]:
match = self.pattern.match(self.value)
assert match is not None
return re.findall(r'\w+', match.group('extras') or '')

def _for(self, dist):
Expand Down Expand Up @@ -271,7 +274,7 @@ def __repr__(self):
f'group={self.group!r})'
)

def __hash__(self):
def __hash__(self) -> int:
return hash(self._key())


Expand All @@ -282,7 +285,7 @@ class EntryPoints(tuple):

__slots__ = ()

def __getitem__(self, name): # -> EntryPoint:
def __getitem__(self, name: str) -> EntryPoint: # type: ignore[override]
"""
Get the EntryPoint in self matching name.
"""
Expand All @@ -299,14 +302,14 @@ def select(self, **params):
return EntryPoints(ep for ep in self if _py39compat.ep_matches(ep, **params))

@property
def names(self):
def names(self) -> Set[str]:
"""
Return the set of all names of all entry points.
"""
return {ep.name for ep in self}

@property
def groups(self):
def groups(self) -> Set[str]:
"""
Return the set of all groups of all entry points.
"""
Expand All @@ -327,24 +330,28 @@ def _from_text(text):
class PackagePath(pathlib.PurePosixPath):
"""A reference to a path in a package"""

def read_text(self, encoding='utf-8'):
hash: Optional["FileHash"]
size: int
dist: "Distribution"

def read_text(self, encoding: str = 'utf-8') -> str: # type: ignore[override]
with self.locate().open(encoding=encoding) as stream:
return stream.read()

def read_binary(self):
def read_binary(self) -> bytes:
with self.locate().open('rb') as stream:
return stream.read()

def locate(self):
def locate(self) -> pathlib.Path:
"""Return a path-like object for this path"""
return self.dist.locate_file(self)


class FileHash:
def __init__(self, spec):
def __init__(self, spec: str) -> None:
self.mode, _, self.value = spec.partition('=')

def __repr__(self):
def __repr__(self) -> str:
return f'<FileHash mode: {self.mode} value: {self.value}>'


Expand Down Expand Up @@ -379,14 +386,14 @@ def read_text(self, filename) -> Optional[str]:
"""

@abc.abstractmethod
def locate_file(self, path):
def locate_file(self, path: StrPath) -> pathlib.Path:
"""
Given a path to a file in this distribution, return a path
to it.
"""

@classmethod
def from_name(cls, name: str):
def from_name(cls, name: str) -> "Distribution":
"""Return the Distribution for the given package name.
:param name: The name of the distribution package to search for.
Expand All @@ -399,12 +406,12 @@ def from_name(cls, name: str):
if not name:
raise ValueError("A distribution name is required.")
try:
return next(cls.discover(name=name))
return next(iter(cls.discover(name=name)))
except StopIteration:
raise PackageNotFoundError(name)

@classmethod
def discover(cls, **kwargs):
def discover(cls, **kwargs) -> Iterable["Distribution"]:
"""Return an iterable of Distribution objects for all packages.
Pass a ``context`` or pass keyword arguments for constructing
Expand All @@ -422,7 +429,7 @@ def discover(cls, **kwargs):
)

@staticmethod
def at(path):
def at(path: StrPath) -> "Distribution":
"""Return a Distribution for the indicated metadata path
:param path: a string or path-like object
Expand Down Expand Up @@ -457,7 +464,7 @@ def metadata(self) -> _meta.PackageMetadata:
return _adapters.Message(email.message_from_string(text))

@property
def name(self):
def name(self) -> str:
"""Return the 'Name' metadata for the distribution package."""
return self.metadata['Name']

Expand All @@ -467,16 +474,16 @@ def _normalized_name(self):
return Prepared.normalize(self.name)

@property
def version(self):
def version(self) -> str:
"""Return the 'Version' metadata for the distribution package."""
return self.metadata['Version']

@property
def entry_points(self):
def entry_points(self) -> EntryPoints:
return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self)

@property
def files(self):
def files(self) -> Optional[List[PackagePath]]:
"""Files in this distribution.
:return: List of PackagePath for this distribution or None
Expand Down Expand Up @@ -561,7 +568,7 @@ def _read_files_egginfo_sources(self):
return text and map('"{}"'.format, text.splitlines())

@property
def requires(self):
def requires(self) -> Optional[List[str]]:
"""Generated requirements specified for this Distribution"""
reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs()
return reqs and list(reqs)
Expand Down Expand Up @@ -640,7 +647,7 @@ def __init__(self, **kwargs):
vars(self).update(kwargs)

@property
def path(self):
def path(self) -> List[str]:
"""
The sequence of directory path that a distribution finder
should search.
Expand All @@ -651,7 +658,7 @@ def path(self):
return vars(self).get('path', sys.path)

@abc.abstractmethod
def find_distributions(self, context=Context()):
def find_distributions(self, context=Context()) -> Iterable[Distribution]:
"""
Find distributions.
Expand Down Expand Up @@ -786,7 +793,9 @@ class MetadataPathFinder(NullFinder, DistributionFinder):
of Python that do not have a PathFinder find_distributions().
"""

def find_distributions(self, context=DistributionFinder.Context()):
def find_distributions(
self, context=DistributionFinder.Context()
) -> Iterable["PathDistribution"]:
"""
Find distributions.
Expand All @@ -806,19 +815,19 @@ def _search_paths(cls, name, paths):
path.search(prepared) for path in map(FastPath, paths)
)

def invalidate_caches(cls):
def invalidate_caches(cls) -> None:
FastPath.__new__.cache_clear()


class PathDistribution(Distribution):
def __init__(self, path: SimplePath):
def __init__(self, path: SimplePath) -> None:
"""Construct a distribution.
:param path: SimplePath indicating the metadata directory.
"""
self._path = path

def read_text(self, filename):
def read_text(self, filename: StrPath) -> Optional[str]:
with suppress(
FileNotFoundError,
IsADirectoryError,
Expand All @@ -828,9 +837,11 @@ def read_text(self, filename):
):
return self._path.joinpath(filename).read_text(encoding='utf-8')

return None

read_text.__doc__ = Distribution.read_text.__doc__

def locate_file(self, path):
def locate_file(self, path: StrPath) -> pathlib.Path:
return self._path.parent / path

@property
Expand Down Expand Up @@ -863,7 +874,7 @@ def _name_from_stem(stem):
return name


def distribution(distribution_name):
def distribution(distribution_name) -> Distribution:
"""Get the ``Distribution`` instance for the named package.
:param distribution_name: The name of the distribution package as a string.
Expand All @@ -872,7 +883,7 @@ def distribution(distribution_name):
return Distribution.from_name(distribution_name)


def distributions(**kwargs):
def distributions(**kwargs) -> Iterable[Distribution]:
"""Get all ``Distribution`` instances in the current environment.
:return: An iterable of ``Distribution`` instances.
Expand All @@ -889,7 +900,7 @@ def metadata(distribution_name) -> _meta.PackageMetadata:
return Distribution.from_name(distribution_name).metadata


def version(distribution_name):
def version(distribution_name) -> str:
"""Get the version string for the named package.
:param distribution_name: The name of the distribution package to query.
Expand Down Expand Up @@ -923,7 +934,7 @@ def entry_points(**params) -> EntryPoints:
return EntryPoints(eps).select(**params)


def files(distribution_name):
def files(distribution_name) -> Optional[List[PackagePath]]:
"""Return a list of files for the named package.
:param distribution_name: The name of the distribution package to query.
Expand All @@ -932,11 +943,11 @@ def files(distribution_name):
return distribution(distribution_name).files


def requires(distribution_name):
def requires(distribution_name) -> Optional[List[str]]:
"""
Return a list of requirements for the named package.
:return: An iterator of requirements, suitable for
:return: An iterable of requirements, suitable for
packaging.requirement.Requirement.
"""
return distribution(distribution_name).requires
Expand Down
10 changes: 10 additions & 0 deletions importlib_metadata/_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import sys
import platform

from typing import Union


__all__ = ['install', 'NullFinder', 'Protocol']

Expand Down Expand Up @@ -70,3 +73,10 @@ def pypy_partial(val):
"""
is_pypy = platform.python_implementation() == 'PyPy'
return val + is_pypy


if sys.version_info >= (3, 9):
StrPath = Union[str, os.PathLike[str]]
else:
# PathLike is only subscriptable at runtime in 3.9+
StrPath = Union[str, "os.PathLike[str]"] # pragma: no cover
2 changes: 1 addition & 1 deletion importlib_metadata/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SimplePath(Protocol[_T]):
A minimal subset of pathlib.Path required by PathDistribution.
"""

def joinpath(self) -> _T:
def joinpath(self, other: Union[str, _T]) -> _T:
... # pragma: no cover

def __truediv__(self, other: Union[str, _T]) -> _T:
Expand Down

0 comments on commit 705a757

Please sign in to comment.