diff --git a/Cargo.lock b/Cargo.lock index 0f8abcd..d0ee508 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -533,9 +533,9 @@ dependencies = [ [[package]] name = "restate-sdk-shared-core" -version = "0.0.4" +version = "0.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54026e33285f65cf92da3770cb4810bc095ad4831f33dfd70fb799fe57720409" +checksum = "badf8da7bdf9459a8ff675272c823db201084adc9d0bcb7942dd9ba7d2bc12f2" dependencies = [ "ambassador", "base64 0.22.1", diff --git a/Cargo.toml b/Cargo.toml index dd350f1..ad7a637 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,4 @@ doc = false [dependencies] pyo3 = { version = "0.22.0", features = ["extension-module"] } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } -restate-sdk-shared-core = { version = "0.0.4" } +restate-sdk-shared-core = "0.0.5" diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 25b8a75..06acd88 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -205,7 +205,7 @@ async def await_point(): return await_point() # do not await here, the caller will do it. def state_keys(self) -> Awaitable[List[str]]: - raise NotImplementedError + return self.create_poll_coroutine(self.vm.sys_get_state_keys()) # type: ignore def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: """Set the value associated with the given name.""" diff --git a/python/restate/vm.py b/python/restate/vm.py index 2035f36..4f23a62 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -15,7 +15,7 @@ from dataclasses import dataclass import typing -from restate._internal import PyVM, PyFailure, PySuspended, PyVoid # pylint: disable=import-error,no-name-in-module +from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys # pylint: disable=import-error,no-name-in-module @dataclass class Invocation: @@ -110,6 +110,9 @@ def take_async_result(self, handle: typing.Any) -> AsyncResultType: if isinstance(result, bytes): # success with a non empty value return result + if isinstance(result, PyStateKeys): + # success with state keys + return result.keys if isinstance(result, PyFailure): # a terminal failure code = result.code @@ -179,6 +182,17 @@ def sys_get_state(self, name) -> int: """ return self.vm.sys_get_state(name) + + def sys_get_state_keys(self) -> int: + """ + Retrieves all keys. + + Returns: + The state keys + """ + return self.vm.sys_get_state_keys() + + def sys_set_state(self, name: str, value: bytes): """ Sets a key-value binding. diff --git a/src/lib.rs b/src/lib.rs index af45526..495cf4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,10 @@ use pyo3::create_exception; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyNone}; -use restate_sdk_shared_core::{ - AsyncResultHandle, CoreVM, Failure, Header, IdentityHeaderMap, IdentityVerifier, Input, - NonEmptyValue, ResponseHead, RunEnterResult, SuspendedOrVMError, TakeOutputResult, Target, - VMError, Value, VM, -}; +use restate_sdk_shared_core::{AsyncResultHandle, CoreVM, Failure, Header, IdentityVerifier, Input, NonEmptyValue, ResponseHead, RunEnterResult, SuspendedOrVMError, TakeOutputResult, Target, VMError, Value, VM}; use std::borrow::Cow; -use std::convert::Infallible; use std::time::Duration; + // Data model #[pyclass] @@ -103,6 +99,13 @@ impl From for Failure { } } +#[pyclass] +#[derive(Clone)] +struct PyStateKeys { + #[pyo3(get, set)] + keys: Vec +} + #[pyclass] pub struct PyInput { #[pyo3(get, set)] @@ -233,6 +236,9 @@ impl PyVM { Ok(Some(Value::Failure(f))) => { Ok(PyFailure::from(f).into_py(py).into_bound(py).into_any()) } + Ok(Some(Value::StateKeys(keys))) => { + Ok(PyStateKeys {keys}.into_py(py).into_bound(py).into_any()) + } } } @@ -248,7 +254,17 @@ impl PyVM { ) -> Result { self_ .vm - .sys_get_state(key) + .sys_state_get(key) + .map(Into::into) + .map_err(Into::into) + } + + fn sys_get_state_keys( + mut self_: PyRefMut<'_, Self>, + ) -> Result { + self_ + .vm + .sys_state_get_keys() .map(Into::into) .map_err(Into::into) } @@ -260,16 +276,16 @@ impl PyVM { ) -> Result<(), PyVMError> { self_ .vm - .sys_set_state(key, buffer.as_bytes().to_vec()) + .sys_state_set(key, buffer.as_bytes().to_vec()) .map_err(Into::into) } fn sys_clear_state(mut self_: PyRefMut<'_, Self>, key: String) -> Result<(), PyVMError> { - self_.vm.sys_clear_state(key).map_err(Into::into) + self_.vm.sys_state_clear(key).map_err(Into::into) } fn sys_clear_all_state(mut self_: PyRefMut<'_, Self>) -> Result<(), PyVMError> { - self_.vm.sys_clear_all_state().map_err(Into::into) + self_.vm.sys_state_clear_all().map_err(Into::into) } fn sys_sleep( @@ -484,21 +500,6 @@ struct PyIdentityVerifier { verifier: IdentityVerifier, } -struct PyIdentityHeaders(Vec<(String, String)>); - -impl IdentityHeaderMap for PyIdentityHeaders { - type Error = Infallible; - - fn extract(&self, name: &str) -> Result, Self::Error> { - for (k, v) in &self.0 { - if k.eq_ignore_ascii_case(name) { - return Ok(Some(v)); - } - } - Ok(None) - } -} - // Exceptions create_exception!( restate_sdk_python_core, @@ -531,7 +532,7 @@ impl PyIdentityVerifier { ) -> PyResult<()> { self_ .verifier - .verify_identity(&PyIdentityHeaders(headers), &path) + .verify_identity(&headers, &path) .map_err(|e| IdentityVerificationException::new_err(e.to_string())) } } @@ -549,6 +550,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/test-services/Dockerfile b/test-services/Dockerfile index ff39f52..9492bc1 100644 --- a/test-services/Dockerfile +++ b/test-services/Dockerfile @@ -4,7 +4,7 @@ FROM ghcr.io/pyo3/maturin AS build-sdk WORKDIR /usr/src/app -COPY --exclude=test-services/ . . +COPY . . RUN maturin build --out dist --interpreter python3.12 diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index 48c6585..ff0ab43 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -1,13 +1,8 @@ exclusions: "alwaysSuspending": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.State" "default": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.KafkaIngress" - - "dev.restate.sdktesting.tests.State" - "lazyState": - - "dev.restate.sdktesting.tests.State" + "lazyState": [] "singleThreadSinglePartition": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.State"