Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/viam/rpc/dial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import sys
import warnings
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Tuple, Type
from typing import Callable, Literal, Optional, Tuple, Type, Union

from grpclib.client import Channel, Stream
from grpclib.const import Cardinality
from grpclib.metadata import Deadline, _MetadataLike
from grpclib.protocol import H2Protocol
from grpclib.stream import _RecvType, _SendType

from viam import logging
Expand Down Expand Up @@ -112,6 +113,27 @@ async def _get_access_token(channel: Channel, address: str, opts: DialOptions) -
class AuthenticatedChannel(Channel):
_metadata: _MetadataLike

def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
*,
ssl: Union[None, bool, ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
):
super().__init__(host, port, ssl=ssl)
self._server_hostname = server_hostname

async def _create_connection(self) -> H2Protocol:
_, protocol = await self._loop.create_connection(
self._protocol_factory,
self._host,
self._port,
ssl=self._ssl,
server_hostname=self._server_hostname,
)
return protocol

def request(
self,
name: str,
Expand Down Expand Up @@ -265,7 +287,7 @@ async def _dial_direct(address: str, options: Optional[DialOptions] = None) -> C
ctx = None

if opts.credentials:
channel = AuthenticatedChannel(server_hostname, port, ssl=ctx)
channel = AuthenticatedChannel(host, port, ssl=ctx, server_hostname=server_hostname)
access_token = await _get_access_token(channel, address, opts)
metadata = {"authorization": f"Bearer {access_token}"}
channel._metadata = metadata
Expand Down