diff --git a/scripts/gen_bridge_client.py b/scripts/gen_bridge_client.py
index 89ae54ec1..13ee47bc5 100644
--- a/scripts/gen_bridge_client.py
+++ b/scripts/gen_bridge_client.py
@@ -172,7 +172,7 @@ def generate_rust_service_call(service_descriptor: ServiceDescriptor) -> str:
         call: RpcCall,
     ) -> PyResult> {
         use temporal_client::${descriptor_name};
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
 $match_arms
diff --git a/temporalio/bridge/client.py b/temporalio/bridge/client.py
index dafd6fb71..709ba6def 100644
--- a/temporalio/bridge/client.py
+++ b/temporalio/bridge/client.py
@@ -116,6 +116,13 @@ def update_api_key(self, api_key: Optional[str]) -> None:
         """Update underlying API key on Core client."""
         self._ref.update_api_key(api_key)
 
+    def unsafe_close(self) -> None:
+        """Force the Core client to drop its underlying gRPC client.
+
+        Client behavior after this call is undefined.
+        """
+        self._ref.drop_client()
+
     async def call(
         self,
         *,
diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs
index dfbd432a1..c205a3270 100644
--- a/temporalio/bridge/src/client.rs
+++ b/temporalio/bridge/src/client.rs
@@ -17,11 +17,12 @@ use crate::runtime;
 
 pyo3::create_exception!(temporal_sdk_bridge, RPCError, PyException);
 
-type Client = RetryClient>;
+type TemporalClient = ConfiguredClient;
+type Client = RetryClient;
 
 #[pyclass]
 pub struct ClientRef {
-    pub(crate) retry_client: Client,
+    retry_client: Option,
     pub(crate) runtime: runtime::Runtime,
 }
 
@@ -95,10 +96,13 @@ pub fn connect_client<'a>(
     let runtime = runtime_ref.runtime.clone();
     runtime_ref.runtime.future_into_py(py, async move {
         Ok(ClientRef {
-            retry_client: opts
-                .connect_no_namespace(runtime.core.telemetry().get_temporal_metric_meter())
-                .await
-                .map_err(|err| PyRuntimeError::new_err(format!("Failed client connect: {err}")))?,
+            retry_client: Some(
+                opts.connect_no_namespace(runtime.core.telemetry().get_temporal_metric_meter())
+                    .await
+                    .map_err(|err| {
+                        PyRuntimeError::new_err(format!("Failed client connect: {err}"))
+                    })?,
+            ),
             runtime,
         })
     })
@@ -117,15 +121,18 @@ macro_rules! 
rpc_call {
 
 #[pymethods]
 impl ClientRef {
+    fn drop_client(&mut self) {
+        self.retry_client = None
+    }
+
     fn update_metadata(&self, headers: HashMap) -> PyResult<()> {
         let (ascii_headers, binary_headers) = partition_headers(headers);
+        let client = self.configured_client()?;
 
-        self.retry_client
-            .get_client()
+        client
             .set_headers(ascii_headers)
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
-        self.retry_client
-            .get_client()
+        client
             .set_binary_headers(binary_headers)
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
 
@@ -133,7 +140,28 @@ impl ClientRef {
     }
 
     fn update_api_key(&self, api_key: Option) {
-        self.retry_client.get_client().set_api_key(api_key);
+        // No-op once the client has been dropped: the error returned by
+        // `configured_client` is discarded here instead of being raised.
+        if let Ok(client) = self.configured_client() {
+            client.set_api_key(api_key);
+        }
+    }
+}
+
+impl ClientRef {
+    /// Borrow the retry client, or return a `PyRuntimeError` if
+    /// `drop_client` has already cleared it.
+    pub(crate) fn retry_client(&self) -> Result<&Client, PyErr> {
+        self.retry_client
+            .as_ref()
+            .ok_or_else(|| PyRuntimeError::new_err("client has been dropped"))
+    }
+
+    /// Borrow the client wrapped by the retry client; fails exactly like
+    /// `retry_client` when the client has been dropped.
+    fn configured_client(&self) -> Result<&TemporalClient, PyErr> {
+        // Delegate so the "client has been dropped" error is built in one place.
+        self.retry_client().map(RetryClient::get_client)
     }
 }
 
diff --git a/temporalio/bridge/src/client_rpc_generated.rs b/temporalio/bridge/src/client_rpc_generated.rs
index 659f5d8cf..681e538af 100644
--- a/temporalio/bridge/src/client_rpc_generated.rs
+++ b/temporalio/bridge/src/client_rpc_generated.rs
@@ -16,7 +16,7 @@ impl ClientRef {
         call: RpcCall,
     ) -> PyResult> {
         use temporal_client::WorkflowService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "count_workflow_executions" => {
@@ -567,7 +567,7 @@ impl ClientRef {
         call: RpcCall,
     ) -> PyResult> {
         use temporal_client::OperatorService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "add_or_update_remote_cluster" => {
@@ -629,7 +629,7 @@ impl ClientRef {
 
     fn call_cloud_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> {
         use temporal_client::CloudService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "add_namespace_region" => {
@@ -843,7 +843,7 @@ impl ClientRef {
 
     fn call_test_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> {
         use temporal_client::TestService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "get_current_time" => {
@@ -882,7 +882,7 @@ impl ClientRef {
 
     fn call_health_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> {
         use temporal_client::HealthService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "check" => {
diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs
index 92b43f356..643618541 100644
--- a/temporalio/bridge/src/worker.rs
+++ b/temporalio/bridge/src/worker.rs
@@ -479,7 +479,7 @@ pub fn new_worker(
         let worker = temporal_sdk_core::init_worker(
             &runtime_ref.runtime.core,
             config,
-            client.retry_client.clone().into_inner(),
+            client.retry_client()?.clone().into_inner(),
         )
         .context("Failed creating worker")?;
         Ok(WorkerRef {
@@ -648,7 +648,15 @@ impl WorkerRef {
         self.worker
             .as_ref()
             .expect("missing worker")
-            .replace_client(client.retry_client.clone().into_inner());
+            .replace_client(
+                // NOTE(review): this `expect` panics if the client was dropped via
+                // `drop_client`, whereas `new_worker` propagates that case with `?`.
+                client
+                    .retry_client()
+                    .expect("client ref had no client")
+                    .clone()
+                    .into_inner(),
+            );
     }
 
     fn initiate_shutdown(&self) -> PyResult<()> {
diff --git a/temporalio/service.py b/temporalio/service.py
index 
bfbaaa051..6eeaa074e 100644
--- a/temporalio/service.py
+++ b/temporalio/service.py
@@ -265,6 +265,16 @@ def update_api_key(self, api_key: Optional[str]) -> None:
         """Update service client's API key."""
         raise NotImplementedError
 
+    @abstractmethod
+    def unsafe_close(self) -> None:
+        """Force disconnect of the client.
+
+        Only advanced users should consider using this. Any use of
+        clients, or of objects that rely on clients, is undefined
+        after this call.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     async def _rpc_call(
         self,
@@ -346,6 +356,11 @@ def update_api_key(self, api_key: Optional[str]) -> None:
         if self._bridge_client:
             self._bridge_client.update_api_key(api_key)
 
+    def unsafe_close(self) -> None:
+        """Force-close the bridge client; does nothing if there is none."""
+        if self._bridge_client:
+            self._bridge_client.unsafe_close()
+
     async def _rpc_call(
         self,
         rpc: str,
diff --git a/tests/test_client.py b/tests/test_client.py
index 63dec2810..0c2048be0 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,3 +1,4 @@
+import asyncio
 import dataclasses
 import json
 import os
@@ -7,6 +8,7 @@
 from unittest import mock
 
 import google.protobuf.any_pb2
+import psutil
 import pytest
 from google.protobuf import json_format
 
@@ -86,6 +88,7 @@
 from temporalio.testing import WorkflowEnvironment
 from tests.helpers import (
     assert_eq_eventually,
+    assert_eventually,
     ensure_search_attributes_present,
     new_worker,
     worker_versioning_enabled,
@@ -1541,3 +1544,49 @@ async def get_schedule_result() -> Tuple[int, Optional[str]]:
     )
 
     await handle.delete()
+
+
+async def test_unsafe_close(env: WorkflowEnvironment):
+    proc = psutil.Process()
+    # proc.connections() is deprecated in newer psutil versions, but uv resolved
+    # to a version that lacks the recommended proc.net_connections(). Fall back
+    # lazily so the deprecated attribute is only touched when it is needed.
+    list_conns = getattr(proc, "net_connections", None) or proc.connections
+
+    target_host = env.client.config()["service_client"].config.target_host
+    # NOTE(review): assumes target_host is "<ip>:<port>"; a hostname such as
+    # "localhost" would never equal the numeric raddr that psutil reports.
+    target_ip = target_host.split(":")[0]
+
+    def sum_connections() -> int:
+        # raddr is an empty tuple for sockets without a remote endpoint (for
+        # example listening sockets), so check it before indexing.
+        return sum(1 for p in list_conns() if p.raddr and p.raddr[0] == target_ip)
+
+    # get the number of connections to the target host.
+    num_conn = sum_connections()
+
+    # create new client that has a connection
+    client = await Client.connect(target_host)
+
+    # get number of connections now that we have a second one.
+    num_conn_after_client = sum_connections()
+    assert num_conn_after_client > num_conn
+
+    # force drop our connection via bridge
+    client.service_client.unsafe_close()
+
+    # now that we've forced our connection to drop, bridge will raise a RuntimeError
+    # if you do anything that uses the client.
+    with pytest.raises(RuntimeError) as dropped_err:
+        await client.start_workflow(
+            "some-workflow", id=f"wf-{uuid.uuid4()}", task_queue=f"tq-{uuid.uuid4()}"
+        )
+    assert dropped_err.match("client has been dropped")
+
+    async def assert_less_connections():
+        # get number of connections now that we've closed the one we opened.
+        num_conn_after_close = sum_connections()
+        assert num_conn_after_close < num_conn_after_client
+
+    await assert_eventually(assert_less_connections, timeout=timedelta(seconds=1))