Skip to content

Commit

Permalink
fix: don't add defaults to connection strings
Browse files Browse the repository at this point in the history
A default such as an empty string for host may shadow values defined in
a service file.

Fix #694.
  • Loading branch information
dvarrazzo committed Dec 13, 2023
1 parent 7f72e4c commit bfb68f1
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 117 deletions.
6 changes: 4 additions & 2 deletions docs/news.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ Future releases
Psycopg 3.1.15 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Fix async connection to hosts resolving to multiple IP addresses
(:ticket:`#695`).
- Fix use of ``service`` in connection string (regression in 3.1.13,
:ticket:`#694`).
- Fix async connection to hosts resolving to multiple IP addresses (regression
in 3.1.13, :ticket:`#695`).


Current release
Expand Down
12 changes: 7 additions & 5 deletions psycopg/psycopg/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
hostaddrs: list[str] = []
ports: list[str] = []

for attempt in conninfo._split_attempts(conninfo._inject_defaults(params)):
for attempt in conninfo._split_attempts(params):
try:
async for a2 in conninfo._split_attempts_and_resolve(attempt):
hosts.append(a2["host"])
hostaddrs.append(a2["hostaddr"])
if "port" in params:
ports.append(a2["port"])
if a2.get("host") is not None:
hosts.append(a2["host"])
if a2.get("hostaddr") is not None:
hostaddrs.append(a2["hostaddr"])
if a2.get("port") is not None:
ports.append(str(a2["port"]))
except OSError as ex:
last_exc = ex

Expand Down
135 changes: 76 additions & 59 deletions psycopg/psycopg/conninfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from datetime import tzinfo
from functools import lru_cache
from ipaddress import ip_address
from dataclasses import dataclass
from typing_extensions import TypeAlias

from . import pq
from . import errors as e
from ._tz import get_tzinfo
from ._compat import cache
from ._encodings import pgconn_encoding

ConnDict: TypeAlias = "dict[str, Any]"
Expand Down Expand Up @@ -283,24 +283,28 @@ def _get_pgconn_attr(self, name: str) -> str:


def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
"""Split a set of connection params on the single attempts to perforn.
"""Split a set of connection params on the single attempts to perform.
A connection param can perform more than one attempt if more than one ``host``
is provided.
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
# TODO: we should actually resolve the hosts ourselves.
# If a host resolves to more than one IP, the libpq will make more than
# one attempt and wouldn't get to try the following ones, as before
# fixing #674.
if params.get("load_balance_hosts", "disable") == "random":
attempts = list(_split_attempts(_inject_defaults(params)))
attempts = list(_split_attempts(params))
shuffle(attempts)
yield from attempts
else:
yield from _split_attempts(_inject_defaults(params))
yield from _split_attempts(params)


async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
"""Split a set of connection params on the single attempts to perforn.
"""Split a set of connection params on the single attempts to perform.
A connection param can perform more than one attempt if more than one ``host``
is provided.
Expand All @@ -313,9 +317,11 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
# TODO: the function should resolve all hosts and shuffle the results
# to replicate the same libpq algorithm.
yielded = False
last_exc = None
for attempt in _split_attempts(_inject_defaults(params)):
for attempt in _split_attempts(params):
try:
async for a2 in _split_attempts_and_resolve(attempt):
yielded = True
Expand All @@ -329,45 +335,13 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
raise e.OperationalError(str(last_exc))


def _inject_defaults(params: ConnDict) -> ConnDict:
"""
Add defaults to a dictionary of parameters.
This avoids the need to look up for env vars at various stages during
processing.
Note that a port is always specified. 5432 likely comes from here.
The `host`, `hostaddr`, `port` will be always set to a string.
"""
defaults = _conn_defaults()
out = params.copy()

def inject(name: str, envvar: str) -> None:
value = out.get(name)
if not value:
out[name] = os.environ.get(envvar, defaults[name])
else:
out[name] = str(value)

inject("host", "PGHOST")
inject("hostaddr", "PGHOSTADDR")
inject("port", "PGPORT")

return out


def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
"""
Split connection parameters with a sequence of hosts into separate attempts.
Assume that `host`, `hostaddr`, `port` are always present and a string (as
emitted from `_inject_defaults()`).
"""

def split_val(key: str) -> list[str]:
# Assume all keys are present and strings.
val: str = params[key]
val = _get_param(params, key)
return val.split(",") if val else []

hosts = split_val("host")
Expand All @@ -386,14 +360,15 @@ def split_val(key: str) -> list[str]:
raise e.OperationalError(
f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
)
elif len(ports) == 1:
ports *= nhosts

# A single attempt to make
# A single attempt to make. Don't mangle the conninfo string.
if nhosts <= 1:
yield params
return

if len(ports) == 1:
ports *= nhosts

# Now all lists are either empty or have the same length
for i in range(nhosts):
attempt = params.copy()
Expand All @@ -412,24 +387,22 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic
:param params: The input parameters, for instance as returned by
`~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
a single entry for host, hostaddr, port and doesn't check for env vars
because it is designed to further process the input of _split_attempts()
a single entry for host, hostaddr because it is designed to further
process the input of _split_attempts().
If a ``host`` param is present but not ``hostaddr``, resolve the host
addresses dynamically.
addresses asynchronously.
The function may change the input ``host``, ``hostaddr``, ``port`` to allow
connecting without further DNS lookups.
Raise `~psycopg.OperationalError` if resolution fails.
"""
host = params["host"]
host = _get_param(params, "host")
if not host or host.startswith("/") or host[1:2] == ":":
# Local path, or no host to resolve
yield params
return

hostaddr = params["hostaddr"]
hostaddr = _get_param(params, "hostaddr")
if hostaddr:
# Already resolved
yield params
Expand All @@ -443,25 +416,69 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic

loop = asyncio.get_running_loop()

port = params["port"]
port = _get_param(params, "port")
if not port:
portdef = _get_param_def("port")
if portdef:
port = portdef.compiled

assert port and "," not in port # assume a libpq default and no multi
ans = await loop.getaddrinfo(
host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
)

for item in ans:
yield {**params, "hostaddr": item[4][0]}


@cache
def _conn_defaults() -> dict[str, str]:
def _get_param(params: ConnDict, name: str) -> str | None:
"""
Return a value from a connection string.
The value may be also specified in a PG* env var.
"""
if name in params:
return str(params[name])

# TODO: check if in service

paramdef = _get_param_def(name)
if not paramdef:
return None

env = os.environ.get(paramdef.envvar)
if env is not None:
return env

return None


@dataclass
class ParamDef:
    """
    Describe where a connection parameter takes its value when unspecified.
    """

    # Name of the parameter, as used in a connection string.
    keyword: str
    # Name of the env var providing a value; empty string if there is none.
    envvar: str
    # Compiled-in default value; None if the parameter has no such default.
    compiled: str | None


def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
"""
Return a dictionary of defaults for connection strings parameters.
Return the ParamDef of a connection string parameter.
"""
defs = pq.Conninfo.get_defaults()
return {
d.keyword.decode(): d.compiled.decode() if d.compiled is not None else ""
for d in defs
}
if not _cache:
defs = pq.Conninfo.get_defaults()
for d in defs:
cd = ParamDef(
keyword=d.keyword.decode(),
envvar=d.envvar.decode() if d.envvar else "",
compiled=d.compiled.decode() if d.compiled is not None else None,
)
_cache[cd.keyword] = cd

return _cache.get(keyword)


@lru_cache()
Expand Down
7 changes: 0 additions & 7 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,13 +876,6 @@ def removeif(key, value):
if params.get(key) == value:
params.pop(key)

removeif("host", "")
removeif("hostaddr", "")
removeif("port", "5432")
if "," in params.get("host", ""):
nhosts = len(params["host"].split(","))
removeif("port", ",".join(["5432"] * nhosts))
removeif("hostaddr", "," * (nhosts - 1))
removeif("connect_timeout", str(DEFAULT_TIMEOUT))

return params
Loading

0 comments on commit bfb68f1

Please sign in to comment.