Skip to content

Commit

Permalink
Merge pull request #133 from sqlalchemy-redshift/enums-for-parameters
Browse files Browse the repository at this point in the history
Use Enums for Format, Compression and Encoding
  • Loading branch information
graingert committed Sep 25, 2017
2 parents f0f5b6e + 1224e68 commit e1b4c23
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
- Do not enumerate `search_path` with external schemas (`Issue #120
<https://github.com/sqlalchemy-redshift/sqlalchemy-redshift/pull/120>`_)
- Return constraint name from get_pk_constraint and get_foreign_keys
- Use Enums for Format, Compression and Encoding.
Deprecate string parameters for these parameter types
(`Issue #133 <https://github.com/sqlalchemy-redshift/sqlalchemy-redshift/pull/133>`_)


0.6.0 (2017-05-04)
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
# version 0.9.2
'SQLAlchemy>=0.9.2',
],
extras_require={
':python_version < "3.4"': 'enum34 >= 1.1.6, < 2.0.0'
},
classifiers=[
"Development Status :: 4 - Beta",
"Environment :: Console",
Expand Down
80 changes: 49 additions & 31 deletions sqlalchemy_redshift/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
import enum
import numbers
import re
import warnings

import sqlalchemy as sa
from sqlalchemy.ext import compiler as sa_compiler
Expand Down Expand Up @@ -261,6 +263,25 @@ def visit_unload_from_select(element, compiler, **kw):
)


@enum.unique
class Format(enum.Enum):
    """Input file formats accepted by the ``format`` parameter of a COPY.

    Members can be looked up by value, e.g. ``Format('CSV') is Format.csv``;
    ``enum.unique`` guards that lookup against accidental aliases.
    """

    csv = 'CSV'
    json = 'JSON'
    avro = 'AVRO'


@enum.unique
class Compression(enum.Enum):
    """Compression types accepted by the ``compression`` parameter of a COPY.

    Each member's value is the keyword appended verbatim to the COPY
    parameters when the statement is compiled.
    """

    gzip = 'GZIP'
    lzop = 'LZOP'
    bzip2 = 'BZIP2'


@enum.unique
class Encoding(enum.Enum):
    """Load-data encodings accepted by the ``encoding`` parameter of a COPY.

    Each member's value is the keyword emitted after ``ENCODING AS`` when
    the statement is compiled.
    """

    utf8 = 'UTF8'
    utf16 = 'UTF16'
    utf16le = 'UTF16LE'
    utf16be = 'UTF16BE'


class CopyCommand(_ExecutableClause):
"""
Prepares a Redshift COPY statement.
Expand All @@ -285,11 +306,11 @@ class CopyCommand(_ExecutableClause):
iam_role_name: str, optional
IAM role name for role-based credentials. Required unless you supply
key based credentials (``access_key_id`` and ``secret_access_key``)
format : str, optional
CSV, JSON, or AVRO. Indicates the type of file to copy from
format : Format, optional
Indicates the type of file to copy from
quote : str, optional
Specifies the character to be used as the quote character when using
``format='CSV'``. The default is a double quotation mark ( ``"`` )
``format=Format.csv``. The default is a double quotation mark ( ``"`` )
delimiter : File delimiter, optional
defaults to ``|``
path_file : str, optional
Expand All @@ -298,9 +319,8 @@ class CopyCommand(_ExecutableClause):
defaults to ``'auto'``
fixed_width: iterable of (str, int), optional
List of (column name, length) pairs to control fixed-width output.
compression : str, optional
GZIP, LZOP, BZIP2, indicates the type of compression of the
file to copy
compression : Compression, optional
Indicates the type of compression of the file to copy
accept_any_date : bool, optional
Allows any date format, including invalid formats such as
``00/00/00 00:00:00``, to be loaded as NULL without generating an error
Expand All @@ -319,10 +339,9 @@ class CopyCommand(_ExecutableClause):
empty_as_null : bool, optional
Boolean value denoting whether to load VARCHAR fields with empty
values as NULL instead of empty string
encoding : str, optional
``'UTF8'``, ``'UTF16'``, ``'UTF16LE'``, ``'UTF16BE'``. Specifies the
encoding type of the load data
defaults to ``'UTF8'``
encoding : Encoding, optional
Specifies the encoding type of the load data. Defaults to
``Encoding.utf8``
escape : bool, optional
When this parameter is specified, the backslash character (``\``) in
input data is treated as an escape character. The character that
Expand Down Expand Up @@ -388,8 +407,6 @@ class CopyCommand(_ExecutableClause):
manifest : bool, optional
Boolean value denoting whether data_location is a manifest file.
"""
formats = ['CSV', 'JSON', 'AVRO', None]
compression_types = ['GZIP', 'LZOP', 'BZIP2']

def __init__(self, to, data_location, access_key_id=None,
secret_access_key=None, session_token=None,
Expand Down Expand Up @@ -425,16 +442,17 @@ def __init__(self, to, data_location, access_key_id=None,
'"ignore_header" parameter should be an integer'
)

if format not in self.formats:
raise ValueError('"format" parameter must be one of %s' %
self.formats)
def check_enum(Enum, val):
if val is None:
return

if compression is not None:
if compression not in self.compression_types:
raise ValueError(
'"compression" parameter must be one of %s' %
self.compression_types
)
cleaned = Enum(val)
if cleaned is not val:
tpl = '{val!r} should be, {cleaned!r}, an instance of {Enum!r}'
msg = tpl.format(val=val, cleaned=cleaned, Enum=Enum)
warnings.warn(msg, DeprecationWarning)

return cleaned

table = None
columns = []
Expand All @@ -456,19 +474,19 @@ def __init__(self, to, data_location, access_key_id=None,
self.columns = columns
self.data_location = data_location
self.credentials = credentials
self.format = format
self.format = check_enum(Format, format)
self.quote = quote
self.path_file = path_file
self.delimiter = delimiter
self.fixed_width = fixed_width
self.compression = compression
self.compression = check_enum(Compression, compression)
self.manifest = manifest
self.accept_any_date = accept_any_date
self.accept_inv_chars = accept_inv_chars
self.blanks_as_null = blanks_as_null
self.date_format = date_format
self.empty_as_null = empty_as_null
self.encoding = encoding
self.encoding = check_enum(Encoding, encoding)
self.escape = escape
self.explicit_ids = explicit_ids
self.fill_record = fill_record
Expand Down Expand Up @@ -510,7 +528,7 @@ def visit_copy_command(element, compiler, **kw):
),
]

if element.format == 'CSV':
if element.format == Format.csv:
format_ = 'FORMAT AS CSV'
if element.quote is not None:
format_ += ' QUOTE AS :quote_character'
Expand All @@ -519,14 +537,14 @@ def visit_copy_command(element, compiler, **kw):
value=element.quote,
type_=sa.String,
))
elif element.format == 'JSON':
elif element.format == Format.json:
format_ = 'FORMAT AS JSON AS :json_option'
bindparams.append(sa.bindparam(
'json_option',
value=element.path_file,
type_=sa.String,
))
elif element.format == 'AVRO':
elif element.format == Format.avro:
format_ = 'FORMAT AS AVRO AS :avro_option'
bindparams.append(sa.bindparam(
'avro_option',
Expand All @@ -552,8 +570,8 @@ def visit_copy_command(element, compiler, **kw):
type_=sa.String,
))

if element.compression in ['GZIP', 'LZOP', 'BZIP2']:
parameters.append(element.compression)
if element.compression is not None:
parameters.append(Compression(element.compression).value)

if element.manifest:
parameters.append('MANIFEST')
Expand Down Expand Up @@ -583,8 +601,8 @@ def visit_copy_command(element, compiler, **kw):
if element.empty_as_null:
parameters.append('EMPTYASNULL')

if element.encoding in ['UTF8', 'UTF16', 'UTF16LE', 'UTF16BE']:
parameters.append('ENCODING AS ' + element.encoding)
if element.encoding is not None:
parameters.append('ENCODING AS ' + Encoding(element.encoding).value)

if element.escape:
parameters.append('ESCAPE')
Expand Down
9 changes: 7 additions & 2 deletions sqlalchemy_redshift/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
)
from sqlalchemy.types import VARCHAR, NullType

from .commands import CopyCommand, UnloadFromSelect
from .commands import (
CopyCommand, UnloadFromSelect, Format, Compression, Encoding
)
from .compat import string_types

try:
Expand All @@ -29,7 +31,10 @@
class RedshiftImpl(postgresql.PostgresqlImpl):
    """``postgresql.PostgresqlImpl`` subclass with ``__dialect__`` set to
    ``'redshift'``.
    """
    # NOTE(review): presumably this is how the implementation gets associated
    # with the 'redshift' dialect name -- confirm against the base class.
    __dialect__ = 'redshift'

__all__ = ['CopyCommand', 'UnloadFromSelect', 'RedshiftDialect']
__all__ = [
'CopyCommand', 'UnloadFromSelect', 'RedshiftDialect', 'Compression',
'Encoding', 'Format',
]


# Regex for parsing an identity constraint out of adsrc, e.g.:
Expand Down
16 changes: 15 additions & 1 deletion tests/test_copy_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,19 @@ def test_different_tables():
data_location='s3://bucket',
access_key_id=access_key_id,
secret_access_key=secret_access_key,
format='CSV'
format=dialect.Format.csv,
)


def test_legacy_string_format():
    """A plain-string ``format`` is still accepted but is deprecated.

    ``'CSV'`` is a valid ``Format`` value, so the command must be built
    successfully, emit a ``DeprecationWarning``, and store the matching
    enum member. Only columns of a single table are passed so that the
    mixed-table ``ValueError`` cannot mask the behaviour under test.
    """
    metadata = sa.MetaData()
    t1 = sa.Table('t1', metadata, sa.Column('col1', sa.Unicode()))
    with pytest.warns(DeprecationWarning):
        copy = dialect.CopyCommand(
            [t1.c.col1],
            data_location='s3://bucket',
            access_key_id=access_key_id,
            secret_access_key=secret_access_key,
            format='CSV',
        )
    # The legacy string must have been coerced to the enum member.
    assert copy.format is dialect.Format.csv

0 comments on commit e1b4c23

Please sign in to comment.