Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add products with typing.Annotated. #394

Merged
merged 7 commits into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ repos:
hooks:
- id: python-check-blanket-noqa
- id: python-check-mock-methods
- id: python-no-eval
exclude: expression.py
- id: python-no-log-warn
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/asottile/reorder-python-imports
rev: v3.9.0
Expand Down
1 change: 1 addition & 0 deletions docs/rtd_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies:
- rich
- sqlalchemy >=1.4.36
- tomli >=1.0.0
- typing_extensions

- pip:
- ../
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies:
- rich
- sqlalchemy >=1.4.36
- tomli >=1.0.0
- typing_extensions

# Misc
- black
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ convention = "numpy"


[tool.pytest.ini_options]
addopts = ["--doctest-modules"]
testpaths = ["src", "tests"]
markers = [
"wip: Tests that are work-in-progress.",
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ install_requires =
rich
sqlalchemy>=1.4.36
tomli>=1.0.0
typing-extensions
python_requires = >=3.8
include_package_data = True
package_dir =
Expand Down
140 changes: 140 additions & 0 deletions src/_pytask/_inspect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from __future__ import annotations

import functools
import sys
import types
from typing import Any
from typing import Callable
from typing import Mapping


__all__ = ["get_annotations"]


if sys.version_info >= (3, 10):
    from inspect import get_annotations
else:

    def get_annotations(  # noqa: C901, PLR0912, PLR0915
        obj: Callable[..., object] | type[Any] | types.ModuleType,
        *,
        globals: Mapping[str, Any] | None = None,  # noqa: A002
        locals: Mapping[str, Any] | None = None,  # noqa: A002
        eval_str: bool = False,
    ) -> dict[str, Any]:
        """Return a new dict holding the annotations defined on ``obj``.

        Backport of :func:`inspect.get_annotations` for interpreters older than
        Python 3.10.

        Parameters
        ----------
        obj
            A class, module, or any other callable. For classes only the
            annotations stored in the class' own ``__dict__`` are used, inherited
            ones are ignored.
        globals
            Global namespace handed to ``eval()``. If ``None``, a default derived
            from ``obj`` is used: the module ``__dict__`` for modules, the
            ``__dict__`` of ``sys.modules[obj.__module__]`` for classes, and
            ``__globals__`` of the (unwrapped) object for callables.
        locals
            Local namespace handed to ``eval()``. If ``None``, the class namespace
            is used for classes and nothing otherwise.
        eval_str
            If true, annotation values of type ``str`` are replaced by the result
            of calling ``eval()`` on them, which un-stringizes annotations created
            with ``from __future__ import annotations``. If false (the default),
            the values are returned unchanged.

        Returns
        -------
        dict[str, Any]
            A freshly created dict on every call. It is empty if ``obj`` has no
            annotations.

        Raises
        ------
        TypeError
            If ``obj`` is not a class, module, or callable.
        ValueError
            If ``obj.__annotations__`` is neither a dict nor ``None``.

        """
        if isinstance(obj, type):
            # Class: only look at the class' own namespace, never at base classes.
            annotations = None
            namespace = getattr(obj, "__dict__", None)
            if namespace and hasattr(namespace, "get"):
                annotations = namespace.get("__annotations__", None)
                if isinstance(annotations, types.GetSetDescriptorType):
                    annotations = None

            fallback_globals = None
            module_name = getattr(obj, "__module__", None)
            if module_name:
                module = sys.modules.get(module_name, None)
                if module:
                    fallback_globals = getattr(module, "__dict__", None)
            fallback_locals = dict(vars(obj))
            target = obj
        elif isinstance(obj, types.ModuleType):
            # Module: its own namespace serves as globals, nothing to unwrap.
            annotations = getattr(obj, "__annotations__", None)
            fallback_globals = obj.__dict__
            fallback_locals = None
            target = None
        elif callable(obj):
            # Any other callable: functions, builtins, partials, callable
            # instances, and so on.
            annotations = getattr(obj, "__annotations__", None)
            fallback_globals = getattr(obj, "__globals__", None)
            fallback_locals = None
            target = obj
        else:
            raise TypeError(f"{obj!r} is not a module, class, or callable.")

        if annotations is None:
            return {}

        if not isinstance(annotations, dict):
            raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")

        if not annotations:
            return {}

        if not eval_str:
            return dict(annotations)

        if target is not None:
            # Strip decorator wrappers and partials so that the globals of the
            # innermost function are used for evaluating the strings.
            unwrapping = True
            while unwrapping:
                if hasattr(target, "__wrapped__"):
                    target = target.__wrapped__
                elif isinstance(target, functools.partial):
                    target = target.func
                else:
                    unwrapping = False
            if hasattr(target, "__globals__"):
                fallback_globals = target.__globals__

        if globals is None:
            globals = fallback_globals  # noqa: A001
        if locals is None:
            locals = fallback_locals  # noqa: A001

        # The alias keeps the call out of sight of linters banning eval(); the
        # strings evaluated here are the annotations of the inspected object.
        evaluate = eval
        return {
            key: evaluate(value, globals, locals) if isinstance(value, str) else value
            for key, value in annotations.items()
        }
6 changes: 1 addition & 5 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from _pytask.collect_utils import parse_dependencies_from_task_function
from _pytask.collect_utils import parse_nodes
from _pytask.collect_utils import parse_products_from_task_function
from _pytask.collect_utils import produces
from _pytask.config import hookimpl
from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE
from _pytask.console import console
Expand Down Expand Up @@ -178,10 +177,7 @@ def pytask_collect_task(
session, path, name, obj
)

if has_mark(obj, "produces"):
products = parse_nodes(session, path, name, obj, produces)
else:
products = parse_products_from_task_function(session, path, name, obj)
products = parse_products_from_task_function(session, path, name, obj)

markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []

Expand Down
105 changes: 97 additions & 8 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@
from typing import Iterable
from typing import TYPE_CHECKING

from _pytask._inspect import get_annotations
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import has_mark
from _pytask.mark_utils import remove_marks
from _pytask.nodes import ProductType
from _pytask.nodes import PythonNode
from _pytask.shared import find_duplicates
from _pytask.task_utils import parse_keyword_arguments_from_signature_defaults
from attrs import define
from attrs import field
from pybaum.tree_util import tree_map
from typing_extensions import Annotated
from typing_extensions import get_origin


if TYPE_CHECKING:
Expand Down Expand Up @@ -211,8 +216,12 @@ def parse_dependencies_from_task_function(
kwargs = {**signature_defaults, **task_kwargs}
kwargs.pop("produces", None)

parameters_with_product_annot = _find_args_with_product_annotation(obj)

dependencies = {}
for name, value in kwargs.items():
if name in parameters_with_product_annot:
continue
parsed_value = tree_map(
lambda x: _collect_dependencies(session, path, name, x), value # noqa: B023
)
Expand All @@ -223,29 +232,108 @@ def parse_dependencies_from_task_function(
return dependencies


# Message of the error raised when a task declares its products in more than one of
# the supported ways. The recommended form shows the type inside ``Annotated`` and the
# actual path as the parameter default, because ``Annotated[Path(...), Product]`` would
# put a path instance where a type is expected.
_ERROR_MULTIPLE_PRODUCT_DEFINITIONS = (
    "The task uses multiple ways to define products. Products should be defined with "
    "either\n\n- 'typing.Annotated[Path, Product] = Path(...)' (recommended)\n"
    "- '@pytask.mark.task(kwargs={'produces': Path(...)})'\n"
    "- as a default argument for 'produces': 'produces = Path(...)'\n"
    "- '@pytask.mark.produces(Path(...))' (deprecated).\n\n"
    "Read more about products in the documentation: https://tinyurl.com/yrezszr4."
)


def parse_products_from_task_function(
    session: Session, path: Path, name: str, obj: Any
) -> dict[str, Any]:
    """Parse products from task function.

    Products can be declared in four ways which are checked in this order:

    1. The ``@pytask.mark.produces`` decorator.
    2. A ``"produces"`` entry in the keyword arguments stored in
       ``obj.pytask_meta.kwargs``.
    3. A default value of the ``produces`` parameter. It is only considered if the
       function has no ``task`` marker.
    4. Defaults of parameters whose annotation is ``Annotated[..., Product]``.

    Parameters
    ----------
    session
        The session which is passed on while collecting the nodes.
    path
        The path which is passed on while collecting the nodes.
    name
        The name of the task which is passed on while collecting the nodes.
    obj
        The task function.

    Returns
    -------
    dict[str, Any]
        A mapping from argument names to the collected products. It is empty if the
        task does not define any products.

    Raises
    ------
    NodeNotCollectedError
        If multiple ways were used to specify products.

    """
    has_produces_decorator = False
    has_task_decorator = False
    has_signature_default = False
    has_annotation = False
    out: dict[str, Any] = {}

    def _collect_paths_only(default: Any) -> Any:
        # Function defaults may only contain paths, strings are rejected.
        return tree_map(
            lambda x: _collect_product(session, path, name, x, is_string_allowed=False),
            default,
        )

    if has_mark(obj, "produces"):
        has_produces_decorator = True
        nodes = parse_nodes(session, path, name, obj, produces)
        out = {"produces": nodes}

    task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
    if "produces" in task_kwargs:
        # Without setting the flag, products passed via the task decorator would
        # never count towards the check for multiple definitions below.
        has_task_decorator = True
        collected_products = tree_map(
            lambda x: _collect_product(session, path, name, x, is_string_allowed=True),
            task_kwargs["produces"],
        )
        out = {"produces": collected_products}

    parameters = inspect.signature(obj).parameters

    if not has_mark(obj, "task") and "produces" in parameters:
        parameter = parameters["produces"]
        if parameter.default is not parameter.empty:
            has_signature_default = True
            out = {"produces": _collect_paths_only(parameter.default)}

    parameters_with_product_annot = _find_args_with_product_annotation(obj)
    if parameters_with_product_annot:
        has_annotation = True
        for parameter_name in parameters_with_product_annot:
            parameter = parameters[parameter_name]
            if parameter.default is not parameter.empty:
                # Add to the mapping instead of rebinding it. Otherwise, only the
                # last annotated product would be kept.
                out[parameter_name] = _collect_paths_only(parameter.default)

    if (
        sum(
            (
                has_produces_decorator,
                has_task_decorator,
                has_signature_default,
                has_annotation,
            )
        )
        >= 2  # noqa: PLR2004
    ):
        raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)

    return out


def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
    """Return the names of all annotations of ``func`` which mark a product.

    The annotations are evaluated first. An entry marks a product if it is an
    ``Annotated`` type and at least one item of its metadata is an instance of
    ``ProductType``. The names are returned in the order of the annotations.

    """
    product_args = []
    for arg_name, hint in get_annotations(func, eval_str=True).items():
        # Only ``Annotated[...]`` carries metadata which can mark a product.
        if get_origin(hint) is not Annotated:
            continue
        if any(isinstance(item, ProductType) for item in hint.__metadata__):
            product_args.append(arg_name)
    return product_args


def _collect_old_dependencies(
Expand Down Expand Up @@ -331,9 +419,10 @@ def _collect_product(
# The parameter defaults only support Path objects.
if not isinstance(node, Path) and not is_string_allowed:
raise ValueError(
"If you use 'produces' as an argument of a task, it can only accept values "
"of type 'pathlib.Path' or the same value nested in "
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
"If you use 'produces' as a function argument of a task and pass values as "
"function defaults, it can only accept values of type 'pathlib.Path' or "
"the same value nested in tuples, lists, and dictionaries. Here, "
f"{node!r} has type {type(node)}."
)

if isinstance(node, str):
Expand Down
5 changes: 3 additions & 2 deletions src/_pytask/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ def pytask_execute_task(session: Session, task: Task) -> bool:
for name, value in task.depends_on.items():
kwargs[name] = tree_map(lambda x: x.value, value)

if task.produces and "produces" in parameters:
kwargs["produces"] = tree_map(lambda x: x.value, task.produces)
for name, value in task.produces.items():
if name in parameters:
kwargs[name] = tree_map(lambda x: x.value, value)

task.execute(**kwargs)
return True
Expand Down
10 changes: 9 additions & 1 deletion src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
from _pytask.mark import Mark


__all__ = ["FilePathNode", "MetaNode", "Task"]
__all__ = ["FilePathNode", "MetaNode", "Product", "Task"]


@define(frozen=True)
class ProductType:
    """A class to mark products.

    The class has no fields. Its instances are placed in the metadata of
    ``Annotated`` to mark an argument of a task function as a product.

    """


# The instance which is exported in ``__all__`` to be used in annotations.
Product = ProductType()


class MetaNode(metaclass=ABCMeta):
Expand Down