Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c27f1a1
add test watcher
maximpertsov May 12, 2023
235a7bf
spike session client
maximpertsov Apr 28, 2023
4992c05
remove unused
maximpertsov May 19, 2023
f6da55a
client testing
maximpertsov May 19, 2023
1ed1298
revert debug
maximpertsov May 19, 2023
7cd5139
send heartbeats faster than interval
maximpertsov May 19, 2023
beaa5bf
init sessions client in viam channel
maximpertsov May 19, 2023
fb2ff54
wip: pass metadata into request in viam channel
maximpertsov May 19, 2023
b48a9a2
fixup
maximpertsov May 19, 2023
5063724
make inner metadata property sync
maximpertsov May 19, 2023
4e9ab91
reset session when intercepting expiration event
maximpertsov May 24, 2023
7720fcf
wip
maximpertsov May 24, 2023
66aa2d9
reset session on close
maximpertsov May 24, 2023
d9eda25
fixup expiration check
maximpertsov May 24, 2023
19919ee
5%
maximpertsov May 24, 2023
0ad2f5d
allow sessions to be disabled
maximpertsov May 24, 2023
327d56a
do not send session metadata for session-specific calls
maximpertsov May 24, 2023
622c9ba
lock
maximpertsov May 24, 2023
ce86b42
change option
maximpertsov May 24, 2023
0222499
shorter beat interval
maximpertsov May 25, 2023
b4ad97a
async lock + only acquire lock at startup
maximpertsov May 25, 2023
c4396b9
send session metadata with heartbeat
maximpertsov May 25, 2023
3ef9be1
reduce debug logs
maximpertsov May 25, 2023
cb886a6
cleanup flow
maximpertsov May 25, 2023
6843d96
Merge branch 'main' into sessions-python-RSDK-2506
maximpertsov May 25, 2023
a5f448d
make session_id and supported field internal
maximpertsov May 25, 2023
afaff1b
remove unused
maximpertsov May 25, 2023
9f0c4e0
CR@edaniels exclude methods from receiving session metadata
maximpertsov May 26, 2023
bf95c3a
CR@njooma: sentence case logs/errors
maximpertsov May 30, 2023
cf91a1f
remove pytest-watcher
maximpertsov May 30, 2023
50c6bda
add sessions disabled test
maximpertsov May 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/viam/robot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from viam.resource.types import RESOURCE_TYPE_COMPONENT, RESOURCE_TYPE_SERVICE, Subtype
from viam.rpc.dial import DialOptions, ViamChannel, dial
from viam.services.service_base import ServiceBase
from viam.sessions_client import SessionsClient
from viam.utils import dict_to_struct

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,6 +94,11 @@ class Options:
The frequency (in seconds) at which to attempt to reconnect a disconnected robot. 0 (zero) signifies no reconnection attempts
"""

disable_sessions: bool = False
"""
Whether sessions are disabled
"""

@classmethod
async def at_address(cls, address: str, options: Options) -> Self:
"""Create a robot client that is connected to the robot at the provided address.
Expand Down Expand Up @@ -138,6 +144,7 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti
else:
self._channel = channel.channel
self._viam_channel = channel

self._connected = True
self._client = RobotServiceStub(self._channel)
self._manager = ResourceManager()
Expand All @@ -146,6 +153,7 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti
self._should_close_channel = close_channel
self._options = options
self._address = self._channel._path if self._channel._path else f"{self._channel._host}:{self._channel._port}"
self._sessions_client = SessionsClient(self._channel, disabled=self._options.disable_sessions)

try:
await self.refresh()
Expand Down Expand Up @@ -180,6 +188,7 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti
_resource_names: List[ResourceName]
_should_close_channel: bool
_closed: bool = False
_sessions_client: SessionsClient

async def refresh(self):
"""
Expand Down Expand Up @@ -270,6 +279,8 @@ async def _check_connection(self, check_every: int, reconnect_every: int):

while not self._connected:
try:
self._sessions_client.reset()

channel = await dial(self._address, self._options.dial_options)

client: RobotServiceStub
Expand All @@ -286,12 +297,14 @@ async def _check_connection(self, check_every: int, reconnect_every: int):
self._channel = channel.channel
self._viam_channel = channel
self._client = RobotServiceStub(self._channel)
self._sessions_client = SessionsClient(channel=self._channel, disabled=self._options.disable_sessions)

await self.refresh()
self._connected = True
LOGGER.debug("Successfully reconnected robot")
except Exception as e:
LOGGER.error(f"Failed to reconnect, trying again in {reconnect_every}sec", exc_info=e)
self._sessions_client.reset()
self._close_channel()
await asyncio.sleep(reconnect_every)

Expand Down Expand Up @@ -423,6 +436,8 @@ async def close(self):
except RuntimeError:
pass

self._sessions_client.reset()

# Cancel all tasks created by VIAM
LOGGER.debug("Closing tasks spawned by Viam")
tasks = [task for task in asyncio.all_tasks() if task.get_name().startswith(viam._TASK_PREFIX)]
Expand Down
145 changes: 145 additions & 0 deletions src/viam/sessions_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
from datetime import timedelta
from typing import Optional

from grpclib import Status
from grpclib.client import Channel
from grpclib.events import RecvTrailingMetadata, SendRequest, listen
from grpclib.exceptions import GRPCError, StreamTerminatedError
from grpclib.metadata import _MetadataLike

from viam import _TASK_PREFIX, logging
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse

LOGGER = logging.getLogger(__name__)
SESSION_METADATA_KEY = "viam-sid"

EXEMPT_METADATA_METHODS = frozenset(
[
"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
"/proto.rpc.webrtc.v1.SignalingService/Call",
"/proto.rpc.webrtc.v1.SignalingService/CallUpdate",
"/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig",
"/proto.rpc.v1.AuthService/Authenticate",
"/viam.robot.v1.RobotService/ResourceNames",
"/viam.robot.v1.RobotService/ResourceRPCSubtypes",
"/viam.robot.v1.RobotService/StartSession",
"/viam.robot.v1.RobotService/SendSessionHeartbeat",
]
)


async def delay(coro, seconds):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a loop.call_later function that might be useful here:

loop = asyncio.get_running_loop()
loop.call_later(seconds, coro)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looked promising, but unfortunately loop.call_later takes a sync callback and not a co-routine: https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.call_later

await asyncio.sleep(seconds)
await coro


class SessionsClient:
"""
A Session allows a client to express that it is actively connected and
supports stopping actuating components when it's not.
"""

_current_id: str = ""
_disabled: bool = False
_lock = asyncio.Lock()
_supported: Optional[bool] = None
_heartbeat_interval: Optional[timedelta] = None

def __init__(self, channel: Channel, *, disabled: bool = False):
self.channel = channel
self.client = RobotServiceStub(channel)
self._disabled = disabled

listen(self.channel, SendRequest, self._send_request)
listen(self.channel, RecvTrailingMetadata, self._recv_trailers)

def reset(self):
if self._lock.locked():
return

LOGGER.debug("resetting session")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] can we make all logs Sentence case please?

self._supported = None

async def _send_request(self, event: SendRequest):
if self._disabled:
return

if event.method_name in EXEMPT_METADATA_METHODS:
return

event.metadata.update(await self.metadata)

async def _recv_trailers(self, event: RecvTrailingMetadata):
if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED":
LOGGER.debug("Session expired")
self.reset()

@property
async def metadata(self) -> _MetadataLike:
if self._disabled:
return self._metadata

if self._supported:
return self._metadata

async with self._lock:
if self._supported is False:
return self._metadata

request = StartSessionRequest(resume=self._current_id)
response: Optional[StartSessionResponse] = None

try:
response = await self.client.StartSession(request)
except GRPCError as error:
if error.status == Status.UNIMPLEMENTED:
self._supported = False
return self._metadata
else:
raise
else:
if response is None:
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")

if response.heartbeat_window is None:
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

self._supported = True
self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
self._current_id = response.id

await self._heartbeat_tick()

return self._metadata

async def _heartbeat_tick(self):
if not self._supported:
return

while self._lock.locked():
pass

request = SendSessionHeartbeatRequest(id=self._current_id)

if self._heartbeat_interval is None:
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

try:
await self.client.SendSessionHeartbeat(request)
except (GRPCError, StreamTerminatedError):
LOGGER.debug("Heartbeat terminated", exc_info=True)
self.reset()
else:
LOGGER.debug("Sent heartbeat successfully")
# We send heartbeats slightly faster than the interval window to
# ensure that we don't fall outside of it and expire the session.
wait = self._heartbeat_interval.total_seconds() / 5
asyncio.create_task(delay(self._heartbeat_tick(), wait), name=f"{_TASK_PREFIX}-heartbeat")

@property
def _metadata(self) -> _MetadataLike:
if self._supported and self._current_id != "":
return {SESSION_METADATA_KEY: self._current_id}

return {}
114 changes: 114 additions & 0 deletions tests/test_sessions_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from datetime import timedelta

import pytest
from google.protobuf.duration_pb2 import Duration
from grpclib import GRPCError, Status
from grpclib.server import Stream
from grpclib.testing import ChannelFor

from viam.proto.robot import SendSessionHeartbeatRequest, SendSessionHeartbeatResponse, StartSessionRequest, StartSessionResponse
from viam.resource.manager import ResourceManager
from viam.robot.service import RobotService
from viam.sessions_client import SESSION_METADATA_KEY, SessionsClient

SESSION_ID = "sid"
HEARTBEAT_INTERVAL = 2


@pytest.fixture(scope="function")
def service() -> RobotService:
async def StartSession(stream: Stream[StartSessionRequest, StartSessionResponse]) -> None:
request = await stream.recv_message()
assert request is not None
heartbeat_window = Duration()
heartbeat_window.FromTimedelta(timedelta(seconds=HEARTBEAT_INTERVAL))
response = StartSessionResponse(id=SESSION_ID, heartbeat_window=heartbeat_window)
await stream.send_message(response)

async def SendSessionHeartbeat(stream: Stream[SendSessionHeartbeatRequest, SendSessionHeartbeatResponse]) -> None:
request = await stream.recv_message()
assert request is not None
response = SendSessionHeartbeatResponse()
await stream.send_message(response)

manager = ResourceManager([])
service = RobotService(manager)
service.StartSession = StartSession
service.SendSessionHeartbeat = SendSessionHeartbeat

return service


@pytest.fixture(scope="function")
def service_without_session(service: RobotService) -> RobotService:
del service.StartSession
return service


@pytest.fixture(scope="function")
def service_without_heartbeat(service: RobotService) -> RobotService:
del service.SendSessionHeartbeat
return service


@pytest.mark.asyncio
async def test_init_client():
async with ChannelFor([]) as channel:
client = SessionsClient(channel)
assert client._current_id == ""
assert client._supported is None


@pytest.mark.asyncio
async def test_sessions_error():
async with ChannelFor([]) as channel:
client = SessionsClient(channel)

with pytest.raises(GRPCError) as e_info:
assert await client.metadata == {}

assert e_info.value.status == Status.UNKNOWN


@pytest.mark.asyncio
async def test_sessions_not_supported():
async with ChannelFor([]) as channel:
client = SessionsClient(channel)
client._supported = False
assert await client.metadata == {}
assert client._supported is False


@pytest.mark.asyncio
async def test_sessions_not_implemented(service_without_session: RobotService):
async with ChannelFor([service_without_session]) as channel:
client = SessionsClient(channel)
assert await client.metadata == {}
assert client._supported is False


@pytest.mark.asyncio
async def test_sessions_heartbeat_disconnect(service_without_heartbeat: RobotService):
async with ChannelFor([service_without_heartbeat]) as channel:
client = SessionsClient(channel)
assert await client.metadata == {}
assert client._supported is None


@pytest.mark.asyncio
async def test_sessions_heartbeat(service: RobotService):
async with ChannelFor([service]) as channel:
client = SessionsClient(channel)
assert await client.metadata == {SESSION_METADATA_KEY: SESSION_ID}
assert client._supported
assert client._heartbeat_interval and client._heartbeat_interval.total_seconds() == HEARTBEAT_INTERVAL
assert client._current_id == SESSION_ID


@pytest.mark.asyncio
async def test_sessions_disabled(service: RobotService):
async with ChannelFor([service]) as channel:
client = SessionsClient(channel, disabled=True)
assert await client.metadata == {}
assert client._supported is None
assert not client._heartbeat_interval