2 changes: 1 addition & 1 deletion scripts/gen_bridge_client.py
@@ -172,7 +172,7 @@ def generate_rust_service_call(service_descriptor: ServiceDescriptor) -> str:
         call: RpcCall,
     ) -> PyResult<Bound<'p, PyAny>> {
         use temporal_client::${descriptor_name};
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 $match_arms
7 changes: 7 additions & 0 deletions temporalio/bridge/client.py
@@ -116,6 +116,13 @@ def update_api_key(self, api_key: Optional[str]) -> None:
         """Update underlying API key on Core client."""
         self._ref.update_api_key(api_key)
 
+    def unsafe_close(self) -> None:
+        """Force Core client to drop the underlying gRPC client.
+
+        Client behavior after this call is undefined.
+        """
+        self._ref.drop_client()
+
     async def call(
         self,
         *,
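Note: the guard this `drop_client` call relies on lives on the Rust side (see `ClientRef::retry_client` in client.rs below); once the Core client is dropped, every later entry point fails fast instead of touching a dropped client. A minimal Python sketch of that guard pattern — illustrative only, with `_CoreClientHolder` as a hypothetical name, not SDK code:

```python
from typing import Optional


class _CoreClientHolder:
    """Hypothetical Python analogue of the Option<Client> guard in client.rs."""

    def __init__(self, core_client: object) -> None:
        self._core_client: Optional[object] = core_client

    def drop_client(self) -> None:
        # Mirrors ClientRef::drop_client: forget the client, never reconnect.
        self._core_client = None

    def client(self) -> object:
        # Mirrors ClientRef::retry_client: fail fast once dropped.
        if self._core_client is None:
            raise RuntimeError("client has been dropped")
        return self._core_client


holder = _CoreClientHolder(core_client=object())
holder.drop_client()
try:
    holder.client()
except RuntimeError as err:
    assert str(err) == "client has been dropped"
```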
46 changes: 35 additions & 11 deletions temporalio/bridge/src/client.rs
@@ -17,11 +17,12 @@ use crate::runtime;
 
 pyo3::create_exception!(temporal_sdk_bridge, RPCError, PyException);
 
-type Client = RetryClient<ConfiguredClient<TemporalServiceClientWithMetrics>>;
+type TemporalClient = ConfiguredClient<TemporalServiceClientWithMetrics>;
+type Client = RetryClient<TemporalClient>;
 
 #[pyclass]
 pub struct ClientRef {
-    pub(crate) retry_client: Client,
+    retry_client: Option<Client>,
     pub(crate) runtime: runtime::Runtime,
 }

@@ -95,10 +96,13 @@ pub fn connect_client<'a>(
     let runtime = runtime_ref.runtime.clone();
     runtime_ref.runtime.future_into_py(py, async move {
         Ok(ClientRef {
-            retry_client: opts
-                .connect_no_namespace(runtime.core.telemetry().get_temporal_metric_meter())
-                .await
-                .map_err(|err| PyRuntimeError::new_err(format!("Failed client connect: {err}")))?,
+            retry_client: Some(
+                opts.connect_no_namespace(runtime.core.telemetry().get_temporal_metric_meter())
+                    .await
+                    .map_err(|err| {
+                        PyRuntimeError::new_err(format!("Failed client connect: {err}"))
+                    })?,
+            ),
             runtime,
         })
     })
@@ -117,23 +121,43 @@ macro_rules! rpc_call {
 
 #[pymethods]
 impl ClientRef {
+    fn drop_client(&mut self) {
+        self.retry_client = None
+    }

[Review comment — @cretz (Member), Nov 12, 2025]
I don't think this is predictable enough. The goal is to guarantee that when unsafe_close returns, the socket is no longer connected. But in this case, if a call was being made, or the client is held in some other way by some other Core aspect, the connection will remain.

I think we may have to provide an "unsafe close" on the Rust side that does a predictable channel drop and confirms it is the only one using the channel. Hopefully that's not too hard to do; I have not looked into it. This may require putting the channel in a non-cloneable wrapper and having an Arc of that. Or maybe you have to get access to the non-cloneable part underneath. Unsure.

[Review comment — Member, on lines +124 to +126]
I'm not 100% sure if this is sufficient. Clients are Clone themselves, and share the underlying connection. It's entirely possible that on the Core side we might clone the client somewhere, and then dropping the client ref won't actually seal the deal.

That said, there's also no readily available way to actually close the underlying connections Core-side either. So we might have to just attach a small caveat here in the docstrings or something.

     fn update_metadata(&self, headers: HashMap<String, RpcMetadataValue>) -> PyResult<()> {
         let (ascii_headers, binary_headers) = partition_headers(headers);
+        let client = self.configured_client()?;
 
-        self.retry_client
-            .get_client()
+        client
             .set_headers(ascii_headers)
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
-        self.retry_client
-            .get_client()
+        client
             .set_binary_headers(binary_headers)
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
 
         Ok(())
     }

     fn update_api_key(&self, api_key: Option<String>) {
-        self.retry_client.get_client().set_api_key(api_key);
+        if let Ok(client) = self.configured_client() {
+            client.set_api_key(api_key);
+        }
     }
}

+impl ClientRef {
+    pub(crate) fn retry_client(&self) -> Result<&Client, PyErr> {
+        self.retry_client
+            .as_ref()
+            .ok_or_else(|| PyRuntimeError::new_err("client has been dropped"))
+    }
+
+    fn configured_client(&self) -> Result<&TemporalClient, PyErr> {
+        self.retry_client
+            .as_ref()
+            .map(RetryClient::get_client)
+            .ok_or_else(|| PyRuntimeError::new_err("client has been dropped"))
+    }
+}

10 changes: 5 additions & 5 deletions temporalio/bridge/src/client_rpc_generated.rs
@@ -16,7 +16,7 @@ impl ClientRef {
         call: RpcCall,
     ) -> PyResult<Bound<'p, PyAny>> {
         use temporal_client::WorkflowService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "count_workflow_executions" => {
@@ -567,7 +567,7 @@ impl ClientRef {
         call: RpcCall,
     ) -> PyResult<Bound<'p, PyAny>> {
         use temporal_client::OperatorService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "add_or_update_remote_cluster" => {
@@ -629,7 +629,7 @@ impl ClientRef {
 
     fn call_cloud_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
         use temporal_client::CloudService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "add_namespace_region" => {
@@ -843,7 +843,7 @@ impl ClientRef {
 
     fn call_test_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
         use temporal_client::TestService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "get_current_time" => {
@@ -882,7 +882,7 @@ impl ClientRef {
 
     fn call_health_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
         use temporal_client::HealthService;
-        let mut retry_client = self.retry_client.clone();
+        let mut retry_client = self.retry_client()?.clone();
         self.runtime.future_into_py(py, async move {
             let bytes = match call.rpc.as_str() {
                 "check" => {
10 changes: 8 additions & 2 deletions temporalio/bridge/src/worker.rs
@@ -479,7 +479,7 @@ pub fn new_worker(
     let worker = temporal_sdk_core::init_worker(
         &runtime_ref.runtime.core,
         config,
-        client.retry_client.clone().into_inner(),
+        client.retry_client()?.clone().into_inner(),
     )
     .context("Failed creating worker")?;
     Ok(WorkerRef {
@@ -648,7 +648,13 @@ impl WorkerRef {
         self.worker
             .as_ref()
             .expect("missing worker")
-            .replace_client(client.retry_client.clone().into_inner());
+            .replace_client(
+                client
+                    .retry_client()
+                    .expect("client ref had no client")
+                    .clone()
+                    .into_inner(),
+            );
     }
 
     fn initiate_shutdown(&self) -> PyResult<()> {
14 changes: 14 additions & 0 deletions temporalio/service.py
@@ -265,6 +265,16 @@ def update_api_key(self, api_key: Optional[str]) -> None:
         """Update service client's API key."""
         raise NotImplementedError
 
+    @abstractmethod
+    def unsafe_close(self) -> None:
+        """Force disconnect of the client.
+
+        Only advanced users should consider using this.
+        Any use of clients or objects that rely on clients is
+        undefined after this call.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     async def _rpc_call(
         self,
@@ -346,6 +356,10 @@ def update_api_key(self, api_key: Optional[str]) -> None:
         if self._bridge_client:
             self._bridge_client.update_api_key(api_key)
 
+    def unsafe_close(self) -> None:
+        if self._bridge_client:
+            self._bridge_client.unsafe_close()
+
     async def _rpc_call(
         self,
         rpc: str,
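For callers, the new API looks roughly like this — a hedged usage sketch distilled from the test below (the workflow name and task queue are placeholders, and, per the docstring, any use of the client after unsafe_close is undefined):

```python
import uuid

from temporalio.client import Client


async def demo_unsafe_close(target_host: str) -> None:
    client = await Client.connect(target_host)

    # Force the Core client to drop its underlying gRPC client. Only
    # advanced users should do this; the client is unusable afterwards.
    client.service_client.unsafe_close()

    # Any further RPC surfaces the bridge guard as a RuntimeError.
    try:
        await client.start_workflow(
            "some-workflow",
            id=f"wf-{uuid.uuid4()}",
            task_queue=f"tq-{uuid.uuid4()}",
        )
    except RuntimeError as err:
        assert "client has been dropped" in str(err)
```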
45 changes: 45 additions & 0 deletions tests/test_client.py
@@ -1,3 +1,4 @@
+import asyncio
 import dataclasses
 import json
 import os
@@ -7,6 +8,7 @@
 from unittest import mock
 
 import google.protobuf.any_pb2
+import psutil
 import pytest
 from google.protobuf import json_format
 
@@ -86,6 +88,7 @@
 from temporalio.testing import WorkflowEnvironment
 from tests.helpers import (
     assert_eq_eventually,
+    assert_eventually,
     ensure_search_attributes_present,
     new_worker,
     worker_versioning_enabled,
@@ -1541,3 +1544,45 @@ async def get_schedule_result() -> Tuple[int, Optional[str]]:
     )
 
     await handle.delete()


+async def test_unsafe_close(env: WorkflowEnvironment):
+    proc = psutil.Process()
+    # proc.connections() is deprecated in newer psutil versions, but uv resolved to
+    # a version that doesn't have the recommended proc.net_connections(). This
+    # future-proofs against an upgrade.
+    list_conns = getattr(proc, "net_connections", proc.connections)
+
+    target_host = env.client.config()["service_client"].config.target_host
+    target_ip = target_host.split(":")[0]
+
+    def sum_connections() -> int:
+        return sum(1 for p in list_conns() if p.raddr and p.raddr[0] == target_ip)
+
+    # Get the number of connections to the target host.
+    num_conn = sum_connections()
+
+    # Create a new client that has its own connection.
+    client = await Client.connect(target_host)
+
+    # Get the number of connections now that we have a second one.
+    num_conn_after_client = sum_connections()
+    assert num_conn_after_client > num_conn
+
+    # Force-drop our connection via the bridge.
+    client.service_client.unsafe_close()
+
+    # Now that we've forced our connection to drop, the bridge will raise a
+    # RuntimeError if we do anything that uses the client.
+    with pytest.raises(RuntimeError) as dropped_err:
+        await client.start_workflow(
+            "some-workflow", id=f"wf-{uuid.uuid4()}", task_queue=f"tq-{uuid.uuid4()}"
+        )
+    assert dropped_err.match("client has been dropped")
+
+    async def assert_less_connections():
+        # Get the number of connections now that we've closed the one we opened.
+        num_conn_after_close = sum_connections()
+        assert num_conn_after_close < num_conn_after_client
+
+    await assert_eventually(assert_less_connections, timeout=timedelta(seconds=1))