Skip to content
This repository has been archived by the owner on Nov 8, 2021. It is now read-only.

Commit

Permalink
Add post insert step and add many validations and tests
Browse files Browse the repository at this point in the history
* Use make_url when creating new engine

* Add column type checking, fix bugs uncovered by new changes

* Improve schema alteration logic and add unit tests

* Bump requirements

* Add escaping to default value in alter table

* Lint

* Add more detailed column value testing

* Fix isort lint

* Add post_insert step

* Add unit test for undefined length string

* Add inits to test directories to trigger tests with nose
  • Loading branch information
villebro committed Dec 12, 2019
1 parent 92730ac commit f6b69fb
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 29 deletions.
22 changes: 11 additions & 11 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
alabaster==0.7.12 # via sphinx
babel==2.7.0 # via sphinx
certifi==2019.9.11 # via requests
certifi==2019.11.28 # via requests
chardet==3.0.4
click==7.0 # via pip-tools
codecov==2.0.15
Expand All @@ -17,30 +17,30 @@ filelock==3.0.12 # via tox
flake8==3.7.9
idna==2.8 # via requests
imagesize==1.1.0 # via sphinx
importlib-metadata==0.23 # via pluggy, tox
importlib-metadata==1.2.0 # via pluggy, tox
isort==4.3.21
jinja2==2.10.3 # via sphinx
markupsafe==1.1.1 # via jinja2
mccabe==0.6.1 # via flake8
more-itertools==7.2.0 # via zipp
more-itertools==8.0.2 # via zipp
mypy-extensions==0.4.3 # via mypy
mypy==0.740
mypy==0.750
nose==1.3.7
packaging==19.2 # via sphinx, tox
pip-tools==4.2.0
pluggy==0.13.0 # via tox
pip-tools==4.3.0
pluggy==0.13.1 # via tox
py==1.8.0 # via tox
pycodestyle==2.5.0 # via flake8
pyflakes==2.1.1 # via flake8
pygments==2.4.2 # via sphinx
pygments==2.5.2 # via sphinx
pyparsing==2.4.5 # via packaging
pytz==2019.3 # via babel
requests==2.22.0 # via codecov, sphinx
six==1.13.0 # via packaging, pip-tools, tox
snowballstemmer==2.0.0 # via sphinx
sphinx-autodoc-typehints==1.10.3
sphinx-rtd-theme==0.4.3
sphinx==2.2.1
sphinx==2.2.2
sphinxcontrib-applehelp==1.0.1 # via sphinx
sphinxcontrib-devhelp==1.0.1 # via sphinx
sphinxcontrib-htmlhelp==1.0.2 # via sphinx
Expand All @@ -49,12 +49,12 @@ sphinxcontrib-qthelp==1.0.2 # via sphinx
sphinxcontrib-serializinghtml==1.1.3 # via sphinx
sqlalchemy==1.3.11
toml==0.10.0 # via tox
tox==3.14.1
tox==3.14.2
typed-ast==1.4.0 # via mypy
typing-extensions==3.7.4.1 # via mypy
urllib3==1.25.7 # via requests
virtualenv==16.7.7 # via tox
virtualenv==16.7.8 # via tox
zipp==0.6.0 # via importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# setuptools==41.6.0 # via sphinx
# setuptools==42.0.2 # via sphinx
9 changes: 9 additions & 0 deletions sqltask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ def insert_rows(self) -> None:
for table_context in self._tables.values():
table_context.insert_rows()

def post_insert(self) -> None:
    """
    Hook executed once all rows have been inserted. Subclasses can override
    this to run SQL statements that need no row-by-row transformation; the
    default implementation is a no-op.
    """

def delete_rows(self) -> None:
"""
Delete rows in target tables.
Expand Down Expand Up @@ -156,6 +163,8 @@ def execute_etl(self):
self.delete_rows()
logger.debug(f"Start insert")
self.insert_rows()
logger.debug(f"Start post insert")
self.post_insert()
logger.debug(f"Finish etl")

def execute(self):
Expand Down
7 changes: 4 additions & 3 deletions sqltask/base/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Optional

from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.schema import MetaData

from sqltask.engine_specs import get_engine_spec
Expand Down Expand Up @@ -45,6 +46,6 @@ def create_new(self,
provided by the original engine context
:return: a new instance of EngineContext with different url
"""
engine = create_engine(str(self.engine.url))
self.engine_spec.modify_url(engine.url, database=database, schema=schema)
return EngineContext(self.name, str(engine.url), **self.metadata_kwargs)
url = make_url(str(self.engine.url))
self.engine_spec.modify_url(url, database=database, schema=schema)
return EngineContext(self.name, str(url), **self.metadata_kwargs)
2 changes: 1 addition & 1 deletion sqltask/base/lookup_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ def get(self, *unnamed_keys, **named_keys) -> Dict[str, Any]:
keys = [key for key in unnamed_keys]
for key in self.keys[len(unnamed_keys):]:
if key not in named_keys:
raise Exception(f"Key not in lookup: {key}")
raise ValueError(f"Key not in lookup: {key}")
keys.append(named_keys[key])
return store.get(tuple(keys), {})
19 changes: 15 additions & 4 deletions sqltask/base/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def __init__(
"""
# comment is apparently not Optional, so needs to be passed via kwargs
table_params = table_params or {}
self.columns = {column.name: column for column in columns}
if comment:
table_params["comment"] = comment

table = Table(name,
engine_context.metadata,
*columns,
Expand Down Expand Up @@ -93,6 +95,7 @@ def migrate_schema(self) -> None:
engine = self.engine_context.engine
engine_spec = self.engine_context.engine_spec
metadata = self.engine_context.metadata

if engine.has_table(table.name, schema=self.schema):
inspector = sa.inspect(engine)

Expand Down Expand Up @@ -201,7 +204,6 @@ def __init__(

dq_table_name = dq_table_name or self.name + "_dq"
dq_engine_context = dq_engine_context or self.engine_context
dq_timestamp_column_name = self.timestamp_column_name or "etl_timestamp"
dq_schema = dq_schema or self.schema
dq_info_column_names = dq_info_column_names or []
dq_table_params = dq_table_params or {}
Expand Down Expand Up @@ -244,7 +246,6 @@ def __init__(
comment=comment,
schema=dq_schema,
batch_params=batch_params,
timestamp_column_name=dq_timestamp_column_name,
table_params=dq_table_params,
)

Expand Down Expand Up @@ -286,11 +287,21 @@ class BaseOutputRow(UserDict):
all batch parameters are prepopulated.
"""
def __init__(self, table_context: BaseTableContext):
super().__init__(table_context.batch_params)
self.table_context = table_context
super().__init__(table_context.batch_params)
if table_context.timestamp_column_name:
self[table_context.timestamp_column_name] = datetime.utcnow()

def __setitem__(self, key, value):
    """Validate ``value`` against the target column schema before storing it."""
    columns = self.table_context.columns
    if columns is not None:
        # A defined schema means every key must name a known target column.
        target_column = columns.get(key)
        if target_column is None:
            raise KeyError(f"Column not found in target schema: {key}")
        spec = self.table_context.engine_context.engine_spec
        spec.validate_column_value(value, target_column)
    super().__setitem__(key, value)

def map_all(self,
input_row: Mapping[str, Any],
columns: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -326,7 +337,7 @@ def append(self) -> None:
"""

output_row = {}
for column in self.table_context.table.columns:
for column in self.table_context.columns.values():
if column.name not in self:
raise Exception(f"No column `{column.name}` in output row for table "
f"`{self.table_context.name}`")
Expand Down
89 changes: 82 additions & 7 deletions sqltask/engine_specs/base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,47 @@
import logging
from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set, Tuple, Type

from sqlalchemy import types
from sqlalchemy.engine.url import URL
from sqlalchemy.schema import Column
from sqlalchemy.sql import text
from sqlalchemy.sql.type_api import TypeEngine

from sqltask.base.common import UrlParams
from sqltask.base.table import BaseTableContext
from sqltask.utils.engine_specs import get_escaped_string_value

log = logging
logger = logging.getLogger(__name__)


class UploadType(Enum):
    # Enumeration of supported upload mechanisms for engine specs.
    SQL_INSERT = 1
    CSV = 2


# Maps a SQLAlchemy column type class to the tuple of Python types accepted as
# values for columns of that type. The lookup in validate_column_value is by
# exact type, so column types missing from this dict are simply not validated.
VALID_COLUMN_TYPES: Dict[Type[TypeEngine], Tuple[Any, ...]] = {
    types.Date: (date,),
    types.DATE: (date,),
    types.DateTime: (date, datetime),
    types.DATETIME: (date, datetime),
    types.INT: (int,),
    types.INTEGER: (int,),
    types.Integer: (int,),
    types.Float: (int, float),
    types.String: (str,),
    types.NVARCHAR: (str,),
    types.VARCHAR: (str,),
    types.SmallInteger: (int,),
    types.SMALLINT: (int,),
    types.BIGINT: (int,),
    types.BigInteger: (int,),
    types.Numeric: (int, float),
    types.NUMERIC: (int, float),
}


class BaseEngineSpec:
"""
Generic spec defining default behaviour for SqlAlchemy engines.
Expand Down Expand Up @@ -100,7 +125,7 @@ def modify_url(cls, url: URL, database: Optional[str], schema: Optional[str]) ->
database_current = url.database
schema_current = None
if not cls.supports_schemas or database is None:
return None
return
if "/" in database_current:
database_current, schema_current = database_current.split("/")

Expand Down Expand Up @@ -142,9 +167,27 @@ def add_column(cls,
:return:
"""
table_name = table_context.table.name
dialect = table_context.engine_context.engine.dialect
logging.debug(f"Add column `{column.name}` to table `{table_name}`")
stmt = f'ALTER TABLE {table_name} ADD COLUMN ' \
f'{column.name} {str(column.type)}'
stmt = f"ALTER TABLE {table_name} ADD COLUMN " \
f"{column.name} {column.type.compile(dialect=dialect)}"
if column.default is not None:
if isinstance(column.default, str):
default_value = f"'{get_escaped_string_value(column.default)}'"
else:
default_value = column.default
stmt += f" DEFAULT {default_value}"
if column.autoincrement is True:
stmt += " AUTOINCREMENT"
if column.nullable is True:
stmt += " NULL"
else:
stmt += " NOT NULL"
if column.primary_key is True:
stmt += " PRIMARY KEY"
if cls.supports_column_comments and column.comment:
comment = get_escaped_string_value(column.comment)
stmt += f" COMMENT '{comment}'"
table_context.engine_context.engine.execute(stmt)

@classmethod
Expand Down Expand Up @@ -177,7 +220,7 @@ def update_table_comment(cls,
"""
table_name = table_context.table.name
logging.info(f"Change comment on table `{table_name}`")
comment = comment.replace("'", "\\'")
comment = get_escaped_string_value(comment)
stmt = f"COMMENT ON TABLE {table_name} IS '{comment}'"
table_context.engine_context.engine.execute(stmt)

Expand All @@ -196,6 +239,38 @@ def update_column_comment(cls,
"""
table_name = table_context.table.name
logging.info(f"Change comment on table `{table_name}`")
comment = comment.replace("'", "\\'")
comment = get_escaped_string_value(comment)
stmt = f"COMMENT ON COLUMN {table_name}.{column_name} IS '{comment}'"
table_context.engine_context.engine.execute(stmt)

@classmethod
def validate_column_value(cls, value: Any, column: Column) -> None:
    """
    Ensure that a value is compatible with the target column. The method doesn't
    return a value, only raises an Exception if the value and target column type
    are incompatible.
    :param value: value to insert into a column of a database table
    :param column: The target column
    :raises ValueError: if the value violates the column's nullability, type
            or (for strings) length constraint.
    """
    # No `global` declaration is needed to read a module-level name; the
    # previous `global VALID_COLUMN_TYPES` statement was a no-op.
    name = column.name
    valid_types = VALID_COLUMN_TYPES.get(type(column.type))
    if value is None:
        if column.nullable:
            return
        raise ValueError(f"Column {name} cannot be null")
    if valid_types is None:
        # No type mapping defined for this column type: skip validation.
        return
    # Deliberately an exact type() check, not isinstance(): e.g. a datetime
    # value is not accepted for a plain Date column.
    if type(value) not in valid_types:
        raise ValueError(f"Column {name} type {column.type} is not compatible "
                         f"with value: {value}")
    # Enforce the maximum length of length-constrained string types.
    if isinstance(value, str) and getattr(column.type, "length", None) is not None \
            and len(value) > column.type.length:  # type: ignore
        raise ValueError(f"Column {name} only supports "
                         f"{column.type.length} "  # type: ignore
                         f"character strings, given value is {len(value)} "
                         f"characters.")
21 changes: 19 additions & 2 deletions sqltask/utils/engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from sqltask.base.table import BaseTableContext

logger = logging.getLogger(__name__)


def create_tmp_csv(table_context: BaseTableContext, delimiter: str = "\t") -> str:
"""
Expand All @@ -15,19 +17,34 @@ def create_tmp_csv(table_context: BaseTableContext, delimiter: str = "\t") -> st
:return: the path of the created temporary csv file.
"""
csv_rows = []
metadata = table_context.engine_context.metadata
if table_context.name not in metadata.tables:
metadata.reflect(only=[table_context.name])
columns = metadata.tables[table_context.name].columns
for row in table_context.output_rows:
csv_row = []
for column in table_context.table.columns:
for column in columns:
csv_row.append(row[column.name])
csv_rows.append(csv_row)

table = table_context.table

epoch = str(datetime.utcnow().timestamp())
file_path = f"{tempfile.gettempdir()}/{table.name}_{epoch}.csv"
logging.info(f"Creating temporary file `{file_path}`")
logger.info(f"Creating temporary file `{file_path}`")

with open(file_path, 'w', encoding="utf-8", newline='') as csv_file:
writer = csv.writer(csv_file, delimiter=delimiter)
writer.writerows(csv_rows)
csv_file.close()
return file_path


def get_escaped_string_value(value: str) -> str:
    """
    Escape a string so it can be safely embedded in a sql expression.
    :param value: string value to be escaped
    :return: the value with every single quote backslash-escaped
    """
    # Splitting on the quote and re-joining with its escaped form is
    # equivalent to value.replace("'", "\\'").
    return "\\'".join(value.split("'"))
Empty file added tests/__init__.py
Empty file.
Empty file added tests/base/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from datetime import date

from sqlalchemy.schema import Column
from sqlalchemy.types import Date, String

from sqltask.base.engine import EngineContext
from sqltask.base.table import BaseTableContext


def get_table_context() -> BaseTableContext:
    """Build a small sqlite-backed table context used as a shared test fixture."""
    columns = [
        Column("report_date", Date, primary_key=True),
        Column("customer_name", String(10), comment="Name", primary_key=True),
        Column("birthdate", Date, comment="Birthday", nullable=True),
    ]
    return BaseTableContext(
        name="table",
        engine_context=EngineContext("source", "sqlite://"),
        columns=columns,
        comment="The table",
        batch_params={"report_date": date(2019, 12, 31)},
    )


def populate_dummy_rows(table_context: BaseTableContext) -> None:
    """Append two hard-coded customer rows to the given table context."""
    dummy_data = (
        (date(2019, 12, 31), "Jill", date(2009, 3, 31)),
        (date(2019, 12, 31), "Jack", date(1999, 2, 28)),
    )
    for report_date, customer_name, birthdate in dummy_data:
        new_row = table_context.get_new_row()
        new_row["report_date"] = report_date
        new_row["customer_name"] = customer_name
        new_row["birthdate"] = birthdate
        new_row.append()
Empty file added tests/sources/__init__.py
Empty file.

0 comments on commit f6b69fb

Please sign in to comment.