Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/viam/robot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti
self._should_close_channel = close_channel
self._options = options
self._address = self._channel._path if self._channel._path else f"{self._channel._host}:{self._channel._port}"
self._sessions_client = SessionsClient(self._channel, disabled=self._options.disable_sessions)
self._sessions_client = SessionsClient(
self._channel, self._address, self._options.dial_options, disabled=self._options.disable_sessions
)

try:
await self.refresh()
Expand Down Expand Up @@ -303,7 +305,12 @@ async def _check_connection(self, check_every: int, reconnect_every: int):
self._channel = channel.channel
self._viam_channel = channel
self._client = RobotServiceStub(self._channel)
self._sessions_client = SessionsClient(channel=self._channel, disabled=self._options.disable_sessions)
self._sessions_client = SessionsClient(
channel=self._channel,
address=self._address,
dial_options=self._options.dial_options,
disabled=self._options.disable_sessions,
)

await self.refresh()
self._connected = True
Expand Down
154 changes: 86 additions & 68 deletions src/viam/sessions_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import sys
from datetime import timedelta
from enum import IntEnum
from threading import Thread, Lock
from typing import Optional

from grpclib import Status
Expand All @@ -9,8 +10,9 @@
from grpclib.exceptions import GRPCError, StreamTerminatedError
from grpclib.metadata import _MetadataLike

from viam import _TASK_PREFIX, logging
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse
from viam import logging
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest
from viam.rpc.dial import DialOptions, dial

LOGGER = logging.getLogger(__name__)
SESSION_METADATA_KEY = "viam-sid"
Expand All @@ -30,10 +32,10 @@
)


def loop_kwargs():
    """Return the extra keyword arguments for asyncio primitives on this interpreter.

    Returns:
        dict: ``{"loop": <running loop>}`` when ``sys.version_info <= (3, 9)``,
        otherwise an empty dict.
    """
    # sys.version_info has five fields, so every 3.9.x release compares
    # greater than (3, 9); the check is only true for older interpreters.
    wants_explicit_loop = sys.version_info <= (3, 9)
    return {"loop": asyncio.get_running_loop()} if wants_explicit_loop else {}
class _SupportedState(IntEnum):
    """Tri-state flag: not yet determined, supported, or not supported."""

    # Nothing has been determined yet.
    UNKNOWN = 0
    # Positively confirmed.
    TRUE = 1
    # Positively ruled out.
    FALSE = 2


class SessionsClient:
Expand All @@ -42,26 +44,34 @@ class SessionsClient:
supports stopping actuating components when it's not.
"""

_current_id: str = ""
_disabled: bool = False
_supported: Optional[bool] = None
_heartbeat_interval: Optional[timedelta] = None

def __init__(self, channel: Channel, *, disabled: bool = False):
def __init__(self, channel: Channel, address: str, dial_options: Optional[DialOptions], *, disabled: bool = False):
self.channel = channel
self.client = RobotServiceStub(channel)
self._address = address
self._dial_options = dial_options
self._disabled = disabled
self._lock = asyncio.Lock(**loop_kwargs())

self._lock: Lock = Lock()
self._current_id: str = ""
self._heartbeat_interval: Optional[timedelta] = None
self._supported: _SupportedState = _SupportedState.UNKNOWN
self._thread: Optional[Thread] = None

listen(self.channel, SendRequest, self._send_request)
listen(self.channel, RecvTrailingMetadata, self._recv_trailers)

def reset(self):
if self._lock.locked():
return
with self._lock:
self._reset()

def _reset(self):
LOGGER.debug("resetting session")
self._supported = None
self._supported = _SupportedState.UNKNOWN
self._current_id = ""
self._heartbeat_interval = None
if self._thread is not None:
self._thread.join(timeout=1)
self._thread = None

async def _send_request(self, event: SendRequest):
if self._disabled:
Expand All @@ -79,77 +89,85 @@ async def _recv_trailers(self, event: RecvTrailingMetadata):

@property
async def metadata(self) -> _MetadataLike:
if self._disabled:
return self._metadata

if self._supported:
return self._metadata

async with self._lock:
if self._supported is False:
with self._lock:
if self._disabled or self._supported != _SupportedState.UNKNOWN:
return self._metadata

request = StartSessionRequest(resume=self._current_id)
response: Optional[StartSessionResponse] = None

try:
response = await self.client.StartSession(request)
except GRPCError as error:
if error.status == Status.UNIMPLEMENTED:
self._supported = False
request = StartSessionRequest(resume=self._current_id)
try:
response = await self.client.StartSession(request)
except GRPCError as error:
if error.status == Status.UNIMPLEMENTED:
with self._lock:
self._reset()
self._supported = _SupportedState.FALSE
return self._metadata
else:
raise
else:
if response is None:
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")

if response.heartbeat_window is None:
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")
raise

self._supported = True
self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
self._current_id = response.id
if response is None:
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")

# tick once to ensure heartbeats are supported
await self._heartbeat_tick()

if self._supported:
# We send heartbeats slightly faster than the interval window to
# ensure that we don't fall outside of it and expire the session.
wait = self._heartbeat_interval.total_seconds() / 5
asyncio.create_task(self._heartbeat_task(wait), name=f"{_TASK_PREFIX}-heartbeat")

return self._metadata

async def _heartbeat_task(self, wait: float):
while self._supported:
await asyncio.sleep(wait)
await self._heartbeat_tick()
if response.heartbeat_window is None:
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

async def _heartbeat_tick(self):
if not self._supported:
return
with self._lock:
self._supported = _SupportedState.TRUE
self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
self._current_id = response.id

while self._lock.locked():
pass
# tick once to ensure heartbeats are supported
await self._heartbeat_tick(self.client)

with self._lock:
if self._thread is not None:
self._reset()
if self._supported == _SupportedState.TRUE:
# We send heartbeats faster than the interval window to
# ensure that we don't fall outside of it and expire the session.
wait = self._heartbeat_interval.total_seconds() / 5

self._thread = Thread(
name="heartbeat-thread",
target=asyncio.run,
args=(self._heartbeat_process(wait),),
daemon=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure this should be daemonized?

Copy link
Member Author

@cheukt cheukt Aug 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main issue I ran into with a non-daemonized thread is that it doesn't stop on Ctrl+C, and there doesn't seem to be a place to catch the interrupt and kill the thread either. Since the daemonic thread just dies on program exit, the cleanup can be handled like normal (since the one Rust connection is reused).

)
self._thread.start()
Comment on lines +130 to +136
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: is it possible to get here with an active thread (i.e. self._thread is not None)? If so, do we need to do any cleanup?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is possible, added a reset if so


request = SendSessionHeartbeatRequest(id=self._current_id)
return self._metadata

if self._heartbeat_interval is None:
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")
async def _heartbeat_tick(self, client: RobotServiceStub):
with self._lock:
if not self._current_id:
LOGGER.debug("Failed to send heartbeat, session client reset")
return
request = SendSessionHeartbeatRequest(id=self._current_id)

try:
await self.client.SendSessionHeartbeat(request)
await client.SendSessionHeartbeat(request)
except (GRPCError, StreamTerminatedError):
LOGGER.debug("Heartbeat terminated", exc_info=True)
self.reset()
else:
LOGGER.debug("Sent heartbeat successfully")

async def _heartbeat_process(self, wait: float):
    """Send session heartbeats over a separately dialed connection.

    Dials ``self._address`` with WebRTC disabled, then loops: exit as soon as
    ``self._supported`` is no longer ``_SupportedState.TRUE``; otherwise send
    one heartbeat and sleep for ``wait`` seconds.

    Args:
        wait (float): Seconds to sleep between consecutive heartbeats.
    """
    import copy  # local import: only this method needs it

    # Work on a copy so that forcing ``disable_webrtc`` below does not leak
    # into the options object that was handed to this client (and that the
    # caller may keep using for its own dials).
    if self._dial_options is not None:
        dial_options = copy.copy(self._dial_options)
    else:
        dial_options = DialOptions()
    dial_options.disable_webrtc = True

    # NOTE(review): the dialed channel is never explicitly closed when the
    # loop exits -- confirm that cleanup is handled elsewhere.
    channel = await dial(address=self._address, options=dial_options)
    client = RobotServiceStub(channel.channel)
    while True:
        # Hold the lock only for the state check, not across the network call.
        with self._lock:
            if self._supported != _SupportedState.TRUE:
                return
        await self._heartbeat_tick(client)
        await asyncio.sleep(wait)

@property
def _metadata(self) -> _MetadataLike:
if self._supported and self._current_id != "":
if self._supported == _SupportedState.TRUE and self._current_id != "":
return {SESSION_METADATA_KEY: self._current_id}

return {}
102 changes: 102 additions & 0 deletions tests/mocks/robot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from datetime import timedelta

from google.protobuf.duration_pb2 import Duration
from grpclib.server import Stream

from viam.errors import MethodNotImplementedError
from viam.proto.robot import (
BlockForOperationRequest,
BlockForOperationResponse,
CancelOperationRequest,
CancelOperationResponse,
DiscoverComponentsRequest,
DiscoverComponentsResponse,
FrameSystemConfigRequest,
FrameSystemConfigResponse,
GetOperationsRequest,
GetOperationsResponse,
GetSessionsRequest,
GetSessionsResponse,
GetStatusRequest,
GetStatusResponse,
ResourceNamesRequest,
ResourceNamesResponse,
ResourceRPCSubtypesRequest,
ResourceRPCSubtypesResponse,
RobotServiceBase,
SendSessionHeartbeatRequest,
SendSessionHeartbeatResponse,
StartSessionRequest,
StartSessionResponse,
StopAllRequest,
StopAllResponse,
StreamStatusRequest,
StreamStatusResponse,
TransformPCDRequest,
TransformPCDResponse,
TransformPoseRequest,
TransformPoseResponse,
)


class MockRobot(RobotServiceBase):
    """Mock robot service: sessions and heartbeats work, every other RPC is unimplemented.

    ``StartSession`` always answers with ``SESSION_ID`` and a heartbeat window
    of ``HEARTBEAT_INTERVAL`` seconds. ``SendSessionHeartbeat`` increments
    ``heartbeat_count`` for each request it receives.
    """

    SESSION_ID = "sid"
    HEARTBEAT_INTERVAL = 2  # passed to timedelta(seconds=...) in StartSession

    def __init__(self):
        # Number of heartbeat requests received so far.
        self.heartbeat_count = 0
        super().__init__()

    @staticmethod
    def _unimplemented(method_name: str):
        """Build the gRPC error raised by every RPC this mock does not implement."""
        return MethodNotImplementedError(method_name).grpc_error

    async def StartSession(self, stream: Stream[StartSessionRequest, StartSessionResponse]) -> None:
        """Reply with the fixed session id and heartbeat window."""
        request = await stream.recv_message()
        assert request is not None
        window = Duration()
        window.FromTimedelta(timedelta(seconds=self.HEARTBEAT_INTERVAL))
        await stream.send_message(StartSessionResponse(id=self.SESSION_ID, heartbeat_window=window))

    async def SendSessionHeartbeat(self, stream: Stream[SendSessionHeartbeatRequest, SendSessionHeartbeatResponse]) -> None:
        """Count the heartbeat and reply with an empty response."""
        request = await stream.recv_message()
        assert request is not None
        self.heartbeat_count += 1
        await stream.send_message(SendSessionHeartbeatResponse())

    # Everything below is intentionally unimplemented.

    async def ResourceNames(self, stream: Stream[ResourceNamesRequest, ResourceNamesResponse]) -> None:
        raise self._unimplemented("ResourceNames")

    async def GetStatus(self, stream: Stream[GetStatusRequest, GetStatusResponse]) -> None:
        raise self._unimplemented("GetStatus")

    async def StreamStatus(self, stream: Stream[StreamStatusRequest, StreamStatusResponse]) -> None:
        raise self._unimplemented("StreamStatus")

    async def GetOperations(self, stream: Stream[GetOperationsRequest, GetOperationsResponse]) -> None:
        raise self._unimplemented("GetOperations")

    async def ResourceRPCSubtypes(self, stream: Stream[ResourceRPCSubtypesRequest, ResourceRPCSubtypesResponse]) -> None:
        raise self._unimplemented("ResourceRPCSubtypes")

    async def CancelOperation(self, stream: Stream[CancelOperationRequest, CancelOperationResponse]) -> None:
        raise self._unimplemented("CancelOperation")

    async def BlockForOperation(self, stream: Stream[BlockForOperationRequest, BlockForOperationResponse]) -> None:
        raise self._unimplemented("BlockForOperation")

    async def FrameSystemConfig(self, stream: Stream[FrameSystemConfigRequest, FrameSystemConfigResponse]) -> None:
        raise self._unimplemented("FrameSystemConfig")

    async def TransformPose(self, stream: Stream[TransformPoseRequest, TransformPoseResponse]) -> None:
        raise self._unimplemented("TransformPose")

    async def DiscoverComponents(self, stream: Stream[DiscoverComponentsRequest, DiscoverComponentsResponse]) -> None:
        raise self._unimplemented("DiscoverComponents")

    async def StopAll(self, stream: Stream[StopAllRequest, StopAllResponse]) -> None:
        raise self._unimplemented("StopAll")

    async def GetSessions(self, stream: Stream[GetSessionsRequest, GetSessionsResponse]) -> None:
        raise self._unimplemented("GetSessions")

    async def TransformPCD(self, stream: Stream[TransformPCDRequest, TransformPCDResponse]) -> None:
        raise self._unimplemented("TransformPCD")
Loading