Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 2 additions & 40 deletions executorlib/shared/cache.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import hashlib
import importlib.util
import os
import queue
import re
import subprocess
import sys
from concurrent.futures import Future
from typing import Any, Tuple

import cloudpickle

from executorlib.shared.executor import get_command_path
from executorlib.shared.hdf import dump, get_output, load
from executorlib.shared.serialize import serialize_funct_h5


class FutureItem:
Expand Down Expand Up @@ -152,7 +149,7 @@ def execute_tasks_h5(
memory_dict=memory_dict,
file_name_dict=file_name_dict,
)
task_key, data_dict = _serialize_funct_h5(
task_key, data_dict = serialize_funct_h5(
task_dict["fn"], *task_args, **task_kwargs
)
if task_key not in memory_dict.keys():
Expand Down Expand Up @@ -228,41 +225,6 @@ def _get_execute_command(file_name: str, cores: int = 1) -> list:
return command_lst


def _get_hash(binary: bytes) -> str:
"""
Get the hash of a binary.

Args:
binary (bytes): The binary to be hashed.

Returns:
str: The hash of the binary.

"""
# Remove specification of jupyter kernel from hash to be deterministic
binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary)
return str(hashlib.md5(binary_no_ipykernel).hexdigest())


def _serialize_funct_h5(fn: callable, *args: Any, **kwargs: Any) -> Tuple[str, dict]:
    """
    Build the task key and payload dictionary for caching a function call.

    Args:
        fn (callable): The function to be serialized.
        *args (Any): The positional arguments of the function.
        **kwargs (Any): The keyword arguments of the function.

    Returns:
        Tuple[str, dict]: The task key (function name plus payload hash) and
            the dictionary holding the function together with its arguments.

    """
    data = {"fn": fn, "args": args, "kwargs": kwargs}
    # The key combines the function name with a hash of the whole pickled
    # payload, so identical calls map onto the same cache entry.
    task_key = fn.__name__ + _get_hash(binary=cloudpickle.dumps(data))
    return task_key, data


def _check_task_output(
task_key: str, future_obj: Future, cache_directory: str
) -> Future:
Expand Down
40 changes: 40 additions & 0 deletions executorlib/shared/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import hashlib
import re
from typing import Any, Tuple

import cloudpickle


def serialize_funct_h5(fn: callable, *args: Any, **kwargs: Any) -> Tuple[str, dict]:
    """
    Serialize a function call into a cache key and a payload dictionary.

    The returned pair is intended to be written to an HDF5 file by the
    caller for cached task execution.

    Args:
        fn (callable): The function to be serialized.
        *args (Any): The positional arguments of the function.
        **kwargs (Any): The keyword arguments of the function.

    Returns:
        Tuple[str, dict]: A tuple containing the task key and the
            serialized data.

    """
    payload = {"fn": fn, "args": args, "kwargs": kwargs}
    # Hash the pickled payload so equal calls share one cache entry.
    key = fn.__name__ + _get_hash(binary=cloudpickle.dumps(payload))
    return key, payload
Comment on lines +8 to +24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add input validation and error handling.

The function should validate inputs and handle serialization errors gracefully.

Consider applying these improvements:

 def serialize_funct_h5(fn: callable, *args: Any, **kwargs: Any) -> Tuple[str, dict]:
+    if not callable(fn):
+        raise TypeError("fn must be callable")
+    
+    try:
         binary_all = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
+    except Exception as e:
+        raise ValueError(f"Failed to serialize function and arguments: {str(e)}")
+
     task_key = fn.__name__ + _get_hash(binary=binary_all)
     data = {"fn": fn, "args": args, "kwargs": kwargs}
     return task_key, data
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def serialize_funct_h5(fn: callable, *args: Any, **kwargs: Any) -> Tuple[str, dict]:
"""
Serialize a function and its arguments and keyword arguments into an HDF5 file.
Args:
fn (callable): The function to be serialized.
*args (Any): The arguments of the function.
**kwargs (Any): The keyword arguments of the function.
Returns:
Tuple[str, dict]: A tuple containing the task key and the serialized data.
"""
binary_all = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
task_key = fn.__name__ + _get_hash(binary=binary_all)
data = {"fn": fn, "args": args, "kwargs": kwargs}
return task_key, data
def serialize_funct_h5(fn: callable, *args: Any, **kwargs: Any) -> Tuple[str, dict]:
"""
Serialize a function and its arguments and keyword arguments into an HDF5 file.
Args:
fn (callable): The function to be serialized.
*args (Any): The arguments of the function.
**kwargs (Any): The keyword arguments of the function.
Returns:
Tuple[str, dict]: A tuple containing the task key and the serialized data.
"""
if not callable(fn):
raise TypeError("fn must be callable")
try:
binary_all = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
except Exception as e:
raise ValueError(f"Failed to serialize function and arguments: {str(e)}")
task_key = fn.__name__ + _get_hash(binary=binary_all)
data = {"fn": fn, "args": args, "kwargs": kwargs}
return task_key, data



def _get_hash(binary: bytes) -> str:
"""
Get the hash of a binary.
Args:
binary (bytes): The binary to be hashed.
Returns:
str: The hash of the binary.
"""
# Remove specification of jupyter kernel from hash to be deterministic
binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary)
return str(hashlib.md5(binary_no_ipykernel).hexdigest())
Comment on lines +27 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider adding input validation and improving hash reliability.

The hash function could be more robust with input validation and a more reliable kernel path handling.

Consider these improvements:

 def _get_hash(binary: bytes) -> str:
+    if not isinstance(binary, bytes):
+        raise TypeError("Input must be bytes")
+
     # Remove specification of jupyter kernel from hash to be deterministic
-    binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary)
+    # Handle both Windows and Unix-style paths
+    binary_no_ipykernel = re.sub(
+        b"(?<=/ipykernel_|\\\\ipykernel_)(.*)(?=/|\\\\)",
+        b"",
+        binary
+    )
     return str(hashlib.md5(binary_no_ipykernel).hexdigest())

Also, consider adding a comment explaining why MD5 is sufficient for this use case:

# MD5 is used here for generating cache keys, not for security purposes

12 changes: 6 additions & 6 deletions tests/test_cache_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@


try:
from executorlib.shared.hdf import dump
from executorlib.shared.cache import (
FutureItem,
execute_task_in_file,
_check_task_output,
_serialize_funct_h5,
)
from executorlib.shared.cache import execute_task_in_file
from executorlib.shared.hdf import dump
from executorlib.shared.serialize import serialize_funct_h5

skip_h5io_test = False
except ImportError:
Expand All @@ -29,7 +29,7 @@ class TestSharedFunctions(unittest.TestCase):
def test_execute_function_mixed(self):
cache_directory = os.path.abspath("cache")
os.makedirs(cache_directory, exist_ok=True)
task_key, data_dict = _serialize_funct_h5(
task_key, data_dict = serialize_funct_h5(
my_funct,
1,
b=2,
Expand All @@ -52,7 +52,7 @@ def test_execute_function_mixed(self):
def test_execute_function_args(self):
cache_directory = os.path.abspath("cache")
os.makedirs(cache_directory, exist_ok=True)
task_key, data_dict = _serialize_funct_h5(
task_key, data_dict = serialize_funct_h5(
my_funct,
1,
2,
Comment on lines +55 to 58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider refactoring duplicate test code.

The test methods test_execute_function_args and test_execute_function_kwargs share significant code with test_execute_function_mixed. Consider extracting common test logic into a helper method to reduce duplication.

def _run_serialization_test(self, *args, **kwargs):
    cache_directory = os.path.abspath("cache")
    os.makedirs(cache_directory, exist_ok=True)
    
    task_key, data_dict = serialize_funct_h5(my_funct, *args, **kwargs)
    file_name = os.path.join(cache_directory, task_key + ".h5in")
    dump(file_name=file_name, data_dict=data_dict)
    execute_task_in_file(file_name=file_name)
    
    future_obj = Future()
    _check_task_output(
        task_key=task_key, 
        future_obj=future_obj, 
        cache_directory=cache_directory
    )
    
    self.assertTrue(future_obj.done())
    self.assertEqual(future_obj.result(), 3)
    future_file_obj = FutureItem(
        file_name=os.path.join(cache_directory, task_key + ".h5out")
    )
    self.assertTrue(future_file_obj.done())
    self.assertEqual(future_file_obj.result(), 3)

Then use it in your test methods:

def test_execute_function_mixed(self):
    self._run_serialization_test(1, b=2)

def test_execute_function_args(self):
    self._run_serialization_test(1, 2)

def test_execute_function_kwargs(self):
    self._run_serialization_test(a=1, b=2)

Also applies to: 78-81

Expand All @@ -75,7 +75,7 @@ def test_execute_function_args(self):
def test_execute_function_kwargs(self):
cache_directory = os.path.abspath("cache")
os.makedirs(cache_directory, exist_ok=True)
task_key, data_dict = _serialize_funct_h5(
task_key, data_dict = serialize_funct_h5(
my_funct,
a=1,
b=2,
Expand Down
Loading