-
Notifications
You must be signed in to change notification settings - Fork 64
RSDK-2506 Sessions client #301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
maximpertsov
merged 31 commits into
viamrobotics:main
from
maximpertsov:sessions-python-RSDK-2506
May 30, 2023
Merged
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
c27f1a1
add test watcher
maximpertsov 235a7bf
spike session client
maximpertsov 4992c05
remove unused
maximpertsov f6da55a
client testing
maximpertsov 1ed1298
revert debug
maximpertsov 7cd5139
send heartbeats faster than interval
maximpertsov beaa5bf
init sessions client in viam channel
maximpertsov fb2ff54
wip: pass metadata into request in viam channel
maximpertsov b48a9a2
fixup
maximpertsov 5063724
make inner metadata property sync
maximpertsov 4e9ab91
reset session when intercepting expiration event
maximpertsov 7720fcf
wip
maximpertsov 66aa2d9
reset session on close
maximpertsov d9eda25
fixup expiration check
maximpertsov 19919ee
5%
maximpertsov 0ad2f5d
allow sessions to be disabled
maximpertsov 327d56a
do not send session metadata for session-specific calls
maximpertsov 622c9ba
lock
maximpertsov ce86b42
change option
maximpertsov 0222499
shorter beat interval
maximpertsov b4ad97a
async lock + only acquire lock at startup
maximpertsov c4396b9
send session metadata with heartbeat
maximpertsov 3ef9be1
reduce debug logs
maximpertsov cb886a6
cleanup flow
maximpertsov 6843d96
Merge branch 'main' into sessions-python-RSDK-2506
maximpertsov a5f448d
make session_id and supported field internal
maximpertsov afaff1b
remove unused
maximpertsov 9f0c4e0
CR@edaniels exclude methods from receiving session metadata
maximpertsov bf95c3a
CR@njooma: sentence case logs/errors
maximpertsov cf91a1f
remove pytest-watcher
maximpertsov 50c6bda
add sessions disabled test
maximpertsov File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import asyncio | ||
from datetime import timedelta | ||
from typing import Optional | ||
|
||
from grpclib import Status | ||
from grpclib.client import Channel | ||
from grpclib.events import RecvTrailingMetadata, SendRequest, listen | ||
from grpclib.exceptions import GRPCError, StreamTerminatedError | ||
from grpclib.metadata import _MetadataLike | ||
|
||
from viam import _TASK_PREFIX, logging | ||
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
SESSION_METADATA_KEY = "viam-sid" | ||
|
||
EXEMPT_METADATA_METHODS = frozenset( | ||
[ | ||
"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", | ||
"/proto.rpc.webrtc.v1.SignalingService/Call", | ||
"/proto.rpc.webrtc.v1.SignalingService/CallUpdate", | ||
"/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig", | ||
"/proto.rpc.v1.AuthService/Authenticate", | ||
"/viam.robot.v1.RobotService/ResourceNames", | ||
"/viam.robot.v1.RobotService/ResourceRPCSubtypes", | ||
"/viam.robot.v1.RobotService/StartSession", | ||
"/viam.robot.v1.RobotService/SendSessionHeartbeat", | ||
] | ||
) | ||
|
||
|
||
async def delay(coro, seconds): | ||
await asyncio.sleep(seconds) | ||
await coro | ||
|
||
|
||
class SessionsClient: | ||
""" | ||
A Session allows a client to express that it is actively connected and | ||
supports stopping actuating components when it's not. | ||
""" | ||
|
||
_current_id: str = "" | ||
_disabled: bool = False | ||
_lock = asyncio.Lock() | ||
_supported: Optional[bool] = None | ||
_heartbeat_interval: Optional[timedelta] = None | ||
|
||
def __init__(self, channel: Channel, *, disabled: bool = False): | ||
self.channel = channel | ||
self.client = RobotServiceStub(channel) | ||
self._disabled = disabled | ||
|
||
listen(self.channel, SendRequest, self._send_request) | ||
listen(self.channel, RecvTrailingMetadata, self._recv_trailers) | ||
|
||
def reset(self): | ||
if self._lock.locked(): | ||
return | ||
|
||
LOGGER.debug("resetting session") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nit] can we make all logs |
||
self._supported = None | ||
|
||
async def _send_request(self, event: SendRequest): | ||
if self._disabled: | ||
return | ||
|
||
if event.method_name in EXEMPT_METADATA_METHODS: | ||
return | ||
|
||
event.metadata.update(await self.metadata) | ||
|
||
async def _recv_trailers(self, event: RecvTrailingMetadata): | ||
if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED": | ||
LOGGER.debug("Session expired") | ||
self.reset() | ||
|
||
@property | ||
async def metadata(self) -> _MetadataLike: | ||
if self._disabled: | ||
return self._metadata | ||
|
||
if self._supported: | ||
return self._metadata | ||
|
||
async with self._lock: | ||
if self._supported is False: | ||
return self._metadata | ||
|
||
request = StartSessionRequest(resume=self._current_id) | ||
response: Optional[StartSessionResponse] = None | ||
|
||
try: | ||
response = await self.client.StartSession(request) | ||
except GRPCError as error: | ||
if error.status == Status.UNIMPLEMENTED: | ||
self._supported = False | ||
return self._metadata | ||
else: | ||
raise | ||
else: | ||
if response is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session") | ||
|
||
if response.heartbeat_window is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") | ||
|
||
self._supported = True | ||
self._heartbeat_interval = response.heartbeat_window.ToTimedelta() | ||
self._current_id = response.id | ||
|
||
await self._heartbeat_tick() | ||
|
||
return self._metadata | ||
|
||
async def _heartbeat_tick(self): | ||
if not self._supported: | ||
return | ||
|
||
while self._lock.locked(): | ||
pass | ||
|
||
request = SendSessionHeartbeatRequest(id=self._current_id) | ||
|
||
if self._heartbeat_interval is None: | ||
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session") | ||
|
||
try: | ||
await self.client.SendSessionHeartbeat(request) | ||
except (GRPCError, StreamTerminatedError): | ||
LOGGER.debug("Heartbeat terminated", exc_info=True) | ||
self.reset() | ||
else: | ||
LOGGER.debug("Sent heartbeat successfully") | ||
# We send heartbeats slightly faster than the interval window to | ||
# ensure that we don't fall outside of it and expire the session. | ||
wait = self._heartbeat_interval.total_seconds() / 5 | ||
asyncio.create_task(delay(self._heartbeat_tick(), wait), name=f"{_TASK_PREFIX}-heartbeat") | ||
|
||
@property | ||
def _metadata(self) -> _MetadataLike: | ||
if self._supported and self._current_id != "": | ||
return {SESSION_METADATA_KEY: self._current_id} | ||
|
||
return {} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from datetime import timedelta | ||
|
||
import pytest | ||
from google.protobuf.duration_pb2 import Duration | ||
from grpclib import GRPCError, Status | ||
from grpclib.server import Stream | ||
from grpclib.testing import ChannelFor | ||
|
||
from viam.proto.robot import SendSessionHeartbeatRequest, SendSessionHeartbeatResponse, StartSessionRequest, StartSessionResponse | ||
from viam.resource.manager import ResourceManager | ||
from viam.robot.service import RobotService | ||
from viam.sessions_client import SESSION_METADATA_KEY, SessionsClient | ||
|
||
SESSION_ID = "sid" | ||
HEARTBEAT_INTERVAL = 2 | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def service() -> RobotService: | ||
async def StartSession(stream: Stream[StartSessionRequest, StartSessionResponse]) -> None: | ||
request = await stream.recv_message() | ||
assert request is not None | ||
heartbeat_window = Duration() | ||
heartbeat_window.FromTimedelta(timedelta(seconds=HEARTBEAT_INTERVAL)) | ||
response = StartSessionResponse(id=SESSION_ID, heartbeat_window=heartbeat_window) | ||
await stream.send_message(response) | ||
|
||
async def SendSessionHeartbeat(stream: Stream[SendSessionHeartbeatRequest, SendSessionHeartbeatResponse]) -> None: | ||
request = await stream.recv_message() | ||
assert request is not None | ||
response = SendSessionHeartbeatResponse() | ||
await stream.send_message(response) | ||
|
||
manager = ResourceManager([]) | ||
service = RobotService(manager) | ||
service.StartSession = StartSession | ||
service.SendSessionHeartbeat = SendSessionHeartbeat | ||
|
||
return service | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def service_without_session(service: RobotService) -> RobotService: | ||
del service.StartSession | ||
return service | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def service_without_heartbeat(service: RobotService) -> RobotService: | ||
del service.SendSessionHeartbeat | ||
return service | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_init_client(): | ||
async with ChannelFor([]) as channel: | ||
client = SessionsClient(channel) | ||
assert client._current_id == "" | ||
assert client._supported is None | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_sessions_error(): | ||
async with ChannelFor([]) as channel: | ||
client = SessionsClient(channel) | ||
|
||
with pytest.raises(GRPCError) as e_info: | ||
assert await client.metadata == {} | ||
|
||
assert e_info.value.status == Status.UNKNOWN | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_sessions_not_supported(): | ||
async with ChannelFor([]) as channel: | ||
client = SessionsClient(channel) | ||
client._supported = False | ||
assert await client.metadata == {} | ||
assert client._supported is False | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_sessions_not_implemented(service_without_session: RobotService): | ||
async with ChannelFor([service_without_session]) as channel: | ||
client = SessionsClient(channel) | ||
assert await client.metadata == {} | ||
assert client._supported is False | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_sessions_heartbeat_disconnect(service_without_heartbeat: RobotService): | ||
async with ChannelFor([service_without_heartbeat]) as channel: | ||
client = SessionsClient(channel) | ||
assert await client.metadata == {} | ||
assert client._supported is None | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_sessions_heartbeat(service: RobotService): | ||
async with ChannelFor([service]) as channel: | ||
client = SessionsClient(channel) | ||
assert await client.metadata == {SESSION_METADATA_KEY: SESSION_ID} | ||
assert client._supported | ||
assert client._heartbeat_interval and client._heartbeat_interval.total_seconds() == HEARTBEAT_INTERVAL | ||
assert client._current_id == SESSION_ID | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_sessions_disabled(service: RobotService): | ||
async with ChannelFor([service]) as channel: | ||
client = SessionsClient(channel, disabled=True) | ||
assert await client.metadata == {} | ||
assert client._supported is None | ||
assert not client._heartbeat_interval |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's a
loop.call_later
function that might be useful here:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looked promising, but unfortunately
loop.call_later
takes a sync callback and not a co-routine: https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.call_later