diff --git a/sdk/python/src/flamepy/__init__.py b/sdk/python/src/flamepy/__init__.py index 4a0d688c..8daf4be0 100644 --- a/sdk/python/src/flamepy/__init__.py +++ b/sdk/python/src/flamepy/__init__.py @@ -44,7 +44,6 @@ ApplicationState, Shim, FlameErrorCode, - DataSource, # Classes FlameError, SessionAttributes, diff --git a/sdk/python/src/flamepy/cache.py b/sdk/python/src/flamepy/cache.py index 0af02647..43002af5 100644 --- a/sdk/python/src/flamepy/cache.py +++ b/sdk/python/src/flamepy/cache.py @@ -15,8 +15,10 @@ from pydantic import BaseModel import logging import contextlib +import pickle +from typing import Any, Optional -from .types import ObjectRef, DataSource, FlameContext +from .types import ObjectRef, FlameContext @contextlib.contextmanager @@ -52,51 +54,80 @@ class ObjectMetadata(BaseModel): size: int -def put_object(session_id: str, data: bytes) -> "ObjectRef": - """Put an object into the cache.""" +def put_object(session_id: str, obj: Any) -> "ObjectRef": + """Put an object into the cache. + + Args: + session_id: The session ID for the object + obj: The object to cache (will be pickled) + + Returns: + ObjectRef pointing to the cached object + + Raises: + Exception: If cache endpoint is not configured or request fails + """ context = FlameContext() - if context._cache_endpoint is None or data is None: - return ObjectRef(source=DataSource.LOCAL, data=data) + if context._cache_endpoint is None: + raise Exception("Cache endpoint is not configured") + + # Serialize the object using pickle + data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) with suppress_dependency_logs(): response = httpx.post(f"{context._cache_endpoint}/objects/{session_id}", data=data) response.raise_for_status() metadata = ObjectMetadata.model_validate(response.json()) - return ObjectRef(source=DataSource.REMOTE, url=metadata.endpoint, data=data, version=metadata.version) - - -def get_object(de: ObjectRef) -> "ObjectRef": - """Get an object from the cache.""" - if de.source != DataSource.REMOTE: - return de - + return ObjectRef(url=metadata.endpoint, version=metadata.version) + + +def get_object(ref: ObjectRef) -> Any: + """Get an object from the cache. + + Args: + ref: ObjectRef pointing to the cached object + + Returns: + The deserialized object + + Raises: + Exception: If request fails + """ with suppress_dependency_logs(): - response = httpx.get(de.url) + response = httpx.get(ref.url) response.raise_for_status() obj = Object.model_validate(response.json()) - - de.data = bytes(obj.data) - de.version = obj.version - - return de - - -def update_object(de: ObjectRef) -> "ObjectRef": - """Update an object in the cache.""" - if de.source != DataSource.REMOTE: - return de - - obj = Object(version=de.version, data=list(de.data)) + data = bytes(obj.data) + + # Deserialize the object using pickle + return pickle.loads(data) + + +def update_object(ref: ObjectRef, new_obj: Any) -> "ObjectRef": + """Update an object in the cache. + + Args: + ref: ObjectRef pointing to the cached object to update + new_obj: The new object to store (will be pickled) + + Returns: + Updated ObjectRef with new version + + Raises: + Exception: If request fails + """ + # Serialize the new object using pickle + new_data = pickle.dumps(new_obj, protocol=pickle.HIGHEST_PROTOCOL) + + obj = Object(version=ref.version, data=list(new_data)) data = obj.model_dump_json() with suppress_dependency_logs(): - response = httpx.put(de.url, data=data) + response = httpx.put(ref.url, data=data) response.raise_for_status() metadata = ObjectMetadata.model_validate(response.json()) - de.version = metadata.version - - return de + return ObjectRef(url=ref.url, version=metadata.version) diff --git a/sdk/python/src/flamepy/client.py b/sdk/python/src/flamepy/client.py index a8c48609..a41b5e3d 100644 --- a/sdk/python/src/flamepy/client.py +++ b/sdk/python/src/flamepy/client.py @@ -316,9 +316,7 @@ def create_session(self, attrs: SessionAttributes) -> "Session": session_id = short_name(attrs.application) if attrs.id is None else attrs.id - common_data_bin = pickle.dumps(attrs.common_data, protocol=pickle.HIGHEST_PROTOCOL) - - object_ref = put_object(session_id, common_data_bin) + object_ref = put_object(session_id, attrs.common_data) session_spec = SessionSpec( application=attrs.application, @@ -511,9 +509,10 @@ def __init__( def common_data(self) -> Any: """Get the common data of Session.""" - self._common_data = get_object(self._common_data) - - return pickle.loads(self._common_data.data) if self._common_data is not None else None + if self._common_data is None: + return None + + return get_object(self._common_data) def create_task(self, input_data: Any) -> Task: """Create a new task in the session.""" diff --git a/sdk/python/src/flamepy/runpy.py b/sdk/python/src/flamepy/runpy.py index 99581388..7b416e47 100644 --- a/sdk/python/src/flamepy/runpy.py +++ b/sdk/python/src/flamepy/runpy.py @@ -57,11 +57,10 @@ def _resolve_object_ref(self, value: Any) -> Any: """ if isinstance(value, ObjectRef): logger.debug(f"Resolving ObjectRef: {value}") - obj_ref = get_object(value) - if obj_ref is None or obj_ref.data is None: + resolved_value = get_object(value) + if resolved_value is None: raise ValueError(f"Failed to retrieve ObjectRef from cache: {value}") - resolved_value = pickle.loads(obj_ref.data) logger.debug(f"Resolved ObjectRef to type: {type(resolved_value)}") return resolved_value diff --git a/sdk/python/src/flamepy/service.py b/sdk/python/src/flamepy/service.py index 0320fac5..6444aae3 100644 --- a/sdk/python/src/flamepy/service.py +++ b/sdk/python/src/flamepy/service.py @@ -67,16 +67,17 @@ class SessionContext: def common_data(self) -> Any: """Get the common data.""" - self._object_ref = get_object(self._object_ref) - return pickle.loads(self._object_ref.data) if self._object_ref is not None else None + if self._object_ref is None: + return None + + return get_object(self._object_ref) def update_common_data(self, data: Any): """Update the common data.""" if self._object_ref is None: return - self._object_ref.data = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) - self._object_ref = update_object(self._object_ref) + self._object_ref = update_object(self._object_ref, data) @dataclass diff --git a/sdk/python/src/flamepy/types.py b/sdk/python/src/flamepy/types.py index c312d81d..74d45a69 100644 --- a/sdk/python/src/flamepy/types.py +++ b/sdk/python/src/flamepy/types.py @@ -232,28 +232,15 @@ def __init__(self): self._cache_endpoint = cache_endpoint -class DataSource(IntEnum): - """Data location enumeration.""" - - LOCAL = 0 - REMOTE = 1 - - @dataclass class ObjectRef: - """Object reference.""" + """Object reference for remote cached objects.""" - source: DataSource - url: Optional[str] = None + url: str version: int = 0 - data: Optional[bytes] = None def encode(self) -> bytes: data = asdict(self) - # For remote data, the data is not included in the JSON - if self.source == DataSource.REMOTE: - data["data"] = None - return bson.dumps(data) @classmethod