-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Use sqlalchemy's event system to apply redshift behavior.
Using patches/manually wrapping an engine's `execute` method is the existing behavior and is comparatively fragile. Instead we now use sqlalchemy's event system (which is designed for this sort of thing!) to hook into the query execution to add our behavior. This fixes long-standing inconsistencies between default sqlalchemy behavior, and hopefully better hides the changes to calling code.
- Loading branch information
Showing
10 changed files
with
206 additions
and
167 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,64 @@ | ||
import contextlib | ||
from unittest import mock | ||
|
||
from sqlalchemy import create_engine | ||
from sqlalchemy.sql.base import Executable | ||
from pytest_mock_resources.container.postgres import PostgresConfig | ||
from sqlalchemy.sql.elements import TextClause | ||
|
||
from pytest_mock_resources.compat import psycopg2 | ||
from pytest_mock_resources.patch.redshift.mock_s3_copy import strip | ||
from pytest_mock_resources.patch.redshift.sqlalchemy import ( | ||
execute_mock_s3_copy_command, | ||
execute_mock_s3_unload_command, | ||
) | ||
import contextlib | ||
|
||
from sqlalchemy.sql.elements import TextClause | ||
|
||
from unittest import mock | ||
from pytest_mock_resources.compat import psycopg2 | ||
from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip | ||
from pytest_mock_resources.patch.redshift.mock_s3_unload import mock_s3_unload_command | ||
|
||
|
||
@contextlib.contextmanager | ||
def patch_connect(config: PostgresConfig, database: str): | ||
new_connect = mock_psycopg2_connect(config, database, _connect=psycopg2._connect) | ||
with mock.patch("psycopg2._connect", new=new_connect) as p: | ||
yield p | ||
|
||
|
||
def mock_psycopg2_connect(config, _connect): | ||
def mock_psycopg2_connect(config: PostgresConfig, database: str, _connect): | ||
"""Patch `psycopg2._connect`. | ||
Add support for S3 COPY and UNLOAD. | ||
""" | ||
|
||
class CustomCursor(psycopg2.extensions.cursor): | ||
"""A custom cursor class to define a custom execute method.""" | ||
|
||
def execute(self, sql, args=None): | ||
dsn_params = self.connection.get_dsn_parameters() | ||
engine = create_engine( | ||
"postgresql+psycopg2://{user}:{password}@{hostname}:{port}/{dbname}".format( | ||
user=config.username, | ||
password=config.password, | ||
hostname=config.host, | ||
port=config.port, | ||
dbname=dsn_params["dbname"], | ||
) | ||
) | ||
if not isinstance(sql, TextClause) and strip(sql).lower().startswith("copy"): | ||
return execute_mock_s3_copy_command(sql, engine) | ||
if not isinstance(sql, TextClause) and strip(sql).lower().startswith("unload"): | ||
return execute_mock_s3_unload_command(sql, engine) | ||
return super(CustomCursor, self).execute(sql, args) | ||
if isinstance(sql, Executable): | ||
return super().execute(sql, args) | ||
|
||
if strip(sql).lower().startswith("copy"): | ||
mock_s3_copy_command(sql, self) | ||
sql = 'commit' | ||
|
||
if strip(sql).lower().startswith("unload"): | ||
mock_s3_unload_command(sql, self) | ||
sql = 'commit' | ||
|
||
return super().execute(sql, args) | ||
|
||
def _mock_psycopg2_connect(*args, **kwargs): | ||
"""Substitute the default cursor with a custom cursor.""" | ||
conn = _connect(*args, **kwargs) | ||
conn.cursor_factory = CustomCursor | ||
return conn | ||
dsn_info = conn.get_dsn_parameters() | ||
|
||
return _mock_psycopg2_connect | ||
connection_info_matches = ( | ||
config.host == dsn_info['host'], | ||
str(config.port) == dsn_info['port'], | ||
database == dsn_info['dbname'], | ||
) | ||
|
||
if connection_info_matches: | ||
conn.cursor_factory = CustomCursor | ||
return conn | ||
|
||
@contextlib.contextmanager | ||
def patch_connect(config): | ||
new_connect = mock_psycopg2_connect(config, _connect=psycopg2._connect) | ||
with mock.patch("psycopg2._connect", new=new_connect) as p: | ||
yield p | ||
return _mock_psycopg2_connect |
Oops, something went wrong.