Skip to content

Add middleware support #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
@@ -264,6 +264,7 @@ def create_pool(dsn=None, *,
setup=None,
init=None,
loop=None,
middlewares=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
**connect_kwargs):
@@ -272,7 +273,7 @@ def create_pool(dsn=None, *,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
connection_class=connection_class, middlewares=middlewares,
**connect_kwargs)


7 changes: 4 additions & 3 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
@@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,


async def _connect_addr(*, addr, loop, timeout, params, config,
connection_class):
middlewares, connection_class):
assert loop is not None

if timeout <= 0:
@@ -633,12 +633,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
tr.close()
raise

con = connection_class(pr, tr, loop, addr, config, params)
con = connection_class(pr, tr, loop, addr, config, params, middlewares)
pr.set_connection(con)
return con


async def _connect(*, loop, timeout, connection_class, **kwargs):
async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

@@ -652,6 +652,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
con = await _connect_addr(
addr=addr, loop=loop, timeout=timeout,
params=params, config=config,
middlewares=middlewares,
connection_class=connection_class)
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
last_error = ex
22 changes: 17 additions & 5 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
"""

__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_top_xact', '_aborted', '_middlewares',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
@@ -52,7 +52,8 @@ class Connection(metaclass=ConnectionMeta):
def __init__(self, protocol, transport, loop,
addr: (str, int) or str,
config: connect_utils._ClientConfiguration,
params: connect_utils._ConnectionParameters):
params: connect_utils._ConnectionParameters,
_middlewares=None):
self._protocol = protocol
self._transport = transport
self._loop = loop
@@ -91,7 +92,7 @@ def __init__(self, protocol, transport, loop,

self._reset_query = None
self._proxy = None

self._middlewares = _middlewares
# Used to serialize operations that might involve anonymous
# statements. Specifically, we want to make the following
# operation atomic:
@@ -1399,8 +1400,13 @@ async def reload_schema_state(self):

async def _execute(self, query, args, limit, timeout, return_status=False):
with self._stmt_exclusive_section:
result, _ = await self.__execute(
query, args, limit, timeout, return_status=return_status)
wrapped = self.__execute
if self._middlewares:
for m in reversed(self._middlewares):
wrapped = await m(connection=self, handler=wrapped)

result, _ = await wrapped(query, args, limit,
timeout, return_status=return_status)
return result

async def __execute(self, query, args, limit, timeout,
@@ -1491,6 +1497,7 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
middlewares=None,
connection_class=Connection,
server_settings=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
@@ -1607,6 +1614,10 @@ async def connect(dsn=None, *,
PostgreSQL documentation for
a `list of supported options <server settings>`_.

:param middlewares:
An optional list of middleware functions. Refer to documentation
on create_pool.

:param Connection connection_class:
Class of the returned connection object. Must be a subclass of
:class:`~asyncpg.connection.Connection`.
@@ -1672,6 +1683,7 @@ async def connect(dsn=None, *,
ssl=ssl, database=database,
server_settings=server_settings,
command_timeout=command_timeout,
middlewares=middlewares,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size)
50 changes: 48 additions & 2 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
@@ -305,7 +305,7 @@ class Pool:
"""

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_queue', '_loop', '_minsize', '_maxsize', '_middlewares',
'_init', '_connect_args', '_connect_kwargs',
'_working_addr', '_working_config', '_working_params',
'_holders', '_initialized', '_initializing', '_closing',
@@ -320,6 +320,7 @@ def __init__(self, *connect_args,
max_inactive_connection_lifetime,
setup,
init,
middlewares,
loop,
connection_class,
**connect_kwargs):
@@ -377,6 +378,7 @@ def __init__(self, *connect_args,
self._closed = False
self._generation = 0
self._init = init
self._middlewares = middlewares
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs

@@ -469,6 +471,7 @@ async def _get_new_connection(self):
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
middlewares=self._middlewares,
**self._connect_kwargs)

self._working_addr = con._addr
@@ -483,6 +486,7 @@ async def _get_new_connection(self):
addr=self._working_addr,
timeout=self._working_params.connect_timeout,
config=self._working_config,
middlewares=self._middlewares,
params=self._working_params,
connection_class=self._connection_class)

@@ -784,13 +788,37 @@ def __await__(self):
return self.pool._acquire(self.timeout).__await__()


def middleware(f):
"""Decorator for adding a middleware

Can be used like such

.. code-block:: python

@pool.middleware
async def my_middleware(query, args, limit,
timeout, return_status, *, handler, conn):
print('do something before')
result, stmt = await handler(query, args, limit,
timeout, return_status)
print('do something after')
return result, stmt

my_pool = await pool.create_pool(middlewares=[my_middleware])
"""
async def middleware_factory(connection, handler):
return functools.partial(f, connection=connection, handler=handler)
return middleware_factory


def create_pool(dsn=None, *,
min_size=10,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
setup=None,
init=None,
middlewares=None,
loop=None,
connection_class=connection.Connection,
**connect_kwargs):
@@ -866,6 +894,23 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.

:param middlewares:
A list of middleware functions to be middleware just
before a connection excecutes a statement.
Syntax of a middleware is as follows:

.. code-block:: python

async def middleware_factory(connection, handler):
async def middleware(query, args, limit,
timeout, return_status):
print('do something before')
result, stmt = await handler(query, args, limit,
timeout, return_status)
print('do something after')
return result, stmt
return middleware

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
@@ -893,6 +938,7 @@ def create_pool(dsn=None, *,
dsn,
connection_class=connection_class,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_queries=max_queries, loop=loop, setup=setup,
middlewares=middlewares, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs)
1 change: 1 addition & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ If you want to build **asyncpg** from a Git checkout you will need:
* CPython header files. These can usually be obtained by installing
the relevant Python development package: **python3-dev** on Debian/Ubuntu,
**python3-devel** on RHEL/Fedora.
* Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`)

Once the above requirements are satisfied, run the following command
in the root of the source checkout:
42 changes: 42 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
@@ -76,6 +76,48 @@ async def worker():
tasks = [worker() for _ in range(n)]
await asyncio.gather(*tasks)

async def test_pool_with_middleware(self):
called = False

async def my_middleware_factory(connection, handler):
async def middleware(query, args, limit, timeout, return_status):
nonlocal called
called = True
return await handler(query, args, limit,
timeout, return_status)
return middleware

pool = await self.create_pool(database='postgres',
min_size=1, max_size=1,
middlewares=[my_middleware_factory])

con = await pool.acquire(timeout=5)
await con.fetchval('SELECT 1')
assert called

pool.terminate()
del con

async def test_pool_with_middleware_decorator(self):
called = False

@pg_pool.middleware
async def my_middleware(query, args, limit, timeout, return_status,
*, connection, handler):
nonlocal called
called = True
return await handler(query, args, limit,
timeout, return_status)

pool = await self.create_pool(database='postgres', min_size=1,
max_size=1, middlewares=[my_middleware])
con = await pool.acquire(timeout=5)
await con.fetchval('SELECT 1')
assert called

pool.terminate()
del con

async def test_pool_03(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)