Skip to content

Commit

Permalink
Merge c5befd0 into c0ec1d8
Browse files Browse the repository at this point in the history
  • Loading branch information
paalvibe committed Sep 30, 2019
2 parents c0ec1d8 + c5befd0 commit 3336596
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
COV_MIN = 84 # Gradually increase this as we add more tests
COV_MIN = 85 # Gradually increase this as we add more tests
TAG = latest

SRC_DIR = $(shell pwd)
Expand Down
10 changes: 2 additions & 8 deletions birgitta/recipetest/localtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from birgitta import timing
from birgitta.recipetest import localtest
from birgitta.recipetest.coverage import report
from birgitta.recipetest.localtest import fixturing, assertion # noqa 401
from birgitta.recipetest.localtest import fixturing, assertion, script_prepend # noqa 401
from birgitta.recipetest.coverage.report import cov_report, dbg_counts, cov_results # noqa 401
from birgitta.recipetest.coverage.transform import prepare

Expand Down Expand Up @@ -153,16 +153,10 @@ def process_recipe(path,
in_fixture_fns,
tmpdir,
spark_session)
full_code = prepend_code() + code_w_reporting
full_code = script_prepend.code() + code_w_reporting
execute_recipe(full_code, globals_dict)


def prepend_code():
    """Read and return the source of the script-prepend module as a string."""
    prepend_path = "birgitta/recipetest/localtest/script_prepend.py"
    with open(prepend_path) as src:
        return src.read()


def dump_test_recipe(test_case, tmpdir, code):
dump_path = tmpdir + "/" + test_case + ".py"
print("\nTest recipe python file:\n", repr(dump_path), "\n")
Expand Down
9 changes: 8 additions & 1 deletion birgitta/recipetest/localtest/script_prepend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Prepend to all pyspark recipes run by localtest
"""
from birgitta import timing
PREPEND = """from birgitta import timing
import sys
Expand All @@ -20,3 +20,10 @@
sys.modules['dataiku.spark'] = mock.MagicMock()
timing.time("script_prepend after magicmock")
"""

__all__ = ["code"]


def code():
    """Return the code snippet prepended to pyspark recipes run by localtest."""
    return PREPEND
9 changes: 9 additions & 0 deletions birgitta/schema/fixtures/values/dtvals.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,12 @@ def format_dt(dt):

def set_timezone(pydt):
return pydt.replace(tzinfo=timezone.utc)


def date_types_to_str(val):
    """Convert date types to string. Used to avoid conversion bugs.

    A ``datetime.datetime`` is rendered with its time component, a plain
    ``datetime.date`` without one. The original code had these two format
    strings swapped. Exact type checks (``is``, not ``isinstance``) are used
    because ``datetime.datetime`` is a subclass of ``datetime.date``.
    Any other value is returned unchanged.
    """
    if type(val) is datetime.datetime:
        return val.strftime('%Y-%m-%d %H:%M:%S')
    if type(val) is datetime.date:
        return val.strftime('%Y-%m-%d')
    return val
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from birgitta.dataset.dataset import DataSet
from .schema import schema # noqa 401

from .schema import schema

dataset = DataSet("one_call_prepared", schema)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from birgitta.dataset.dataset import DataSet
from .schema import schema # noqa 401

from .schema import schema

dataset = DataSet("chronicle_contracts", schema)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from birgitta.dataset.dataset import DataSet
from .schema import schema # noqa 401

from .schema import schema

dataset = DataSet("tribune_contracts", schema)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from birgitta.dataset.dataset import DataSet
from .schema import schema # noqa 401

from .schema import schema

dataset = DataSet("daily_contract_states", schema)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from birgitta.dataset.dataset import DataSet
from .schema import schema # noqa 401

from .schema import schema

dataset = DataSet("date_dim", schema)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from birgitta.dataset.dataset import DataSet
from .schema import schema # noqa 401

from .schema import schema

dataset = DataSet("filtered_contracts", schema)
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import datetime

from birgitta.schema import fixtures
from birgitta.schema.fixtures.values import dtvals

Expand All @@ -14,17 +12,9 @@ def fx_default(spark):
return fixtures.df_w_rows(spark, schema, rows)


def escape_date_types(val):
    """Stringify date/datetime values to avoid type errors in fixtures.

    A ``datetime.datetime`` keeps its time component, a plain
    ``datetime.date`` does not. The original code had the two format
    strings swapped. Exact type checks (``is``) are used because
    ``datetime.datetime`` is a subclass of ``datetime.date``.
    Non-date values pass through unchanged.
    """
    if type(val) is datetime.datetime:
        return val.strftime('%Y-%m-%d %H:%M:%S')
    if type(val) is datetime.date:
        return val.strftime('%Y-%m-%d')
    return val


def date_row(dt):
    """Build one fixture row for date *dt*, one value per schema field.

    Each value is passed through ``dtvals.date_types_to_str`` to escape
    date/datetime objects and avoid type errors downstream. (The captured
    diff interleaved the removed and added call lines, leaving unbalanced
    parentheses; this is the reconstructed post-commit form.)
    """
    return [
        dtvals.date_types_to_str(  # escape to avoid type errors
            dtvals.field_val(f, dt)
        ) for f in schema.fields()
    ]

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from setuptools import setup


version = '0.1.3'
version = '0.1.5'
here = path.abspath(path.dirname(__file__))

long_description = """Birgitta is a Python ETL test and schema framework,
Expand Down
52 changes: 52 additions & 0 deletions tests/schema/spark/test_print_df_rows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
from birgitta import spark
from birgitta.schema.spark import print_df_rows
from pyspark.sql.types import IntegerType
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType


@pytest.fixture(scope="session")
def spark_session():
    """Session-scoped local Spark session (startup takes roughly 3 seconds)."""
    session = spark.local_session()
    return session


# Schema for the test fixture frame: one string column, one int column.
_fixture_fields = [
    ('letter', StringType()),
    ('number', IntegerType()),
]
fixtures_schema = StructType(
    [StructField(name, dtype) for name, dtype in _fixture_fields]
)


@pytest.fixture()
def fixtures_data():
    """Two sample rows matching fixtures_schema."""
    rows = [
        ['a', 1],
        ['b', 2],
    ]
    return rows


@pytest.fixture()
def df_print():
    """Exact stdout expected from print_df_rows for fixtures_data."""
    expected = """[
{
    'letter': 'a',
    'number': 1
},
{
    'letter': 'b',
    'number': 2
}
]
"""
    return expected


@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed")
def test_print_df_rows(spark_session,
                       fixtures_data,
                       df_print,
                       capsys):
    """print_df_rows should dump rows as dict-style text to stdout."""
    frame = spark_session.createDataFrame(fixtures_data, fixtures_schema)
    print_df_rows(frame)
    out = capsys.readouterr().out
    assert out == df_print

0 comments on commit 3336596

Please sign in to comment.