Skip to content

Commit

Permalink
fix: Decide on behavior support for multiple redshift statements. (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Oct 27, 2022
1 parent 088b180 commit e9d5f0f
Show file tree
Hide file tree
Showing 16 changed files with 199 additions and 62 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ jobs:
python-versions: '["3.7"]'
sqlalchemy-versions: '["1.3.24", "1.4.39"]'

# Sqlalchemy 1.3 doesn't yet include the mypy plugin so we get a bunch of spurious
# mypy issues.
test-sqlalchemy13:
uses: ./.github/workflows/lint_and_test.yml
with:
Expand Down Expand Up @@ -43,4 +45,5 @@ jobs:
uses: ./.github/workflows/lint_and_test.yml
with:
install: make install
sqlalchemy-versions: '["1.4.39"]'
python-versions: '["3.7", "3.8", "3.9", "3.10"]'
2 changes: 1 addition & 1 deletion .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ jobs:
uses: actions/cache@v2
with:
path: ~/.cache/pypoetry
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.sqlalchemy-version }}-${{ matrix.pytest-asyncio-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
run: ${{ inputs.install }}

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest-mock-resources"
version = "2.6.1"
version = "2.6.3"
description = "A pytest plugin for easily instantiating reproducible mock resources."
authors = [
"Omar Khan <oakhan3@gmail.com>",
Expand Down
2 changes: 1 addition & 1 deletion src/pytest_mock_resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
RedisConfig,
RedshiftConfig,
)
from pytest_mock_resources.credentials import Credentials
from pytest_mock_resources.fixture import (
create_mongo_fixture,
create_moto_fixture,
Expand All @@ -14,7 +15,6 @@
create_redis_fixture,
create_redshift_fixture,
create_sqlite_fixture,
Credentials,
pmr_mongo_config,
pmr_mongo_container,
pmr_moto_config,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sqlalchemy.orm import Session

from pytest_mock_resources import compat


Expand Down Expand Up @@ -74,6 +76,26 @@ def as_redis_kwargs(self):
"password": self.password,
}

@classmethod
def assign_from_connection(cls, connection):
if isinstance(connection, Session):
url = connection.connection().engine.url
else:
url = connection.url

instance = cls(
drivername=url.drivername,
host=url.host,
port=url.port,
username=url.username,
password=url.password,
database=url.database,
)
connection.pmr_credentials = instance
return instance

def assign_fixture_credentials(engine, **credentials):
engine.pmr_credentials = Credentials(**credentials)
@classmethod
def assign_from_credentials(cls, engine, **credentials):
instance = Credentials(**credentials)
engine.pmr_credentials = instance
return instance
2 changes: 0 additions & 2 deletions src/pytest_mock_resources/fixture/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pytest_mock_resources.fixture.credentials import Credentials
from pytest_mock_resources.fixture.mongo import (
create_mongo_fixture,
pmr_mongo_config,
Expand Down Expand Up @@ -33,7 +32,6 @@
from pytest_mock_resources.fixture.sqlite import create_sqlite_fixture

__all__ = [
"Credentials",
"create_mongo_fixture",
"create_moto_fixture",
"create_mysql_fixture",
Expand Down
9 changes: 4 additions & 5 deletions src/pytest_mock_resources/fixture/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytest_mock_resources.compat import pymongo
from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.mongo import MongoConfig
from pytest_mock_resources.fixture.credentials import assign_fixture_credentials
from pytest_mock_resources.credentials import Credentials


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -64,14 +64,13 @@ def _create_clean_database(config):
)
limited_db = limited_client[db_id]

assign_fixture_credentials(
Credentials.assign_from_credentials(
limited_db,
drivername="mongodb",
host=config.host,
port=config.port,
database=db_id,
username=db_id,
password="password",
password=password,
database=db_id,
)

return limited_db
11 changes: 0 additions & 11 deletions src/pytest_mock_resources/fixture/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.mysql import get_sqlalchemy_engine, MysqlConfig
from pytest_mock_resources.fixture.credentials import assign_fixture_credentials
from pytest_mock_resources.sqlalchemy import EngineManager


Expand Down Expand Up @@ -45,16 +44,6 @@ def _(pmr_mysql_container, pmr_mysql_config):
database_name = _create_clean_database(pmr_mysql_config)
engine = get_sqlalchemy_engine(pmr_mysql_config, database_name)

assign_fixture_credentials(
engine,
drivername="mysql+pymysql",
host=pmr_mysql_config.host,
port=pmr_mysql_config.port,
database=database_name,
username=pmr_mysql_config.username,
password=pmr_mysql_config.password,
)

engine_manager = EngineManager.create(
engine, dynamic_actions=ordered_actions, tables=tables, session=session
)
Expand Down
16 changes: 2 additions & 14 deletions src/pytest_mock_resources/fixture/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.postgres import get_sqlalchemy_engine, PostgresConfig
from pytest_mock_resources.credentials import Credentials
from pytest_mock_resources.fixture.base import asyncio_fixture, generate_fixture_id
from pytest_mock_resources.fixture.credentials import assign_fixture_credentials
from pytest_mock_resources.sqlalchemy import bifurcate_actions, EngineManager, normalize_actions

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -170,7 +170,7 @@ def create_engine_manager(
# distinct from what might have been used for the template database.
database_name = _produce_clean_database(root_engine, createdb_template=createdb_template)
engine = get_sqlalchemy_engine(pmr_postgres_config, database_name, **engine_kwargs)
_assign_credential(engine, pmr_postgres_config, database_name)
Credentials.assign_from_connection(engine)

return EngineManager(
engine,
Expand Down Expand Up @@ -217,15 +217,3 @@ def _generate_database_name(conn):
id_ = tuple(result)[0][0]
database_name = "pytest_mock_resource_db_{}".format(id_)
return database_name


def _assign_credential(engine, config, database_name):
assign_fixture_credentials(
engine,
drivername="postgresql+psycopg2",
host=config.host,
port=config.port,
username=config.username,
password=config.password,
database=database_name,
)
4 changes: 2 additions & 2 deletions src/pytest_mock_resources/fixture/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytest_mock_resources.compat import redis
from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.redis import RedisConfig
from pytest_mock_resources.fixture.credentials import assign_fixture_credentials
from pytest_mock_resources.credentials import Credentials


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -65,7 +65,7 @@ def _(request, pmr_redis_container, pmr_redis_config):
db = redis.Redis(host=pmr_redis_config.host, port=pmr_redis_config.port, db=database_number)
db.flushdb()

assign_fixture_credentials(
Credentials.assign_from_credentials(
db,
drivername="redis",
host=pmr_redis_config.host,
Expand Down
3 changes: 2 additions & 1 deletion src/pytest_mock_resources/fixture/redshift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ async def _async(pmr_redshift_container, pmr_redshift_config):
database_name = engine_manager.engine.url.database

async for conn in engine_manager.manage_async():
engine = conn
if session:
engine = conn.sync_session.connection().engine
else:
engine = conn.sync_engine

sqlalchemy.register_redshift_behavior(engine)

Expand Down
11 changes: 0 additions & 11 deletions src/pytest_mock_resources/fixture/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes

from pytest_mock_resources.fixture.credentials import assign_fixture_credentials
from pytest_mock_resources.sqlalchemy import EngineManager


Expand Down Expand Up @@ -261,16 +260,6 @@ def _():
)
for engine in engine_manager.manage_sync():
with filter_sqlalchemy_warnings(decimal_warnings_enabled=(not decimal_warnings)):
assign_fixture_credentials(
raw_engine,
drivername=driver_name,
host="",
port=None,
database=database_name,
username="",
password="",
)

yield engine

event.remove(raw_engine, "connect", enable_foreign_key_checks)
Expand Down
12 changes: 4 additions & 8 deletions src/pytest_mock_resources/patch/redshift/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Union

from sqlalchemy import event, text
from sqlalchemy import event
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import TextClause

from pytest_mock_resources.compat import sqlparse
from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip
from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command
from pytest_mock_resources.patch.redshift.mock_s3_unload import mock_s3_unload_command


Expand All @@ -25,9 +24,6 @@ def receive_before_execute(
individual cursor executions. Only the final statement's return value will be
returned.
"""
if isinstance(clauseelement, TextClause):
clauseelement = clauseelement.text

if isinstance(clauseelement, Executable):
return clauseelement, multiparams, params

Expand All @@ -37,7 +33,7 @@ def receive_before_execute(
for statement in statements:
cursor.execute(statement, *multiparams, **params)

return text(final_statement), multiparams, params
return final_statement, multiparams, params


def receive_before_cursor_execute(_, cursor, statement: str, parameters, context, executemany):
Expand All @@ -51,7 +47,7 @@ def receive_before_cursor_execute(_, cursor, statement: str, parameters, context
we return a no-op query to be executed by sqlalchemy for certain kinds of supported
extra features.
"""
normalized_statement = strip(statement).lower()
normalized_statement = _preprocess(statement).lower()
if normalized_statement.startswith("unload"):
mock_s3_unload_command(statement, cursor)
return "SELECT 1", {}
Expand Down
5 changes: 5 additions & 0 deletions src/pytest_mock_resources/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlalchemy.sql.schema import Table

from pytest_mock_resources import compat
from pytest_mock_resources.credentials import Credentials

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -136,6 +137,7 @@ def manage_sync(self):
if self.actions_share_transaction is False:
self.engine.dispose()

Credentials.assign_from_connection(session)
yield session
finally:
session.close()
Expand All @@ -147,6 +149,7 @@ def manage_sync(self):
if self.actions_share_transaction is False:
self.engine.dispose()

Credentials.assign_from_connection(self.engine)
yield self.engine

finally:
Expand All @@ -173,6 +176,7 @@ async def manage_async(self, session=None):
await session.commit()
await session.close()

Credentials.assign_from_connection(engine.sync_engine)
yield session
else:
async with engine.begin() as conn:
Expand All @@ -182,6 +186,7 @@ async def manage_async(self, session=None):
if not self.actions_share_transaction:
await engine.dispose()

Credentials.assign_from_connection(engine.sync_engine)
yield engine
finally:
await engine.dispose()
Expand Down
32 changes: 29 additions & 3 deletions tests/fixture/redshift/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlalchemy import text

from pytest_mock_resources import create_postgres_fixture, create_redshift_fixture
from tests import skip_if_sqlalchemy2
from tests.fixture.redshift.utils import (
copy_fn_to_test_create_engine_patch,
copy_fn_to_test_psycopg2_connect_patch,
Expand Down Expand Up @@ -77,17 +78,42 @@ def test_tightly_scoped_patch(redshift, postgres):
assert 'syntax error at or near "credentials"' in str(e.value)


redshift_engine = create_redshift_fixture(session=False)
redshift_session = create_redshift_fixture(session=True)
async_redshift_engine = create_redshift_fixture(session=False, async_=True)
async_redshift_session = create_redshift_fixture(session=True, async_=True)


def test_event_listener_registration(redshift_session):
@skip_if_sqlalchemy2
def test_event_listener_registration_engine(redshift_engine):
with redshift_engine.connect() as conn:
result = conn.execute("select 1; select 1").scalar()

result = redshift_engine.execute("select 1; select 1").scalar()
assert result == 1


@skip_if_sqlalchemy2
def test_event_listener_registration_session(redshift_session):
result = redshift_session.execute("select 1; select 1").scalar()
assert result == 1


def test_event_listener_registration_text(redshift_session):
result = redshift_session.execute(text("select 1; select 1")).scalar()
assert result == 1


@pytest.mark.asyncio
async def test_event_listener_registration_async(async_redshift_session):
result = await async_redshift_session.execute(text("select 1; select 1"))
async def test_event_listener_registration_async_engine(async_redshift_engine):
async with async_redshift_engine.connect() as conn:
result = await conn.execute(text("select 1"))
value = result.scalar()
assert value == 1


@pytest.mark.asyncio
async def test_event_listener_registration_async_session(async_redshift_session):
result = await async_redshift_session.execute(text("select 1"))
value = result.scalar()
assert value == 1
Loading

0 comments on commit e9d5f0f

Please sign in to comment.