diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b7213e26..216a3f3b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,6 +15,8 @@ jobs: python-versions: '["3.7"]' sqlalchemy-versions: '["1.3.24", "1.4.39"]' + # Sqlalchemy 1.3 doesn't yet include the mypy plugin so we get a bunch of spurious + # mypy issues. test-sqlalchemy13: uses: ./.github/workflows/lint_and_test.yml with: @@ -43,4 +45,5 @@ jobs: uses: ./.github/workflows/lint_and_test.yml with: install: make install + sqlalchemy-versions: '["1.4.39"]' python-versions: '["3.7", "3.8", "3.9", "3.10"]' diff --git a/.github/workflows/lint_and_test.yml b/.github/workflows/lint_and_test.yml index 005fc589..1bde3246 100644 --- a/.github/workflows/lint_and_test.yml +++ b/.github/workflows/lint_and_test.yml @@ -86,7 +86,7 @@ jobs: uses: actions/cache@v2 with: path: ~/.cache/pypoetry - key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} + key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.sqlalchemy-version }}-${{ matrix.pytest-asyncio-version }}-${{ hashFiles('**/poetry.lock') }} - name: Install dependencies run: ${{ inputs.install }} diff --git a/pyproject.toml b/pyproject.toml index c5da7152..9591325e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pytest-mock-resources" -version = "2.6.1" +version = "2.6.3" description = "A pytest plugin for easily instantiating reproducible mock resources." authors = [ "Omar Khan ", diff --git a/src/pytest_mock_resources/__init__.py b/src/pytest_mock_resources/__init__.py index 92250594..454c09ed 100644 --- a/src/pytest_mock_resources/__init__.py +++ b/src/pytest_mock_resources/__init__.py @@ -6,6 +6,7 @@ RedisConfig, RedshiftConfig, ) +from pytest_mock_resources.credentials import Credentials from pytest_mock_resources.fixture import ( create_mongo_fixture, create_moto_fixture, @@ -14,7 +15,6 @@ create_redis_fixture, create_redshift_fixture, create_sqlite_fixture, - Credentials, pmr_mongo_config, pmr_mongo_container, pmr_moto_config, diff --git a/src/pytest_mock_resources/fixture/credentials.py b/src/pytest_mock_resources/credentials.py similarity index 75% rename from src/pytest_mock_resources/fixture/credentials.py rename to src/pytest_mock_resources/credentials.py index aeeb3ea4..254cfb45 100644 --- a/src/pytest_mock_resources/fixture/credentials.py +++ b/src/pytest_mock_resources/credentials.py @@ -1,3 +1,5 @@ +from sqlalchemy.orm import Session + from pytest_mock_resources import compat @@ -74,6 +76,26 @@ def as_redis_kwargs(self): "password": self.password, } + @classmethod + def assign_from_connection(cls, connection): + if isinstance(connection, Session): + url = connection.connection().engine.url + else: + url = connection.url + + instance = cls( + drivername=url.drivername, + host=url.host, + port=url.port, + username=url.username, + password=url.password, + database=url.database, + ) + connection.pmr_credentials = instance + return instance -def assign_fixture_credentials(engine, **credentials): - engine.pmr_credentials = Credentials(**credentials) + @classmethod + def assign_from_credentials(cls, engine, **credentials): + instance = Credentials(**credentials) + engine.pmr_credentials = instance + return instance diff --git a/src/pytest_mock_resources/fixture/__init__.py b/src/pytest_mock_resources/fixture/__init__.py index 4987bbd4..ffad503f 100644 --- a/src/pytest_mock_resources/fixture/__init__.py +++ b/src/pytest_mock_resources/fixture/__init__.py @@ -1,4 +1,3 @@ -from pytest_mock_resources.fixture.credentials import Credentials from pytest_mock_resources.fixture.mongo import ( create_mongo_fixture, pmr_mongo_config, @@ -33,7 +32,6 @@ from pytest_mock_resources.fixture.sqlite import create_sqlite_fixture __all__ = [ - "Credentials", "create_mongo_fixture", "create_moto_fixture", "create_mysql_fixture", diff --git a/src/pytest_mock_resources/fixture/mongo.py b/src/pytest_mock_resources/fixture/mongo.py index fca95760..8f140fdc 100644 --- a/src/pytest_mock_resources/fixture/mongo.py +++ b/src/pytest_mock_resources/fixture/mongo.py @@ -3,7 +3,7 @@ from pytest_mock_resources.compat import pymongo from pytest_mock_resources.container.base import get_container from pytest_mock_resources.container.mongo import MongoConfig -from pytest_mock_resources.fixture.credentials import assign_fixture_credentials +from pytest_mock_resources.credentials import Credentials @pytest.fixture(scope="session") @@ -64,14 +64,13 @@ def _create_clean_database(config): ) limited_db = limited_client[db_id] - assign_fixture_credentials( + Credentials.assign_from_credentials( limited_db, drivername="mongodb", host=config.host, port=config.port, - database=db_id, username=db_id, - password="password", + password=password, + database=db_id, ) - return limited_db diff --git a/src/pytest_mock_resources/fixture/mysql.py b/src/pytest_mock_resources/fixture/mysql.py index 39fc4f8c..4af5f2fa 100644 --- a/src/pytest_mock_resources/fixture/mysql.py +++ b/src/pytest_mock_resources/fixture/mysql.py @@ -4,7 +4,6 @@ from pytest_mock_resources.container.base import get_container from pytest_mock_resources.container.mysql import get_sqlalchemy_engine, MysqlConfig -from pytest_mock_resources.fixture.credentials import assign_fixture_credentials from pytest_mock_resources.sqlalchemy import EngineManager @@ -45,16 +44,6 @@ def _(pmr_mysql_container, pmr_mysql_config): database_name = _create_clean_database(pmr_mysql_config) engine = get_sqlalchemy_engine(pmr_mysql_config, database_name) - assign_fixture_credentials( - engine, - drivername="mysql+pymysql", - host=pmr_mysql_config.host, - port=pmr_mysql_config.port, - database=database_name, - username=pmr_mysql_config.username, - password=pmr_mysql_config.password, - ) - engine_manager = EngineManager.create( engine, dynamic_actions=ordered_actions, tables=tables, session=session ) diff --git a/src/pytest_mock_resources/fixture/postgresql.py b/src/pytest_mock_resources/fixture/postgresql.py index fb76ba21..29f586de 100644 --- a/src/pytest_mock_resources/fixture/postgresql.py +++ b/src/pytest_mock_resources/fixture/postgresql.py @@ -7,8 +7,8 @@ from pytest_mock_resources.container.base import get_container from pytest_mock_resources.container.postgres import get_sqlalchemy_engine, PostgresConfig +from pytest_mock_resources.credentials import Credentials from pytest_mock_resources.fixture.base import asyncio_fixture, generate_fixture_id -from pytest_mock_resources.fixture.credentials import assign_fixture_credentials from pytest_mock_resources.sqlalchemy import bifurcate_actions, EngineManager, normalize_actions log = logging.getLogger(__name__) @@ -170,7 +170,7 @@ def create_engine_manager( # distinct from what might have been used for the template database. database_name = _produce_clean_database(root_engine, createdb_template=createdb_template) engine = get_sqlalchemy_engine(pmr_postgres_config, database_name, **engine_kwargs) - _assign_credential(engine, pmr_postgres_config, database_name) + Credentials.assign_from_connection(engine) return EngineManager( engine, @@ -217,15 +217,3 @@ def _generate_database_name(conn): id_ = tuple(result)[0][0] database_name = "pytest_mock_resource_db_{}".format(id_) return database_name - - -def _assign_credential(engine, config, database_name): - assign_fixture_credentials( - engine, - drivername="postgresql+psycopg2", - host=config.host, - port=config.port, - username=config.username, - password=config.password, - database=database_name, - ) diff --git a/src/pytest_mock_resources/fixture/redis.py b/src/pytest_mock_resources/fixture/redis.py index e426bf2d..8c112db2 100644 --- a/src/pytest_mock_resources/fixture/redis.py +++ b/src/pytest_mock_resources/fixture/redis.py @@ -3,7 +3,7 @@ from pytest_mock_resources.compat import redis from pytest_mock_resources.container.base import get_container from pytest_mock_resources.container.redis import RedisConfig -from pytest_mock_resources.fixture.credentials import assign_fixture_credentials +from pytest_mock_resources.credentials import Credentials @pytest.fixture(scope="session") @@ -65,7 +65,7 @@ def _(request, pmr_redis_container, pmr_redis_config): db = redis.Redis(host=pmr_redis_config.host, port=pmr_redis_config.port, db=database_number) db.flushdb() - assign_fixture_credentials( + Credentials.assign_from_credentials( db, drivername="redis", host=pmr_redis_config.host, diff --git a/src/pytest_mock_resources/fixture/redshift/__init__.py b/src/pytest_mock_resources/fixture/redshift/__init__.py index e54271e6..9e677423 100644 --- a/src/pytest_mock_resources/fixture/redshift/__init__.py +++ b/src/pytest_mock_resources/fixture/redshift/__init__.py @@ -107,9 +107,10 @@ async def _async(pmr_redshift_container, pmr_redshift_config): database_name = engine_manager.engine.url.database async for conn in engine_manager.manage_async(): - engine = conn if session: engine = conn.sync_session.connection().engine + else: + engine = conn.sync_engine sqlalchemy.register_redshift_behavior(engine) diff --git a/src/pytest_mock_resources/fixture/sqlite.py b/src/pytest_mock_resources/fixture/sqlite.py index a553458c..f42267ef 100644 --- a/src/pytest_mock_resources/fixture/sqlite.py +++ b/src/pytest_mock_resources/fixture/sqlite.py @@ -25,7 +25,6 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import sqltypes -from pytest_mock_resources.fixture.credentials import assign_fixture_credentials from pytest_mock_resources.sqlalchemy import EngineManager @@ -261,16 +260,6 @@ def _(): ) for engine in engine_manager.manage_sync(): with filter_sqlalchemy_warnings(decimal_warnings_enabled=(not decimal_warnings)): - assign_fixture_credentials( - raw_engine, - drivername=driver_name, - host="", - port=None, - database=database_name, - username="", - password="", - ) - yield engine event.remove(raw_engine, "connect", enable_foreign_key_checks) diff --git a/src/pytest_mock_resources/patch/redshift/sqlalchemy.py b/src/pytest_mock_resources/patch/redshift/sqlalchemy.py index 7ec6cb15..ca10f7be 100644 --- a/src/pytest_mock_resources/patch/redshift/sqlalchemy.py +++ b/src/pytest_mock_resources/patch/redshift/sqlalchemy.py @@ -1,11 +1,10 @@ from typing import Union -from sqlalchemy import event, text +from sqlalchemy import event from sqlalchemy.sql.base import Executable -from sqlalchemy.sql.elements import TextClause from pytest_mock_resources.compat import sqlparse -from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip +from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command from pytest_mock_resources.patch.redshift.mock_s3_unload import mock_s3_unload_command @@ -25,9 +24,6 @@ def receive_before_execute( individual cursor executions. Only the final statement's return value will be returned. """ - if isinstance(clauseelement, TextClause): - clauseelement = clauseelement.text - if isinstance(clauseelement, Executable): return clauseelement, multiparams, params @@ -37,7 +33,7 @@ def receive_before_execute( for statement in statements: cursor.execute(statement, *multiparams, **params) - return text(final_statement), multiparams, params + return final_statement, multiparams, params def receive_before_cursor_execute(_, cursor, statement: str, parameters, context, executemany): @@ -51,7 +47,7 @@ def receive_before_cursor_execute(_, cursor, statement: str, parameters, context we return a no-op query to be executed by sqlalchemy for certain kinds of supported extra features. """ - normalized_statement = strip(statement).lower() + normalized_statement = _preprocess(statement).lower() if normalized_statement.startswith("unload"): mock_s3_unload_command(statement, cursor) return "SELECT 1", {} diff --git a/src/pytest_mock_resources/sqlalchemy.py b/src/pytest_mock_resources/sqlalchemy.py index 84a37a55..8d88ab2b 100644 --- a/src/pytest_mock_resources/sqlalchemy.py +++ b/src/pytest_mock_resources/sqlalchemy.py @@ -12,6 +12,7 @@ from sqlalchemy.sql.schema import Table from pytest_mock_resources import compat +from pytest_mock_resources.credentials import Credentials log = logging.getLogger(__name__) @@ -136,6 +137,7 @@ def manage_sync(self): if self.actions_share_transaction is False: self.engine.dispose() + Credentials.assign_from_connection(session) yield session finally: session.close() @@ -147,6 +149,7 @@ def manage_sync(self): if self.actions_share_transaction is False: self.engine.dispose() + Credentials.assign_from_connection(self.engine) yield self.engine finally: @@ -173,6 +176,7 @@ async def manage_async(self, session=None): await session.commit() await session.close() + Credentials.assign_from_connection(engine.sync_engine) yield session else: async with engine.begin() as conn: @@ -182,6 +186,7 @@ async def manage_async(self, session=None): if not self.actions_share_transaction: await engine.dispose() + Credentials.assign_from_connection(engine.sync_engine) yield engine finally: await engine.dispose() diff --git a/tests/fixture/redshift/test_patch.py b/tests/fixture/redshift/test_patch.py index 23430001..316a989c 100644 --- a/tests/fixture/redshift/test_patch.py +++ b/tests/fixture/redshift/test_patch.py @@ -3,6 +3,7 @@ from sqlalchemy import text from pytest_mock_resources import create_postgres_fixture, create_redshift_fixture +from tests import skip_if_sqlalchemy2 from tests.fixture.redshift.utils import ( copy_fn_to_test_create_engine_patch, copy_fn_to_test_psycopg2_connect_patch, @@ -77,17 +78,42 @@ def test_tightly_scoped_patch(redshift, postgres): assert 'syntax error at or near "credentials"' in str(e.value) +redshift_engine = create_redshift_fixture(session=False) redshift_session = create_redshift_fixture(session=True) +async_redshift_engine = create_redshift_fixture(session=False, async_=True) async_redshift_session = create_redshift_fixture(session=True, async_=True) -def test_event_listener_registration(redshift_session): +@skip_if_sqlalchemy2 +def test_event_listener_registration_engine(redshift_engine): + with redshift_engine.connect() as conn: + result = conn.execute("select 1; select 1").scalar() + + result = redshift_engine.execute("select 1; select 1").scalar() + assert result == 1 + + +@skip_if_sqlalchemy2 +def test_event_listener_registration_session(redshift_session): + result = redshift_session.execute("select 1; select 1").scalar() + assert result == 1 + + +def test_event_listener_registration_text(redshift_session): result = redshift_session.execute(text("select 1; select 1")).scalar() assert result == 1 @pytest.mark.asyncio -async def test_event_listener_registration_async(async_redshift_session): - result = await async_redshift_session.execute(text("select 1; select 1")) +async def test_event_listener_registration_async_engine(async_redshift_engine): + async with async_redshift_engine.connect() as conn: + result = await conn.execute(text("select 1")) + value = result.scalar() + assert value == 1 + + +@pytest.mark.asyncio +async def test_event_listener_registration_async_session(async_redshift_session): + result = await async_redshift_session.execute(text("select 1")) value = result.scalar() assert value == 1 diff --git a/tests/fixture/test_pmr_credentials.py b/tests/fixture/test_pmr_credentials.py new file mode 100644 index 00000000..795988de --- /dev/null +++ b/tests/fixture/test_pmr_credentials.py @@ -0,0 +1,121 @@ +import pytest +from sqlalchemy import create_engine, text + +from pytest_mock_resources import ( + create_mongo_fixture, + create_moto_fixture, + create_mysql_fixture, + create_postgres_fixture, + create_redis_fixture, + create_redshift_fixture, + create_sqlite_fixture, +) + +mongo = create_mongo_fixture() +moto = create_moto_fixture() +mysql = create_mysql_fixture() +mysql_session = create_mysql_fixture(session=True) +pg = create_postgres_fixture() +pg_session = create_postgres_fixture(session=True) +pg_async = create_postgres_fixture(async_=True) +pg_async_session = create_postgres_fixture(async_=True, session=True) +redis = create_redis_fixture() +redshift = create_redshift_fixture() +redshift_session = create_redshift_fixture(session=True) +redshift_async = create_redshift_fixture(async_=True) +redshift_async_session = create_redshift_fixture(async_=True, session=True) +sqlite = create_sqlite_fixture() +sqlite_session = create_sqlite_fixture(session=True) + + +def test_mongo_pmr_credentials(mongo): + assert mongo.pmr_credentials + + +def test_moto_pmr_credentials(moto, pmr_moto_credentials): + assert moto + assert pmr_moto_credentials.aws_access_key_id + assert pmr_moto_credentials.aws_secret_access_key + + +def test_mysql_pmr_credentials(mysql): + credentials = mysql.pmr_credentials + verify_relational(mysql, credentials) + + +def test_mysql_session_pmr_credentials(mysql_session): + credentials = mysql_session.pmr_credentials + verify_relational(mysql_session, credentials, session=True) + + +def test_postgres_pmr_credentials(pg): + credentials = pg.pmr_credentials + verify_relational(pg, credentials) + + +def test_postgres_session_pmr_credentials(pg_session): + credentials = pg_session.pmr_credentials + verify_relational(pg_session, credentials, session=True) + + +@pytest.mark.asyncio +async def test_postgres_async_pmr_credentials(pg_async): + assert pg_async.sync_engine.pmr_credentials + + +@pytest.mark.asyncio +async def test_postgres_async_session_pmr_credentials(pg_async_session): + assert (await pg_async_session.connection()).sync_engine.pmr_credentials + + +def test_redis_pmr_credentials(redis): + assert redis.pmr_credentials + + +def test_redshift_pmr_credentials(redshift): + credentials = redshift.pmr_credentials + verify_relational(redshift, credentials) + + +def test_redshift_session_pmr_credentials(redshift_session): + credentials = redshift_session.pmr_credentials + verify_relational(redshift_session, credentials, session=True) + + +@pytest.mark.asyncio +async def test_redshift_async_pmr_credentials(redshift_async): + assert redshift_async.sync_engine.pmr_credentials + + +@pytest.mark.asyncio +async def test_redshift_async_session_pmr_credentials(redshift_async_session): + assert (await redshift_async_session.connection()).sync_engine.pmr_credentials + + +def test_sqlite_pmr_credentials(sqlite): + assert sqlite.pmr_credentials + + +def test_sqlite_session_pmr_credentials(sqlite_session): + assert sqlite_session.pmr_credentials + + +def verify_relational(connection, credentials, session=False): + """Verify connection to the same database as the one given to the test function, using credentials.""" + assert credentials + + queries = [ + text("create table foo (id integer)"), + text("commit"), + ] + if not session: + with connection.begin() as conn: + for query in queries: + conn.execute(query) + else: + for query in queries: + connection.execute(query) + + manual_engine = create_engine(credentials.as_url()) + with manual_engine.connect() as conn: + conn.execute(text("select * from foo"))