Skip to content

Commit

Permalink
fix: Address poor handling of SQL statement parsing for redshift.
Browse files Browse the repository at this point in the history
    Addresses a bug where `CreateSchema` statements would be incorrectly
    handled as though they were strings.
  • Loading branch information
DanCardin committed Aug 17, 2021
1 parent 281427d commit 11c4b97
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/pytest_mock_resources/patch/redshift/psycopg2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import contextlib
from unittest import mock

from sqlalchemy import create_engine
from sqlalchemy.sql.elements import TextClause

from unittest import mock
from pytest_mock_resources.compat import psycopg2
from pytest_mock_resources.patch.redshift.mock_s3_copy import strip
from pytest_mock_resources.patch.redshift.sqlalchemy import (
Expand Down
43 changes: 20 additions & 23 deletions src/pytest_mock_resources/patch/redshift/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import Insert, Select, Update
from typing import Union

from sqlalchemy.sql.base import Executable

from pytest_mock_resources.compat import sqlparse
from pytest_mock_resources.patch.redshift.mock_s3_copy import execute_mock_s3_copy_command, strip
from pytest_mock_resources.patch.redshift.mock_s3_unload import execute_mock_s3_unload_command

SQLALCHEMY_BASES = (Select, Insert, Update, TextClause)


def substitute_execute_with_custom_execute(engine):
"""Substitute the default execute method with a custom execute for copy and unload command."""
default_execute = engine.execute

def custom_execute(statement, *args, **kwargs):
if not isinstance(statement, SQLALCHEMY_BASES) and strip(statement).lower().startswith(
"copy"
):
def custom_execute(statement: Union[Executable, str], *args, **kwargs):
if isinstance(statement, Executable):
return default_execute(statement, *args, **kwargs)

# The statement is assumed to be a string at this point.
normalized_statement = strip(statement).lower()
if normalized_statement.startswith("copy"):
return execute_mock_s3_copy_command(statement, engine)
if not isinstance(statement, SQLALCHEMY_BASES) and strip(statement).lower().startswith(
"unload"
):
elif normalized_statement.startswith("unload"):
return execute_mock_s3_unload_command(statement, engine)

return default_execute(statement, *args, **kwargs)

def handle_multiple_statements(statement, *args, **kwargs):
def handle_multiple_statements(statement: Union[Executable, str], *args, **kwargs):
"""Split statement into individual sql statements and execute.
Splits multiple statements by ';' and executes each.
Expand All @@ -41,24 +42,20 @@ def handle_multiple_statements(statement, *args, **kwargs):
return engine


def parse_multiple_statements(statement):
def parse_multiple_statements(statement: Union[Executable, str]):
"""Split the given sql statement into a list of individual sql statements."""
statements_list = []

# Ignore SQLAlchemy Text Objects.
if isinstance(statement, SQLALCHEMY_BASES):
statements_list.append(statement)
return statements_list

# Prprocess input statement
statement = _preprocess(statement)
if isinstance(statement, Executable):
return [statement]

statements_list = [str(statement) for statement in sqlparse.split(statement)]
# Preprocess input statement
processed_statement = _preprocess(statement)
statements_list = [str(statement) for statement in sqlparse.split(processed_statement)]

return statements_list


def _preprocess(statement):
def _preprocess(statement: str):
"""Preprocess the input statement."""
statement = statement.strip()
# Replace any occourance of " with '.
Expand Down
48 changes: 47 additions & 1 deletion tests/fixture/database/test_copy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import time

import pytest
from sqlalchemy import text
from sqlalchemy import Column, Integer, text
from sqlalchemy.ext.declarative import declarative_base

from pytest_mock_resources import create_redshift_fixture
from pytest_mock_resources.compat import boto3, moto
Expand Down Expand Up @@ -238,3 +239,48 @@ def test_multiple_sql_statemts(redshift):
)

fetch_values_from_table_and_assert(redshift)


Base = declarative_base()


class Example(Base):
__tablename__ = "quarter"
__table_args__ = {"schema": "foo"}

id = Column(Integer, primary_key=True)


redshift = create_redshift_fixture(Base)


def test_redshift_auto_schema_creation(redshift):
with moto.mock_s3():
conn = boto3.resource(
"s3",
region_name="us-east-1",
aws_access_key_id="AAAAAAAAAAAAAAAAAAAA",
aws_secret_access_key="AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
)
conn.create_bucket(Bucket="mybucket")
conn.Object("mybucket", "file.csv").put(Body=get_data_csv(ResultProxy(original_data)))

redshift.execute(
(
"CREATE TEMP TABLE test_s3_copy_into_redshift "
"(i INT, f FLOAT, c CHAR(1), v VARCHAR(16));"
"{COMMAND} test_s3_copy_into_redshift {COLUMNS} {FROM} '{LOCATION}' "
"{CREDENTIALS} 'aws_access_key_id=AAAAAAAAAAAAAAAAAAAA;"
"aws_secret_access_key=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'"
"{OPTIONAL_ARGS};".format(
COMMAND="COPY",
LOCATION="s3://mybucket/file.csv",
COLUMNS="(i, f, c, v)",
FROM="from",
CREDENTIALS="credentials",
OPTIONAL_ARGS="",
)
)
)

fetch_values_from_table_and_assert(redshift)

0 comments on commit 11c4b97

Please sign in to comment.