Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement connection service file functionality #1223

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 143 additions & 8 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
from __future__ import annotations

import asyncio
import configparser
import collections
from collections.abc import Callable
import enum
@@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum):
PGPASSFILE = '.pgpass'


PG_SERVICEFILE = '.pg_service.conf'


def _read_password_file(passfile: pathlib.Path) \
-> typing.List[typing.Tuple[str, ...]]:

@@ -268,7 +272,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
password, passfile, database, ssl, service,
direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
# `auth_hosts` is the version of host information for the purposes
@@ -281,6 +285,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if dsn:
parsed = urllib.parse.urlparse(dsn)

query = None
if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]

if 'service' in query:
val = query.pop('service')
if not service and val:
service = val

connection_service_file = os.getenv('PGSERVICEFILE')
if connection_service_file is None:
homedir = compat.get_pg_home_directory()
if homedir:
connection_service_file = homedir / PG_SERVICEFILE
else:
connection_service_file = None
else:
connection_service_file = pathlib.Path(connection_service_file)

if parsed.scheme not in {'postgresql', 'postgres'}:
raise exceptions.ClientConfigurationError(
'invalid DSN: scheme is expected to be either '
@@ -315,11 +341,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if password is None and dsn_password:
password = urllib.parse.unquote(dsn_password)

if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]
if query:

if 'port' in query:
val = query.pop('port')
@@ -406,12 +428,124 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if gsslib is None:
gsslib = val

if 'service' in query:
val = query.pop('service')
if service is None:
service = val

if query:
if server_settings is None:
server_settings = query
else:
server_settings = {**query, **server_settings}

if connection_service_file is not None and service is not None:
pg_service = configparser.ConfigParser()
pg_service.read(connection_service_file)
if service in pg_service.sections():
service_params = pg_service[service]
if 'port' in service_params:
val = service_params.pop('port')
if not port and val:
port = [int(p) for p in val.split(',')]

if 'host' in service_params:
val = service_params.pop('host')
if not host and val:
host, port = _parse_hostlist(val, port)

if 'dbname' in service_params:
val = service_params.pop('dbname')
if database is None:
database = val

if 'database' in service_params:
val = service_params.pop('database')
if database is None:
database = val

if 'user' in service_params:
val = service_params.pop('user')
if user is None:
user = val

if 'password' in service_params:
val = service_params.pop('password')
if password is None:
password = val

if 'passfile' in service_params:
val = service_params.pop('passfile')
if passfile is None:
passfile = val

if 'sslmode' in service_params:
val = service_params.pop('sslmode')
if ssl is None:
ssl = val

if 'sslcert' in service_params:
val = service_params.pop('sslcert')
if sslcert is None:
sslcert = val

if 'sslkey' in service_params:
val = service_params.pop('sslkey')
if sslkey is None:
sslkey = val

if 'sslrootcert' in service_params:
val = service_params.pop('sslrootcert')
if sslrootcert is None:
sslrootcert = val

if 'sslnegotiation' in service_params:
val = service_params.pop('sslnegotiation')
if sslnegotiation is None:
sslnegotiation = val

if 'sslcrl' in service_params:
val = service_params.pop('sslcrl')
if sslcrl is None:
sslcrl = val

if 'sslpassword' in service_params:
val = service_params.pop('sslpassword')
if sslpassword is None:
sslpassword = val

if 'ssl_min_protocol_version' in service_params:
val = service_params.pop(
'ssl_min_protocol_version'
)
if ssl_min_protocol_version is None:
ssl_min_protocol_version = val

if 'ssl_max_protocol_version' in service_params:
val = service_params.pop(
'ssl_max_protocol_version'
)
if ssl_max_protocol_version is None:
ssl_max_protocol_version = val

if 'target_session_attrs' in service_params:
dsn_target_session_attrs = service_params.pop(
'target_session_attrs'
)
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs

if 'krbsrvname' in service_params:
val = service_params.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val

if 'gsslib' in service_params:
val = service_params.pop('gsslib')
if gsslib is None:
gsslib = val
if not service:
service = os.environ.get('PGSERVICE')
if not host:
hostspec = os.environ.get('PGHOST')
if hostspec:
@@ -724,7 +858,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
target_session_attrs, krbsrvname, gsslib,
service):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
@@ -754,7 +889,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

config = _ClientConfiguration(
command_timeout=command_timeout,
6 changes: 6 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
@@ -2074,6 +2074,7 @@ async def _do_execute(
async def connect(dsn=None, *,
host=None, port=None,
user=None, password=None, passfile=None,
service=None,
database=None,
loop=None,
timeout=60,
@@ -2183,6 +2184,10 @@ async def connect(dsn=None, *,
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
on Windows).

:param service:
The name of the postgres connection service stored in the postgres
connection service file.

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
@@ -2428,6 +2433,7 @@ async def connect(dsn=None, *,
user=user,
password=password,
passfile=passfile,
service=service,
ssl=ssl,
direct_tls=direct_tls,
database=database,
111 changes: 109 additions & 2 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
@@ -1116,7 +1116,8 @@ def run_testcase(self, testcase):
env = testcase.get('env', {})
test_env = {'PGHOST': None, 'PGPORT': None,
'PGUSER': None, 'PGPASSWORD': None,
'PGDATABASE': None, 'PGSSLMODE': None}
'PGDATABASE': None, 'PGSSLMODE': None,
'PGSERVICE': None, }
test_env.update(env)

dsn = testcase.get('dsn')
@@ -1132,6 +1133,7 @@ def run_testcase(self, testcase):
target_session_attrs = testcase.get('target_session_attrs')
krbsrvname = testcase.get('krbsrvname')
gsslib = testcase.get('gsslib')
service = testcase.get('service')

expected = testcase.get('result')
expected_error = testcase.get('error')
@@ -1157,7 +1159,7 @@ def run_testcase(self, testcase):
direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

params = {
k: v for k, v in params._asdict().items()
@@ -1236,6 +1238,111 @@ def test_connect_params(self):
for testcase in self.TESTS:
self.run_testcase(testcase)

def test_connect_connection_service_file(self):
connection_service_file = tempfile.NamedTemporaryFile(
'w+t', delete=False)
connection_service_file.write(textwrap.dedent('''
[test_service_dbname]
port=5433
host=somehost
dbname=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi

[test_service_database]
port=5433
host=somehost
database=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi
'''))
connection_service_file.close()
os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR)
try:
# Test connection service file with dbname
self.run_testcase({
'dsn': 'postgresql://?service=test_service_dbname',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
# Test connection service file with database
self.run_testcase({
'dsn': 'postgresql://?service=test_service_database',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
# Test that envvars are overridden by service file
self.run_testcase({
'dsn': 'postgresql://?service=test_service_dbname',
'env': {
'PGUSER': 'user',
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
# Test that dsn params overwrite service file
self.run_testcase({
'dsn': 'postgresql://?service={}&dbname={}'.format(
"test_service_dbname", "test_dbname_dsn"
),
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname_dsn',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
finally:
os.unlink(connection_service_file.name)

def test_connect_pgpass_regular(self):
passfile = tempfile.NamedTemporaryFile('w+t', delete=False)
passfile.write(textwrap.dedent(R'''