From 5982df29fc0c729be134903cc66d845d618b9134 Mon Sep 17 00:00:00 2001 From: robcxyz Date: Sun, 14 May 2023 22:32:10 +0800 Subject: [PATCH] feat: add enum as field type for declarative hooks --- tackle/parser.py | 21 +++++++++++++++++++++ tackle/utils/dicts.py | 14 ++++++++++---- tests/functions/test_functions_enums.py | 8 ++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tackle/parser.py b/tackle/parser.py index 9fa536d5e..27a797389 100644 --- a/tackle/parser.py +++ b/tackle/parser.py @@ -12,6 +12,7 @@ - Insert the output into the appropriate key within the output context """ from collections import OrderedDict +import enum from functools import partialmethod import os from pydantic import Field, create_model, ValidationError @@ -1491,6 +1492,15 @@ def create_function_model( # fmt: on new_func = {'hook_type': func_name, 'function_fields': []} + + # First pass through the func_dict to parse out the methods + for k, v in func_dict.copy().items(): + if k.endswith(('<-', '<_')): + # Implement method which is instantiated later in `enrich_hook` + new_func[k[:-2]] = (Callable, LazyBaseFunction(function_dict=v)) + func_dict.pop(k) + continue + literals = ('str', 'int', 'float', 'bool', 'dict', 'list') # strings to match # Create function fields from anything left over in the function dict for k, v in func_dict.items(): @@ -1506,6 +1516,17 @@ def create_function_model( continue elif isinstance(v, dict): + if 'enum' in v: + enum_type = enum.Enum(k, {i: i for i in v['enum']}) + if 'default' not in v: + new_func[k] = (enum_type, ...) + else: + new_func[k] = (enum_type, v['default']) + + # TODO: Clean up this logic + new_func['function_fields'].append(k) + continue + if 'type' in v: # TODO: Qualify type in enum -> Type type_ = v['type'] diff --git a/tackle/utils/dicts.py b/tackle/utils/dicts.py index 5b29dbe38..68deb944c 100644 --- a/tackle/utils/dicts.py +++ b/tackle/utils/dicts.py @@ -3,6 +3,7 @@ a list of strings for key value lookups and byte encoded integers for items in a list. """ from typing import Union, Any, TYPE_CHECKING +from enum import Enum from ruamel.yaml.constructor import CommentedKeyMap if TYPE_CHECKING: @@ -75,7 +76,7 @@ def get_readable_key_path(key_path: list) -> str: def nested_delete(element, keys): """ Delete items in a generic element (list / dict) based on a key path in the form of - a list with strings for keys and byte encoded integers for indexes in a list. + a list with strings for keys and byte encoded integers for indexes in a list. """ num_elements = len(keys) @@ -146,8 +147,9 @@ def nested_get(element, keys): def nested_set(element, keys, value, index: int = 0): """ Set the value of an arbitrary object based on a key_path in the form of a list - with strings for keys and byte encoded integers for indexes in a list. This function - recurses through the element until it is at the end of the keys where it sets it. + with strings for keys and byte encoded integers for indexes in a list. This + function recurses through the element until it is at the end of the keys where it + sets it. :param element: A generic dictionary or list :param keys: List of string and byte encoded integers. @@ -157,6 +159,10 @@ def nested_set(element, keys, value, index: int = 0): num_elements = len(keys) # Check if we are at the last element of the list to insert the value if index == num_elements - 1: + # Check is value is enum and evaluate it as value so it is serializable. + if isinstance(value, Enum): + value = value.value + if isinstance(keys[-1], bytes): element.insert(decode_list_index(keys[-1]), value) else: @@ -278,7 +284,7 @@ def cleanup_unquoted_strings(element: Union[dict, list]): def merge(a, b, path=None, update=True): """ See https://stackoverflow.com/questions/7204805/python-dictionaries-of-dictionaries-merge - Merges b into a + Merges b into a. """ if path is None: path = [] diff --git a/tests/functions/test_functions_enums.py b/tests/functions/test_functions_enums.py index c65ba3583..5de17305d 100644 --- a/tests/functions/test_functions_enums.py +++ b/tests/functions/test_functions_enums.py @@ -19,4 +19,12 @@ def test_functions_composition_enum_basic(fixture_dir): assert output['failure_default'] assert output['success']['color'] == 'blue' + assert output['success']['color_default'] == 'red' assert output['success_default']['color'] == 'blue' + assert output['success_default']['color_default'] == 'green' + + +def test_functions_composition_enum_basic_1(fixture_dir): + output = tackle('scratch.yaml') + + assert output