feat: Use sqlalchemy's event system to apply redshift behavior.
The existing behavior patches/manually wraps an engine's `execute` method,
which is comparatively fragile. Instead we now use sqlalchemy's event system
(which is designed for this sort of thing!) to hook into query execution and
add our behavior.

This fixes long-standing inconsistencies with default sqlalchemy behavior, and
hopefully better hides the changes from calling code.
DanCardin committed Nov 9, 2021
1 parent a09fc18 commit 2455620
Showing 10 changed files with 206 additions and 167 deletions.
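
The commit message describes hooking sqlalchemy's event system instead of wrapping `execute`. As a rough sketch of that approach (the event name, handler body, and usage below are assumptions for illustration, not necessarily what `register_redshift_behavior` actually does):

from sqlalchemy import create_engine, event


def register_redshift_behavior(engine):
    @event.listens_for(engine, "before_cursor_execute", retval=True)
    def _redshift_hook(conn, cursor, statement, parameters, context, executemany):
        # Intercept Redshift-only statements (COPY/UNLOAD) before they reach
        # the underlying cursor; everything else passes through untouched.
        if statement.lstrip().lower().startswith(("copy ", "unload ")):
            # ...run the mocked COPY/UNLOAD against `cursor` here...
            statement = "commit"  # swap in a no-op, as the psycopg2 patch below does
        return statement, parameters

    return engine


# Illustrative usage; any postgres-compatible connection string would do.
engine = register_redshift_behavior(create_engine("postgresql://user:password@localhost/db"))

Registering on the engine keeps calling code unchanged: callers still use a plain sqlalchemy engine, and the hook only fires for statements routed through it.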
@@ -43,13 +43,11 @@ def _(_redshift_container, pmr_postgres_config):
password=pmr_postgres_config.password,
)

engine = sqlalchemy.substitute_execute_with_custom_execute(engine)
engine_manager = EngineManager(
engine, ordered_actions, tables=tables, default_schema="public"
)

with psycopg2.patch_connect(pmr_postgres_config):
for engine in engine_manager.manage(session=session):
yield engine
sqlalchemy.register_redshift_behavior(engine)
with psycopg2.patch_connect(pmr_postgres_config, database_name):
engine_manager = EngineManager(
engine, ordered_actions, tables=tables, default_schema="public"
)
yield from engine_manager.manage(session=session)

return _
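
For context, a hedged sketch of how such a fixture is typically consumed in a test, using the library's `create_redshift_fixture` entry point (table and S3 details are placeholders):

from pytest_mock_resources import create_redshift_fixture

redshift = create_redshift_fixture()


def test_copy_is_mocked(redshift):
    # The patched cursor intercepts COPY, so no real Redshift cluster is
    # required; the S3 read is typically backed by a mocked S3 (e.g. moto).
    redshift.execute(
        "COPY my_table FROM 's3://my-bucket/data.csv' "
        "CREDENTIALS 'aws_access_key_id=key;aws_secret_access_key=secret' CSV"
    )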
81 changes: 50 additions & 31 deletions src/pytest_mock_resources/patch/redshift/mock_s3_copy.py
@@ -2,25 +2,28 @@
import csv
import gzip
import io
import sys

from sqlalchemy import MetaData, Table
import attr

from pytest_mock_resources.compat import boto3


def execute_mock_s3_copy_command(statement, engine):
params = _parse_s3_command(statement)
@attr.s
class S3CopyCommand:
table_name = attr.ib()
delimiter = attr.ib()
s3_uri = attr.ib()
empty_as_null = attr.ib()
format = attr.ib(default="CSV")
aws_access_key_id = attr.ib(default=None)
aws_secret_access_key = attr.ib(default=None)
columns = attr.ib(default=None)
schema_name = attr.ib(default=None)

_mock_s3_copy(
table_name=params["table_name"],
schema_name=params["schema_name"],
s3_uri=params["s3_uri"],
aws_secret_access_key=params["aws_secret_access_key"],
aws_access_key_id=params["aws_access_key_id"],
columns=params.get("columns", None),
engine=engine,
)

def mock_s3_copy_command(statement, cursor):
copy_command = _parse_s3_command(statement)
return _mock_s3_copy(cursor, copy_command)


def _parse_s3_command(statement):
@@ -35,6 +38,7 @@ def _parse_s3_command(statement):
params["schema_name"], params["table_name"] = _split_table_name(tokens.pop(0))

# Checking for columns
params["columns"] = []
if tokens[0][0] == "(":
ending_index = 0
for index, arg in enumerate(tokens):
@@ -64,7 +68,8 @@ def _parse_s3_command(statement):
).format(statement=statement)
)
params["s3_uri"] = strip(tokens.pop(0))

empty_as_null = False
delimiter = None
# Fetching credentials
for token in tokens:
if "aws_access_key_id" in token.lower() or "aws_secret_access_key" in token.lower():
@@ -100,7 +105,14 @@
" No Support for additional credential formats, eg IAM roles, etc, yet."
).format(statement=statement)
)
return params
if "emptyasnull" == token.lower():
empty_as_null = True
if "csv" == token.lower():
delimiter = ","

if delimiter is None:
delimiter = "|"
return S3CopyCommand(**params, empty_as_null=empty_as_null, delimiter=delimiter)


def _split_table_name(table_name):
@@ -116,14 +128,17 @@ def _split_table_name(table_name):


def _mock_s3_copy(
table_name, s3_uri, schema_name, aws_secret_access_key, aws_access_key_id, columns, engine
cursor,
copy_command,
):
"""Execute patched 'copy' command."""
s3 = boto3.client(
"s3", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
"s3",
aws_access_key_id=copy_command.aws_access_key_id,
aws_secret_access_key=copy_command.aws_secret_access_key,
)
ending_index = len(s3_uri)
path_to_file = s3_uri[5:ending_index]
ending_index = len(copy_command.s3_uri)
path_to_file = copy_command.s3_uri[5:ending_index]
bucket, key = path_to_file.split("/", 1)
response = s3.get_object(Bucket=bucket, Key=key)

@@ -134,25 +149,29 @@ def _mock_s3_copy(
is_gzipped = binascii.hexlify(response["Body"].read(2)) == b"1f8b"

response = s3.get_object(Bucket=bucket, Key=key)
data = read_data_csv(response["Body"].read(), is_gzipped, columns)

meta = MetaData()
table = Table(table_name, meta, autoload=True, schema=schema_name, autoload_with=engine)
engine.execute(table.insert(data))
data = get_raw_file(response["Body"].read(), is_gzipped)

cursor.copy_expert(
"COPY {cc.table_name} FROM STDIN WITH DELIMITER AS '{cc.delimiter}' {cc.format} HEADER {non_null_clause}".format(
cc=copy_command,
non_null_clause=("FORCE NOT NULL " + ", ".join(copy_command.columns))
if copy_command.columns
else "",
),
data,
)


def read_data_csv(file, is_gzipped=False, columns=None, delimiter="|"):
def get_raw_file(file, is_gzipped=False):
buffer = io.BytesIO(file)
if is_gzipped:
buffer = gzip.GzipFile(fileobj=buffer, mode="rb")
return buffer

# FUCK you python 2. This is ridiculous!
wrapper = buffer
if sys.version_info.major >= 3:
wrapper = io.TextIOWrapper(buffer)
else:
delimiter = delimiter.encode("utf-8")

def read_data_csv(file, is_gzipped=False, columns=None, delimiter="|"):
buffer = get_raw_file(file, is_gzipped=is_gzipped)
wrapper = io.TextIOWrapper(buffer)
reader = csv.DictReader(
wrapper,
delimiter=delimiter,
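
The new `_mock_s3_copy` above streams the fetched file into the target table with psycopg2's `copy_expert`. A standalone sketch of that pattern (connection details and the target table are placeholders):

import io

import psycopg2

conn = psycopg2.connect(host="localhost", dbname="test", user="user", password="password")
data = io.BytesIO(b"id|name\n1|foo\n2|bar\n")

with conn.cursor() as cursor:
    # COPY ... FROM STDIN reads directly from the file-like object, which is
    # how the mock replays the S3 object's contents into the table.
    cursor.copy_expert(
        "COPY my_table FROM STDIN WITH DELIMITER AS '|' CSV HEADER",
        data,
    )
conn.commit()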
33 changes: 15 additions & 18 deletions src/pytest_mock_resources/patch/redshift/mock_s3_unload.py
@@ -1,21 +1,20 @@
import csv
import gzip
import io
import sys

from pytest_mock_resources.compat import boto3
from pytest_mock_resources.patch.redshift.mock_s3_copy import strip


def execute_mock_s3_unload_command(statement, engine):
def mock_s3_unload_command(statement, cursor):
params = _parse_s3_command(statement)

_mock_s3_unload(
return _mock_s3_unload(
select_statement=params["select_statement"],
s3_uri=params["s3_uri"],
aws_secret_access_key=params["aws_secret_access_key"],
aws_access_key_id=params["aws_access_key_id"],
engine=engine,
cursor=cursor,
delimiter=params.get("delimiter", "|"),
is_gzipped=params["gzip"],
)
@@ -136,7 +135,7 @@ def _mock_s3_unload(
s3_uri,
aws_secret_access_key,
aws_access_key_id,
engine,
cursor,
delimiter,
is_gzipped,
):
@@ -146,8 +145,12 @@ def _mock_s3_unload(
path_to_file = s3_uri[5:ending_index]
bucket, key = path_to_file.split("/", 1)

result = engine.execute(select_statement)
buffer = get_data_csv(result, is_gzipped=is_gzipped, delimiter=delimiter)
cursor.execute(select_statement)
result = cursor.fetchall()
column_names = [desc[0] for desc in cursor.description]
buffer = get_data_csv(
result, column_names=column_names, is_gzipped=is_gzipped, delimiter=delimiter
)

# Push the data to the S3 Bucket.
conn = boto3.resource(
@@ -158,22 +161,17 @@
obj.put(Body=buffer)


def get_data_csv(rows, is_gzipped=False, delimiter="|", **additional_to_csv_options):
def get_data_csv(rows, column_names, is_gzipped=False, delimiter="|", **additional_to_csv_options):
result = io.BytesIO()
buffer = result
if is_gzipped:
buffer = gzip.GzipFile(fileobj=buffer, mode="wb")

# FUCK you python 2. This is ridiculous!
wrapper = buffer
if sys.version_info.major >= 3:
wrapper = io.TextIOWrapper(buffer)
else:
delimiter = delimiter.encode("utf-8")
wrapper = io.TextIOWrapper(buffer)

writer = csv.DictWriter(
wrapper,
fieldnames=rows.keys(),
fieldnames=column_names,
delimiter=delimiter,
quoting=csv.QUOTE_MINIMAL,
quotechar='"',
@@ -183,10 +181,9 @@ def get_data_csv(rows, is_gzipped=False, delimiter="|", **additional_to_csv_options):
)
writer.writeheader()
for row in rows:
writer.writerow(dict(row.items()))
writer.writerow(dict(zip(column_names, row)))

if sys.version_info.major >= 3:
wrapper.detach()
wrapper.detach()

if is_gzipped:
buffer.close()
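
The unload path builds its payload entirely in memory before pushing it to S3; the subtle part is detaching the `TextIOWrapper` so the gzip stream can be finalized without closing the underlying buffer. A minimal sketch of that pattern (assuming gzip output; the rows and columns stand in for cursor results):

import csv
import gzip
import io

rows = [(1, "foo"), (2, "bar")]      # e.g. cursor.fetchall()
column_names = ["id", "name"]        # e.g. [desc[0] for desc in cursor.description]

result = io.BytesIO()
buffer = gzip.GzipFile(fileobj=result, mode="wb")

wrapper = io.TextIOWrapper(buffer)
writer = csv.DictWriter(wrapper, fieldnames=column_names, delimiter="|")
writer.writeheader()
for row in rows:
    writer.writerow(dict(zip(column_names, row)))

wrapper.detach()   # flush the text layer without closing the gzip stream
buffer.close()     # finalize the gzip member, writing into `result`
payload = result.getvalue()  # bytes, ready for s3 Object(...).put(Body=payload)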
70 changes: 40 additions & 30 deletions src/pytest_mock_resources/patch/redshift/psycopg2.py
@@ -1,54 +1,64 @@
import contextlib
from unittest import mock

from sqlalchemy import create_engine
from sqlalchemy.sql.base import Executable
from pytest_mock_resources.container.postgres import PostgresConfig
from sqlalchemy.sql.elements import TextClause

from pytest_mock_resources.compat import psycopg2
from pytest_mock_resources.patch.redshift.mock_s3_copy import strip
from pytest_mock_resources.patch.redshift.sqlalchemy import (
execute_mock_s3_copy_command,
execute_mock_s3_unload_command,
)
import contextlib

from sqlalchemy.sql.elements import TextClause

from unittest import mock
from pytest_mock_resources.compat import psycopg2
from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip
from pytest_mock_resources.patch.redshift.mock_s3_unload import mock_s3_unload_command


@contextlib.contextmanager
def patch_connect(config: PostgresConfig, database: str):
new_connect = mock_psycopg2_connect(config, database, _connect=psycopg2._connect)
with mock.patch("psycopg2._connect", new=new_connect) as p:
yield p


def mock_psycopg2_connect(config, _connect):
def mock_psycopg2_connect(config: PostgresConfig, database: str, _connect):
"""Patch `psycopg2._connect`.
Add support for S3 COPY and UNLOAD.
"""

class CustomCursor(psycopg2.extensions.cursor):
"""A custom cursor class to define a custom execute method."""

def execute(self, sql, args=None):
dsn_params = self.connection.get_dsn_parameters()
engine = create_engine(
"postgresql+psycopg2://{user}:{password}@{hostname}:{port}/{dbname}".format(
user=config.username,
password=config.password,
hostname=config.host,
port=config.port,
dbname=dsn_params["dbname"],
)
)
if not isinstance(sql, TextClause) and strip(sql).lower().startswith("copy"):
return execute_mock_s3_copy_command(sql, engine)
if not isinstance(sql, TextClause) and strip(sql).lower().startswith("unload"):
return execute_mock_s3_unload_command(sql, engine)
return super(CustomCursor, self).execute(sql, args)
if isinstance(sql, Executable):
return super().execute(sql, args)

if strip(sql).lower().startswith("copy"):
mock_s3_copy_command(sql, self)
sql = 'commit'

if strip(sql).lower().startswith("unload"):
mock_s3_unload_command(sql, self)
sql = 'commit'

return super().execute(sql, args)

def _mock_psycopg2_connect(*args, **kwargs):
"""Substitute the default cursor with a custom cursor."""
conn = _connect(*args, **kwargs)
conn.cursor_factory = CustomCursor
return conn
dsn_info = conn.get_dsn_parameters()

return _mock_psycopg2_connect
connection_info_matches = (
config.host == dsn_info['host'],
str(config.port) == dsn_info['port'],
database == dsn_info['dbname'],
)

if connection_info_matches:
conn.cursor_factory = CustomCursor
return conn

@contextlib.contextmanager
def patch_connect(config):
new_connect = mock_psycopg2_connect(config, _connect=psycopg2._connect)
with mock.patch("psycopg2._connect", new=new_connect) as p:
yield p
return _mock_psycopg2_connect
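
The patch above relies on psycopg2's `cursor_factory`: the patched `_connect` swaps in `CustomCursor` only when the connection matches the fixture's host, port, and database. A minimal standalone illustration of the cursor-factory mechanism itself (connection details are placeholders):

import psycopg2
import psycopg2.extensions


class LoggingCursor(psycopg2.extensions.cursor):
    def execute(self, sql, args=None):
        # A real hook (like CustomCursor above) would branch on COPY/UNLOAD here.
        print("about to execute:", sql)
        return super().execute(sql, args)


conn = psycopg2.connect(host="localhost", dbname="test", user="user", password="password")
conn.cursor_factory = LoggingCursor
with conn.cursor() as cursor:
    cursor.execute("SELECT 1")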