Skip to content

Commit

Permalink
feat: add enum as field type for declarative hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
robcxyz committed May 14, 2023
1 parent 50bff41 commit 5982df2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
21 changes: 21 additions & 0 deletions tackle/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- Insert the output into the appropriate key within the output context
"""
from collections import OrderedDict
import enum
from functools import partialmethod
import os
from pydantic import Field, create_model, ValidationError
Expand Down Expand Up @@ -1491,6 +1492,15 @@ def create_function_model(

# fmt: on
new_func = {'hook_type': func_name, 'function_fields': []}

# First pass through the func_dict to parse out the methods
for k, v in func_dict.copy().items():
if k.endswith(('<-', '<_')):
# Implement method which is instantiated later in `enrich_hook`
new_func[k[:-2]] = (Callable, LazyBaseFunction(function_dict=v))
func_dict.pop(k)
continue

literals = ('str', 'int', 'float', 'bool', 'dict', 'list') # strings to match
# Create function fields from anything left over in the function dict
for k, v in func_dict.items():
Expand All @@ -1506,6 +1516,17 @@ def create_function_model(
continue

elif isinstance(v, dict):
if 'enum' in v:
enum_type = enum.Enum(k, {i: i for i in v['enum']})
if 'default' not in v:
new_func[k] = (enum_type, ...)
else:
new_func[k] = (enum_type, v['default'])

# TODO: Clean up this logic
new_func['function_fields'].append(k)
continue

if 'type' in v:
# TODO: Qualify type in enum -> Type
type_ = v['type']
Expand Down
14 changes: 10 additions & 4 deletions tackle/utils/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
a list of strings for key value lookups and byte encoded integers for items in a list.
"""
from typing import Union, Any, TYPE_CHECKING
from enum import Enum
from ruamel.yaml.constructor import CommentedKeyMap

if TYPE_CHECKING:
Expand Down Expand Up @@ -75,7 +76,7 @@ def get_readable_key_path(key_path: list) -> str:
def nested_delete(element, keys):
"""
Delete items in a generic element (list / dict) based on a key path in the form of
a list with strings for keys and byte encoded integers for indexes in a list.
a list with strings for keys and byte encoded integers for indexes in a list.
"""
num_elements = len(keys)

Expand Down Expand Up @@ -146,8 +147,9 @@ def nested_get(element, keys):
def nested_set(element, keys, value, index: int = 0):
"""
Set the value of an arbitrary object based on a key_path in the form of a list
with strings for keys and byte encoded integers for indexes in a list. This function
recurses through the element until it is at the end of the keys where it sets it.
with strings for keys and byte encoded integers for indexes in a list. This
function recurses through the element until it is at the end of the keys where it
sets it.
:param element: A generic dictionary or list
:param keys: List of string and byte encoded integers.
Expand All @@ -157,6 +159,10 @@ def nested_set(element, keys, value, index: int = 0):
num_elements = len(keys)
# Check if we are at the last element of the list to insert the value
if index == num_elements - 1:
# Check is value is enum and evaluate it as value so it is serializable.
if isinstance(value, Enum):
value = value.value

if isinstance(keys[-1], bytes):
element.insert(decode_list_index(keys[-1]), value)
else:
Expand Down Expand Up @@ -278,7 +284,7 @@ def cleanup_unquoted_strings(element: Union[dict, list]):
def merge(a, b, path=None, update=True):
"""
See https://stackoverflow.com/questions/7204805/python-dictionaries-of-dictionaries-merge
Merges b into a
Merges b into a.
"""
if path is None:
path = []
Expand Down
8 changes: 8 additions & 0 deletions tests/functions/test_functions_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ def test_functions_composition_enum_basic(fixture_dir):
assert output['failure_default']

assert output['success']['color'] == 'blue'
assert output['success']['color_default'] == 'red'
assert output['success_default']['color'] == 'blue'
assert output['success_default']['color_default'] == 'green'


def test_functions_composition_enum_basic_1(fixture_dir):
output = tackle('scratch.yaml')

assert output

0 comments on commit 5982df2

Please sign in to comment.