Skip to content

Commit

Permalink
fix: Fix redshift event listener interaction with the session keyword. (
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Oct 20, 2022
1 parent d405c7d commit 088b180
Show file tree
Hide file tree
Showing 9 changed files with 437 additions and 334 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,13 @@ jobs:
- name: Run image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: 1.1.8
poetry-version: 1.2.0

- name: Set up cache
uses: actions/cache@v2
with:
path: ~/.cache/pypoetry/virtualenvs
path: ~/.cache/pypoetry
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-poetry-
- name: Install dependencies
run: ${{ inputs.install }}

Expand Down
691 changes: 383 additions & 308 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest-mock-resources"
version = "2.6.0"
version = "2.6.1"
description = "A pytest plugin for easily instantiating reproducible mock resources."
authors = [
"Omar Khan <oakhan3@gmail.com>",
Expand Down Expand Up @@ -55,9 +55,9 @@ black = "22.3.0"
coverage = "*"
flake8 = "*"
isort = ">=5.0"
mypy = {version = "0.931", python = ">=3.5"}
pydocstyle = {version = "*", python = ">=3.5"}
sqlalchemy-stubs = {version = "*", python = ">=3.5"}
mypy = {version = "0.982"}
pydocstyle = {version = "*"}
sqlalchemy-stubs = {version = "*"}
pytest-xdist = "*"
pytest-asyncio = "*"
types-six = "^1.16.0"
Expand Down
13 changes: 13 additions & 0 deletions src/pytest_mock_resources/fixture/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import uuid

import pytest


def generate_fixture_id(enabled: bool = True, name=""):
if enabled:
uuid_str = str(uuid.uuid4()).replace("-", "_")
return "_".join(["pmr_template", name, uuid_str])
return None


def asyncio_fixture(async_fixture, scope="function"):
# pytest-asyncio in versions >=0.17 force you to use a `pytest_asyncio.fixture`
# call instead of `pytest.fixture`. Given that this would introduce an unncessary
# dependency on pytest-asyncio (when there are other alternatives) seems less than
# ideal, so instead we can just set the flag that they set, as the indicator.
async_fixture._force_asyncio_fixture = True

fixture = pytest.fixture(scope=scope)
return fixture(async_fixture)
11 changes: 2 additions & 9 deletions src/pytest_mock_resources/fixture/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.postgres import get_sqlalchemy_engine, PostgresConfig
from pytest_mock_resources.fixture.base import generate_fixture_id
from pytest_mock_resources.fixture.base import asyncio_fixture, generate_fixture_id
from pytest_mock_resources.fixture.credentials import assign_fixture_credentials
from pytest_mock_resources.sqlalchemy import bifurcate_actions, EngineManager, normalize_actions

Expand Down Expand Up @@ -108,14 +108,7 @@ async def _async(pmr_postgres_container, pmr_postgres_config):
yield engine

if async_:
# pytest-asyncio in versions >=0.17 force you to use a `pytest_asyncio.fixture`
# call instead of `pytest.fixture`. Given that this would introduce an unncessary
# dependency on pytest-asyncio (when there are other alternatives) seems less than
# ideal, so instead we can just set the flag that they set, as the indicator.
_async._force_asyncio_fixture = True

fixture = pytest.fixture(scope=scope)
return fixture(_async)
return asyncio_fixture(_async, scope=scope)
else:
return _sync

Expand Down
16 changes: 10 additions & 6 deletions src/pytest_mock_resources/fixture/redshift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.redshift import get_sqlalchemy_engine, RedshiftConfig
from pytest_mock_resources.fixture.base import generate_fixture_id
from pytest_mock_resources.fixture.base import asyncio_fixture, generate_fixture_id
from pytest_mock_resources.fixture.postgresql import create_engine_manager
from pytest_mock_resources.patch.redshift import psycopg2, sqlalchemy

Expand Down Expand Up @@ -98,21 +98,25 @@ def _sync(pmr_redshift_container, pmr_redshift_config):
database_name = engine_manager.engine.url.database

for engine in engine_manager.manage_sync():
sqlalchemy.register_redshift_behavior(engine)
sqlalchemy.register_redshift_behavior(engine_manager.engine)
with psycopg2.patch_connect(pmr_redshift_config, database_name):
yield engine

@pytest.fixture(scope=scope)
async def _async(pmr_redshift_container, pmr_redshift_config):
engine_manager = _create_engine_manager(pmr_redshift_config)
database_name = engine_manager.engine.url.database

async for engine in engine_manager.manage_async():
async for conn in engine_manager.manage_async():
engine = conn
if session:
engine = conn.sync_session.connection().engine

sqlalchemy.register_redshift_behavior(engine)

with psycopg2.patch_connect(pmr_redshift_config, database_name):
yield engine
yield conn

if async_:
return _async
return asyncio_fixture(_async, scope=scope)
else:
return _sync
2 changes: 1 addition & 1 deletion src/pytest_mock_resources/hooks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings

_resource_kinds = ["postgres", "redshift", "mongo", "redis", "mysql"]
_resource_kinds = ["postgres", "redshift", "mongo", "redis", "mysql", "moto"]


def pytest_addoption(parser):
Expand Down
8 changes: 6 additions & 2 deletions src/pytest_mock_resources/patch/redshift/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Union

from sqlalchemy import event
from sqlalchemy import event, text
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import TextClause

from pytest_mock_resources.compat import sqlparse
from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip
Expand All @@ -24,6 +25,9 @@ def receive_before_execute(
individual cursor executions. Only the final statement's return value will be
returned.
"""
if isinstance(clauseelement, TextClause):
clauseelement = clauseelement.text

if isinstance(clauseelement, Executable):
return clauseelement, multiparams, params

Expand All @@ -33,7 +37,7 @@ def receive_before_execute(
for statement in statements:
cursor.execute(statement, *multiparams, **params)

return final_statement, multiparams, params
return text(final_statement), multiparams, params


def receive_before_cursor_execute(_, cursor, statement: str, parameters, context, executemany):
Expand Down
16 changes: 16 additions & 0 deletions tests/fixture/redshift/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,19 @@ def test_tightly_scoped_patch(redshift, postgres):
conn.execute(copy_command)

assert 'syntax error at or near "credentials"' in str(e.value)


redshift_session = create_redshift_fixture(session=True)
async_redshift_session = create_redshift_fixture(session=True, async_=True)


def test_event_listener_registration(redshift_session):
result = redshift_session.execute(text("select 1; select 1")).scalar()
assert result == 1


@pytest.mark.asyncio
async def test_event_listener_registration_async(async_redshift_session):
result = await async_redshift_session.execute(text("select 1; select 1"))
value = result.scalar()
assert value == 1

0 comments on commit 088b180

Please sign in to comment.