Skip to content

Commit

Permalink
feature: support complex types #161
Browse files Browse the repository at this point in the history
  • Loading branch information
robcxyz committed Jun 14, 2023
1 parent 193cd8d commit 01f3e80
Showing 1 changed file with 114 additions and 20 deletions.
134 changes: 114 additions & 20 deletions tackle/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import re
from ruamel.yaml.constructor import CommentedKeyMap, CommentedMap
from ruamel.yaml.parser import ParserError
import typing
from typing import Type, Any, Union, Callable, Optional

from tackle import exceptions
Expand Down Expand Up @@ -1323,6 +1324,25 @@ def parse_tmp_context(context: Context, element: Any, existing_context: dict):
return tmp_context.public_context


def get_complex_field(
field: Any,
) -> Type:
"""
Takes an input field such as `list[str]` which can include uncalled hooks and calls
those hooks before returning the field. Works recursively to find nested hooks
within types.
"""
if isinstance(field, list):
for i, v in enumerate(field):
field[i] = get_complex_field(v)
elif isinstance(field, dict):
for k, v in field.items():
field[k] = get_complex_field(v)
elif isinstance(field, BaseFunction):
field = field.exec()
return field


def function_walk(
self: 'Context',
input_element: Union[list, dict],
Expand All @@ -1340,7 +1360,7 @@ def function_walk(
# a function call without an exec method.
input_element = {}
for i in self.function_fields:
input_element[i] = getattr(self, i)
input_element[i] = get_complex_field(getattr(self, i))

if self.public_context:
existing_context = self.public_context.copy()
Expand All @@ -1366,7 +1386,8 @@ def function_walk(
function=self, # noqa
) from None
else:
existing_context.update({i: getattr(self, i)})
# Otherwise just the value itself
existing_context.update({i: get_complex_field(getattr(self, i))})

tmp_context = Context(
public_hooks=self.public_hooks,
Expand Down Expand Up @@ -1415,6 +1436,79 @@ def function_walk(
raise NotImplementedError(f"Return must be of list or string {return_}.")
return tmp_context.public_context

LITERAL_TYPES: set = {'str', 'int', 'float', 'bool', 'dict', 'list'} # strings to match


def parse_function_type(
context: Context,
type_str: str,
func_name: str,
):
"""
Parse the `type` field within a declarative hook and use recursion to parse the
string into real types.
"""
type_str = type_str.strip()
# Check if it's a generic type with type arguments
if '[' in type_str:
# Strip the brackets. Base type will then have subtypes
base_type_str, type_args_str_raw = type_str.split('[', 1)
type_args_str = type_args_str_raw.rsplit(']', 1)[0]
# Get list of types separated by commas but not within brackets.
# ie `'dict[str, Base], Base'` -> `['dict[str, Base]', 'Base']`
type_args = [
parse_function_type(
context=context,
type_str=arg,
func_name=func_name,
) for arg in re.split(r',(?![^[\]]*])', type_args_str)
]
# Get base type
base_type = parse_function_type(
context=context,
type_str=base_type_str,
func_name=func_name,
)

if len(type_args) == 0:
return base_type
elif base_type == typing.Optional:
# Optional only takes a single arg
if len(type_args) == 1:
return base_type[type_args[0]]
else:
raise exceptions.MalformedFunctionFieldException(
"The type `Optional` only takes one arg.",
context=context, function_name=func_name,
) from None
else:
return base_type[tuple(type_args)]

# Check if it's a generic type without type arguments
if hasattr(typing, type_str):
return getattr(typing, type_str)
elif type_str not in LITERAL_TYPES:
hook = get_public_or_private_hook(context=context, hook_type=type_str)
if hook is None:
try:
type_ = getattr(typing, type_str.title())
except AttributeError:
raise exceptions.MalformedFunctionFieldException(
f"The type `{type_str}` is not recognized. Must be in python's "
f"`typing` module.", context=context, function_name=func_name,
) from None
return type_
elif isinstance(hook, LazyBaseFunction):
# We have a hook we need to build
return create_function_model(
context=context,
func_name=type_str,
func_dict=hook.function_dict.copy(),
)
else:
return hook
# Treat it as a plain name - Safe eval as already qualified as literal
return eval(type_str)

def create_function_model(
context: 'Context', func_name: str, func_dict: dict
Expand Down Expand Up @@ -1501,7 +1595,6 @@ def create_function_model(
func_dict.pop(k)
continue

literals = ('str', 'int', 'float', 'bool', 'dict', 'list') # strings to match
# Create function fields from anything left over in the function dict
for k, v in func_dict.items():
if v is None:
Expand All @@ -1517,29 +1610,30 @@ def create_function_model(

elif isinstance(v, dict):
if 'enum' in v:
if 'type' in v:
raise exceptions.MalformedFunctionFieldException(
'Enums are implicitly typed.',
context=context, function_name=func_name
)
enum_type = enum.Enum(k, {i: i for i in v['enum']})
if 'default' not in v:
new_func[k] = (enum_type, ...)
else:
if 'default' in v:
new_func[k] = (enum_type, v['default'])

# TODO: Clean up this logic
new_func['function_fields'].append(k)
continue

if 'type' in v:
# TODO: Qualify type in enum -> Type
else:
new_func[k] = (enum_type, ...)
elif 'type' in v:
type_ = v['type']
if type_ not in literals:
raise exceptions.MalformedFunctionFieldException(
f"Function field {k} with type={v} unknown. Must be one of {','.join(literals)}",
function_name=func_name,
if type_ in LITERAL_TYPES:
parsed_type = locate(type_).__name__
else:
parsed_type = parse_function_type(
context=context,
) from None
type_str=type_,
func_name=func_name,
)
if 'description' in v:
v = dict(v)
v['description'] = v['description'].__repr__()
new_func[k] = (type_, Field(**v))
new_func[k] = (parsed_type, Field(**v))
elif 'default' in v:
if isinstance(v['default'], dict) and '->' in v['default']:
# For hooks in the default fields.
Expand All @@ -1548,7 +1642,7 @@ def create_function_model(
new_func[k] = (type(v['default']), Field(**v))
else:
new_func[k] = (dict, v)
elif v in literals:
elif v in LITERAL_TYPES:
new_func[k] = (locate(v).__name__, Field(...))
elif isinstance(v, (str, int, float, bool)):
new_func[k] = v
Expand Down

0 comments on commit 01f3e80

Please sign in to comment.