diff --git a/src/viam/robot/client.py b/src/viam/robot/client.py index 1b1fbaa97..9edb50da1 100644 --- a/src/viam/robot/client.py +++ b/src/viam/robot/client.py @@ -155,7 +155,9 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti self._should_close_channel = close_channel self._options = options self._address = self._channel._path if self._channel._path else f"{self._channel._host}:{self._channel._port}" - self._sessions_client = SessionsClient(self._channel, disabled=self._options.disable_sessions) + self._sessions_client = SessionsClient( + self._channel, self._address, self._options.dial_options, disabled=self._options.disable_sessions + ) try: await self.refresh() @@ -303,7 +305,12 @@ async def _check_connection(self, check_every: int, reconnect_every: int): self._channel = channel.channel self._viam_channel = channel self._client = RobotServiceStub(self._channel) - self._sessions_client = SessionsClient(channel=self._channel, disabled=self._options.disable_sessions) + self._sessions_client = SessionsClient( + channel=self._channel, + address=self._address, + dial_options=self._options.dial_options, + disabled=self._options.disable_sessions, + ) await self.refresh() self._connected = True diff --git a/src/viam/sessions_client.py b/src/viam/sessions_client.py index 0bc9b9877..8748ab314 100644 --- a/src/viam/sessions_client.py +++ b/src/viam/sessions_client.py @@ -1,6 +1,7 @@ import asyncio -import sys from datetime import timedelta +from enum import IntEnum +from threading import Thread, Lock from typing import Optional from grpclib import Status @@ -9,8 +10,9 @@ from grpclib.exceptions import GRPCError, StreamTerminatedError from grpclib.metadata import _MetadataLike -from viam import _TASK_PREFIX, logging -from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse +from viam import logging +from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest +from viam.rpc.dial import DialOptions, dial LOGGER = logging.getLogger(__name__) SESSION_METADATA_KEY = "viam-sid" @@ -30,10 +32,10 @@ ) -def loop_kwargs(): - if sys.version_info <= (3, 9): - return {"loop": asyncio.get_running_loop()} - return {} +class _SupportedState(IntEnum): + UNKNOWN = 0 + TRUE = 1 + FALSE = 2 class SessionsClient: @@ -42,26 +44,34 @@ class SessionsClient: supports stopping actuating components when it's not. """ - _current_id: str = "" - _disabled: bool = False - _supported: Optional[bool] = None - _heartbeat_interval: Optional[timedelta] = None - - def __init__(self, channel: Channel, *, disabled: bool = False): + def __init__(self, channel: Channel, address: str, dial_options: Optional[DialOptions], *, disabled: bool = False): self.channel = channel self.client = RobotServiceStub(channel) + self._address = address + self._dial_options = dial_options self._disabled = disabled - self._lock = asyncio.Lock(**loop_kwargs()) + + self._lock: Lock = Lock() + self._current_id: str = "" + self._heartbeat_interval: Optional[timedelta] = None + self._supported: _SupportedState = _SupportedState.UNKNOWN + self._thread: Optional[Thread] = None listen(self.channel, SendRequest, self._send_request) listen(self.channel, RecvTrailingMetadata, self._recv_trailers) def reset(self): - if self._lock.locked(): - return + with self._lock: + self._reset() + def _reset(self): LOGGER.debug("resetting session") - self._supported = None + self._supported = _SupportedState.UNKNOWN + self._current_id = "" + self._heartbeat_interval = None + if self._thread is not None: + self._thread.join(timeout=1) + self._thread = None async def _send_request(self, event: SendRequest): if self._disabled: @@ -79,77 +89,85 @@ async def _recv_trailers(self, event: RecvTrailingMetadata): @property async def metadata(self) -> _MetadataLike: - if self._disabled: - return self._metadata - - if self._supported: - return self._metadata - - async with self._lock: - if self._supported is False: + with self._lock: + if self._disabled or self._supported != _SupportedState.UNKNOWN: return self._metadata - request = StartSessionRequest(resume=self._current_id) - response: Optional[StartSessionResponse] = None - - try: - response = await self.client.StartSession(request) - except GRPCError as error: - if error.status == Status.UNIMPLEMENTED: - self._supported = False + request = StartSessionRequest(resume=self._current_id) + try: + response = await self.client.StartSession(request) + except GRPCError as error: + if error.status == Status.UNIMPLEMENTED: + with self._lock: + self._reset() + self._supported = _SupportedState.FALSE return self._metadata - else: - raise else: - if response is None: - raise GRPCError(status=Status.INTERNAL, message="Expected response to start session") - - if response.heartbeat_window is None: - raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") + raise - self._supported = True - self._heartbeat_interval = response.heartbeat_window.ToTimedelta() - self._current_id = response.id + if response is None: + raise GRPCError(status=Status.INTERNAL, message="Expected response to start session") - # tick once to ensure heartbeats are supported - await self._heartbeat_tick() - - if self._supported: - # We send heartbeats slightly faster than the interval window to - # ensure that we don't fall outside of it and expire the session. - wait = self._heartbeat_interval.total_seconds() / 5 - asyncio.create_task(self._heartbeat_task(wait), name=f"{_TASK_PREFIX}-heartbeat") - - return self._metadata - - async def _heartbeat_task(self, wait: float): - while self._supported: - await asyncio.sleep(wait) - await self._heartbeat_tick() + if response.heartbeat_window is None: + raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") - async def _heartbeat_tick(self): - if not self._supported: - return + with self._lock: + self._supported = _SupportedState.TRUE + self._heartbeat_interval = response.heartbeat_window.ToTimedelta() + self._current_id = response.id - while self._lock.locked(): - pass + # tick once to ensure heartbeats are supported + await self._heartbeat_tick(self.client) + + with self._lock: + if self._thread is not None: + self._reset() + if self._supported == _SupportedState.TRUE: + # We send heartbeats faster than the interval window to + # ensure that we don't fall outside of it and expire the session. + wait = self._heartbeat_interval.total_seconds() / 5 + + self._thread = Thread( + name="heartbeat-thread", + target=asyncio.run, + args=(self._heartbeat_process(wait),), + daemon=True, + ) + self._thread.start() - request = SendSessionHeartbeatRequest(id=self._current_id) + return self._metadata - if self._heartbeat_interval is None: - raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") + async def _heartbeat_tick(self, client: RobotServiceStub): + with self._lock: + if not self._current_id: + LOGGER.debug("Failed to send heartbeat, session client reset") + return + request = SendSessionHeartbeatRequest(id=self._current_id) try: - await self.client.SendSessionHeartbeat(request) + await client.SendSessionHeartbeat(request) except (GRPCError, StreamTerminatedError): LOGGER.debug("Heartbeat terminated", exc_info=True) self.reset() else: LOGGER.debug("Sent heartbeat successfully") + async def _heartbeat_process(self, wait: float): + dial_options = self._dial_options if self._dial_options is not None else DialOptions() + dial_options.disable_webrtc = True + + channel = await dial(address=self._address, options=dial_options) + client = RobotServiceStub(channel.channel) + while True: + with self._lock: + if self._supported != _SupportedState.TRUE: + return + await self._heartbeat_tick(client) + await asyncio.sleep(wait) + @property def _metadata(self) -> _MetadataLike: - if self._supported and self._current_id != "": + if self._supported == _SupportedState.TRUE and self._current_id != "": return {SESSION_METADATA_KEY: self._current_id} return {} diff --git a/tests/mocks/robot.py b/tests/mocks/robot.py new file mode 100644 index 000000000..ee9a6220a --- /dev/null +++ b/tests/mocks/robot.py @@ -0,0 +1,102 @@ +from datetime import timedelta + +from google.protobuf.duration_pb2 import Duration +from grpclib.server import Stream + +from viam.errors import MethodNotImplementedError +from viam.proto.robot import ( + BlockForOperationRequest, + BlockForOperationResponse, + CancelOperationRequest, + CancelOperationResponse, + DiscoverComponentsRequest, + DiscoverComponentsResponse, + FrameSystemConfigRequest, + FrameSystemConfigResponse, + GetOperationsRequest, + GetOperationsResponse, + GetSessionsRequest, + GetSessionsResponse, + GetStatusRequest, + GetStatusResponse, + ResourceNamesRequest, + ResourceNamesResponse, + ResourceRPCSubtypesRequest, + ResourceRPCSubtypesResponse, + RobotServiceBase, + SendSessionHeartbeatRequest, + SendSessionHeartbeatResponse, + StartSessionRequest, + StartSessionResponse, + StopAllRequest, + StopAllResponse, + StreamStatusRequest, + StreamStatusResponse, + TransformPCDRequest, + TransformPCDResponse, + TransformPoseRequest, + TransformPoseResponse, +) + + +class MockRobot(RobotServiceBase): + SESSION_ID = "sid" + HEARTBEAT_INTERVAL = 2 + + def __init__(self): + self.heartbeat_count = 0 + super().__init__() + + async def StartSession(self, stream: Stream[StartSessionRequest, StartSessionResponse]) -> None: + request = await stream.recv_message() + assert request is not None + heartbeat_window = Duration() + heartbeat_window.FromTimedelta(timedelta(seconds=self.HEARTBEAT_INTERVAL)) + response = StartSessionResponse(id=self.SESSION_ID, heartbeat_window=heartbeat_window) + await stream.send_message(response) + + async def SendSessionHeartbeat(self, stream: Stream[SendSessionHeartbeatRequest, SendSessionHeartbeatResponse]) -> None: + request = await stream.recv_message() + assert request is not None + self.heartbeat_count += 1 + response = SendSessionHeartbeatResponse() + await stream.send_message(response) + + async def ResourceNames(self, stream: Stream[ResourceNamesRequest, ResourceNamesResponse]) -> None: + raise MethodNotImplementedError("ResourceNames").grpc_error + + async def GetStatus(self, stream: Stream[GetStatusRequest, GetStatusResponse]) -> None: + raise MethodNotImplementedError("GetStatus").grpc_error + + async def StreamStatus(self, stream: Stream[StreamStatusRequest, StreamStatusResponse]) -> None: + raise MethodNotImplementedError("StreamStatus").grpc_error + + async def GetOperations(self, stream: Stream[GetOperationsRequest, GetOperationsResponse]) -> None: + raise MethodNotImplementedError("GetOperations").grpc_error + + async def ResourceRPCSubtypes(self, stream: Stream[ResourceRPCSubtypesRequest, ResourceRPCSubtypesResponse]) -> None: + raise MethodNotImplementedError("ResourceRPCSubtypes").grpc_error + + async def CancelOperation(self, stream: Stream[CancelOperationRequest, CancelOperationResponse]) -> None: + raise MethodNotImplementedError("CancelOperation").grpc_error + + async def BlockForOperation(self, stream: Stream[BlockForOperationRequest, BlockForOperationResponse]) -> None: + raise MethodNotImplementedError("BlockForOperation").grpc_error + + async def FrameSystemConfig(self, stream: Stream[FrameSystemConfigRequest, FrameSystemConfigResponse]) -> None: + raise MethodNotImplementedError("FrameSystemConfig").grpc_error + + async def TransformPose(self, stream: Stream[TransformPoseRequest, TransformPoseResponse]) -> None: + raise MethodNotImplementedError("TransformPose").grpc_error + + async def DiscoverComponents(self, stream: Stream[DiscoverComponentsRequest, DiscoverComponentsResponse]) -> None: + raise MethodNotImplementedError("DiscoverComponents").grpc_error + + async def StopAll(self, stream: Stream[StopAllRequest, StopAllResponse]) -> None: + raise MethodNotImplementedError("StopAll").grpc_error + + async def GetSessions(self, stream: Stream[GetSessionsRequest, GetSessionsResponse]) -> None: + raise MethodNotImplementedError("GetSessions").grpc_error + + async def TransformPCD(self, stream: Stream[TransformPCDRequest, TransformPCDResponse]) -> None: + raise MethodNotImplementedError("TransformPCD").grpc_error diff --git a/tests/test_sessions_client.py b/tests/test_sessions_client.py index 81834a799..97b02a80f 100644 --- a/tests/test_sessions_client.py +++ b/tests/test_sessions_client.py @@ -1,68 +1,57 @@ -from datetime import timedelta +import asyncio +import socket +import time +from concurrent.futures import ThreadPoolExecutor +from typing import List import pytest -from google.protobuf.duration_pb2 import Duration from grpclib import GRPCError, Status -from grpclib.server import Stream +from grpclib._typing import IServable +from grpclib.server import Server as GRPCServer from grpclib.testing import ChannelFor -from viam.proto.robot import SendSessionHeartbeatRequest, SendSessionHeartbeatResponse, StartSessionRequest, StartSessionResponse -from viam.resource.manager import ResourceManager -from viam.robot.service import RobotService -from viam.sessions_client import SESSION_METADATA_KEY, SessionsClient +from viam.errors import MethodNotImplementedError +from viam.rpc.dial import DialOptions, dial +from viam.sessions_client import SESSION_METADATA_KEY, SessionsClient, _SupportedState -SESSION_ID = "sid" -HEARTBEAT_INTERVAL = 2 +from .mocks.robot import MockRobot @pytest.fixture(scope="function") -def service() -> RobotService: - async def StartSession(stream: Stream[StartSessionRequest, StartSessionResponse]) -> None: - request = await stream.recv_message() - assert request is not None - heartbeat_window = Duration() - heartbeat_window.FromTimedelta(timedelta(seconds=HEARTBEAT_INTERVAL)) - response = StartSessionResponse(id=SESSION_ID, heartbeat_window=heartbeat_window) - await stream.send_message(response) - - async def SendSessionHeartbeat(stream: Stream[SendSessionHeartbeatRequest, SendSessionHeartbeatResponse]) -> None: - request = await stream.recv_message() - assert request is not None - response = SendSessionHeartbeatResponse() - await stream.send_message(response) - - manager = ResourceManager([]) - service = RobotService(manager) - service.StartSession = StartSession - service.SendSessionHeartbeat = SendSessionHeartbeat - - return service +def service() -> MockRobot: + return MockRobot() @pytest.fixture(scope="function") -def service_without_session(service: RobotService) -> RobotService: - del service.StartSession +def service_without_session(service: MockRobot) -> MockRobot: + async def StartSession(stream) -> None: + raise MethodNotImplementedError("StartSession").grpc_error + + service.StartSession = StartSession return service @pytest.fixture(scope="function") -def service_without_heartbeat(service: RobotService) -> RobotService: - del service.SendSessionHeartbeat +def service_without_heartbeat(service: MockRobot) -> MockRobot: + async def SendSessionHeartbeat(stream) -> None: + raise MethodNotImplementedError("SendSessionHeartbeat").grpc_error + + service.SendSessionHeartbeat = SendSessionHeartbeat return service @pytest.mark.asyncio async def test_init_client(): async with ChannelFor([]) as channel: - client = SessionsClient(channel) + client = SessionsClient(channel, "", None) assert client._current_id == "" - assert client._supported is None + assert client._supported == _SupportedState.UNKNOWN @pytest.mark.asyncio async def test_sessions_error(): async with ChannelFor([]) as channel: - client = SessionsClient(channel) + client = SessionsClient(channel, "", None) with pytest.raises(GRPCError) as e_info: assert await client.metadata == {} @@ -73,42 +62,71 @@ async def test_sessions_error(): @pytest.mark.asyncio async def test_sessions_not_supported(): async with ChannelFor([]) as channel: - client = SessionsClient(channel) - client._supported = False + client = SessionsClient(channel, "", None) + client._supported = _SupportedState.FALSE assert await client.metadata == {} - assert client._supported is False + assert client._supported == _SupportedState.FALSE @pytest.mark.asyncio -async def test_sessions_not_implemented(service_without_session: RobotService): +async def test_sessions_not_implemented(service_without_session: MockRobot): async with ChannelFor([service_without_session]) as channel: - client = SessionsClient(channel) + client = SessionsClient(channel, "", None) assert await client.metadata == {} - assert client._supported is False + assert client._supported == _SupportedState.FALSE @pytest.mark.asyncio -async def test_sessions_heartbeat_disconnect(service_without_heartbeat: RobotService): +async def test_sessions_heartbeat_disconnect(service_without_heartbeat: MockRobot): async with ChannelFor([service_without_heartbeat]) as channel: - client = SessionsClient(channel) + client = SessionsClient(channel, "", None) assert await client.metadata == {} - assert client._supported is None + assert client._supported == _SupportedState.UNKNOWN + + +async def _run_server(sock: socket.socket, handlers: List[IServable], shutdown_signal: asyncio.Event): + server = GRPCServer(handlers=handlers) + await server.start(sock=sock) + # shutdown_signal.wait() seems to be bugged <3.9 and blocks the thread, + # so have to do a bit of a busy wait here + while not shutdown_signal.is_set(): + await asyncio.sleep(0.1) + server.close() @pytest.mark.asyncio -async def test_sessions_heartbeat(service: RobotService): - async with ChannelFor([service]) as channel: - client = SessionsClient(channel) - assert await client.metadata == {SESSION_METADATA_KEY: SESSION_ID} - assert client._supported - assert client._heartbeat_interval and client._heartbeat_interval.total_seconds() == HEARTBEAT_INTERVAL - assert client._current_id == SESSION_ID +async def test_sessions_heartbeat_thread_blocked(): + sock = socket.socket() + sock.bind(("", 0)) + + shutdown_signal = asyncio.Event() + m = MockRobot() + t = ThreadPoolExecutor() + t.submit(asyncio.run, _run_server(sock, [m], shutdown_signal)) + + await asyncio.sleep(0.5) + + port = sock.getsockname()[1] + addr = f"localhost:{port}" + options = DialOptions(disable_webrtc=True, insecure=True) + channel = await dial(address=addr, options=options) + + client = SessionsClient(channel.channel, addr, options) + assert await client.metadata == {SESSION_METADATA_KEY: MockRobot.SESSION_ID} + + assert client._supported == _SupportedState.TRUE + assert client._heartbeat_interval and client._heartbeat_interval.total_seconds() == MockRobot.HEARTBEAT_INTERVAL + + time.sleep(3) + shutdown_signal.set() + client.reset() + assert m.heartbeat_count >= 5 @pytest.mark.asyncio -async def test_sessions_disabled(service: RobotService): +async def test_sessions_disabled(service: MockRobot): async with ChannelFor([service]) as channel: - client = SessionsClient(channel, disabled=True) + client = SessionsClient(channel, "", None, disabled=True) assert await client.metadata == {} - assert client._supported is None + assert client._supported == _SupportedState.UNKNOWN assert not client._heartbeat_interval