
Commit

Merge 70c55be into ae31b65
tobiasraabe committed Jul 8, 2023
2 parents ae31b65 + 70c55be commit 05fc3c6
Showing 18 changed files with 369 additions and 22 deletions.
3 changes: 0 additions & 3 deletions .pre-commit-config.yaml
@@ -21,10 +21,7 @@ repos:
hooks:
- id: python-check-blanket-noqa
- id: python-check-mock-methods
- id: python-no-eval
exclude: expression.py
- id: python-no-log-warn
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/asottile/reorder-python-imports
rev: v3.9.0
1 change: 1 addition & 0 deletions docs/rtd_environment.yml
@@ -30,6 +30,7 @@ dependencies:
- rich
- sqlalchemy >=1.4.36
- tomli >=1.0.0
- typing_extensions

- pip:
- ../
1 change: 1 addition & 0 deletions environment.yml
@@ -20,6 +20,7 @@ dependencies:
- rich
- sqlalchemy >=1.4.36
- tomli >=1.0.0
- typing_extensions

# Misc
- black
1 change: 1 addition & 0 deletions pyproject.toml
@@ -83,6 +83,7 @@ convention = "numpy"


[tool.pytest.ini_options]
addopts = ["--doctest-modules"]
testpaths = ["src", "tests"]
markers = [
"wip: Tests that are work-in-progress.",
1 change: 1 addition & 0 deletions setup.cfg
@@ -40,6 +40,7 @@ install_requires =
rich
sqlalchemy>=1.4.36
tomli>=1.0.0
typing-extensions
python_requires = >=3.8
include_package_data = True
package_dir =
140 changes: 140 additions & 0 deletions src/_pytask/_inspect.py
@@ -0,0 +1,140 @@
from __future__ import annotations

import functools
import sys
import types
from typing import Any
from typing import Callable
from typing import Mapping


__all__ = ["get_annotations"]


if sys.version_info >= (3, 10):
from inspect import get_annotations
else:

def get_annotations( # noqa: C901, PLR0912, PLR0915
obj: Callable[..., object] | type[Any] | types.ModuleType,
*,
globals: Mapping[str, Any] | None = None, # noqa: A002
locals: Mapping[str, Any] | None = None, # noqa: A002
eval_str: bool = False,
) -> dict[str, Any]:
"""Compute the annotations dict for an object.
obj may be a callable, class, or module.
Passing in an object of any other type raises TypeError.
Returns a dict. get_annotations() returns a new dict every time
it's called; calling it twice on the same object will return two
different but equivalent dicts.
This function handles several details for you:
* If eval_str is true, values of type str will
be un-stringized using eval(). This is intended
for use with stringized annotations
("from __future__ import annotations").
* If obj doesn't have an annotations dict, returns an
empty dict. (Functions and methods always have an
annotations dict; classes, modules, and other types of
callables may not.)
* Ignores inherited annotations on classes. If a class
doesn't have its own annotations dict, returns an empty dict.
* All accesses to object members and dict values are done
using getattr() and dict.get() for safety.
* Always, always, always returns a freshly-created dict.
eval_str controls whether or not values of type str are replaced
with the result of calling eval() on those values:
* If eval_str is true, eval() is called on values of type str.
* If eval_str is false (the default), values of type str are unchanged.
globals and locals are passed in to eval(); see the documentation
for eval() for more information. If either globals or locals is
None, this function may replace that value with a context-specific
default, contingent on type(obj):
* If obj is a module, globals defaults to obj.__dict__.
* If obj is a class, globals defaults to
sys.modules[obj.__module__].__dict__ and locals
defaults to the obj class namespace.
* If obj is a callable, globals defaults to obj.__globals__,
although if obj is a wrapped function (using
functools.update_wrapper()) it is first unwrapped.
"""
if isinstance(obj, type):
# class
obj_dict = getattr(obj, "__dict__", None)
if obj_dict and hasattr(obj_dict, "get"):
ann = obj_dict.get("__annotations__", None)
if isinstance(ann, types.GetSetDescriptorType):
ann = None
else:
ann = None

obj_globals = None
module_name = getattr(obj, "__module__", None)
if module_name:
module = sys.modules.get(module_name, None)
if module:
obj_globals = getattr(module, "__dict__", None)
obj_locals = dict(vars(obj))
unwrap = obj
elif isinstance(obj, types.ModuleType):
# module
ann = getattr(obj, "__annotations__", None)
obj_globals = obj.__dict__
obj_locals = None
unwrap = None
elif callable(obj):
# this includes types.Function, types.BuiltinFunctionType,
# types.BuiltinMethodType, functools.partial, functools.singledispatch,
# "class funclike" from Lib/test/test_inspect... on and on it goes.
ann = getattr(obj, "__annotations__", None)
obj_globals = getattr(obj, "__globals__", None)
obj_locals = None
unwrap = obj
else:
raise TypeError(f"{obj!r} is not a module, class, or callable.")

if ann is None:
return {}

if not isinstance(ann, dict):
raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")

if not ann:
return {}

if not eval_str:
return dict(ann)

if unwrap is not None:
while True:
if hasattr(unwrap, "__wrapped__"):
unwrap = unwrap.__wrapped__
continue
if isinstance(unwrap, functools.partial):
unwrap = unwrap.func
continue
break
if hasattr(unwrap, "__globals__"):
obj_globals = unwrap.__globals__

if globals is None:
globals = obj_globals # noqa: A001
if locals is None:
locals = obj_locals # noqa: A001

eval_func = eval
return_value = {
key: value
if not isinstance(value, str)
else eval_func(value, globals, locals)
for key, value in ann.items()
}
return return_value
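
A brief usage sketch of the backport above; the task function and the printed values are illustrative, not part of the commit. With stringized annotations, the default returns the raw strings, while eval_str=True evaluates them against the function's globals.

from __future__ import annotations  # stringizes the annotations below

from pathlib import Path

from _pytask._inspect import get_annotations


def task_example(produces: Path) -> None:
    ...


# Default: string annotation values are returned unchanged.
get_annotations(task_example)
# -> {'produces': 'Path', 'return': 'None'}

# eval_str=True evaluates the strings using the function's __globals__.
get_annotations(task_example, eval_str=True)
# -> {'produces': <class 'pathlib.Path'>, 'return': None}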
6 changes: 1 addition & 5 deletions src/_pytask/collect.py
@@ -15,7 +15,6 @@
from _pytask.collect_utils import parse_dependencies_from_task_function
from _pytask.collect_utils import parse_nodes
from _pytask.collect_utils import parse_products_from_task_function
from _pytask.collect_utils import produces
from _pytask.config import hookimpl
from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE
from _pytask.console import console
@@ -178,10 +177,7 @@ def pytask_collect_task(
session, path, name, obj
)

if has_mark(obj, "produces"):
products = parse_nodes(session, path, name, obj, produces)
else:
products = parse_products_from_task_function(session, path, name, obj)
products = parse_products_from_task_function(session, path, name, obj)

markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []

105 changes: 97 additions & 8 deletions src/_pytask/collect_utils.py
@@ -11,14 +11,19 @@
from typing import Iterable
from typing import TYPE_CHECKING

from _pytask._inspect import get_annotations
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import has_mark
from _pytask.mark_utils import remove_marks
from _pytask.nodes import ProductType
from _pytask.nodes import PythonNode
from _pytask.shared import find_duplicates
from _pytask.task_utils import parse_keyword_arguments_from_signature_defaults
from attrs import define
from attrs import field
from pybaum.tree_util import tree_map
from typing_extensions import Annotated
from typing_extensions import get_origin


if TYPE_CHECKING:
@@ -211,8 +216,12 @@ def parse_dependencies_from_task_function(
kwargs = {**signature_defaults, **task_kwargs}
kwargs.pop("produces", None)

parameters_with_product_annot = _find_args_with_product_annotation(obj)

dependencies = {}
for name, value in kwargs.items():
if name in parameters_with_product_annot:
continue
parsed_value = tree_map(
lambda x: _collect_dependencies(session, path, name, x), value # noqa: B023
)
@@ -223,29 +232,108 @@
return dependencies


_ERROR_MULTIPLE_PRODUCT_DEFINITIONS = (
"The task uses multiple ways to define products. Products should be defined with "
"either\n\n- 'typing.Annotated[Path(...), Product]' (recommended)\n"
"- '@pytask.mark.task(kwargs={'produces': Path(...)})'\n"
"- as a default argument for 'produces': 'produces = Path(...)'\n"
"- '@pytask.mark.produces(Path(...))' (deprecated).\n\n"
"Read more about products in the documentation: https://tinyurl.com/yrezszr4."
)


def parse_products_from_task_function(
session: Session, path: Path, name: str, obj: Any
) -> dict[str, Any]:
"""Parse dependencies from task function."""
"""Parse products from task function.
Raises
------
NodeNotCollectedError
If multiple ways were used to specify products.
"""
has_produces_decorator = False
has_task_decorator = False
has_signature_default = False
has_annotation = False
out = {}

if has_mark(obj, "produces"):
has_produces_decorator = True
nodes = parse_nodes(session, path, name, obj, produces)
out = {"produces": nodes}

task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
if "produces" in task_kwargs:
return tree_map(
collected_products = tree_map(
lambda x: _collect_product(session, path, name, x, is_string_allowed=True),
task_kwargs["produces"],
)
out = {"produces": collected_products}

parameters = inspect.signature(obj).parameters
if "produces" in parameters:

if not has_mark(obj, "task") and "produces" in parameters:
parameter = parameters["produces"]
if parameter.default is not parameter.empty:
has_signature_default = True
# Pass is_string_allowed=False so plain strings are not collected.
return tree_map(
collected_products = tree_map(
lambda x: _collect_product(
session, path, name, x, is_string_allowed=False
),
parameter.default,
)
return {}
out = {"produces": collected_products}

parameters_with_product_annot = _find_args_with_product_annotation(obj)
if parameters_with_product_annot:
has_annotation = True
for parameter_name in parameters_with_product_annot:
parameter = parameters[parameter_name]
if parameter.default is not parameter.empty:
# Pass is_string_allowed=False so plain strings are not collected.
collected_products = tree_map(
lambda x: _collect_product(
session, path, name, x, is_string_allowed=False
),
parameter.default,
)
out = {parameter_name: collected_products}

if (
sum(
(
has_produces_decorator,
has_task_decorator,
has_signature_default,
has_annotation,
)
)
>= 2 # noqa: PLR2004
):
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)

return out
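
As a rough illustration of the new guard, a task that mixes the deprecated decorator with the new annotation is now rejected during collection. Product is imported from _pytask.nodes here because the diff does not show whether it is re-exported publicly; the task function is made up.

from pathlib import Path

import pytask
from typing_extensions import Annotated

from _pytask.nodes import Product


@pytask.mark.produces("data.csv")  # deprecated way of declaring the product ...
def task_write_data(path: Annotated[Path, Product] = Path("data.csv")) -> None:
    # ... combined with the annotation-based way declared on ``path``.
    path.write_text("a,b\n1,2\n")


# parse_products_from_task_function() counts two product definitions for this task
# and raises NodeNotCollectedError with _ERROR_MULTIPLE_PRODUCT_DEFINITIONS.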


def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
"""Find args with product annotation."""
annotations = get_annotations(func, eval_str=True)
metas = {
name: annotation.__metadata__
for name, annotation in annotations.items()
if get_origin(annotation) is Annotated
}

args_with_product_annot = []
for name, meta in metas.items():
has_product_annot = any(isinstance(i, ProductType) for i in meta)
if has_product_annot:
args_with_product_annot.append(name)

return args_with_product_annot
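
A minimal sketch of what the helper above detects; the task function is made up for illustration.

from pathlib import Path

from typing_extensions import Annotated

from _pytask.collect_utils import _find_args_with_product_annotation
from _pytask.nodes import Product


def task_plot(data: Path, plot: Annotated[Path, Product] = Path("plot.png")) -> None:
    ...


# Only parameters whose Annotated metadata contains a ProductType instance are returned.
_find_args_with_product_annotation(task_plot)  # -> ['plot']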


def _collect_old_dependencies(
@@ -331,9 +419,10 @@ def _collect_product(
# The parameter defaults only support Path objects.
if not isinstance(node, Path) and not is_string_allowed:
raise ValueError(
"If you use 'produces' as an argument of a task, it can only accept values "
"of type 'pathlib.Path' or the same value nested in "
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
"If you use 'produces' as a function argument of a task and pass values as "
"function defaults, it can only accept values of type 'pathlib.Path' or "
"the same value nested in tuples, lists, and dictionaries. Here, "
f"{node!r} has type {type(node)}."
)

if isinstance(node, str):
5 changes: 3 additions & 2 deletions src/_pytask/execute.py
@@ -149,8 +149,9 @@ def pytask_execute_task(session: Session, task: Task) -> bool:
for name, value in task.depends_on.items():
kwargs[name] = tree_map(lambda x: x.value, value)

if task.produces and "produces" in parameters:
kwargs["produces"] = tree_map(lambda x: x.value, task.produces)
for name, value in task.produces.items():
if name in parameters:
kwargs[name] = tree_map(lambda x: x.value, value)

task.execute(**kwargs)
return True
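
With this change, product values are injected under their own parameter names instead of only under a produces argument. A rough sketch of a task that relies on this, again importing Product from _pytask.nodes as an assumption:

from pathlib import Path

from typing_extensions import Annotated

from _pytask.nodes import Product


def task_create_file(out: Annotated[Path, Product] = Path("file.txt")) -> None:
    # The loop above passes task.produces['out'] as ``out`` at execution time,
    # just as the legacy ``produces`` argument used to be injected.
    out.write_text("Hello, World!")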
10 changes: 9 additions & 1 deletion src/_pytask/nodes.py
@@ -17,7 +17,15 @@
from _pytask.mark import Mark


__all__ = ["FilePathNode", "MetaNode", "Task"]
__all__ = ["FilePathNode", "MetaNode", "Product", "Task"]


@define(frozen=True)
class ProductType:
"""A class to mark products."""


Product = ProductType()


class MetaNode(metaclass=ABCMeta):