Runner get data#363
Conversation
This design document describes the get_data helper function that allows applications to retrieve the real input/output data of Runner tasks. The function: - Takes raw task input/output bytes (ObjectRef encoded data) - Decodes and retrieves data from the Flame cache - Resolves nested ObjectRef instances in RunnerRequest args/kwargs - Returns a structured dictionary with the resolved data Closes #1
Add comprehensive test suite for the `get_data` helper function in flamepy.runner as specified in RFE001-runner-get-data HLD. Test Coverage: - TC-GD-001: Basic task input retrieval (function with positional args) - TC-GD-002: Basic task output retrieval - TC-GD-003: Task input with keyword arguments - TC-GD-004: Task input with mixed args and kwargs - TC-GD-005: Task input with ObjectRef arguments (resolved) - TC-GD-006: Class method invocation input - TC-GD-007: Error handling - invalid data format - TC-GD-008: Error handling - empty bytes - TC-GD-009: Multiple tasks inspection - TC-GD-010: Task with no arguments - TC-GD-011: End-to-end workflow test Relates to #1
- Add structured response types (TaskInputData, TaskOutputData) - Add specific error types (DecodeError, CacheRetrievalError, DataFormatError) - Implement recursive resolution for nested structures (lists, dicts, tuples) - Remove unused imports (Optional, Tuple) - Add metadata field to response for object_ref_key - Add test cases for nested data structures and error types - Export new types from flamepy.runner module
TC-GD-007 and TC-GD-008 now expect DecodeError instead of ValueError/TypeError, matching the actual implementation. - Import DecodeError from flamepy.runner - Update test_get_data_invalid_data_format to expect DecodeError - Update test_get_data_empty_bytes to expect DecodeError - Update docstrings to reflect the correct exception type
…ceptions - Rename sdk/python/src/flamepy/runner/data.py to helper.py - Replace individual exceptions (GetDataError, DecodeError, CacheRetrievalError, DataFormatError) with single RunnerError class using ErrorType enum - Update __init__.py to export only get_data, RunnerError, and ErrorType (removed TaskInputData, TaskOutputData from public API) - Update e2e/tests/test_get_data.py to use new exception handling Closes xflops#2
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the observability and debugging capabilities within Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a get_data helper function to inspect task inputs and outputs, which is a valuable addition for debugging and observability. However, the implementation contains a critical insecure deserialization vulnerability. The function decodes an ObjectRef from user-supplied bytes and uses it to fetch and unpickle data from a remote endpoint specified within that ObjectRef, allowing an attacker to trigger Remote Code Execution (RCE) by providing a malicious ObjectRef. It is highly recommended to validate remote endpoints against a trusted allow-list and avoid using pickle or cloudpickle for data that can be influenced by untrusted sources. Additionally, the review suggests improving tests for robustness, ensuring design documentation aligns with the implementation, and minor code cleanup.
| def get_data(data: bytes) -> Dict[str, Any]: | ||
| """Retrieve the real data from task input or output. | ||
|
|
||
| This function takes the raw bytes from a Flame task's input or output, | ||
| decodes the ObjectRef, retrieves the data from cache, and resolves | ||
| any nested ObjectRef instances to their actual values. | ||
|
|
||
| Args: | ||
| data: Raw bytes from task input or output. This is expected to be | ||
| an encoded ObjectRef pointing to either: | ||
| - A RunnerRequest (for task input) | ||
| - A result object (for task output) | ||
|
|
||
| Returns: | ||
| A dictionary containing the resolved data: | ||
|
|
||
| For task input (RunnerRequest): | ||
| { | ||
| "type": "input", | ||
| "method": str | None, # Method name or None for callable | ||
| "args": tuple | None, # Resolved positional arguments | ||
| "kwargs": dict | None, # Resolved keyword arguments | ||
| "metadata": dict # Additional metadata | ||
| } | ||
|
|
||
| For task output (result): | ||
| { | ||
| "type": "output", | ||
| "result": Any, # The actual result value | ||
| "metadata": dict # Additional metadata | ||
| } | ||
|
|
||
| Raises: | ||
| RunnerError: With error_type indicating the specific error: | ||
| - ErrorType.DECODE_ERROR: If the data cannot be decoded as ObjectRef | ||
| - ErrorType.CACHE_RETRIEVAL_ERROR: If the object cannot be retrieved from cache | ||
| - ErrorType.DATA_FORMAT_ERROR: If the data format is not recognized | ||
|
|
||
| Example: | ||
| >>> from flamepy.runner import get_data | ||
| >>> from flamepy.core import get_session | ||
| >>> | ||
| >>> # Get a session and its tasks | ||
| >>> session = get_session("my-session-id") | ||
| >>> for task in session.tasks: | ||
| ... if task.input: | ||
| ... input_data = get_data(task.input) | ||
| ... print(f"Task {task.id} input: {input_data}") | ||
| ... if task.output: | ||
| ... output_data = get_data(task.output) | ||
| ... print(f"Task {task.id} output: {output_data}") | ||
| """ | ||
| # Step 1: Decode ObjectRef from bytes | ||
| try: | ||
| object_ref = ObjectRef.decode(data) | ||
| except Exception as e: | ||
| raise RunnerError( | ||
| ErrorType.DECODE_ERROR, | ||
| f"Failed to decode ObjectRef from data: {e}", | ||
| cause=e, | ||
| ) | ||
|
|
||
| # Step 2: Retrieve object from cache | ||
| try: | ||
| cached_data = get_object(object_ref) | ||
| except Exception as e: | ||
| raise RunnerError( | ||
| ErrorType.CACHE_RETRIEVAL_ERROR, | ||
| f"Failed to retrieve object from cache: {e}", | ||
| cause=e, | ||
| key=getattr(object_ref, "key", None), | ||
| ) | ||
|
|
||
| # Step 3: Check if it's serialized data (bytes) that needs unpickling | ||
| if isinstance(cached_data, bytes): | ||
| try: | ||
| cached_data = cloudpickle.loads(cached_data) | ||
| except Exception: | ||
| # Not pickled data, use as-is | ||
| pass | ||
|
|
||
| # Step 4: Determine type and process accordingly | ||
| if isinstance(cached_data, RunnerRequest): | ||
| # This is task input | ||
| return _process_runner_request(cached_data, object_ref) | ||
| else: | ||
| # This is task output (result) | ||
| output_data = TaskOutputData( | ||
| result=cached_data, | ||
| metadata={"object_ref_key": object_ref.key}, | ||
| ) | ||
| return output_data.to_dict() | ||
|
|
||
|
|
||
| def _process_runner_request(request: RunnerRequest, object_ref: ObjectRef = None) -> Dict[str, Any]: | ||
| """Process a RunnerRequest and resolve any ObjectRef instances. | ||
|
|
||
| Args: | ||
| request: The RunnerRequest to process. | ||
| object_ref: Optional ObjectRef for metadata. | ||
|
|
||
| Returns: | ||
| Dictionary with resolved input data. | ||
| """ | ||
| # Resolve args (recursively handles nested structures) | ||
| resolved_args = None | ||
| if request.args is not None: | ||
| resolved_args = tuple(_resolve_value(arg) for arg in request.args) | ||
|
|
||
| # Resolve kwargs (recursively handles nested structures) | ||
| resolved_kwargs = None | ||
| if request.kwargs is not None: | ||
| resolved_kwargs = {key: _resolve_value(value) for key, value in request.kwargs.items()} | ||
|
|
||
| metadata = {} | ||
| if object_ref is not None: | ||
| metadata["object_ref_key"] = object_ref.key | ||
|
|
||
| input_data = TaskInputData( | ||
| method=request.method, | ||
| args=resolved_args, | ||
| kwargs=resolved_kwargs, | ||
| metadata=metadata, | ||
| ) | ||
| return input_data.to_dict() | ||
|
|
||
|
|
||
| def _resolve_value(value: Any, max_depth: int = 10, _current_depth: int = 0) -> Any: | ||
| """Resolve a value, fetching from cache if it's an ObjectRef. | ||
|
|
||
| Recursively resolves nested structures (lists, dicts, tuples) that may | ||
| contain ObjectRef instances. | ||
|
|
||
| Args: | ||
| value: The value to resolve. | ||
| max_depth: Maximum recursion depth to prevent infinite loops. | ||
| _current_depth: Current recursion depth (internal use). | ||
|
|
||
| Returns: | ||
| The resolved value with all ObjectRef instances replaced by their actual data. | ||
|
|
||
| Raises: | ||
| RunnerError: With ErrorType.CACHE_RETRIEVAL_ERROR if an ObjectRef cannot be resolved. | ||
| """ | ||
| # Prevent infinite recursion | ||
| if _current_depth > max_depth: | ||
| return value | ||
|
|
||
| # Handle ObjectRef directly | ||
| if isinstance(value, ObjectRef): | ||
| try: | ||
| return get_object(value) | ||
| except Exception as e: | ||
| raise RunnerError( | ||
| ErrorType.CACHE_RETRIEVAL_ERROR, | ||
| f"Failed to resolve ObjectRef: {e}", | ||
| cause=e, | ||
| key=getattr(value, "key", None), | ||
| ) | ||
|
|
||
| # Handle bytes that might be encoded ObjectRef | ||
| if isinstance(value, bytes): | ||
| try: | ||
| object_ref = ObjectRef.decode(value) | ||
| return get_object(object_ref) | ||
| except Exception: | ||
| # Not an ObjectRef, return as-is | ||
| return value | ||
|
|
||
| # Handle lists - recursively resolve each element | ||
| if isinstance(value, list): | ||
| return [_resolve_value(item, max_depth, _current_depth + 1) for item in value] | ||
|
|
||
| # Handle tuples - recursively resolve each element | ||
| if isinstance(value, tuple): | ||
| return tuple(_resolve_value(item, max_depth, _current_depth + 1) for item in value) | ||
|
|
||
| # Handle dicts - recursively resolve each value | ||
| if isinstance(value, dict): | ||
| return {k: _resolve_value(v, max_depth, _current_depth + 1) for k, v in value.items()} | ||
|
|
||
| # Return other types as-is | ||
| return value |
There was a problem hiding this comment.
The get_data function and its recursive helper _resolve_value are vulnerable to insecure deserialization. The function takes raw bytes as input, decodes them into an ObjectRef (which contains a remote endpoint URL), and then retrieves and deserializes data from that endpoint using cloudpickle.loads (both directly on line 188 and indirectly via get_object on lines 176, 264, and 277). Since the remote endpoint is controlled by the input data, an attacker can provide a malicious ObjectRef pointing to a server they control. This server can then return a malicious pickle payload that executes arbitrary code when deserialized, leading to Remote Code Execution (RCE).
| Raises: | ||
| ValueError: If the data cannot be decoded or retrieved from cache | ||
| TypeError: If the data format is not recognized |
There was a problem hiding this comment.
The Raises section in the docstring is out of sync with the actual implementation. The code raises a custom RunnerError with specific ErrorType enums, not ValueError or TypeError. Please update the design document to reflect the implemented error handling strategy for consistency. The docstring in sdk/python/src/flamepy/runner/helper.py can be used as a reference.
|
|
||
| # At minimum, verify we can retrieve input data without errors | ||
| assert len(tasks) >= 3, f"Expected at least 3 tasks, got {len(tasks)}" |
There was a problem hiding this comment.
This test checks if an ObjectRef argument is resolved, but it's missing a final assertion to verify that the resolution actually occurred. The found_resolved flag is set but its value is never checked, which could lead to the test passing silently even if the feature is broken. Please add an assertion after the loop to confirm that found_resolved is True.
| # At minimum, verify we can retrieve input data without errors | |
| assert len(tasks) >= 3, f"Expected at least 3 tasks, got {len(tasks)}" | |
| assert found_resolved, "Did not find a task with a resolved ObjectRef argument." | |
| # At minimum, verify we can retrieve input data without errors | |
| assert len(tasks) >= 3, f"Expected at least 3 tasks, got {len(tasks)}" |
| output_results_found.add(output_data["result"]) | ||
|
|
||
| # Verify we found the expected results | ||
| assert 3 in output_results_found or 30 in output_results_found or 300 in output_results_found |
There was a problem hiding this comment.
The assertion assert 3 in output_results_found or 30 in output_results_found or 300 in output_results_found is not strict enough. It will pass if only one of the tasks succeeded, potentially masking issues with the other tasks. To ensure all tasks are correctly processed, you should verify that all expected inputs and outputs were found.
| assert 3 in output_results_found or 30 in output_results_found or 300 in output_results_found | |
| expected_outputs = {3, 30, 300} | |
| assert expected_outputs.issubset(output_results_found), f"Missing outputs: {expected_outputs - output_results_found}" | |
| expected_inputs = {(1, 2), (10, 20), (100, 200)} | |
| assert expected_inputs.issubset(input_args_found), f"Missing inputs: {expected_inputs - input_args_found}" |
| def test_get_data_nested_list_with_objectref(check_package_config, check_flmrun_app): | ||
| """TC-GD-012: Test get_data resolves ObjectRef in nested lists. |
There was a problem hiding this comment.
The test name test_get_data_nested_list_with_objectref and its docstring are misleading. The test implementation does not involve a nested list; instead, it re-tests a simple ObjectRef resolution scenario similar to test_get_data_task_input_objectref_resolved. Additionally, the assertion logic is weak as it only checks that the argument is not an ObjectRef without verifying its resolved value. Please either update the test to correctly test nested lists with ObjectRefs or remove it if it's redundant.
|
|
||
| DECODE_ERROR = "decode_error" | ||
| CACHE_RETRIEVAL_ERROR = "cache_retrieval_error" | ||
| DATA_FORMAT_ERROR = "data_format_error" |
There was a problem hiding this comment.
The ErrorType.DATA_FORMAT_ERROR enum member is defined but it's not used anywhere in the get_data function. The data_type parameter in the RunnerError constructor is also unused. This appears to be dead code. If it's intended for future use, consider adding a comment to clarify. Otherwise, it should be removed to improve code clarity and maintainability.
… data The get_data function now handles two data formats: 1. Encoded ObjectRef (BSON format) - fetches data from cache 2. Directly pickled data (RunnerRequest or result) This fixes the CI failure in test_get_data_task_input_positional_args where task.input contains pickled RunnerRequest bytes (starting with 0x80 0x05 pickle protocol header) instead of encoded ObjectRef. Changes: - Add _is_pickle_data() helper to detect pickle protocol headers - Update get_data() to try pickle decode first if data looks like pickle - Fall back to ObjectRef decode if pickle decode fails - Handle metadata correctly when no ObjectRef is present
Signed-off-by: Klaus Ma <klausm@nvidia.com>
fix #364