-
Notifications
You must be signed in to change notification settings - Fork 64
RSDK-4455 - Run session heartbeat loop in separate thread #400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
02e6820
eaa7fe6
3865703
1cf0388
48e061c
aa96a24
c18a306
c2dc642
3f1f61d
9712192
a9ee2c6
02117e9
ae4c86b
6418165
47ac09f
9b88c15
e7bbe73
ec6a1c8
9c41432
bfab59c
645a128
de20e08
345b293
a629208
b74a73d
240e42f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import asyncio | ||
import sys | ||
from datetime import timedelta | ||
from enum import IntEnum | ||
from threading import Thread, Lock | ||
from typing import Optional | ||
|
||
from grpclib import Status | ||
|
@@ -9,8 +10,9 @@ | |
from grpclib.exceptions import GRPCError, StreamTerminatedError | ||
from grpclib.metadata import _MetadataLike | ||
|
||
from viam import _TASK_PREFIX, logging | ||
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse | ||
from viam import logging | ||
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest | ||
from viam.rpc.dial import DialOptions, dial | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
SESSION_METADATA_KEY = "viam-sid" | ||
|
@@ -30,10 +32,10 @@ | |
) | ||
|
||
|
||
def loop_kwargs(): | ||
if sys.version_info <= (3, 9): | ||
return {"loop": asyncio.get_running_loop()} | ||
return {} | ||
class _SupportedState(IntEnum): | ||
UNKNOWN = 0 | ||
TRUE = 1 | ||
FALSE = 2 | ||
|
||
|
||
class SessionsClient: | ||
|
@@ -42,26 +44,34 @@ class SessionsClient: | |
supports stopping actuating components when it's not. | ||
""" | ||
|
||
_current_id: str = "" | ||
_disabled: bool = False | ||
_supported: Optional[bool] = None | ||
_heartbeat_interval: Optional[timedelta] = None | ||
|
||
def __init__(self, channel: Channel, *, disabled: bool = False): | ||
def __init__(self, channel: Channel, address: str, dial_options: Optional[DialOptions], *, disabled: bool = False): | ||
self.channel = channel | ||
self.client = RobotServiceStub(channel) | ||
self._address = address | ||
self._dial_options = dial_options | ||
self._disabled = disabled | ||
self._lock = asyncio.Lock(**loop_kwargs()) | ||
|
||
self._lock: Lock = Lock() | ||
self._current_id: str = "" | ||
self._heartbeat_interval: Optional[timedelta] = None | ||
self._supported: _SupportedState = _SupportedState.UNKNOWN | ||
self._thread: Optional[Thread] = None | ||
|
||
listen(self.channel, SendRequest, self._send_request) | ||
listen(self.channel, RecvTrailingMetadata, self._recv_trailers) | ||
|
||
def reset(self): | ||
if self._lock.locked(): | ||
return | ||
with self._lock: | ||
self._reset() | ||
|
||
def _reset(self): | ||
LOGGER.debug("resetting session") | ||
self._supported = None | ||
self._supported = _SupportedState.UNKNOWN | ||
self._current_id = "" | ||
self._heartbeat_interval = None | ||
if self._thread is not None: | ||
self._thread.join(timeout=1) | ||
self._thread = None | ||
|
||
async def _send_request(self, event: SendRequest): | ||
if self._disabled: | ||
|
@@ -79,77 +89,85 @@ async def _recv_trailers(self, event: RecvTrailingMetadata): | |
|
||
@property | ||
async def metadata(self) -> _MetadataLike: | ||
if self._disabled: | ||
return self._metadata | ||
|
||
if self._supported: | ||
return self._metadata | ||
|
||
async with self._lock: | ||
if self._supported is False: | ||
with self._lock: | ||
if self._disabled or self._supported != _SupportedState.UNKNOWN: | ||
return self._metadata | ||
|
||
request = StartSessionRequest(resume=self._current_id) | ||
response: Optional[StartSessionResponse] = None | ||
|
||
try: | ||
response = await self.client.StartSession(request) | ||
except GRPCError as error: | ||
if error.status == Status.UNIMPLEMENTED: | ||
self._supported = False | ||
request = StartSessionRequest(resume=self._current_id) | ||
try: | ||
response = await self.client.StartSession(request) | ||
except GRPCError as error: | ||
if error.status == Status.UNIMPLEMENTED: | ||
with self._lock: | ||
self._reset() | ||
self._supported = _SupportedState.FALSE | ||
return self._metadata | ||
else: | ||
raise | ||
else: | ||
if response is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session") | ||
|
||
if response.heartbeat_window is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") | ||
raise | ||
|
||
self._supported = True | ||
self._heartbeat_interval = response.heartbeat_window.ToTimedelta() | ||
self._current_id = response.id | ||
if response is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session") | ||
|
||
# tick once to ensure heartbeats are supported | ||
await self._heartbeat_tick() | ||
|
||
if self._supported: | ||
# We send heartbeats slightly faster than the interval window to | ||
# ensure that we don't fall outside of it and expire the session. | ||
wait = self._heartbeat_interval.total_seconds() / 5 | ||
asyncio.create_task(self._heartbeat_task(wait), name=f"{_TASK_PREFIX}-heartbeat") | ||
|
||
return self._metadata | ||
|
||
async def _heartbeat_task(self, wait: float): | ||
while self._supported: | ||
await asyncio.sleep(wait) | ||
await self._heartbeat_tick() | ||
if response.heartbeat_window is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") | ||
|
||
async def _heartbeat_tick(self): | ||
if not self._supported: | ||
return | ||
with self._lock: | ||
self._supported = _SupportedState.TRUE | ||
self._heartbeat_interval = response.heartbeat_window.ToTimedelta() | ||
self._current_id = response.id | ||
|
||
while self._lock.locked(): | ||
pass | ||
# tick once to ensure heartbeats are supported | ||
await self._heartbeat_tick(self.client) | ||
|
||
with self._lock: | ||
if self._thread is not None: | ||
self._reset() | ||
if self._supported == _SupportedState.TRUE: | ||
# We send heartbeats faster than the interval window to | ||
# ensure that we don't fall outside of it and expire the session. | ||
wait = self._heartbeat_interval.total_seconds() / 5 | ||
|
||
self._thread = Thread( | ||
name="heartbeat-thread", | ||
target=asyncio.run, | ||
args=(self._heartbeat_process(wait),), | ||
daemon=True, | ||
) | ||
self._thread.start() | ||
Comment on lines
+130
to
+136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: is it possible to get here with an active thread? (e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is possible, added a reset if so |
||
|
||
request = SendSessionHeartbeatRequest(id=self._current_id) | ||
return self._metadata | ||
|
||
if self._heartbeat_interval is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") | ||
async def _heartbeat_tick(self, client: RobotServiceStub): | ||
with self._lock: | ||
if not self._current_id: | ||
LOGGER.debug("Failed to send heartbeat, session client reset") | ||
return | ||
request = SendSessionHeartbeatRequest(id=self._current_id) | ||
|
||
try: | ||
await self.client.SendSessionHeartbeat(request) | ||
await client.SendSessionHeartbeat(request) | ||
except (GRPCError, StreamTerminatedError): | ||
LOGGER.debug("Heartbeat terminated", exc_info=True) | ||
self.reset() | ||
else: | ||
LOGGER.debug("Sent heartbeat successfully") | ||
|
||
async def _heartbeat_process(self, wait: float): | ||
dial_options = self._dial_options if self._dial_options is not None else DialOptions() | ||
dial_options.disable_webrtc = True | ||
|
||
channel = await dial(address=self._address, options=dial_options) | ||
client = RobotServiceStub(channel.channel) | ||
while True: | ||
with self._lock: | ||
if self._supported != _SupportedState.TRUE: | ||
return | ||
await self._heartbeat_tick(client) | ||
await asyncio.sleep(wait) | ||
|
||
@property | ||
def _metadata(self) -> _MetadataLike: | ||
if self._supported and self._current_id != "": | ||
if self._supported == _SupportedState.TRUE and self._current_id != "": | ||
return {SESSION_METADATA_KEY: self._current_id} | ||
|
||
return {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from datetime import timedelta | ||
|
||
from google.protobuf.duration_pb2 import Duration | ||
from grpclib.server import Stream | ||
|
||
from viam.errors import MethodNotImplementedError | ||
from viam.proto.robot import ( | ||
BlockForOperationRequest, | ||
BlockForOperationResponse, | ||
CancelOperationRequest, | ||
CancelOperationResponse, | ||
DiscoverComponentsRequest, | ||
DiscoverComponentsResponse, | ||
FrameSystemConfigRequest, | ||
FrameSystemConfigResponse, | ||
GetOperationsRequest, | ||
GetOperationsResponse, | ||
GetSessionsRequest, | ||
GetSessionsResponse, | ||
GetStatusRequest, | ||
GetStatusResponse, | ||
ResourceNamesRequest, | ||
ResourceNamesResponse, | ||
ResourceRPCSubtypesRequest, | ||
ResourceRPCSubtypesResponse, | ||
RobotServiceBase, | ||
SendSessionHeartbeatRequest, | ||
SendSessionHeartbeatResponse, | ||
StartSessionRequest, | ||
StartSessionResponse, | ||
StopAllRequest, | ||
StopAllResponse, | ||
StreamStatusRequest, | ||
StreamStatusResponse, | ||
TransformPCDRequest, | ||
TransformPCDResponse, | ||
TransformPoseRequest, | ||
TransformPoseResponse, | ||
) | ||
|
||
|
||
class MockRobot(RobotServiceBase): | ||
SESSION_ID = "sid" | ||
HEARTBEAT_INTERVAL = 2 | ||
|
||
def __init__(self): | ||
self.heartbeat_count = 0 | ||
super().__init__() | ||
|
||
async def StartSession(self, stream: Stream[StartSessionRequest, StartSessionResponse]) -> None: | ||
request = await stream.recv_message() | ||
assert request is not None | ||
heartbeat_window = Duration() | ||
heartbeat_window.FromTimedelta(timedelta(seconds=self.HEARTBEAT_INTERVAL)) | ||
response = StartSessionResponse(id=self.SESSION_ID, heartbeat_window=heartbeat_window) | ||
await stream.send_message(response) | ||
|
||
async def SendSessionHeartbeat(self, stream: Stream[SendSessionHeartbeatRequest, SendSessionHeartbeatResponse]) -> None: | ||
request = await stream.recv_message() | ||
assert request is not None | ||
self.heartbeat_count += 1 | ||
response = SendSessionHeartbeatResponse() | ||
await stream.send_message(response) | ||
|
||
async def ResourceNames(self, stream: Stream[ResourceNamesRequest, ResourceNamesResponse]) -> None: | ||
raise MethodNotImplementedError("ResourceNames").grpc_error | ||
|
||
async def GetStatus(self, stream: Stream[GetStatusRequest, GetStatusResponse]) -> None: | ||
raise MethodNotImplementedError("GetStatus").grpc_error | ||
|
||
async def StreamStatus(self, stream: Stream[StreamStatusRequest, StreamStatusResponse]) -> None: | ||
raise MethodNotImplementedError("StreamStatus").grpc_error | ||
|
||
async def GetOperations(self, stream: Stream[GetOperationsRequest, GetOperationsResponse]) -> None: | ||
raise MethodNotImplementedError("GetOperations").grpc_error | ||
|
||
async def ResourceRPCSubtypes(self, stream: Stream[ResourceRPCSubtypesRequest, ResourceRPCSubtypesResponse]) -> None: | ||
raise MethodNotImplementedError("ResourceRPCSubtypes").grpc_error | ||
|
||
async def CancelOperation(self, stream: Stream[CancelOperationRequest, CancelOperationResponse]) -> None: | ||
raise MethodNotImplementedError("CancelOperation").grpc_error | ||
|
||
async def BlockForOperation(self, stream: Stream[BlockForOperationRequest, BlockForOperationResponse]) -> None: | ||
raise MethodNotImplementedError("BlockForOperation").grpc_error | ||
|
||
async def FrameSystemConfig(self, stream: Stream[FrameSystemConfigRequest, FrameSystemConfigResponse]) -> None: | ||
raise MethodNotImplementedError("FrameSystemConfig").grpc_error | ||
|
||
async def TransformPose(self, stream: Stream[TransformPoseRequest, TransformPoseResponse]) -> None: | ||
raise MethodNotImplementedError("TransformPose").grpc_error | ||
|
||
async def DiscoverComponents(self, stream: Stream[DiscoverComponentsRequest, DiscoverComponentsResponse]) -> None: | ||
raise MethodNotImplementedError("DiscoverComponents").grpc_error | ||
|
||
async def StopAll(self, stream: Stream[StopAllRequest, StopAllResponse]) -> None: | ||
raise MethodNotImplementedError("StopAll").grpc_error | ||
|
||
async def GetSessions(self, stream: Stream[GetSessionsRequest, GetSessionsResponse]) -> None: | ||
raise MethodNotImplementedError("GetSessions").grpc_error | ||
|
||
async def TransformPCD(self, stream: Stream[TransformPCDRequest, TransformPCDResponse]) -> None: | ||
raise MethodNotImplementedError("TransformPCD").grpc_error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we sure this should be daemonized?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
main issue I ran into a non-daemonized thread is that it doesn't stop on a ctrl+c, and there doesn't seem to be a place to catch to interrupt and kill the thread either? since the daemonic thread just dies on program exit, the cleanup can be handled like normal (since the one rust connection is reused).