diff --git a/src/viam/rpc/dial.py b/src/viam/rpc/dial.py index 580997a5c..a63284a17 100644 --- a/src/viam/rpc/dial.py +++ b/src/viam/rpc/dial.py @@ -6,11 +6,12 @@ import sys import warnings from dataclasses import dataclass -from typing import Callable, Literal, Optional, Tuple, Type +from typing import Callable, Literal, Optional, Tuple, Type, Union from grpclib.client import Channel, Stream from grpclib.const import Cardinality from grpclib.metadata import Deadline, _MetadataLike +from grpclib.protocol import H2Protocol from grpclib.stream import _RecvType, _SendType from viam import logging @@ -112,6 +113,27 @@ async def _get_access_token(channel: Channel, address: str, opts: DialOptions) - class AuthenticatedChannel(Channel): _metadata: _MetadataLike + def __init__( + self, + host: Optional[str] = None, + port: Optional[int] = None, + *, + ssl: Union[None, bool, ssl.SSLContext] = None, + server_hostname: Optional[str] = None, + ): + super().__init__(host, port, ssl=ssl) + self._server_hostname = server_hostname + + async def _create_connection(self) -> H2Protocol: + _, protocol = await self._loop.create_connection( + self._protocol_factory, + self._host, + self._port, + ssl=self._ssl, + server_hostname=self._server_hostname, + ) + return protocol + def request( self, name: str, @@ -265,7 +287,7 @@ async def _dial_direct(address: str, options: Optional[DialOptions] = None) -> C ctx = None if opts.credentials: - channel = AuthenticatedChannel(server_hostname, port, ssl=ctx) + channel = AuthenticatedChannel(host, port, ssl=ctx, server_hostname=server_hostname) access_token = await _get_access_token(channel, address, opts) metadata = {"authorization": f"Bearer {access_token}"} channel._metadata = metadata