diff --git a/executorlib/shared/cache.py b/executorlib/shared/cache.py
index a2b5975b..07fb36eb 100644
--- a/executorlib/shared/cache.py
+++ b/executorlib/shared/cache.py
@@ -1,17 +1,14 @@
-import hashlib
 import importlib.util
 import os
 import queue
-import re
 import subprocess
 import sys
 from concurrent.futures import Future
 from typing import Any, Tuple
 
-import cloudpickle
-
 from executorlib.shared.executor import get_command_path
 from executorlib.shared.hdf import dump, get_output, load
+from executorlib.shared.serialize import serialize_funct_h5
 
 
 class FutureItem:
@@ -152,7 +149,7 @@ def execute_tasks_h5(
                 memory_dict=memory_dict,
                 file_name_dict=file_name_dict,
             )
-            task_key, data_dict = _serialize_funct_h5(
+            task_key, data_dict = serialize_funct_h5(
                 task_dict["fn"], *task_args, **task_kwargs
             )
             if task_key not in memory_dict.keys():
@@ -228,41 +225,6 @@ def _get_execute_command(file_name: str, cores: int = 1) -> list:
     return command_lst
 
 
-def _get_hash(binary: bytes) -> str:
-    """
-    Get the hash of a binary.
-
-    Args:
-        binary (bytes): The binary to be hashed.
-
-    Returns:
-        str: The hash of the binary.
-
-    """
-    # Remove specification of jupyter kernel from hash to be deterministic
-    binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary)
-    return str(hashlib.md5(binary_no_ipykernel).hexdigest())
-
-
-def _serialize_funct_h5(fn: callable, *args: Any, **kwargs: Any) -> Tuple[str, dict]:
-    """
-    Serialize a function and its arguments and keyword arguments into an HDF5 file.
-
-    Args:
-        fn (callable): The function to be serialized.
-        *args (Any): The arguments of the function.
-        **kwargs (Any): The keyword arguments of the function.
-
-    Returns:
-        Tuple[str, dict]: A tuple containing the task key and the serialized data.
-
-    """
-    binary_all = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
-    task_key = fn.__name__ + _get_hash(binary=binary_all)
-    data = {"fn": fn, "args": args, "kwargs": kwargs}
-    return task_key, data
-
-
 def _check_task_output(
     task_key: str, future_obj: Future, cache_directory: str
 ) -> Future:
diff --git a/executorlib/shared/serialize.py b/executorlib/shared/serialize.py
new file mode 100644
index 00000000..4851dd99
--- /dev/null
+++ b/executorlib/shared/serialize.py
@@ -0,0 +1,40 @@
+import hashlib
+import re
+from typing import Any, Tuple
+
+import cloudpickle
+
+
+def serialize_funct_h5(fn: callable, *args: Any, **kwargs: Any) -> Tuple[str, dict]:
+    """
+    Serialize a function and its arguments and keyword arguments into an HDF5 file.
+
+    Args:
+        fn (callable): The function to be serialized.
+        *args (Any): The arguments of the function.
+        **kwargs (Any): The keyword arguments of the function.
+
+    Returns:
+        Tuple[str, dict]: A tuple containing the task key and the serialized data.
+
+    """
+    binary_all = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
+    task_key = fn.__name__ + _get_hash(binary=binary_all)
+    data = {"fn": fn, "args": args, "kwargs": kwargs}
+    return task_key, data
+
+
+def _get_hash(binary: bytes) -> str:
+    """
+    Get the hash of a binary.
+
+    Args:
+        binary (bytes): The binary to be hashed.
+
+    Returns:
+        str: The hash of the binary.
+
+    """
+    # Remove specification of jupyter kernel from hash to be deterministic
+    binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary)
+    return str(hashlib.md5(binary_no_ipykernel).hexdigest())
diff --git a/tests/test_cache_shared.py b/tests/test_cache_shared.py
index de763835..ccc16b49 100644
--- a/tests/test_cache_shared.py
+++ b/tests/test_cache_shared.py
@@ -5,13 +5,13 @@
 
 
 try:
-    from executorlib.shared.hdf import dump
     from executorlib.shared.cache import (
         FutureItem,
+        execute_task_in_file,
         _check_task_output,
-        _serialize_funct_h5,
     )
-    from executorlib.shared.cache import execute_task_in_file
+    from executorlib.shared.hdf import dump
+    from executorlib.shared.serialize import serialize_funct_h5
 
     skip_h5io_test = False
 except ImportError:
@@ -29,7 +29,7 @@ class TestSharedFunctions(unittest.TestCase):
     def test_execute_function_mixed(self):
         cache_directory = os.path.abspath("cache")
         os.makedirs(cache_directory, exist_ok=True)
-        task_key, data_dict = _serialize_funct_h5(
+        task_key, data_dict = serialize_funct_h5(
             my_funct,
             1,
             b=2,
@@ -52,7 +52,7 @@
     def test_execute_function_args(self):
         cache_directory = os.path.abspath("cache")
         os.makedirs(cache_directory, exist_ok=True)
-        task_key, data_dict = _serialize_funct_h5(
+        task_key, data_dict = serialize_funct_h5(
             my_funct,
             1,
             2,
@@ -75,7 +75,7 @@
     def test_execute_function_kwargs(self):
         cache_directory = os.path.abspath("cache")
         os.makedirs(cache_directory, exist_ok=True)
-        task_key, data_dict = _serialize_funct_h5(
+        task_key, data_dict = serialize_funct_h5(
             my_funct,
             a=1,
             b=2,
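For context, a minimal usage sketch of the relocated helper, assuming only the new module path introduced by this diff; the `add` function and the printed values are illustrative examples, not part of the change:

# Hypothetical example: call the relocated helper directly.
from executorlib.shared.serialize import serialize_funct_h5


def add(a, b):
    return a + b


# The task key is the function name plus an MD5 hash of the cloudpickled
# payload (with any ipykernel path stripped for determinism); the data dict
# carries the function together with its args and kwargs.
task_key, data_dict = serialize_funct_h5(add, 1, b=2)
print(task_key)             # "add" followed by a 32-character hex digest
print(data_dict["args"])    # (1,)
print(data_dict["kwargs"])  # {'b': 2}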