From 9187fee430714566dfec3aa15d8fdb6bec8c3dfd Mon Sep 17 00:00:00 2001 From: Dan Cardin Date: Mon, 20 Sep 2021 10:06:39 -0400 Subject: [PATCH] refactor: Extract sqlalchemy compat into its own module. --- src/pytest_mock_resources/compat/__init__.py | 35 ++----------------- src/pytest_mock_resources/compat/import_.py | 23 ++++++++++++ .../compat/sqlalchemy.py | 10 ++++++ .../fixture/database/relational/generic.py | 8 +++-- 4 files changed, 41 insertions(+), 35 deletions(-) create mode 100644 src/pytest_mock_resources/compat/import_.py create mode 100644 src/pytest_mock_resources/compat/sqlalchemy.py diff --git a/src/pytest_mock_resources/compat/__init__.py b/src/pytest_mock_resources/compat/__init__.py index b3204af2..a28b1803 100644 --- a/src/pytest_mock_resources/compat/__init__.py +++ b/src/pytest_mock_resources/compat/__init__.py @@ -1,27 +1,7 @@ -class ImportAdaptor(object): - __wrapped__ = False - - def __init__(self, package, recommended_extra, fail_message=None, **attrs): - self.package = package - self.recommended_extra = recommended_extra - self.fail_message = fail_message - - for key, value in attrs.items(): - setattr(self, key, value) - - def fail(self): - if self.fail_message: - fail_message = self.fail_message - else: - fail_message = "Cannot use {recommended_extra} fixtures without {package}. pip install pytest-mock-resources[{recommended_extra}]".format( - package=self.package, recommended_extra=self.recommended_extra - ) - - raise RuntimeError(fail_message) - - def __getattr__(self, attr): - self.fail() +from pytest_mock_resources.compat.import_ import ImportAdaptor +# isort: split +from pytest_mock_resources.compat import import_, sqlalchemy # flake8: ignore try: import psycopg2 @@ -88,12 +68,3 @@ def __getattr__(self, attr): import pymysql except ImportError: pymysql = ImportAdaptor("pymysql", "mysql") # type: ignore - -try: - from sqlalchemy.ext import asyncio as sqlalchemy_asyncio # type: ignore -except ImportError: - sqlalchemy_asyncio = ImportAdaptor( # type: ignore - "SQLAlchemy", - "SQLAlchemy >= 1.4", - fail_message="Cannot use sqlalchemy async features with SQLAlchemy < 1.4.\n", - ) diff --git a/src/pytest_mock_resources/compat/import_.py b/src/pytest_mock_resources/compat/import_.py new file mode 100644 index 00000000..9c657ba2 --- /dev/null +++ b/src/pytest_mock_resources/compat/import_.py @@ -0,0 +1,23 @@ +class ImportAdaptor(object): + __wrapped__ = False + + def __init__(self, package, recommended_extra, fail_message=None, **attrs): + self.package = package + self.recommended_extra = recommended_extra + self.fail_message = fail_message + + for key, value in attrs.items(): + setattr(self, key, value) + + def fail(self): + if self.fail_message: + fail_message = self.fail_message + else: + fail_message = "Cannot use {recommended_extra} fixtures without {package}. pip install pytest-mock-resources[{recommended_extra}]".format( + package=self.package, recommended_extra=self.recommended_extra + ) + + raise RuntimeError(fail_message) + + def __getattr__(self, attr): + self.fail() diff --git a/src/pytest_mock_resources/compat/sqlalchemy.py b/src/pytest_mock_resources/compat/sqlalchemy.py new file mode 100644 index 00000000..ed4ebed1 --- /dev/null +++ b/src/pytest_mock_resources/compat/sqlalchemy.py @@ -0,0 +1,10 @@ +from pytest_mock_resources.compat.import_ import ImportAdaptor + +try: + from sqlalchemy.ext import asyncio as asyncio # type: ignore +except ImportError: + asyncio = ImportAdaptor( # type: ignore + "SQLAlchemy", + "SQLAlchemy >= 1.4", + fail_message="Cannot use sqlalchemy async features with SQLAlchemy < 1.4.\n", + ) diff --git a/src/pytest_mock_resources/fixture/database/relational/generic.py b/src/pytest_mock_resources/fixture/database/relational/generic.py index cae563e6..c787aa24 100644 --- a/src/pytest_mock_resources/fixture/database/relational/generic.py +++ b/src/pytest_mock_resources/fixture/database/relational/generic.py @@ -11,7 +11,7 @@ from sqlalchemy.sql.ddl import CreateSchema from sqlalchemy.sql.schema import Table -from pytest_mock_resources.compat import sqlalchemy_asyncio +from pytest_mock_resources import compat @six.add_metaclass(abc.ABCMeta) @@ -188,7 +188,9 @@ async def manage_async(self, session=None): session_factory = session else: session_factory = sessionmaker( - async_engine, expire_on_commit=False, class_=sqlalchemy_asyncio.AsyncSession + async_engine, + expire_on_commit=False, + class_=compat.sqlalchemy.asyncio.AsyncSession, ) async with session_factory() as session: yield session @@ -210,7 +212,7 @@ def _get_async_engine(self, isolation_level=None): options = {} if isolation_level: options["isolation_level"] = isolation_level - return sqlalchemy_asyncio.create_async_engine(url, **options) + return compat.sqlalchemy.asyncio.create_async_engine(url, **options) def identify_matching_tables(metadata, table_specifier):