From 2455620b7b9dd08b92e0ac33d625eb83cd6eeeaf Mon Sep 17 00:00:00 2001
From: Dan Cardin
Date: Wed, 25 Mar 2020 08:18:40 -0400
Subject: [PATCH] feat: Use sqlalchemy's event system to apply redshift
 behavior.

The existing behavior patches/manually wraps an engine's `execute`
method, which is comparatively fragile. Instead we now use sqlalchemy's
event system (which is designed for this sort of thing!) to hook into
query execution and add our behavior.

This fixes long-standing inconsistencies with default sqlalchemy
behavior, and hopefully better hides the changes from calling code.
---
 .../database/relational/redshift/__init__.py | 14 ++--
 .../patch/redshift/mock_s3_copy.py           | 81 ++++++++++++-------
 .../patch/redshift/mock_s3_unload.py         | 33 ++++----
 .../patch/redshift/psycopg2.py               | 70 +++++++++-------
 .../patch/redshift/sqlalchemy.py             | 78 ++++++++++--------
 tests/fixture/__init__.py                    |  3 +
 tests/fixture/database/__init__.py           | 65 +++++++--------
 .../database/relational/test_generic.py      |  3 +-
 tests/fixture/database/test_copy.py          | 18 ++---
 tests/fixture/database/test_patch.py         |  8 +-
 10 files changed, 206 insertions(+), 167 deletions(-)

diff --git a/src/pytest_mock_resources/fixture/database/relational/redshift/__init__.py b/src/pytest_mock_resources/fixture/database/relational/redshift/__init__.py
index f703c319..1fe5f797 100644
--- a/src/pytest_mock_resources/fixture/database/relational/redshift/__init__.py
+++ b/src/pytest_mock_resources/fixture/database/relational/redshift/__init__.py
@@ -43,13 +43,11 @@ def _(_redshift_container, pmr_postgres_config):
         password=pmr_postgres_config.password,
     )

-    engine = sqlalchemy.substitute_execute_with_custom_execute(engine)
-    engine_manager = EngineManager(
-        engine, ordered_actions, tables=tables, default_schema="public"
-    )
-
-    with psycopg2.patch_connect(pmr_postgres_config):
-        for engine in engine_manager.manage(session=session):
-            yield engine
+    sqlalchemy.register_redshift_behavior(engine)
+    with psycopg2.patch_connect(pmr_postgres_config, database_name):
+        engine_manager = EngineManager(
+            engine, ordered_actions, tables=tables, default_schema="public"
+        )
+        yield from engine_manager.manage(session=session)

     return _

diff --git a/src/pytest_mock_resources/patch/redshift/mock_s3_copy.py b/src/pytest_mock_resources/patch/redshift/mock_s3_copy.py
index 41dbe518..2593ec78 100644
--- a/src/pytest_mock_resources/patch/redshift/mock_s3_copy.py
+++ b/src/pytest_mock_resources/patch/redshift/mock_s3_copy.py
@@ -2,25 +2,28 @@
 import csv
 import gzip
 import io
-import sys

-from sqlalchemy import MetaData, Table
+import attr

 from pytest_mock_resources.compat import boto3


-def execute_mock_s3_copy_command(statement, engine):
-    params = _parse_s3_command(statement)
+@attr.s
+class S3CopyCommand:
+    table_name = attr.ib()
+    delimiter = attr.ib()
+    s3_uri = attr.ib()
+    empty_as_null = attr.ib()
+    format = attr.ib(default="CSV")
+    aws_access_key_id = attr.ib(default=None)
+    aws_secret_access_key = attr.ib(default=None)
+    columns = attr.ib(default=None)
+    schema_name = attr.ib(default=None)

-    _mock_s3_copy(
-        table_name=params["table_name"],
-        schema_name=params["schema_name"],
-        s3_uri=params["s3_uri"],
-        aws_secret_access_key=params["aws_secret_access_key"],
-        aws_access_key_id=params["aws_access_key_id"],
-        columns=params.get("columns", None),
-        engine=engine,
-    )
+
+def mock_s3_copy_command(statement, cursor):
+    copy_command = _parse_s3_command(statement)
+    return _mock_s3_copy(cursor, copy_command)


 def _parse_s3_command(statement):
@@ -35,6 +38,7 @@ def _parse_s3_command(statement):
     params["schema_name"], params["table_name"] = _split_table_name(tokens.pop(0))

     # Checking for columns
+    params["columns"] = []
     if tokens[0][0] == "(":
         ending_index = 0
         for index, arg in enumerate(tokens):
@@ -64,7 +68,8 @@ def _parse_s3_command(statement):
             ).format(statement=statement)
         )
     params["s3_uri"] = strip(tokens.pop(0))
-
+    empty_as_null = False
+    delimiter = None
     # Fetching credentials
     for token in tokens:
         if "aws_access_key_id" in token.lower() or "aws_secret_access_key" in token.lower():
@@ -100,7 +105,14 @@ def _parse_s3_command(statement):
                     " No Support for additional credential formats, eg IAM roles, etc, yet."
                 ).format(statement=statement)
             )
-    return params
+        if "emptyasnull" == token.lower():
+            empty_as_null = True
+        if "csv" == token.lower():
+            delimiter = ","
+
+    if delimiter is None:
+        delimiter = "|"
+    return S3CopyCommand(**params, empty_as_null=empty_as_null, delimiter=delimiter)


 def _split_table_name(table_name):
@@ -116,14 +128,17 @@ def _split_table_name(table_name):


 def _mock_s3_copy(
-    table_name, s3_uri, schema_name, aws_secret_access_key, aws_access_key_id, columns, engine
+    cursor,
+    copy_command,
 ):
     """Execute patched 'copy' command."""
     s3 = boto3.client(
-        "s3", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
+        "s3",
+        aws_access_key_id=copy_command.aws_access_key_id,
+        aws_secret_access_key=copy_command.aws_secret_access_key,
     )
-    ending_index = len(s3_uri)
-    path_to_file = s3_uri[5:ending_index]
+    ending_index = len(copy_command.s3_uri)
+    path_to_file = copy_command.s3_uri[5:ending_index]
     bucket, key = path_to_file.split("/", 1)

     response = s3.get_object(Bucket=bucket, Key=key)
@@ -134,25 +149,29 @@ def _mock_s3_copy(
     is_gzipped = binascii.hexlify(response["Body"].read(2)) == b"1f8b"

     response = s3.get_object(Bucket=bucket, Key=key)
-    data = read_data_csv(response["Body"].read(), is_gzipped, columns)
-
-    meta = MetaData()
-    table = Table(table_name, meta, autoload=True, schema=schema_name, autoload_with=engine)
-    engine.execute(table.insert(data))
+    data = get_raw_file(response["Body"].read(), is_gzipped)
+
+    cursor.copy_expert(
+        "COPY {cc.table_name} FROM STDIN WITH DELIMITER AS '{cc.delimiter}' {cc.format} HEADER {non_null_clause}".format(
+            cc=copy_command,
+            non_null_clause=("FORCE NOT NULL " + ", ".join(copy_command.columns))
+            if copy_command.columns
+            else "",
+        ),
+        data,
+    )


-def read_data_csv(file, is_gzipped=False, columns=None, delimiter="|"):
+def get_raw_file(file, is_gzipped=False):
     buffer = io.BytesIO(file)
     if is_gzipped:
         buffer = gzip.GzipFile(fileobj=buffer, mode="rb")
+    return buffer

-    # FUCK you python 2. This is ridiculous!
-    wrapper = buffer
-    if sys.version_info.major >= 3:
-        wrapper = io.TextIOWrapper(buffer)
-    else:
-        delimiter = delimiter.encode("utf-8")

+def read_data_csv(file, is_gzipped=False, columns=None, delimiter="|"):
+    buffer = get_raw_file(file, is_gzipped=is_gzipped)
+    wrapper = io.TextIOWrapper(buffer)
     reader = csv.DictReader(
         wrapper,
         delimiter=delimiter,
diff --git a/src/pytest_mock_resources/patch/redshift/mock_s3_unload.py b/src/pytest_mock_resources/patch/redshift/mock_s3_unload.py
index 26064e37..8d71f5f3 100644
--- a/src/pytest_mock_resources/patch/redshift/mock_s3_unload.py
+++ b/src/pytest_mock_resources/patch/redshift/mock_s3_unload.py
@@ -1,21 +1,20 @@
 import csv
 import gzip
 import io
-import sys

 from pytest_mock_resources.compat import boto3
 from pytest_mock_resources.patch.redshift.mock_s3_copy import strip


-def execute_mock_s3_unload_command(statement, engine):
+def mock_s3_unload_command(statement, cursor):
     params = _parse_s3_command(statement)

-    _mock_s3_unload(
+    return _mock_s3_unload(
         select_statement=params["select_statement"],
         s3_uri=params["s3_uri"],
         aws_secret_access_key=params["aws_secret_access_key"],
         aws_access_key_id=params["aws_access_key_id"],
-        engine=engine,
+        cursor=cursor,
         delimiter=params.get("delimiter", "|"),
         is_gzipped=params["gzip"],
     )
@@ -136,7 +135,7 @@ def _mock_s3_unload(
     s3_uri,
     aws_secret_access_key,
     aws_access_key_id,
-    engine,
+    cursor,
     delimiter,
     is_gzipped,
 ):
@@ -146,8 +145,12 @@ def _mock_s3_unload(
     path_to_file = s3_uri[5:ending_index]
     bucket, key = path_to_file.split("/", 1)

-    result = engine.execute(select_statement)
-    buffer = get_data_csv(result, is_gzipped=is_gzipped, delimiter=delimiter)
+    cursor.execute(select_statement)
+    result = cursor.fetchall()
+    column_names = [desc[0] for desc in cursor.description]
+    buffer = get_data_csv(
+        result, column_names=column_names, is_gzipped=is_gzipped, delimiter=delimiter
+    )

     # Push the data to the S3 Bucket.
     conn = boto3.resource(
@@ -158,22 +161,17 @@ def _mock_s3_unload(
     obj.put(Body=buffer)


-def get_data_csv(rows, is_gzipped=False, delimiter="|", **additional_to_csv_options):
+def get_data_csv(rows, column_names, is_gzipped=False, delimiter="|", **additional_to_csv_options):
     result = io.BytesIO()
     buffer = result
     if is_gzipped:
         buffer = gzip.GzipFile(fileobj=buffer, mode="wb")

-    # FUCK you python 2. This is ridiculous!
-    wrapper = buffer
-    if sys.version_info.major >= 3:
-        wrapper = io.TextIOWrapper(buffer)
-    else:
-        delimiter = delimiter.encode("utf-8")
+    wrapper = io.TextIOWrapper(buffer)

     writer = csv.DictWriter(
         wrapper,
-        fieldnames=rows.keys(),
+        fieldnames=column_names,
         delimiter=delimiter,
         quoting=csv.QUOTE_MINIMAL,
         quotechar='"',
@@ -183,10 +181,9 @@ def get_data_csv(rows, is_gzipped=False, delimiter="|", **additional_to_csv_opti
     )
     writer.writeheader()
     for row in rows:
-        writer.writerow(dict(row.items()))
+        writer.writerow(dict(zip(column_names, row)))

-    if sys.version_info.major >= 3:
-        wrapper.detach()
+    wrapper.detach()

     if is_gzipped:
         buffer.close()
diff --git a/src/pytest_mock_resources/patch/redshift/psycopg2.py b/src/pytest_mock_resources/patch/redshift/psycopg2.py
index 4aaf11cc..99e367ac 100644
--- a/src/pytest_mock_resources/patch/redshift/psycopg2.py
+++ b/src/pytest_mock_resources/patch/redshift/psycopg2.py
@@ -1,18 +1,30 @@
 import contextlib
-from unittest import mock

 from sqlalchemy import create_engine
+from sqlalchemy.sql.base import Executable
+from pytest_mock_resources.container.postgres import PostgresConfig
 from sqlalchemy.sql.elements import TextClause

 from pytest_mock_resources.compat import psycopg2
 from pytest_mock_resources.patch.redshift.mock_s3_copy import strip
-from pytest_mock_resources.patch.redshift.sqlalchemy import (
-    execute_mock_s3_copy_command,
-    execute_mock_s3_unload_command,
-)
+import contextlib
+
+from sqlalchemy.sql.elements import TextClause
+
+from unittest import mock
+from pytest_mock_resources.compat import psycopg2
+from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip
+from pytest_mock_resources.patch.redshift.mock_s3_unload import mock_s3_unload_command
+
+
+@contextlib.contextmanager
+def patch_connect(config: PostgresConfig, database: str):
+    new_connect = mock_psycopg2_connect(config, database, _connect=psycopg2._connect)
+    with mock.patch("psycopg2._connect", new=new_connect) as p:
+        yield p


-def mock_psycopg2_connect(config, _connect):
+def mock_psycopg2_connect(config: PostgresConfig, database: str, _connect):
     """Patch `psycopg2._connect`.

     Add support for S3 COPY and UNLOAD.
@@ -20,35 +32,33 @@ def mock_psycopg2_connect(config, _connect):
     class CustomCursor(psycopg2.extensions.cursor):
         """A custom cursor class to define a custom execute method."""

         def execute(self, sql, args=None):
-            dsn_params = self.connection.get_dsn_parameters()
-            engine = create_engine(
-                "postgresql+psycopg2://{user}:{password}@{hostname}:{port}/{dbname}".format(
-                    user=config.username,
-                    password=config.password,
-                    hostname=config.host,
-                    port=config.port,
-                    dbname=dsn_params["dbname"],
-                )
-            )
-            if not isinstance(sql, TextClause) and strip(sql).lower().startswith("copy"):
-                return execute_mock_s3_copy_command(sql, engine)
-            if not isinstance(sql, TextClause) and strip(sql).lower().startswith("unload"):
-                return execute_mock_s3_unload_command(sql, engine)
-            return super(CustomCursor, self).execute(sql, args)
+            if isinstance(sql, Executable):
+                return super().execute(sql, args)
+
+            if strip(sql).lower().startswith("copy"):
+                mock_s3_copy_command(sql, self)
+                sql = 'commit'
+
+            if strip(sql).lower().startswith("unload"):
+                mock_s3_unload_command(sql, self)
+                sql = 'commit'
+
+            return super().execute(sql, args)

     def _mock_psycopg2_connect(*args, **kwargs):
         """Substitute the default cursor with a custom cursor."""
         conn = _connect(*args, **kwargs)
-        conn.cursor_factory = CustomCursor
-        return conn
+        dsn_info = conn.get_dsn_parameters()

-    return _mock_psycopg2_connect
+        connection_info_matches = (
+            config.host == dsn_info['host'],
+            str(config.port) == dsn_info['port'],
+            database == dsn_info['dbname'],
+        )
+        if connection_info_matches:
+            conn.cursor_factory = CustomCursor
+        return conn

-@contextlib.contextmanager
-def patch_connect(config):
-    new_connect = mock_psycopg2_connect(config, _connect=psycopg2._connect)
-    with mock.patch("psycopg2._connect", new=new_connect) as p:
-        yield p
+    return _mock_psycopg2_connect
diff --git a/src/pytest_mock_resources/patch/redshift/sqlalchemy.py b/src/pytest_mock_resources/patch/redshift/sqlalchemy.py
index 67132e1d..4e46b623 100644
--- a/src/pytest_mock_resources/patch/redshift/sqlalchemy.py
+++ b/src/pytest_mock_resources/patch/redshift/sqlalchemy.py
@@ -1,54 +1,66 @@
 from typing import Union

+from sqlalchemy import event
 from sqlalchemy.sql.base import Executable

 from pytest_mock_resources.compat import sqlparse
-from pytest_mock_resources.patch.redshift.mock_s3_copy import execute_mock_s3_copy_command, strip
-from pytest_mock_resources.patch.redshift.mock_s3_unload import execute_mock_s3_unload_command
+from pytest_mock_resources.patch.redshift.mock_s3_copy import mock_s3_copy_command, strip
+from pytest_mock_resources.patch.redshift.mock_s3_unload import mock_s3_unload_command


-def substitute_execute_with_custom_execute(engine):
+def register_redshift_behavior(engine):
     """Substitute the default execute method with a custom execute for copy and unload command."""
-    default_execute = engine.execute

-    def custom_execute(statement: Union[Executable, str], *args, **kwargs):
-        if isinstance(statement, Executable):
-            return default_execute(statement, *args, **kwargs)
+    event.listen(engine, "before_execute", receive_before_execute, retval=True)
+    event.listen(engine, "before_cursor_execute", receive_before_cursor_execute, retval=True)

-        # The statement is assumed to be a string at this point.
-        normalized_statement = strip(statement).lower()
-        if normalized_statement.startswith("copy"):
-            return execute_mock_s3_copy_command(statement, engine)
-        elif normalized_statement.startswith("unload"):
-            return execute_mock_s3_unload_command(statement, engine)
-        return default_execute(statement, *args, **kwargs)

+def receive_before_execute(
+    conn, clauseelement: Union[Executable, str], multiparams, params, execution_options=None
+):
+    """Handle the `before_execute` event.

-    def handle_multiple_statements(statement: Union[Executable, str], *args, **kwargs):
-        """Split statement into individual sql statements and execute.
+    Specifically, this only needs to handle the parsing of multiple statements into
+    individual cursor executions. Only the final statement's return value will be
+    returned.
+    """
+    if isinstance(clauseelement, Executable):
+        return clauseelement, multiparams, params

-        Splits multiple statements by ';' and executes each.
-        NOTE: Only the result of the last statements is returned.
-        """
-        statements_list = parse_multiple_statements(statement)
-        result = None
-        for statement in statements_list:
-            result = custom_execute(statement, *args, **kwargs)
+    *statements, final_statement = parse_multiple_statements(clauseelement)

-        return result
+    cursor = conn.connection.cursor()
+    for statement in statements:
+        cursor.execute(statement, *multiparams, **params)

-    # Now each statement is handled as if it contains multiple sql statements
-    engine.execute = handle_multiple_statements
-    return engine
+    return final_statement, multiparams, params


-def parse_multiple_statements(statement: Union[Executable, str]):
-    """Split the given sql statement into a list of individual sql statements."""
-    # Ignore SQLAlchemy Text Objects.
-    if isinstance(statement, Executable):
-        return [statement]
+def receive_before_cursor_execute(_, cursor, statement: str, parameters, context, executemany):
+    """Handle the `before_cursor_execute` event.
+
+    This is where we add support for custom features such as redshift COPY/UNLOAD because
+    the query has already been rendered (into a string) at this point.
+
+    Notably, COPY/UNLOAD need to perform extra non-sql behavior and potentially execute
+    more than a single query and the interface requires that we return a query. Thus,
+    we return a no-op query to be executed by sqlalchemy for certain kinds of supported
+    extra features.
+    """
+    normalized_statement = strip(statement).lower()
+    if normalized_statement.startswith("unload"):
+        mock_s3_unload_command(statement, cursor)
+        return "SELECT 1", {}

-    # Preprocess input statement
+    if normalized_statement.startswith("copy"):
+        mock_s3_copy_command(statement, cursor)
+        context.should_autocommit = True
+        return "SELECT 1", {}
+    return statement, parameters
+
+
+def parse_multiple_statements(statement: str):
+    """Split the given sql statement into a list of individual sql statements."""
     processed_statement = _preprocess(statement)

     statements_list = [str(statement) for statement in sqlparse.split(processed_statement)]
diff --git a/tests/fixture/__init__.py b/tests/fixture/__init__.py
index e69de29b..817bde56 100644
--- a/tests/fixture/__init__.py
+++ b/tests/fixture/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.register_assert_rewrite("tests.fixture.database")
diff --git a/tests/fixture/database/__init__.py b/tests/fixture/database/__init__.py
index b67ecc73..55feb7d1 100644
--- a/tests/fixture/database/__init__.py
+++ b/tests/fixture/database/__init__.py
@@ -7,39 +7,34 @@ from pytest_mock_resources.patch.redshift.mock_s3_unload import get_data_csv

 original_data = [
-    {"i": 3342, "f": 32434.0, "c": "a", "v": "gfhsdgaf"},
-    {"i": 3343, "f": 0.0, "c": "b", "v": None},
-    {"i": 0, "f": 32434.0, "c": None, "v": "gfhsdgaf"},
+    (3342, 32434.0, "a", "gfhsdgaf"),
+    (3343, 0.0, "b", None),
+    (0, 32434.0, None, "gfhsdgaf"),
 ]
+data_columns = ["i", "f", "c", "v"]


-class ResultProxy:
-    def __init__(self, data):
-        self.data = data
+def empty_as_string(row, stringify_value=False, c_space=True):
+    result = {}
+    for key, value in dict(zip(data_columns, row)).items():
+        if value is None:
+            if c_space and key == "c":
+                row_value = " "
+            else:
+                row_value = ""
+        else:
+            row_value = value
+        if stringify_value:
+            row_value = str(value)

-    def keys(self):
-        return self.data[0].keys()
-
-    def __iter__(self):
-        return iter(self.data)
-
-
-def empty_as_string(row):
-    return {
-        key: value if value is not None else "" if key != "c" else " " for key, value in row.items()
-    }
-
-
-def data_as_csv(data):
-    if isinstance(data, dict):
-        return {key: str(value) if value is not None else "" for key, value in data.items()}
-    return [data_as_csv(row) for row in data]
+        result[key] = row_value
+    return result


 COPY_TEMPLATE = (
     "{COMMAND} test_s3_copy_into_redshift {COLUMNS} {FROM} '{LOCATION}' "
     "{CREDENTIALS} 'aws_access_key_id=AAAAAAAAAAAAAAAAAAAA;"
-    "aws_secret_access_key=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'"
+    "aws_secret_access_key=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' "
     "{OPTIONAL_ARGS};"
 )
@@ -54,7 +49,6 @@ def data_as_csv(data):
 def fetch_values_from_table_and_assert(engine):
     execute = engine.execute("SELECT * from test_s3_copy_into_redshift")
     results = [row for row in execute]
-    engine.execute("DROP TABLE test_s3_copy_into_redshift")
     assert len(results) == len(original_data)
     for index, val in enumerate(results):
         assert empty_as_string(results[index]) == empty_as_string(original_data[index])
@@ -64,10 +58,12 @@ def fetch_and_assert_psycopg2(cursor):
     cursor.execute("SELECT * from test_s3_copy_into_redshift")
     results = cursor.fetchall()
     assert len(results) == len(original_data)
-    for index, val in enumerate(results):
-        og_data = empty_as_string(original_data[index])
-        assert results[index] == tuple([og_data[k] for k in ["i", "f", "c", "v"]])
-    cursor.execute("DROP TABLE test_s3_copy_into_redshift")
+    for result, original in zip(results, original_data):
+        og_data = empty_as_string(original)
+        expected_result = tuple(og_data.values())
+        print(result)
+        print(expected_result)
+        assert result == expected_result


 def setup_table_and_bucket(redshift, file_name="file.csv"):
@@ -83,7 +79,7 @@ def setup_table_and_bucket(redshift, file_name="file.csv"):
         aws_secret_access_key="AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
     )
     conn.create_bucket(Bucket="mybucket")
-    conn.Object("mybucket", file_name).put(Body=get_data_csv(ResultProxy(original_data)))
+    conn.Object("mybucket", file_name).put(Body=get_data_csv(original_data, data_columns))


 def setup_table_and_insert_data(engine):
@@ -111,8 +107,7 @@ def fetch_values_from_s3_and_assert(
     )
     response = s3.get_object(Bucket="mybucket", Key=file_name)
     data = read_data_csv(response["Body"].read(), is_gzipped=is_gzipped, delimiter=delimiter)
-    assert data == data_as_csv(original_data)
-    engine.execute("DROP TABLE test_s3_unload_from_redshift")
+    assert data == [empty_as_string(f, stringify_value=True, c_space=False) for f in original_data]


 def randomcase(s):
@@ -148,10 +143,10 @@ def copy_fn_to_test_psycopg2_connect_patch(config):
             COPY_TEMPLATE.format(
                 COMMAND="COPY",
                 LOCATION="s3://mybucket/file.csv",
-                COLUMNS="",
+                COLUMNS="(i, f, c, v)",
                 FROM="from",
                 CREDENTIALS="credentials",
-                OPTIONAL_ARGS="",
+                OPTIONAL_ARGS="EMPTYASNULL",
             )
         )
@@ -168,7 +163,7 @@ def copy_fn_to_test_psycopg2_connect_patch_as_context_manager(config):
                 COPY_TEMPLATE.format(
                     COMMAND="COPY",
                     LOCATION="s3://mybucket/file.csv",
-                    COLUMNS="",
+                    COLUMNS="(i, f, c, v)",
                     FROM="from",
                     CREDENTIALS="credentials",
                     OPTIONAL_ARGS="",
diff --git a/tests/fixture/database/relational/test_generic.py b/tests/fixture/database/relational/test_generic.py
index 0afec510..9e55fd6b 100644
--- a/tests/fixture/database/relational/test_generic.py
+++ b/tests/fixture/database/relational/test_generic.py
@@ -1,8 +1,7 @@
 import pytest
 import sqlalchemy
-from sqlalchemy import Column, Integer, MetaData, SmallInteger, Table, Unicode
+from sqlalchemy import Column, Integer, MetaData, select, SmallInteger, Table, Unicode
 from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.future import select
 from sqlalchemy.orm import sessionmaker

 from pytest_mock_resources import create_postgres_fixture, create_sqlite_fixture, Rows
diff --git a/tests/fixture/database/test_copy.py b/tests/fixture/database/test_copy.py
index 0a1904f3..512d1114 100644
--- a/tests/fixture/database/test_copy.py
+++ b/tests/fixture/database/test_copy.py
@@ -1,6 +1,5 @@
 import time

-import pytest
 from sqlalchemy import Column, Integer, text
 from sqlalchemy.ext.declarative import declarative_base

@@ -8,11 +7,11 @@ from pytest_mock_resources.compat import boto3, moto
 from tests.fixture.database import (
     COPY_TEMPLATE,
+    data_columns,
     fetch_values_from_table_and_assert,
     get_data_csv,
     original_data,
     randomcase,
-    ResultProxy,
     setup_table_and_bucket,
 )
@@ -38,13 +37,12 @@ def test_s3_copy_into_redshift(redshift):
     fetch_values_from_table_and_assert(redshift)


-@pytest.mark.xfail(strict=True, reason="Existing bug, we should reconsider our existing mechanism")
 def test_s3_copy_into_redshift_transaction(redshift):
     with moto.mock_s3():
         setup_table_and_bucket(redshift)

-        with redshift.begin() as redshift_connection:
-            redshift_connection.execute(
+        with redshift.begin() as conn:
+            conn.execute(
                 COPY_TEMPLATE.format(
                     COMMAND="COPY",
                     LOCATION="s3://mybucket/file.csv",
@@ -95,7 +93,9 @@ def test_s3_copy_from_gzip(redshift):
             time_in_mills=int(round(time.time() * 1000))
         )

-        file = get_data_csv(ResultProxy(original_data), is_gzipped=True, path_or_buf=temp_file_name)
+        file = get_data_csv(
+            original_data, data_columns, is_gzipped=True, path_or_buf=temp_file_name
+        )
         conn.Object("mybucket", "file.csv.gz").put(Body=file)

         redshift.execute(
@@ -218,7 +218,7 @@ def test_multiple_sql_statemts(redshift):
             aws_secret_access_key="AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
         )
         conn.create_bucket(Bucket="mybucket")
-        conn.Object("mybucket", "file.csv").put(Body=get_data_csv(ResultProxy(original_data)))
+        conn.Object("mybucket", "file.csv").put(Body=get_data_csv(original_data, data_columns))

         redshift.execute(
             (
@@ -235,7 +235,7 @@ def test_multiple_sql_statemts(redshift):
                 CREDENTIALS="credentials",
                 OPTIONAL_ARGS="",
             )
-        )
+        ),
     )

     fetch_values_from_table_and_assert(redshift)
@@ -263,7 +263,7 @@ def test_redshift_auto_schema_creation(redshift):
             aws_secret_access_key="AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
         )
         conn.create_bucket(Bucket="mybucket")
-        conn.Object("mybucket", "file.csv").put(Body=get_data_csv(ResultProxy(original_data)))
+        conn.Object("mybucket", "file.csv").put(Body=get_data_csv(original_data, data_columns))

         redshift.execute(
             (
diff --git a/tests/fixture/database/test_patch.py b/tests/fixture/database/test_patch.py
index 6fd1df10..83331191 100644
--- a/tests/fixture/database/test_patch.py
+++ b/tests/fixture/database/test_patch.py
@@ -1,4 +1,4 @@
-from pytest_mock_resources import create_redshift_fixture
+from pytest_mock_resources import create_postgres_fixture, create_redshift_fixture
 from tests.fixture.database import (
     copy_fn_to_test_create_engine_patch,
     copy_fn_to_test_psycopg2_connect_patch,
@@ -9,6 +9,7 @@
 )

 redshift = create_redshift_fixture()
+postgres = create_postgres_fixture()


 def test_copy(redshift):
@@ -37,3 +38,8 @@ def test_unload_with_psycopg2(redshift):
 def test_unload_with_psycopg2_as_context_manager(redshift):
     config = redshift.pmr_credentials.as_psycopg2_kwargs()
     unload_fn_to_test_psycopg2_connect_patch_as_context_manager(config)
+
+
+def test_tightly_scoped_patch(redshift, postgres):
+    redshift.execute("select 1; select 1;")
+    postgres.execute("select 1; select 1;")
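
For reference, the mechanism this patch leans on is SQLAlchemy's public `event.listen` hook with `retval=True`, which lets a listener replace the statement and parameters just before they reach the DBAPI cursor. Below is a minimal, self-contained sketch of that pattern, separate from the patch itself; the SQLite URL, the function name `register_statement_rewrite`, and the trivial rewrite rule are placeholders chosen only for illustration, not part of this change.

import sqlalchemy
from sqlalchemy import event


def register_statement_rewrite(engine):
    """Rewrite statements right before cursor execution (sketch of the event-based approach)."""

    def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        # With retval=True, the returned (statement, parameters) pair replaces the
        # originals. Returning a harmless no-op query is how side-effectful statements
        # that were already handled out-of-band (like COPY above) can be swallowed.
        if statement.lstrip().lower().startswith("copy"):
            return "SELECT 1", parameters
        return statement, parameters

    event.listen(engine, "before_cursor_execute", before_cursor_execute, retval=True)


# Usage sketch (SQLAlchemy 1.3-style string execution, matching the era of this patch):
engine = sqlalchemy.create_engine("sqlite://")
register_statement_rewrite(engine)
with engine.connect() as conn:
    # Not valid SQLite, but the listener rewrites it to "SELECT 1" before the
    # cursor ever sees it.
    conn.execute("COPY some_table FROM 's3://bucket/key'")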