Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions ydb/resolver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import contextlib
import logging
import threading
import random
import itertools
from . import connection as conn_impl, issues, settings as settings_impl, _apis
import typing
from . import connection as conn_impl, driver, issues, settings as settings_impl, _apis


# Workaround for good IDE and universal for runtime
if typing.TYPE_CHECKING:
from ._grpc.v4.protos import ydb_discovery_pb2
else:
from ._grpc.common.protos import ydb_discovery_pb2


logger = logging.getLogger(__name__)

Expand All @@ -22,7 +33,7 @@ class EndpointInfo(object):
"node_id",
)

def __init__(self, endpoint_info):
def __init__(self, endpoint_info: ydb_discovery_pb2.EndpointInfo):
self.address = endpoint_info.address
self.endpoint = "%s:%s" % (endpoint_info.address, endpoint_info.port)
self.location = endpoint_info.location
Expand All @@ -33,7 +44,7 @@ def __init__(self, endpoint_info):
self.ssl_target_name_override = endpoint_info.ssl_target_name_override
self.node_id = endpoint_info.node_id

def endpoints_with_options(self):
def endpoints_with_options(self) -> typing.Generator[typing.Tuple[str, conn_impl.EndpointOptions], None, None]:
ssl_target_name_override = None
if self.ssl:
if self.ssl_target_name_override:
Expand Down Expand Up @@ -73,14 +84,14 @@ def __eq__(self, other):
return self.endpoint == other.endpoint


def _list_endpoints_request_factory(connection_params):
def _list_endpoints_request_factory(connection_params: driver.DriverConfig) -> _apis.ydb_discovery.ListEndpointsRequest:
request = _apis.ydb_discovery.ListEndpointsRequest()
request.database = connection_params.database
return request


class DiscoveryResult(object):
def __init__(self, self_location, endpoints):
def __init__(self, self_location: str, endpoints: "list[EndpointInfo]"):
self.self_location = self_location
self.endpoints = endpoints

Expand All @@ -94,7 +105,12 @@ def __repr__(self):
return self.__str__()

@classmethod
def from_response(cls, rpc_state, response, use_all_nodes=False):
def from_response(
cls,
rpc_state: conn_impl._RpcState,
response: ydb_discovery_pb2.ListEndpointsResponse,
use_all_nodes: bool = False,
) -> DiscoveryResult:
issues._process_response(response.operation)
message = _apis.ydb_discovery.ListEndpointsResult()
response.operation.result.Unpack(message)
Expand Down Expand Up @@ -123,7 +139,7 @@ def from_response(cls, rpc_state, response, use_all_nodes=False):


class DiscoveryEndpointsResolver(object):
def __init__(self, driver_config):
def __init__(self, driver_config: driver.DriverConfig):
self.logger = logger.getChild(self.__class__.__name__)
self._driver_config = driver_config
self._ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10)
Expand All @@ -136,27 +152,27 @@ def __init__(self, driver_config):
random.shuffle(self._endpoints)
self._endpoints_iter = itertools.cycle(self._endpoints)

def _add_debug_details(self, message, *args):
def _add_debug_details(self, message: str, *args):
self.logger.debug(message, *args)
message = message % args
with self._lock:
self._debug_details_items.append(message)
if len(self._debug_details_items) > self._debug_details_history_size:
self._debug_details_items.pop()

def debug_details(self):
def debug_details(self) -> str:
"""
Returns last resolver errors as a debug string.
"""
with self._lock:
return "\n".join(self._debug_details_items)

def resolve(self):
def resolve(self) -> typing.ContextManager[typing.Optional[DiscoveryResult]]:
with self.context_resolve() as result:
return result

@contextlib.contextmanager
def context_resolve(self):
def context_resolve(self) -> typing.ContextManager[typing.Optional[DiscoveryResult]]:
self.logger.debug("Preparing initial endpoint to resolve endpoints")
endpoint = next(self._endpoints_iter)
initial = conn_impl.Connection.ready_factory(endpoint, self._driver_config, ready_timeout=self._ready_timeout)
Expand Down