2121from pytask import hookimpl
2222from pytask import Mark
2323from pytask import parse_warning_filter
24+ from pytask import PNode
2425from pytask import PTask
26+ from pytask import PythonNode
2527from pytask import remove_internal_traceback_frames_from_exc_info
2628from pytask import Session
2729from pytask import Task
@@ -114,7 +116,11 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
114116 warning_reports = []
115117 # A task raised an exception.
116118 else :
117- warning_reports , task_exception = future .result ()
119+ (
120+ python_nodes ,
121+ warning_reports ,
122+ task_exception ,
123+ ) = future .result ()
118124 session .warnings .extend (warning_reports )
119125 exc_info = (
120126 _parse_future_exception (future .exception ())
@@ -132,6 +138,19 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
132138 session .scheduler .done (task_name )
133139 else :
134140 task = session .dag .nodes [task_name ]["task" ]
141+
142+ # Update PythonNodes with the values from the future if
143+ # not threads.
144+ if (
145+ session .config ["parallel_backend" ]
146+ != ParallelBackend .THREADS
147+ ):
148+ task .produces = tree_map (
149+ _update_python_node ,
150+ task .produces ,
151+ python_nodes ,
152+ )
153+
135154 try :
136155 session .hook .pytask_execute_task_teardown (
137156 session = session , task = task
@@ -169,6 +188,12 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
169188 return None
170189
171190
191+ def _update_python_node (x : PNode , y : PythonNode | None ) -> PNode :
192+ if y :
193+ x .save (y .load ())
194+ return x
195+
196+
172197def _parse_future_exception (
173198 exc : BaseException | None ,
174199) -> tuple [type [BaseException ], BaseException , TracebackType ] | None :
@@ -240,7 +265,11 @@ def _execute_task( # noqa: PLR0913
240265 console_options : ConsoleOptions ,
241266 session_filterwarnings : tuple [str , ...],
242267 task_filterwarnings : tuple [Mark , ...],
243- ) -> tuple [list [WarningReport ], tuple [type [BaseException ], BaseException , str ] | None ]:
268+ ) -> tuple [
269+ PyTree [PythonNode | None ],
270+ list [WarningReport ],
271+ tuple [type [BaseException ], BaseException , str ] | None ,
272+ ]:
244273 """Unserialize and execute task.
245274
246275 This function receives bytes and unpickles them to a task which is them execute in a
@@ -251,9 +280,6 @@ def _execute_task( # noqa: PLR0913
251280 _patch_set_trace_and_breakpoint ()
252281
253282 with warnings .catch_warnings (record = True ) as log :
254- # mypy can't infer that record=True means log is not None; help it.
255- assert log is not None # noqa: S101
256-
257283 for arg in session_filterwarnings :
258284 warnings .filterwarnings (* parse_warning_filter (arg , escape = False ))
259285
@@ -301,7 +327,11 @@ def _execute_task( # noqa: PLR0913
301327 )
302328 )
303329
304- return warning_reports , processed_exc_info
330+ python_nodes = tree_map (
331+ lambda x : x if isinstance (x , PythonNode ) else None , task .produces
332+ )
333+
334+ return python_nodes , warning_reports , processed_exc_info
305335
306336
307337def _process_exception (
@@ -339,7 +369,9 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
339369
340370def _mock_processes_for_threads (
341371 func : Callable [..., Any ], ** kwargs : Any
342- ) -> tuple [list [Any ], tuple [type [BaseException ], BaseException , TracebackType ] | None ]:
372+ ) -> tuple [
373+ None , list [Any ], tuple [type [BaseException ], BaseException , TracebackType ] | None
374+ ]:
343375 """Mock execution function such that it returns the same as for processes.
344376
345377 The function for processes returns ``warning_reports`` and an ``exception``. With
@@ -354,7 +386,7 @@ def _mock_processes_for_threads(
354386 exc_info = sys .exc_info ()
355387 else :
356388 exc_info = None
357- return [], exc_info
389+ return None , [], exc_info
358390
359391
360392def _create_kwargs_for_task (task : PTask ) -> dict [str , PyTree [Any ]]:
@@ -395,7 +427,7 @@ def sleep(self) -> None:
395427 time .sleep (self .timings [self .timing_idx ])
396428
397429
398- def _get_module (func : Callable [..., Any ], path : Path ) -> ModuleType :
430+ def _get_module (func : Callable [..., Any ], path : Path | None ) -> ModuleType :
399431 """Get the module of a python function.
400432
401433 For Python <3.10, functools.partial does not set a `__module__` attribute which is
@@ -410,4 +442,6 @@ def _get_module(func: Callable[..., Any], path: Path) -> ModuleType:
410442 does not really support ``functools.partial``. Instead, use ``@task(kwargs=...)``.
411443
412444 """
413- return inspect .getmodule (func , path .as_posix ())
445+ if path :
446+ return inspect .getmodule (func , path .as_posix ())
447+ return inspect .getmodule (func )
0 commit comments