diff --git a/src/pytest_mock_resources/patch/redshift/psycopg2.py b/src/pytest_mock_resources/patch/redshift/psycopg2.py index 5ddece49..4aaf11cc 100644 --- a/src/pytest_mock_resources/patch/redshift/psycopg2.py +++ b/src/pytest_mock_resources/patch/redshift/psycopg2.py @@ -1,9 +1,9 @@ import contextlib +from unittest import mock from sqlalchemy import create_engine from sqlalchemy.sql.elements import TextClause -from unittest import mock from pytest_mock_resources.compat import psycopg2 from pytest_mock_resources.patch.redshift.mock_s3_copy import strip from pytest_mock_resources.patch.redshift.sqlalchemy import ( diff --git a/src/pytest_mock_resources/patch/redshift/sqlalchemy.py b/src/pytest_mock_resources/patch/redshift/sqlalchemy.py index 605b1b3f..67132e1d 100644 --- a/src/pytest_mock_resources/patch/redshift/sqlalchemy.py +++ b/src/pytest_mock_resources/patch/redshift/sqlalchemy.py @@ -1,29 +1,30 @@ -from sqlalchemy.sql.elements import TextClause -from sqlalchemy.sql.expression import Insert, Select, Update +from typing import Union + +from sqlalchemy.sql.base import Executable from pytest_mock_resources.compat import sqlparse from pytest_mock_resources.patch.redshift.mock_s3_copy import execute_mock_s3_copy_command, strip from pytest_mock_resources.patch.redshift.mock_s3_unload import execute_mock_s3_unload_command -SQLALCHEMY_BASES = (Select, Insert, Update, TextClause) - def substitute_execute_with_custom_execute(engine): """Substitute the default execute method with a custom execute for copy and unload command.""" default_execute = engine.execute - def custom_execute(statement, *args, **kwargs): - if not isinstance(statement, SQLALCHEMY_BASES) and strip(statement).lower().startswith( - "copy" - ): + def custom_execute(statement: Union[Executable, str], *args, **kwargs): + if isinstance(statement, Executable): + return default_execute(statement, *args, **kwargs) + + # The statement is assumed to be a string at this point. + normalized_statement = strip(statement).lower() + if normalized_statement.startswith("copy"): return execute_mock_s3_copy_command(statement, engine) - if not isinstance(statement, SQLALCHEMY_BASES) and strip(statement).lower().startswith( - "unload" - ): + elif normalized_statement.startswith("unload"): return execute_mock_s3_unload_command(statement, engine) + return default_execute(statement, *args, **kwargs) - def handle_multiple_statements(statement, *args, **kwargs): + def handle_multiple_statements(statement: Union[Executable, str], *args, **kwargs): """Split statement into individual sql statements and execute. Splits multiple statements by ';' and executes each. @@ -41,24 +42,20 @@ def handle_multiple_statements(statement, *args, **kwargs): return engine -def parse_multiple_statements(statement): +def parse_multiple_statements(statement: Union[Executable, str]): """Split the given sql statement into a list of individual sql statements.""" - statements_list = [] - # Ignore SQLAlchemy Text Objects. - if isinstance(statement, SQLALCHEMY_BASES): - statements_list.append(statement) - return statements_list - - # Prprocess input statement - statement = _preprocess(statement) + if isinstance(statement, Executable): + return [statement] - statements_list = [str(statement) for statement in sqlparse.split(statement)] + # Preprocess input statement + processed_statement = _preprocess(statement) + statements_list = [str(statement) for statement in sqlparse.split(processed_statement)] return statements_list -def _preprocess(statement): +def _preprocess(statement: str): """Preprocess the input statement.""" statement = statement.strip() # Replace any occourance of " with '. diff --git a/tests/fixture/database/test_copy.py b/tests/fixture/database/test_copy.py index afae0b7c..0a1904f3 100644 --- a/tests/fixture/database/test_copy.py +++ b/tests/fixture/database/test_copy.py @@ -1,7 +1,8 @@ import time import pytest -from sqlalchemy import text +from sqlalchemy import Column, Integer, text +from sqlalchemy.ext.declarative import declarative_base from pytest_mock_resources import create_redshift_fixture from pytest_mock_resources.compat import boto3, moto @@ -238,3 +239,48 @@ def test_multiple_sql_statemts(redshift): ) fetch_values_from_table_and_assert(redshift) + + +Base = declarative_base() + + +class Example(Base): + __tablename__ = "quarter" + __table_args__ = {"schema": "foo"} + + id = Column(Integer, primary_key=True) + + +redshift = create_redshift_fixture(Base) + + +def test_redshift_auto_schema_creation(redshift): + with moto.mock_s3(): + conn = boto3.resource( + "s3", + region_name="us-east-1", + aws_access_key_id="AAAAAAAAAAAAAAAAAAAA", + aws_secret_access_key="AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", + ) + conn.create_bucket(Bucket="mybucket") + conn.Object("mybucket", "file.csv").put(Body=get_data_csv(ResultProxy(original_data))) + + redshift.execute( + ( + "CREATE TEMP TABLE test_s3_copy_into_redshift " + "(i INT, f FLOAT, c CHAR(1), v VARCHAR(16));" + "{COMMAND} test_s3_copy_into_redshift {COLUMNS} {FROM} '{LOCATION}' " + "{CREDENTIALS} 'aws_access_key_id=AAAAAAAAAAAAAAAAAAAA;" + "aws_secret_access_key=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'" + "{OPTIONAL_ARGS};".format( + COMMAND="COPY", + LOCATION="s3://mybucket/file.csv", + COLUMNS="(i, f, c, v)", + FROM="from", + CREDENTIALS="credentials", + OPTIONAL_ARGS="", + ) + ) + ) + + fetch_values_from_table_and_assert(redshift)