Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix redshift event listener interaction with the session keyword. #171

Merged
merged 1 commit into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,13 @@ jobs:
- name: Run image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: 1.1.8
poetry-version: 1.2.0

- name: Set up cache
uses: actions/cache@v2
with:
path: ~/.cache/pypoetry/virtualenvs
path: ~/.cache/pypoetry
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-poetry-
- name: Install dependencies
run: ${{ inputs.install }}

Expand Down
691 changes: 383 additions & 308 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest-mock-resources"
version = "2.6.0"
version = "2.6.1"
description = "A pytest plugin for easily instantiating reproducible mock resources."
authors = [
"Omar Khan <oakhan3@gmail.com>",
Expand Down Expand Up @@ -55,9 +55,9 @@ black = "22.3.0"
coverage = "*"
flake8 = "*"
isort = ">=5.0"
mypy = {version = "0.931", python = ">=3.5"}
pydocstyle = {version = "*", python = ">=3.5"}
sqlalchemy-stubs = {version = "*", python = ">=3.5"}
mypy = {version = "0.982"}
pydocstyle = {version = "*"}
sqlalchemy-stubs = {version = "*"}
pytest-xdist = "*"
pytest-asyncio = "*"
types-six = "^1.16.0"
Expand Down
13 changes: 13 additions & 0 deletions src/pytest_mock_resources/fixture/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import uuid

import pytest


def generate_fixture_id(enabled: bool = True, name=""):
if enabled:
uuid_str = str(uuid.uuid4()).replace("-", "_")
return "_".join(["pmr_template", name, uuid_str])
return None


def asyncio_fixture(async_fixture, scope="function"):
# pytest-asyncio in versions >=0.17 force you to use a `pytest_asyncio.fixture`
# call instead of `pytest.fixture`. Given that this would introduce an unncessary
# dependency on pytest-asyncio (when there are other alternatives) seems less than
# ideal, so instead we can just set the flag that they set, as the indicator.
async_fixture._force_asyncio_fixture = True

fixture = pytest.fixture(scope=scope)
return fixture(async_fixture)
11 changes: 2 additions & 9 deletions src/pytest_mock_resources/fixture/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.postgres import get_sqlalchemy_engine, PostgresConfig
from pytest_mock_resources.fixture.base import generate_fixture_id
from pytest_mock_resources.fixture.base import asyncio_fixture, generate_fixture_id
from pytest_mock_resources.fixture.credentials import assign_fixture_credentials
from pytest_mock_resources.sqlalchemy import bifurcate_actions, EngineManager, normalize_actions

Expand Down Expand Up @@ -108,14 +108,7 @@ async def _async(pmr_postgres_container, pmr_postgres_config):
yield engine

if async_:
# pytest-asyncio in versions >=0.17 force you to use a `pytest_asyncio.fixture`
# call instead of `pytest.fixture`. Given that this would introduce an unncessary
# dependency on pytest-asyncio (when there are other alternatives) seems less than
# ideal, so instead we can just set the flag that they set, as the indicator.
_async._force_asyncio_fixture = True

fixture = pytest.fixture(scope=scope)
return fixture(_async)
return asyncio_fixture(_async, scope=scope)
else:
return _sync

Expand Down
16 changes: 10 additions & 6 deletions src/pytest_mock_resources/fixture/redshift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pytest_mock_resources.container.base import get_container
from pytest_mock_resources.container.redshift import get_sqlalchemy_engine, RedshiftConfig
from pytest_mock_resources.fixture.base import generate_fixture_id
from pytest_mock_resources.fixture.base import asyncio_fixture, generate_fixture_id
from pytest_mock_resources.fixture.postgresql import create_engine_manager
from pytest_mock_resources.patch.redshift import psycopg2, sqlalchemy

Expand Down Expand Up @@ -98,21 +98,25 @@ def _sync(pmr_redshift_container, pmr_redshift_config):
database_name = engine_manager.engine.url.database

for engine in engine_manager.manage_sync():
sqlalchemy.register_redshift_behavior(engine)
sqlalchemy.register_redshift_behavior(engine_manager.engine)
with psycopg2.patch_connect(pmr_redshift_config, database_name):
yield engine

@pytest.fixture(scope=scope)
async def _async(pmr_redshift_container, pmr_redshift_config):
engine_manager = _create_engine_manager(pmr_redshift_config)
database_name = engine_manager.engine.url.database

async for engine in engine_manager.manage_async():
async for conn in engine_manager.manage_async():
engine = conn
if session:
engine = conn.sync_session.connection().engine

sqlalchemy.register_redshift_behavior(engine)

with psycopg2.patch_connect(pmr_redshift_config, database_name):
yield engine
yield conn

if async_:
return _async
return asyncio_fixture(_async, scope=scope)
else:
return _sync
2 changes: 1 addition & 1 deletion src/pytest_mock_resources/hooks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings

_resource_kinds = ["postgres", "redshift", "mongo", "redis", "mysql"]
_resource_kinds = ["postgres", "redshift", "mongo", "redis", "mysql", "moto"]


def pytest_addoption(parser):
Expand Down
8 changes: 6 additions & 2 deletions src/pytest_mock_resources/patch/redshift/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Union

from sqlalchemy import event
from sqlalchemy import event, text
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import TextClause

from pytest_mock_resources.compat import sqlparse
from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip
Expand All @@ -24,6 +25,9 @@ def receive_before_execute(
individual cursor executions. Only the final statement's return value will be
returned.
"""
if isinstance(clauseelement, TextClause):
clauseelement = clauseelement.text

if isinstance(clauseelement, Executable):
return clauseelement, multiparams, params

Expand All @@ -33,7 +37,7 @@ def receive_before_execute(
for statement in statements:
cursor.execute(statement, *multiparams, **params)

return final_statement, multiparams, params
return text(final_statement), multiparams, params


def receive_before_cursor_execute(_, cursor, statement: str, parameters, context, executemany):
Expand Down
16 changes: 16 additions & 0 deletions tests/fixture/redshift/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,19 @@ def test_tightly_scoped_patch(redshift, postgres):
conn.execute(copy_command)

assert 'syntax error at or near "credentials"' in str(e.value)


redshift_session = create_redshift_fixture(session=True)
async_redshift_session = create_redshift_fixture(session=True, async_=True)


def test_event_listener_registration(redshift_session):
result = redshift_session.execute(text("select 1; select 1")).scalar()
assert result == 1


@pytest.mark.asyncio
async def test_event_listener_registration_async(async_redshift_session):
result = await async_redshift_session.execute(text("select 1; select 1"))
value = result.scalar()
assert value == 1