From 40e75936c08408cb89c6814d890e97ae7889e02d Mon Sep 17 00:00:00 2001 From: okay Date: Tue, 23 Jan 2018 17:47:52 -0800 Subject: [PATCH] [external modules] compatibility with plaitpy-ipc (0.1.0) this diff includes features and fixes added to support external modules like plaitpy-ipc. plaitpy-ipc is a module that provides IPC like features for multiple plaitpy processes, like locks and queues. new features: * add "effect" field for "exec" type statements with side effects * add "init / setup" field for custom template setup code * expose GLOBALS variable in templates that can hold data cleanup: * update Makefile, install scripts and tests to run py2 and py3 * move debug printing to its own module * rename fields.py -> template.py (thats what it really is) * all field errors should cause process exit except "suppressed" fields speed improvements: * add cache for parsed YAML templates * memoize lambda compilation tests: * add test for effect, init fields * add test for CSV and JSON printing * add test for custom printer --- README.md | 9 +- docs/FORMAT.md | 3 + scripts/install_package.sh | 3 +- src/__init__.py | 9 +- src/cli.py | 23 ++- src/debug.py | 15 ++ src/fakerb.py | 3 +- src/helpers.py | 76 ++++++--- src/{fields.py => template.py} | 184 +++++++++++++-------- src/version.py | 2 +- templates/behavior/web_browsing.yaml | 2 + tests/templates/csv_no_indexing_error.yaml | 19 +++ tests/templates/effects.yaml | 3 + tests/templates/print_test.yaml | 3 + tests/templates/printer.yaml | 7 + tests/templates/setup.yaml | 7 + tests/test_fields.py | 128 ++++++++++---- 17 files changed, 355 insertions(+), 141 deletions(-) create mode 100644 src/debug.py rename src/{fields.py => template.py} (89%) create mode 100644 tests/templates/csv_no_indexing_error.yaml create mode 100644 tests/templates/effects.yaml create mode 100644 tests/templates/print_test.yaml create mode 100644 tests/templates/printer.yaml create mode 100644 tests/templates/setup.yaml diff --git a/README.md b/README.md index 768e5c0..6c2dc77 100644 --- a/README.md +++ b/README.md @@ -146,10 +146,13 @@ plait.py also simplifies looking up faker fields: * see docs/TROUBLESHOOTING.md -### future direction -Currently, plait.py models independent markov processes - future investigation -into modeling processes that can interact with each other is needed. +### Dependent Markov Processes + +To simulate data that comes from many markov processes (a markov ecosystem), +see the [plaitpy-ipc](https://github.com/plaitpy/plaitpy-ipc) repository. + +### future direction If you have ideas on features to add, open an issue - Feedback is appreciated! diff --git a/docs/FORMAT.md b/docs/FORMAT.md index f3001aa..fe9f9f1 100644 --- a/docs/FORMAT.md +++ b/docs/FORMAT.md @@ -41,6 +41,7 @@ field can be of multiple types, including: * **switch**: specifies that this field is a switch field (similar to if / else if clauses) * **csv**: specifies that this field is generated via sampling from a CSV file * **lambda**: specifies that this field is generated via custom lambda function +* **effect**: specifies that this field is an effect field with sideeffects * **mixture**: specifies that this field is a mixture field (similar to a probabilistic if clause) * **random**: specifies that this field should be generated from random module * **template**: specifies that this field is generated from another template @@ -55,6 +56,8 @@ params: * cast: the type to cast this field to, "int", "float", "str", etc * initial: the initial value for this field (if it is self-referential) * finalize: a final lambda expression to be run on this field (a finalizer) + * suppress: suppress errors from this field (if its known to throw any) + * onlyif: only add this field if it matches the supplied expr ## fields API diff --git a/scripts/install_package.sh b/scripts/install_package.sh index c232ebc..677121d 100644 --- a/scripts/install_package.sh +++ b/scripts/install_package.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash VERSION=`python src/version.py` -sudo pip install dist/plaitpy-${VERSION}.tar.gz --upgrade +sudo pip2 install dist/plaitpy-${VERSION}.tar.gz --upgrade +sudo pip3 install dist/plaitpy-${VERSION}.tar.gz --upgrade diff --git a/src/__init__.py b/src/__init__.py index 8dbb104..5625ded 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,6 +1,9 @@ -from . fields import Template +from . template import Template from . import cli from . version import VERSION -from . ecosystem import Ecosystem +from . import helpers -__all__ = [ "Template", "cli" ] +import sys +sys.modules['plaitpy'] = sys.modules[__name__] + +__all__ = [ "Template", "cli", "helpers" ] diff --git a/src/cli.py b/src/cli.py index 6170464..acd2307 100644 --- a/src/cli.py +++ b/src/cli.py @@ -8,9 +8,9 @@ import os from . import fakerb -from . import fields +from . import template from . import helpers -from .helpers import debug +from . import debug def setup_args(): parser = argparse.ArgumentParser(description='Generate fake datasets from yaml template files') @@ -33,7 +33,7 @@ def setup_args(): parser.add_argument('--debug', dest='debug', action="store_true", default=False, help='Turn on debugging output for plait.py') parser.add_argument('--exit-on-error', dest='exit_error', action="store_true", default=False, - help='Exit loudly on error') + help='Exit loudly on any error') args = parser.parse_args() @@ -44,20 +44,19 @@ def main(): args, parser = setup_args() if args.csv: - fields.CSV = True - fields.JSON = False + template.CSV = True + template.JSON = False elif args.json: - fields.JSON = True - fields.CSV = False + template.JSON = True + template.CSV = False if args.exit_error: args.debug = True - fields.EXIT_ON_ERROR = True + template.EXIT_ON_ERROR = True if args.debug: - fields.DEBUG = True - helpers.DEBUG = True + debug.DEBUG = True @@ -105,7 +104,7 @@ def main(): helpers.add_template_path(args.dir) helpers.setup_globals() - tmpl = fields.Template(template_file) - debug("*** GENERATING %s RECORDS" % args.num_records) + tmpl = template.Template(template_file) + debug.debug("*** GENERATING %s RECORDS" % args.num_records) tmpl.print_records(args.num_records) fakerb.save_cache() diff --git a/src/debug.py b/src/debug.py new file mode 100644 index 0000000..fbed62e --- /dev/null +++ b/src/debug.py @@ -0,0 +1,15 @@ +from __future__ import print_function +from os import environ as ENV + +import sys + +DEBUG="DEBUG" in ENV +VERBOSE=False + +def debug(*args): + if DEBUG: + print(" ".join(map(str, args)), file=sys.stderr) + +def verbose(*args): + if VERBOSE: + print(" ".join(map(str, args)), file=sys.stderr) diff --git a/src/fakerb.py b/src/fakerb.py index 9082fa7..01036bc 100644 --- a/src/fakerb.py +++ b/src/fakerb.py @@ -5,7 +5,8 @@ import re import hashlib -from .helpers import debug, verbose, exit_error, readfile, LAMBDA_TYPE +from .debug import debug, verbose +from .helpers import LAMBDA_TYPE, readfile, exit_error # a key looks like: base.field # base maps to faker/lib/locales/en/base.yaml diff --git a/src/helpers.py b/src/helpers.py index 08a8c95..320dd90 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -8,22 +8,24 @@ import math import os -from . import tween +from . import debug from os import environ as ENV -DEBUG="DEBUG" in ENV LAMBDA_TYPE = type(lambda w: w) -VERBOSE=False TRACEBACK=True -def debug(*args): - if DEBUG: - print(" ".join(map(str, args)), file=sys.stderr) +class DotWrapper(dict): + def __getattr__(self, attr): + if attr in self: + return self[attr] + + if not attr in self: + debug.debug("MISSING ATTR", attr) + + def __setattr__(self, attr, val): + self[attr] = val -def verbose(*args): - if VERBOSE: - print(" ".join(map(str, args)), file=sys.stderr) def exit(): sys.exit(1) @@ -31,17 +33,31 @@ def exit(): def exit_error(e=None): import traceback if e: - debug("Error:", e) + debug.debug("Error:", e) if TRACEBACK: traceback.print_exc() sys.exit(1) - +# from comment on http://code.activestate.com/recipes/578231-probably-the-fastest-memoization-decorator-in-the-/ +def memoize(f): + """ Memoization decorator for functions taking one or more arguments. """ + class memodict(dict): + def __init__(self, f): + self.f = f + def __call__(self, *args): + return self[args] + def __missing__(self, key): + ret = self[key] = self.f(*key) + return ret + return memodict(f) + +@memoize def make_func(expr, name): func = compile_lambda(str(expr), name, 'exec') return lambda: eval(func, GLOBALS, LOCALS) +@memoize def make_lambda(expr, name): func = compile_lambda(str(expr), name) return lambda: eval(func, GLOBALS, LOCALS) @@ -60,27 +76,47 @@ def __setattr__(self, attr, val): self[attr] = val - GLOBALS = ObjWrapper({}) RAND_GLOBALS = ObjWrapper({}) LOCALS = ObjWrapper() +class GlobalAssigner(dict): + def __str__(self): + return "GLOBAL ASSIGNER %s" % (id(self)) + + def __setitem__(self, attr, val): + GLOBALS[attr] = val + RAND_GLOBALS[attr] = val + + def __setattr__(self, attr, val): + GLOBALS[attr] = val + RAND_GLOBALS[attr] = val + + def __getitem__(self, attr): + if attr in RAND_GLOBALS: + return RAND_GLOBALS[attr] + + def __getattr__(self, attr): + if attr in RAND_GLOBALS: + return RAND_GLOBALS[attr] def setup_globals(): if "__plaitpy__" in GLOBALS: return - g = globals() + ga = GlobalAssigner() + ga["__plaitpy__"] = True + ga.time = time + ga.random = random + ga.re = re + ga.GLOBALS = ga + ga.globals = ga - GLOBALS.time = time - GLOBALS.random = random - GLOBALS.tween = tween - GLOBALS.re = re - GLOBALS["__plaitpy__"] = True + from . import tween + ga.tween = tween for field in dir(math): - GLOBALS[field] = getattr(math, field) - RAND_GLOBALS[field] = getattr(math, field) + ga[field] = getattr(math, field) for field in dir(random): RAND_GLOBALS[field] = getattr(random, field) diff --git a/src/fields.py b/src/template.py similarity index 89% rename from src/fields.py rename to src/template.py index 6e50847..fc75a2a 100644 --- a/src/fields.py +++ b/src/template.py @@ -12,6 +12,7 @@ import yaml from os import environ as ENV +from . import debug from . import fakerb from . import toposort from .helpers import * @@ -21,7 +22,6 @@ CUSTOM_OUT=True SKIP = False -DEBUG="DEBUG" in ENV DEBUG_FIELD_SETUP = False DEBUG_GEN_TIMINGS = "TIMING" in ENV PROFILE_EVERY=25 # profile every 25 records @@ -37,19 +37,30 @@ STATIC_BRACE_CAPTURE = "\${(.*?)}" STATIC_NOBRACE_CAPTURE = "\$(\w+)" - LANGUAGE = "en_US" -class DotWrapper(dict): - def __getattr__(self, attr): - if attr in self: - return self[attr] +CSV_WRITER = None +def print_record(process, r, csv_writer=None): + if CUSTOM_OUT and process.output_func: + process.output_func(r) + elif CSV and csv_writer: + pr = {} + for field in process.public_fields(): + pr[field] = r[field] + csv_writer.writerow(pr) + elif JSON: + pr = {} + for field in process.public_fields(): + if field not in r: + continue - if not attr in self: - debug("MISSING ATTR", attr) + if r[field] is not None and r[field] is not "": + pr[field] = r[field] - def __setattr__(self, attr, val): - self[attr] = val + clean_json(pr) + print(json.dumps(pr)) + else: + raise Exception("UNDEFINED PRINT") class RecordWrapper(dict): def __init__(self, *args, **kwargs): @@ -65,6 +76,10 @@ def stop_profile(self): def set_template(self, template): self.__template = template + self.__id = id(self.__template) + + def get_id(self): + return self.__id def populate_field(self, field): if field in self: @@ -83,7 +98,7 @@ def populate_field(self, field): try: self[field] = val() except Exception as e: - self.error("ERROR POPULATING FIELD", field, "TEMPLATE IS", self.__template.name, "RECORD IS", self) + self.__template.error("ERROR POPULATING FIELD", field, "TEMPLATE IS", self.__template.name, "RECORD IS", self) exit_error() if self.__profile: @@ -99,9 +114,10 @@ def __getattr__(self, attr): if attr in self: return self[attr] - +YAML_CACHE = {} +REGISTERED = {} class Template(object): - def __init__(self, template, overrides=None, hidden=None, depth=0): + def __init__(self, template, overrides=None, hidden=None, depth=0, quiet=False): self.name = template setup_globals() @@ -109,6 +125,9 @@ def __init__(self, template, overrides=None, hidden=None, depth=0): self.field_errors = {} self.field_definitions = {} + # fields in ignore_errors should have their errors not halt program + self.ignore_errors = {} + self.error_types = {} self.count_until_profile = -1 if DEBUG_GEN_TIMINGS: @@ -118,6 +137,7 @@ def __init__(self, template, overrides=None, hidden=None, depth=0): self.depth = depth self.pad = " " * self.depth + self.quiet = quiet self.templates = {} self.timings = {} @@ -138,6 +158,7 @@ def __init__(self, template, overrides=None, hidden=None, depth=0): self.record = RecordWrapper({}) self.record.set_template(self) + self.record_invalid = False overrides = overrides or {} self.overrides = {} @@ -159,7 +180,7 @@ def __init__(self, template, overrides=None, hidden=None, depth=0): self.overrides[o] = overrides[o] - self.load_template(template) + self.setup_template(template) all_keys = {} for k in self.field_data: all_keys[k] = 1 @@ -169,6 +190,10 @@ def __init__(self, template, overrides=None, hidden=None, depth=0): def register_paths(self): + if self.name in REGISTERED: + return + + REGISTERED[self.name] = 1 # setup any paths on the way to the template (for convenience) this_path = os.path.realpath(".") that_path = os.path.realpath(self.name) @@ -178,20 +203,21 @@ def register_paths(self): fullpath = "." for token in tokens[:-1]: fullpath = os.path.join(fullpath, token) + debug.debug("ADDING PATH", fullpath) add_path(fullpath) def error(self, *args): - if DEBUG or not TEST_MODE: + if debug.DEBUG or not TEST_MODE: print("%s" % (" ".join(map(str, args))), file=sys.stderr) def debug(self, *args): - if DEBUG: - print("%s%s" % (self.pad, " ".join(map(str, args))), file=sys.stderr) + if self.quiet: + return - def set_language(self, ln): - self.language = ln + if debug.DEBUG: + print("%s%s" % (self.pad, " ".join(map(str, args))), file=sys.stderr) def build_template(self, name, data): if type(data) == dict: @@ -218,12 +244,16 @@ def build_template(self, name, data): if key in d: hidden = d[key] - return Template("%s" % name, overrides=overrides, hidden=hidden, depth=self.depth+1) + return Template("%s" % name, overrides=overrides, hidden=hidden, depth=self.depth+1, quiet=self.quiet) else: - return Template("%s" % name, depth=self.depth+1) + return Template("%s" % name, depth=self.depth+1, quiet=self.quiet) + + def setup_template(self, template): + + if template in YAML_CACHE: + return self.setup_template_from_cache(template) - def load_template(self, template): if template.endswith(".json"): with readfile(template) as f: try: @@ -237,15 +267,24 @@ def load_template(self, template): data = f.read() else: self.error("Unknown template type: %s" % template) + exit_error() + + return self.setup_template_from_data(template, data) - self.load_template_from_data(template, data) + def setup_template_from_cache(self, template): + doc = YAML_CACHE[template] + return self.setup_template_from_yaml_doc(template, doc) - def load_template_from_data(self, template, data): + + def setup_template_from_data(self, template, data): doc = yaml.load(data) - self.include = {} + YAML_CACHE[template] = doc + return self.setup_template_from_yaml_doc(template, doc) + def setup_template_from_yaml_doc(self, template, doc): + self.include = {} for choice in [ "include", "includes" ]: if choice in doc: @@ -281,6 +320,36 @@ def load_template_from_data(self, template, data): else: self.debug("DEF. %s - $%s as %s" % (template, o, self.static[o])) + if "imports" in doc: + for m in doc["imports"]: + if type(m) == str: + d = {} + d[m] = m + m = d + + try: + for modname in m: + asname = m[modname] + mod = __import__(modname) + + GLOBALS[asname] = mod + RAND_GLOBALS[asname] = mod + self.debug("IMPORTING", modname, "AS", asname) + except ImportError as e: + self.debug("*** COULD NOT IMPORT MODULE %s" % m) + if "requirements" in doc: + self.debug("*** MAKE SURE TO INSTALL ALL REQUIREMENTS:", ", ".join(doc["requirements"])) + + exit_error() + + + for choice in [ "setup", "init" ]: + if choice in doc: + init_data = self.replace_statics(doc[choice]) + init_lambda = make_func(init_data, "" % (self.name)) + init_lambda() + + for choice in [ "print", "printer", "format", "output" ]: if choice in doc: print_lambda = make_func(doc[choice], "" % (self.name)) @@ -289,26 +358,12 @@ def print_func(r): print_lambda() pop_this_record() self.output_func = print_func - debug("ADDED CUSTOM PRINTER TO", self.name) + debug.debug("ADDED CUSTOM PRINTER TO", self.name) for g in [ "mixin", "mixins"]: if g in doc: self.setup_mixins(template, doc[g]) - if "imports" in doc: - for m in doc["imports"]: - try: - mod = __import__(m) - except ImportError as e: - self.debug("*** COULD NOT IMPORT MODULE %s" % m) - if "requirements" in doc: - self.debug("*** MAKE SURE TO INSTALL ALL REQUIREMENTS:", ", ".join(doc["requirements"])) - - exit_error() - - GLOBALS[m] = mod - RAND_GLOBALS[m] = mod - if "embed" in doc: raise Exception("Embeds are not supported") @@ -587,7 +642,7 @@ def setup_mixture_field(self, field, **field_data): total_weight += weight cases.append([weight, case_id, val_func]) - debug("CASES ARE", cases) + debug.debug("CASES ARE", cases) cases.sort(reverse=True) for c in cases: c[0] = c[0]/total_weight @@ -647,6 +702,9 @@ def pick_case(): return pick_case + def setup_effect_field(self, field, **kwargs): + return make_func(str(kwargs.get('effect')), '' % field) + def setup_lambda_field(self, field, **kwargs): return make_lambda(str(kwargs.get('lambda')), '' % field) @@ -757,6 +815,9 @@ def setup_field(self, field, field_data): if type(deps) == str: deps = [ deps ] + if "suppress" in field_data: + self.ignore_errors[field] = True + val_func = None ## the heart of it all if "value" in field_data: @@ -771,6 +832,8 @@ def setup_field(self, field, field_data): val_func = self.setup_template_field(field, **field_data) elif "switch" in field_data: val_func = self.setup_switch_field(field, **field_data) + elif "effect" in field_data: + val_func = self.setup_effect_field(field, **field_data) elif "onlyif" in field_data: val_func = self.setup_fixed_field(field, value="true") elif "mixture" in field_data: @@ -857,7 +920,6 @@ def replace_field(m): static_lambda = make_lambda(static_init, '' % field) self.static[field] = static_lambda() - def replace_statics(self, field_data): def replace_field(m): return str(self.static[m.group(1)]) @@ -920,7 +982,7 @@ def track_error(self, field, e): else: error_types[err_str] += 1 - if EXIT_ON_ERROR: + if field not in self.ignore_errors or EXIT_ON_ERROR: exit_error() @@ -1020,43 +1082,25 @@ def gen_record(self, args={}, scrub_record=True): return ret - def print_records(self, num_records): - csv_writer = None + def print_headers(self): + self.csv_writer = None if CSV: - csv_writer = csv.DictWriter(sys.stdout, fieldnames=self.headers) - csv_writer.writeheader() + self.csv_writer = csv.DictWriter(sys.stdout, fieldnames=self.headers) + self.csv_writer.writeheader() + def print_records(self, num_records): chunk_size = 1000 - def print_record(r): - if CUSTOM_OUT and self.output_func: - self.output_func(r) - elif JSON: - pr = {} - for field in self.public_fields(): - if field not in r: - continue - - if r[field] is not None and r[field] is not "": - pr[field] = r[field] - - clean_json(pr) - print(json.dumps(pr)) - - elif CSV: - pr = {} - for field in self.public_fields(): - pr[field] = r[field] - csv_writer.writerow(pr) + self.print_headers() for _ in range(num_records // chunk_size): for r in self.gen_records(chunk_size, print_timing=False): - print_record(r) + print_record(self, r, csv_writer=self.csv_writer) if num_records % chunk_size != 0: for r in self.gen_records(num_records % chunk_size, print_timing=False): - print_record(r) + print_record(self, r, csv_writer=self.csv_writer) self.print_dropped() if DEBUG_GEN_TIMINGS: @@ -1078,7 +1122,7 @@ def gen_records(self, num_records, print_timing=True): ret.append(r) except Exception as e: - debug(e) + debug.debug(e) finally: if print_timing: self.print_dropped() diff --git a/src/version.py b/src/version.py index 1d246c1..1a39f63 100644 --- a/src/version.py +++ b/src/version.py @@ -1,4 +1,4 @@ -VERSION="0.0.13" +VERSION="0.1.0" if __name__ == "__main__": print(VERSION) diff --git a/templates/behavior/web_browsing.yaml b/templates/behavior/web_browsing.yaml index 3f94a92..44da8fa 100644 --- a/templates/behavior/web_browsing.yaml +++ b/templates/behavior/web_browsing.yaml @@ -43,6 +43,8 @@ fields: lambda: 0 - onlyif: not this._goto_new_site lambda: prev.page_visit + 1 + - default: + lambda: 0 current_site: switch: diff --git a/tests/templates/csv_no_indexing_error.yaml b/tests/templates/csv_no_indexing_error.yaml new file mode 100644 index 0000000..4069a3e --- /dev/null +++ b/tests/templates/csv_no_indexing_error.yaml @@ -0,0 +1,19 @@ +fields: + foo: + csv: indexing.csv + index: [0, 1] + column: 2 + suppress: 1 + lookup: > + [this.bar, this.baz] + + bar: a + baz: b + + boo: + csv: indexing.csv + index: [0, 1] + column: 2 + suppress: 1 + lookup: > + [this.bar, this.bag] diff --git a/tests/templates/effects.yaml b/tests/templates/effects.yaml new file mode 100644 index 0000000..4c81fe2 --- /dev/null +++ b/tests/templates/effects.yaml @@ -0,0 +1,3 @@ +fields: + foo: + effect: globals.effects_foo = 100 diff --git a/tests/templates/print_test.yaml b/tests/templates/print_test.yaml new file mode 100644 index 0000000..e4a948e --- /dev/null +++ b/tests/templates/print_test.yaml @@ -0,0 +1,3 @@ +fields: + foo: + lambda: 100 diff --git a/tests/templates/printer.yaml b/tests/templates/printer.yaml new file mode 100644 index 0000000..3340d9c --- /dev/null +++ b/tests/templates/printer.yaml @@ -0,0 +1,7 @@ +fields: + foo: + lambda: 100 + + +printer: | + print("foo: %s" % this.foo) diff --git a/tests/templates/setup.yaml b/tests/templates/setup.yaml new file mode 100644 index 0000000..17a784b --- /dev/null +++ b/tests/templates/setup.yaml @@ -0,0 +1,7 @@ +setup: | + globals.foo = 100 + +fields: + foo: + lambda: foo + diff --git a/tests/test_fields.py b/tests/test_fields.py index d08084b..35a44ff 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -6,12 +6,23 @@ # i hate the path dance sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) -from src import fields +from src import template from src import helpers helpers.TRACEBACK = False -fields.TEST_MODE = True -fields.EXIT_ON_ERROR = True +template.TEST_MODE = True +template.EXIT_ON_ERROR = True + +from contextlib import contextmanager + +@contextmanager +def stdout_redirector(stream): + old_stdout = sys.stdout + sys.stdout = stream + try: + yield + finally: + sys.stdout = old_stdout class TestTemplateFields(unittest.TestCase): @@ -24,20 +35,20 @@ def assert_num_errors(self, t, n=0): def test_imports(self): import marshal - t = fields.Template("tests/templates/imports.yaml") + t = template.Template("tests/templates/imports.yaml") r = t.gen_record() self.assertEqual(r.testfield, marshal) self.assert_no_errors(t) def test_mixins(self): - t = fields.Template("tests/templates/mixins.yaml") + t = template.Template("tests/templates/mixins.yaml") r = t.gen_record() self.assertEqual(r.foo, "bar") self.assertEqual(r.abc, "def") self.assert_no_errors(t) def test_defines(self): - t = fields.Template("tests/templates/define.yaml") + t = template.Template("tests/templates/define.yaml") r = t.gen_record() self.assertEqual(r.foo, "foobarbaz") self.assertEqual(r.bar, 510) @@ -45,21 +56,21 @@ def test_defines(self): # field types def test_mixin_overrides(self): - t = fields.Template("tests/templates/mixins_override.yaml") + t = template.Template("tests/templates/mixins_override.yaml") r = t.gen_record() self.assertEqual(r.foo, "baz") self.assertEqual(r.abc, "def") self.assert_no_errors(t) def test_lambda_fields(self): - t = fields.Template("tests/templates/lambdas.yaml") + t = template.Template("tests/templates/lambdas.yaml") r = t.gen_record() self.assertEqual(r.foo, "foobarbaz") self.assertEqual(r.sum, 100) self.assert_no_errors(t) def test_csv_sampling(self): - t = fields.Template("tests/templates/csv_sampling.yaml") + t = template.Template("tests/templates/csv_sampling.yaml") rs = t.gen_records(100) counts = { "foo" : 0, "bar" : 0, "baz" : 0 } for r in rs: @@ -70,22 +81,21 @@ def test_csv_sampling(self): self.assertEqual(counts["bar"], 0) def test_csv_resampling(self): - t = fields.Template("tests/templates/resample.yaml") + t = template.Template("tests/templates/resample.yaml") r = t.gen_record() self.assertEqual(r.foo * 2, r.bar) self.assertEqual(r.foo * 3, r.baz) def test_csv_indexing_error(self): - t = fields.Template("tests/templates/csv_indexing_error.yaml") + template.EXIT_ON_ERROR = False + t = template.Template("tests/templates/csv_indexing_error.yaml") with self.assertRaises(SystemExit) as context: r = t.gen_record() - fields.EXIT_ON_ERROR = False - t = fields.Template("tests/templates/csv_indexing_error.yaml") + t = template.Template("tests/templates/csv_no_indexing_error.yaml") r = t.gen_record() - fields.EXIT_ON_ERROR = True self.assertEqual(r.foo, "foobarbaz") # this template has a bad lookup in it, @@ -93,7 +103,7 @@ def test_csv_indexing_error(self): self.assert_num_errors(t, 1) def test_csv_indexing(self): - t = fields.Template("tests/templates/csv_indexing.yaml") + t = template.Template("tests/templates/csv_indexing.yaml") r = t.gen_record() self.assertEqual(r.foo, "foobarbaz") @@ -102,7 +112,7 @@ def test_csv_indexing(self): self.assert_no_errors(t) def test_switch_fields(self): - t = fields.Template("tests/templates/switch.yaml") + t = template.Template("tests/templates/switch.yaml") r = t.gen_record() self.assertEqual(r.foo, "foo") @@ -110,7 +120,7 @@ def test_switch_fields(self): self.assertEqual(r.baz, None) def test_mixture_fields(self): - t = fields.Template("tests/templates/mixture.yaml") + t = template.Template("tests/templates/mixture.yaml") rs = t.gen_records(1000) counts = { 1 :0, 2: 0, 3: 0 } @@ -125,14 +135,14 @@ def test_mixture_fields(self): def test_template_fields(self): - t = fields.Template("tests/templates/nesting.yaml") + t = template.Template("tests/templates/nesting.yaml") r = t.gen_record() self.assertEqual(r.nested.foo, "bar") self.assertEqual(r.nested.abc, "def") self.assert_no_errors(t) def test_template_bad_args(self): - t = fields.Template("tests/templates/bad_args.yaml") + t = template.Template("tests/templates/bad_args.yaml") r = None with self.assertRaises(SystemExit) as context: @@ -140,14 +150,14 @@ def test_template_bad_args(self): def test_template_args(self): - t = fields.Template("tests/templates/args.yaml") + t = template.Template("tests/templates/args.yaml") r = t.gen_record() self.assertEqual(r.nested.foo, "baz") self.assertEqual(r.nested.abc, "xyz") self.assert_no_errors(t) def test_faker_interpolation(self): - t = fields.Template("tests/templates/faker.yaml") + t = template.Template("tests/templates/faker.yaml") self.assertEqual(t.field_definitions["foo"], "#{name.name}") r = t.gen_record() @@ -155,7 +165,7 @@ def test_faker_interpolation(self): self.assert_no_errors(t) def test_hidden_fields(self): - t = fields.Template("tests/templates/hidden.yaml") + t = template.Template("tests/templates/hidden.yaml") self.assertEqual(t.field_definitions["_foo"], "bar") self.assertEqual(t.field_definitions["vis"], "baz") r = t.gen_record() @@ -164,7 +174,7 @@ def test_hidden_fields(self): self.assertEqual(r.vis, "baz") def test_random_fields(self): - t = fields.Template("tests/templates/random.yaml") + t = template.Template("tests/templates/random.yaml") rs = t.gen_records(1000) for r in rs: @@ -176,7 +186,7 @@ def test_random_fields(self): def test_prev_record(self): for _ in range(3): - t = fields.Template("tests/templates/init.yaml") + t = template.Template("tests/templates/init.yaml") r = t.gen_record() self.assertEqual(r.foo, 1) @@ -187,7 +197,7 @@ def test_prev_record(self): # field operations def test_field_cast(self): - t = fields.Template("tests/templates/casts.yaml") + t = template.Template("tests/templates/casts.yaml") r = t.gen_record() self.assertEqual(type(r.foo), str) @@ -195,7 +205,7 @@ def test_field_cast(self): def test_field_init(self): # inadvertently tests "prev", as well - t = fields.Template("tests/templates/init.yaml") + t = template.Template("tests/templates/init.yaml") r = t.gen_record() self.assertEqual(r.foo, 1) @@ -204,26 +214,84 @@ def test_field_init(self): self.assertEqual(r.foo, 2) def test_field_finalize(self): - t = fields.Template("tests/templates/finalize.yaml") + t = template.Template("tests/templates/finalize.yaml") r = t.gen_record() self.assertEqual(r.foo, 100) def test_field_finalize_json(self): - t = fields.Template("tests/templates/finalize.json") + t = template.Template("tests/templates/finalize.json") r = t.gen_record() self.assertEqual(r.foo, 100) def test_field_include_json(self): - t = fields.Template("tests/templates/finalize_wrapper.json") + t = template.Template("tests/templates/finalize_wrapper.json") r = t.gen_record() self.assertEqual(r.foo, 100) + def test_print_csv(self): + import json + template.CSV = True + template.JSON = False + + t = template.Template("tests/templates/print_test.yaml") + r = t.gen_record() + + import io + f = io.BytesIO() + with stdout_redirector(f): + t.print_records(1) + + self.assertEqual(f.getvalue(), "foo\r\n100\r\n") + + + def test_print_json(self): + import json + template.JSON = True + template.CSV = False + t = template.Template("tests/templates/print_test.yaml") + r = t.gen_record() + + import io + f = io.BytesIO() + with stdout_redirector(f): + t.print_records(1) + + self.assertEqual(f.getvalue(), json.dumps({"foo" : 100}) + "\n") + + def test_custom_printer(self): + t = template.Template("tests/templates/printer.yaml") + r = t.gen_record() + + self.assertEqual(r.foo, 100) + + import io + f = io.StringIO() + with stdout_redirector(f): + t.print_records(1) + + self.assertEqual(f.getvalue(), "foo: 100\n") + + + def test_setup(self): + t = template.Template("tests/templates/setup.yaml") + r = t.gen_record() + + self.assertEqual(r.foo, 100) + + def test_effects_field(self): + helpers.GLOBALS.effects_foo = None + t = template.Template("tests/templates/effects.yaml") + r = t.gen_record() + + self.assertEqual(helpers.GLOBALS.effects_foo, 100) + + # test field errors def test_exit_on_error(self): - t = fields.Template("tests/templates/bad_args.yaml") + t = template.Template("tests/templates/bad_args.yaml") with self.assertRaises(SystemExit) as context: r = t.gen_record()