Skip to content

Commit

Permalink
Merge branch 'release/0.18.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
ri-gilfanov committed Jun 23, 2021
2 parents fe0aaa3 + d07d067 commit 137983f
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 53 deletions.
12 changes: 6 additions & 6 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Copy and paste this code in a file and run:
class MyModel(Base):
__tablename__ = "my_table"
__tablename__ = 'my_table'
pk = sa.Column(sa.Integer, primary_key=True)
timestamp = sa.Column(sa.DateTime(), default=datetime.now)
Expand All @@ -108,14 +108,14 @@ Copy and paste this code in a file and run:
async def app_factory():
app = web.Application()
bind = aiohttp_sqlalchemy.bind("sqlite+aiosqlite:///")
aiohttp_sqlalchemy.setup(app, [bind])
aiohttp_sqlalchemy.setup(app, [
aiohttp_sqlalchemy.bind('sqlite+aiosqlite:///'),
])
await aiohttp_sqlalchemy.init_db(app, metadata)
app.add_routes([web.get("/", main)])
app.add_routes([web.get('/', main)])
return app
if __name__ == "__main__":
if __name__ == '__main__':
web.run_app(app_factory())
44 changes: 19 additions & 25 deletions aiohttp_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
DuplicateRequestKeyError,
)
from aiohttp_sqlalchemy.middlewares import sa_middleware
from aiohttp_sqlalchemy.typedefs import (
TBinding,
TBindings,
TBindTo,
TSessionFactory,
)
from aiohttp_sqlalchemy.typedefs import TBind, TBinds, TSessionFactory, TTarget
from aiohttp_sqlalchemy.utils import (
init_db,
sa_init_db,
Expand All @@ -37,7 +32,7 @@
SAView,
)

__version__ = '0.17.4'
__version__ = '0.18.0'

__all__ = [
'SA_DEFAULT_KEY',
Expand All @@ -63,51 +58,50 @@


def bind(
bind_to: TBindTo,
target: TTarget,
key: str = SA_DEFAULT_KEY,
*,
middleware: bool = True,
) -> 'TBinding':
) -> 'TBind':
"""Function wrapper for binding.
:param bind_to: target for SQLAlchemy binding. Argument can be database
connection url, asynchronous engine or asynchronous session
factory.
:param target: argument can be database connection url, asynchronous engine
or asynchronous session factory.
:param key: key of SQLAlchemy binding.
:param middleware: `bool` for enable middleware. True by default.
"""
if isinstance(bind_to, str):
bind_to = cast(AsyncEngine, create_async_engine(bind_to))
if isinstance(target, str):
target = cast(AsyncEngine, create_async_engine(target))

if isinstance(bind_to, AsyncEngine):
bind_to = cast(
if isinstance(target, AsyncEngine):
target = cast(
TSessionFactory,
sessionmaker(
bind=bind_to,
bind=target,
class_=AsyncSession,
expire_on_commit=False,
),
)

for type_ in (AsyncSession, Engine, Session):
if isinstance(bind_to, type_):
msg = f'{type_} is unsupported type of argument `bind_to`.'
if isinstance(target, type_):
msg = f'{type_} is unsupported type of argument `target`.'
raise TypeError(msg)

if not callable(bind_to):
msg = f'{bind_to} is unsupported type of argument `bind_to`.'
if not callable(target):
msg = f'{target} is unsupported type of argument `target`.'
raise TypeError(msg)

return bind_to, key, middleware
return target, key, middleware


def setup(app: Application, bindings: "TBindings") -> None:
def setup(app: Application, binds: "TBinds") -> None:
"""Setup function for SQLAlchemy binding to AIOHTTP application.
:param app: your AIOHTTP application.
:param bindings: iterable of `aiohttp_sqlalchemy.bind()` calls.
:param binds: iterable of `aiohttp_sqlalchemy.bind()` calls.
"""
for factory, key, middleware in bindings:
for factory, key, middleware in binds:
if key in app:
raise DuplicateAppKeyError(key)

Expand Down
13 changes: 11 additions & 2 deletions aiohttp_sqlalchemy/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,18 @@ async def wrapped(*args: Any, **kwargs: Any) -> StreamResponse:
if key in request:
raise DuplicateRequestKeyError(key)

# TODO: after dropped Python 3.7
# if session_factory := request.config_dict.get(key):
session_factory = request.config_dict.get(key)
async with session_factory() as request[key]:
return await handler(*args, **kwargs)
if session_factory:
async with session_factory() as request[key]:
return await handler(*args, **kwargs)
else:
raise KeyError(
f'Session factory not found by {key}.'
'Check `key` argument of `sa_decorator()`'
'or arguments of `aiohttp_sqlalchemy.setup()`.'
)

return wrapped

Expand Down
12 changes: 10 additions & 2 deletions aiohttp_sqlalchemy/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@ async def sa_middleware_(
if key in request:
raise DuplicateRequestKeyError(key)

# TODO: after dropped Python 3.7
# if session_factory := request.config_dict.get(key):
session_factory = request.config_dict.get(key)
async with session_factory() as request[key]:
return await handler(request)
if session_factory:
async with session_factory() as request[key]:
return await handler(request)
else:
raise KeyError(
f'Session factory not found by {key}.'
'Check `aiohttp_sqlalchemy.setup()`.'
)

return sa_middleware_
8 changes: 4 additions & 4 deletions aiohttp_sqlalchemy/typedefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from aiohttp.web import StreamResponse
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession

TSessionFactory = Callable[..., AsyncSession]
THandler = Callable[..., Awaitable[StreamResponse]]
THandlerWrapper = Callable[..., THandler]
TSessionFactory = Callable[..., AsyncSession]

TBindTo = Union[str, AsyncEngine, TSessionFactory]
TBinding = Tuple[TSessionFactory, str, bool]
TBindings = Iterable[TBinding]
TTarget = Union[str, AsyncEngine, TSessionFactory]
TBind = Tuple[TSessionFactory, str, bool]
TBinds = Iterable[TBind]
5 changes: 4 additions & 1 deletion docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ Copy and paste this code in a file and run:
result = await db_session.execute(sa.select(MyModel))
result = result.scalars()
data = {instance.pk: instance.timestamp.isoformat() for instance in result}
data = {
instance.pk: instance.timestamp.isoformat()
for instance in result
}
return web.json_response(data)
Expand Down
10 changes: 10 additions & 0 deletions docs/releases.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
========
Releases
========
Version 0.18
------------
**Changed**

* First argument of function ``aiohttp_sqlalchemy.bind()`` renamed from
``bind_to`` to ``target``;
* Type hint alias ``TBinding`` renamed to ``TBind``;
* Type hint alias ``TBindings`` renamed to ``TBinds``;
* Type hint alias ``TBindTo`` renamed to ``TTarget``.

Version 0.17
------------
**Added**
Expand Down
5 changes: 4 additions & 1 deletion examples/simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ async def main(request):
result = await db_session.execute(sa.select(MyModel))
result = result.scalars()

data = {instance.pk: instance.timestamp.isoformat() for instance in result}
data = {
instance.pk: instance.timestamp.isoformat()
for instance in result
}
return web.json_response(data)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aiohttp-sqlalchemy"
version = "0.17.4"
version = "0.18.0"
description = "SQLAlchemy 1.4 / 2.0 support for aiohttp."
authors = [
"Ruslan Ilyasovich Gilfanov <ri.gilfanov@yandex.ru>",
Expand Down
14 changes: 11 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
pytest_plugins = 'aiohttp.pytest_plugin'


@pytest.fixture
def wrong_key() -> str:
return 'wrong_key'


@pytest.fixture
def orm_async_engine() -> AsyncEngine:
return create_async_engine('sqlite+aiosqlite:///')
Expand All @@ -39,7 +44,7 @@ def session(session_factory: TSessionFactory) -> AsyncSession:


@pytest.fixture
def sa_main_middleware() -> THandler:
def main_middleware() -> THandler:
return sa_middleware(SA_DEFAULT_KEY)


Expand All @@ -55,8 +60,11 @@ def mocked_request(middlewared_app: Application) -> 'Request':
return make_mocked_request(METH_GET, '/', app=middlewared_app)


async def function_handler(request: Request) -> Response:
return web.json_response({})
@pytest.fixture
def function_handler() -> THandler:
async def handler(request: Request) -> Response:
return web.json_response({})
return handler


class ClassHandler:
Expand Down
18 changes: 16 additions & 2 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
DuplicateRequestKeyError,
sa_decorator,
)
from tests.conftest import ClassBasedView, ClassHandler, function_handler
from aiohttp_sqlalchemy.typedefs import THandler
from tests.conftest import ClassBasedView, ClassHandler


async def test_duplicate_request_key_error(
mocked_request: Request,
session: AsyncSession,
function_handler: THandler,
) -> None:
assert mocked_request.get(SA_DEFAULT_KEY) is None
mocked_request[SA_DEFAULT_KEY] = session
Expand All @@ -21,6 +23,15 @@ async def test_duplicate_request_key_error(
await sa_decorator()(function_handler)(mocked_request)


async def test_session_factory_not_found(
mocked_request: Request,
wrong_key: str,
) -> None:
assert wrong_key not in mocked_request
with pytest.raises(KeyError):
await sa_decorator(wrong_key)(ClassBasedView.get)(mocked_request)


async def test_decorated_class_based_view(mocked_request: Request) -> None:
assert mocked_request.get(SA_DEFAULT_KEY) is None
await sa_decorator()(ClassBasedView.get)(mocked_request)
Expand All @@ -34,7 +45,10 @@ async def test_decorated_class_handler(mocked_request: Request) -> None:
assert isinstance(mocked_request.get(SA_DEFAULT_KEY), AsyncSession)


async def test_decorated_function_handler(mocked_request: Request) -> None:
async def test_decorated_function_handler(
mocked_request: Request,
function_handler: THandler,
) -> None:
assert mocked_request.get(SA_DEFAULT_KEY) is None
await sa_decorator()(function_handler)(mocked_request)
assert isinstance(mocked_request.get(SA_DEFAULT_KEY), AsyncSession)
27 changes: 21 additions & 6 deletions tests/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,43 @@
from aiohttp.web import Request
from sqlalchemy.ext.asyncio import AsyncSession

from aiohttp_sqlalchemy import SA_DEFAULT_KEY, DuplicateRequestKeyError
from aiohttp_sqlalchemy import (
SA_DEFAULT_KEY,
DuplicateRequestKeyError,
sa_middleware,
)
from aiohttp_sqlalchemy.typedefs import THandler
from tests.conftest import function_handler


async def test_duplicate_request_key_error(
sa_main_middleware: THandler,
mocked_request: Request,
function_handler: THandler,
main_middleware: THandler,
session: AsyncSession,
) -> None:
assert mocked_request.get(SA_DEFAULT_KEY) is None
mocked_request[SA_DEFAULT_KEY] = session
assert mocked_request.get(SA_DEFAULT_KEY) is session

with pytest.raises(DuplicateRequestKeyError):
await sa_main_middleware(mocked_request, function_handler)
await main_middleware(mocked_request, function_handler)


async def test_session_factory_not_found(
mocked_request: Request,
function_handler: THandler,
wrong_key: str,
) -> None:
assert wrong_key not in mocked_request
with pytest.raises(KeyError):
await sa_middleware(wrong_key)(mocked_request, function_handler)


async def test_sa_middleware(
sa_main_middleware: THandler,
mocked_request: Request,
function_handler: THandler,
main_middleware: THandler,
) -> None:
assert mocked_request.get(SA_DEFAULT_KEY) is None
await sa_main_middleware(mocked_request, function_handler)
await main_middleware(mocked_request, function_handler)
assert isinstance(mocked_request.get(SA_DEFAULT_KEY), AsyncSession)

0 comments on commit 137983f

Please sign in to comment.