diff --git a/executorlib/task_scheduler/file/shared.py b/executorlib/task_scheduler/file/shared.py index 665ecbe4..98642e93 100644 --- a/executorlib/task_scheduler/file/shared.py +++ b/executorlib/task_scheduler/file/shared.py @@ -52,7 +52,6 @@ def done(self) -> bool: def execute_tasks_h5( future_queue: queue.Queue, - cache_directory: str, execute_function: Callable, resource_dict: dict, terminate_function: Optional[Callable] = None, @@ -65,7 +64,6 @@ def execute_tasks_h5( Args: future_queue (queue.Queue): The queue containing the tasks. - cache_directory (str): The directory to store the HDF5 files. resource_dict (dict): A dictionary of resources required by the task. With the following keys: - cores (int): number of MPI cores to be used for each function call - cwd (str/None): current working directory where the parallel python task is executed @@ -81,6 +79,7 @@ def execute_tasks_h5( """ memory_dict: dict = {} process_dict: dict = {} + cache_dir_dict: dict = {} file_name_dict: dict = {} while True: task_dict = None @@ -104,6 +103,7 @@ def execute_tasks_h5( {k: v for k, v in resource_dict.items() if k not in task_resource_dict} ) cache_key = task_resource_dict.pop("cache_key", None) + cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory")) task_key, data_dict = serialize_funct_h5( fn=task_dict["fn"], fn_args=task_args, @@ -147,11 +147,14 @@ def execute_tasks_h5( cache_directory, task_key + "_o.h5" ) memory_dict[task_key] = task_dict["future"] + cache_dir_dict[task_key] = cache_directory future_queue.task_done() else: memory_dict = { key: _check_task_output( - task_key=key, future_obj=value, cache_directory=cache_directory + task_key=key, + future_obj=value, + cache_directory=cache_dir_dict[key], ) for key, value in memory_dict.items() if not value.done() diff --git a/executorlib/task_scheduler/file/task_scheduler.py b/executorlib/task_scheduler/file/task_scheduler.py index a373d902..65daffab 100644 --- a/executorlib/task_scheduler/file/task_scheduler.py +++ b/executorlib/task_scheduler/file/task_scheduler.py @@ -1,4 +1,3 @@ -import os from threading import Thread from typing import Callable, Optional @@ -27,7 +26,6 @@ class FileTaskScheduler(TaskSchedulerBase): def __init__( self, - cache_directory: str = "executorlib_cache", resource_dict: Optional[dict] = None, execute_function: Callable = execute_with_pysqa, terminate_function: Optional[Callable] = None, @@ -39,10 +37,10 @@ def __init__( Initialize the FileExecutor. Args: - cache_directory (str, optional): The directory to store cache files. Defaults to "executorlib_cache". resource_dict (dict): A dictionary of resources required by the task. With the following keys: - cores (int): number of MPI cores to be used for each function call - cwd (str/None): current working directory where the parallel python task is executed + - cache_directory (str): The directory to store cache files. execute_function (Callable, optional): The function to execute tasks. Defaults to execute_in_subprocess. terminate_function (Callable, optional): The function to terminate the tasks. pysqa_config_directory (str, optional): path to the pysqa config directory (only for pysqa based backend). @@ -53,6 +51,7 @@ def __init__( default_resource_dict = { "cores": 1, "cwd": None, + "cache_directory": "executorlib_cache", } if resource_dict is None: resource_dict = {} @@ -61,12 +60,10 @@ def __init__( ) if execute_function == execute_in_subprocess and terminate_function is None: terminate_function = terminate_subprocess - cache_directory_path = os.path.abspath(cache_directory) self._process_kwargs = { + "resource_dict": resource_dict, "future_queue": self._future_queue, "execute_function": execute_function, - "cache_directory": cache_directory_path, - "resource_dict": resource_dict, "terminate_function": terminate_function, "pysqa_config_directory": pysqa_config_directory, "backend": backend, @@ -81,11 +78,11 @@ def __init__( def create_file_executor( + resource_dict: dict, max_workers: Optional[int] = None, backend: str = "flux_submission", max_cores: Optional[int] = None, cache_directory: Optional[str] = None, - resource_dict: Optional[dict] = None, flux_executor=None, flux_executor_pmi_mode: Optional[str] = None, flux_executor_nesting: bool = False, @@ -96,8 +93,6 @@ def create_file_executor( init_function: Optional[Callable] = None, disable_dependencies: bool = False, ): - if cache_directory is None: - cache_directory = "executorlib_cache" if block_allocation: raise ValueError( "The option block_allocation is not available with the pysqa based backend." @@ -106,6 +101,8 @@ def create_file_executor( raise ValueError( "The option to specify an init_function is not available with the pysqa based backend." ) + if cache_directory is not None: + resource_dict["cache_directory"] = cache_directory check_flux_executor_pmi_mode(flux_executor_pmi_mode=flux_executor_pmi_mode) check_max_workers_and_cores(max_cores=max_cores, max_workers=max_workers) check_hostname_localhost(hostname_localhost=hostname_localhost) @@ -113,7 +110,6 @@ def create_file_executor( check_nested_flux_executor(nested_flux_executor=flux_executor_nesting) check_flux_log_files(flux_log_files=flux_log_files) return FileTaskScheduler( - cache_directory=cache_directory, resource_dict=resource_dict, pysqa_config_directory=pysqa_config_directory, backend=backend.split("_submission")[0], diff --git a/tests/test_cache_fileexecutor_serial.py b/tests/test_cache_fileexecutor_serial.py index 8b68df53..eb62c166 100644 --- a/tests/test_cache_fileexecutor_serial.py +++ b/tests/test_cache_fileexecutor_serial.py @@ -58,10 +58,12 @@ def test_executor_dependence_mixed(self): self.assertTrue(fs2.done()) def test_create_file_executor_error(self): + with self.assertRaises(TypeError): + create_file_executor() with self.assertRaises(ValueError): - create_file_executor(block_allocation=True) + create_file_executor(block_allocation=True, resource_dict={}) with self.assertRaises(ValueError): - create_file_executor(init_function=True) + create_file_executor(init_function=True, resource_dict={}) def test_executor_dependence_error(self): with self.assertRaises(ValueError): @@ -106,9 +108,8 @@ def test_executor_function(self): target=execute_tasks_h5, kwargs={ "future_queue": q, - "cache_directory": cache_dir, "execute_function": execute_in_subprocess, - "resource_dict": {"cores": 1, "cwd": None}, + "resource_dict": {"cores": 1, "cwd": None, "cache_directory": cache_dir}, "terminate_function": terminate_subprocess, }, ) @@ -147,9 +148,8 @@ def test_executor_function_dependence_kwargs(self): target=execute_tasks_h5, kwargs={ "future_queue": q, - "cache_directory": cache_dir, "execute_function": execute_in_subprocess, - "resource_dict": {"cores": 1, "cwd": None}, + "resource_dict": {"cores": 1, "cwd": None, "cache_directory": cache_dir}, "terminate_function": terminate_subprocess, }, ) @@ -188,9 +188,8 @@ def test_executor_function_dependence_args(self): target=execute_tasks_h5, kwargs={ "future_queue": q, - "cache_directory": cache_dir, "execute_function": execute_in_subprocess, - "resource_dict": {"cores": 1}, + "resource_dict": {"cores": 1, "cache_directory": cache_dir}, "terminate_function": terminate_subprocess, }, )