Skip to content

Commit

Permalink
fix: shuffle attempts when one host resolves to more than one IP
Browse files Browse the repository at this point in the history
This behaviour (first resolve all the hosts, then shuffle the IPs) mimics
better what the libpq does in non-async mode.
  • Loading branch information
dvarrazzo committed Dec 13, 2023
1 parent bfb68f1 commit bd259ee
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 47 deletions.
22 changes: 7 additions & 15 deletions psycopg/psycopg/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,13 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
hostaddrs: list[str] = []
ports: list[str] = []

for attempt in conninfo._split_attempts(params):
try:
async for a2 in conninfo._split_attempts_and_resolve(attempt):
if a2.get("host") is not None:
hosts.append(a2["host"])
if a2.get("hostaddr") is not None:
hostaddrs.append(a2["hostaddr"])
if a2.get("port") is not None:
ports.append(str(a2["port"]))
except OSError as ex:
last_exc = ex

if params.get("host") and not hosts:
# We couldn't resolve anything
raise e.OperationalError(str(last_exc))
async for attempt in conninfo.conninfo_attempts_async(params):
if attempt.get("host") is not None:
hosts.append(attempt["host"])
if attempt.get("hostaddr") is not None:
hostaddrs.append(attempt["hostaddr"])
if attempt.get("port") is not None:
ports.append(str(attempt["port"]))

out = params.copy()
shosts = ",".join(hosts)
Expand Down
58 changes: 27 additions & 31 deletions psycopg/psycopg/conninfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,10 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
# If an host resolves to more than one ip, the libpq will make more than
# one attempt and wouldn't get to try the following ones, as before
# fixing #674.
attempts = _split_attempts(params)
if params.get("load_balance_hosts", "disable") == "random":
attempts = list(_split_attempts(params))
shuffle(attempts)
yield from attempts
else:
yield from _split_attempts(params)
yield from attempts


async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
Expand All @@ -317,25 +315,27 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
# TODO: the function should resolve all hosts and shuffle the results
# to replicate the same libpq algorithm.
yielded = False
last_exc = None
attempts = []
for attempt in _split_attempts(params):
try:
async for a2 in _split_attempts_and_resolve(attempt):
yielded = True
yield a2
attempts.extend(await _resolve_hostnames(attempt))
except OSError as ex:
last_exc = ex

if not yielded:
if not attempts:
assert last_exc
# We couldn't resolve anything
raise e.OperationalError(str(last_exc))

if params.get("load_balance_hosts", "disable") == "random":
shuffle(attempts)

for attempt in attempts:
yield attempt

def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:

def _split_attempts(params: ConnDict) -> list[ConnDict]:
"""
Split connection parameters with a sequence of hosts into separate attempts.
"""
Expand Down Expand Up @@ -363,13 +363,13 @@ def split_val(key: str) -> list[str]:

# A single attempt to make. Don't mangle the conninfo string.
if nhosts <= 1:
yield params
return
return [params]

if len(ports) == 1:
ports *= nhosts

# Now all lists are either empty or have the same length
rv = []
for i in range(nhosts):
attempt = params.copy()
if hosts:
Expand All @@ -378,41 +378,39 @@ def split_val(key: str) -> list[str]:
attempt["hostaddr"] = hostaddrs[i]
if ports:
attempt["port"] = ports[i]
yield attempt
rv.append(attempt)

return rv


async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]:
async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
"""
Perform async DNS lookup of the hosts and return a new params dict.
If a ``host`` param is present but not ``hostname``, resolve the host
addresses asynchronously.
:param params: The input parameters, for instance as returned by
`~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
a single entry for host, hostaddr because it is designed to further
process the input of _split_attempts().
If a ``host`` param is present but not ``hostname``, resolve the host
addresses asynchronously.
The function may change the input ``host``, ``hostname``, ``port`` to allow
connecting without further DNS lookups.
:return: A list of attempts to make (to include the case of a hostname
resolving to more than one IP).
"""
host = _get_param(params, "host")
if not host or host.startswith("/") or host[1:2] == ":":
# Local path, or no host to resolve
yield params
return
return [params]

hostaddr = _get_param(params, "hostaddr")
if hostaddr:
# Already resolved
yield params
return
return [params]

if is_ip_address(host):
# If the host is already an ip address don't try to resolve it
params["hostaddr"] = host
yield params
return
return [{**params, "hostaddr": host}]

loop = asyncio.get_running_loop()

Expand All @@ -426,9 +424,7 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic
ans = await loop.getaddrinfo(
host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
)

for item in ans:
yield {**params, "hostaddr": item[4][0]}
return [{**params, "hostaddr": item[4][0]} for item in ans]


def _get_param(params: ConnDict, name: str) -> str | None:
Expand Down
19 changes: 18 additions & 1 deletion tests/test_conninfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env):
list(conninfo_attempts(params))


def test_conninfo_random(dsn, conn_cls):
def test_conninfo_random():
hosts = [f"host{n:02d}" for n in range(50)]
args = {"host": ",".join(hosts)}
ahosts = [att["host"] for att in conninfo_attempts(args)]
Expand All @@ -515,13 +515,30 @@ def test_conninfo_random(dsn, conn_cls):
assert ahosts == hosts


@pytest.mark.anyio
async def test_conninfo_random_async(fake_resolve):
args = {"host": "alot.com"}
hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
assert len(hostaddrs) == 20
assert hostaddrs == sorted(hostaddrs)

args["load_balance_hosts"] = "disable"
hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
assert hostaddrs == sorted(hostaddrs)

args["load_balance_hosts"] = "random"
hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
assert hostaddrs != sorted(hostaddrs)


@pytest.fixture
async def fake_resolve(monkeypatch):
fake_hosts = {
"localhost": ["127.0.0.1"],
"foo.com": ["1.1.1.1"],
"qux.com": ["2.2.2.2"],
"dup.com": ["3.3.3.3", "3.3.3.4"],
"alot.com": [f"4.4.4.{n}" for n in range(10, 30)],
}

def family(host):
Expand Down

0 comments on commit bd259ee

Please sign in to comment.