Skip to content

Commit

Permalink
fix: override inputs so that they are able to apply to declarative ho…
Browse files Browse the repository at this point in the history
…ok execs and are passed through remote calls
  • Loading branch information
robcxyz committed Dec 12, 2022
1 parent 93f5020 commit d3ccf0a
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 31 deletions.
10 changes: 7 additions & 3 deletions tackle/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def get_global_kwargs(kwargs):
"""Check for unknown kwargs and return so they can be consumed later."""
global_kwargs = {}
for k, v in kwargs.items():
if k not in Context.__fields__:
if k not in Context.__fields__ and k != 'override':
global_kwargs.update({k: v})
if global_kwargs == {}:
return None
Expand Down Expand Up @@ -70,15 +70,19 @@ def tackle(
if os.path.exists(overrides):
override_dict = read_config_file(overrides)
if override_dict is not None:
context.global_kwargs.update(override_dict)
# context.global_kwargs.update(override_dict)
context.override_context.update(override_dict)
else:
raise exceptions.UnknownInputArgumentException(
f"The `override` input={overrides}, when given as a string must "
f"be a path to an file. Exiting.",
context=context,
)
elif isinstance(overrides, dict):
context.global_kwargs.update(overrides)
context.override_context.update(overrides)

# context.global_kwargs.update(overrides)
# context.global_kwargs.pop('override')

# Main loop
output = update_source(context)
Expand Down
6 changes: 1 addition & 5 deletions tackle/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class BaseContext(BaseModel):
private_context: Union[dict, list] = None
temporary_context: Union[dict, list] = None
existing_context: dict = {}
override_context: dict = Field({}, description="A dict to override inputs with.")

key_path: list = []
key_path_block: list = []
Expand Down Expand Up @@ -82,15 +83,10 @@ class Context(BaseContext):
input_string: str = None
input_dir: str = None
input_file: str = None
override_context: Union[str, dict] = Field(
None, description="A str for a file or dict to override inputs with."
)

function_fields: list = None
function_dict: dict = None

# return_: bool = False

hook_dirs: list = Field(
None, description="A list of additional directories to import hooks."
)
Expand Down
65 changes: 46 additions & 19 deletions tackle/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def parse_hook(
verbose=context.verbose,
env_=context.env_,
is_hook_call=True,
override_context=context.override_context,
)
except TypeError as e:
# TODO: Improve -> This is an error when we have multiple of the same
Expand Down Expand Up @@ -840,20 +841,27 @@ def walk_sync(context: 'Context', element):
set_key(context=context, value=element)


def update_input_context_with_kwargs(context: 'Context', kwargs: dict):
def update_input_context(input_dict: dict, update_dict: dict) -> dict:
"""
Update the input dict with kwargs which in this context are treated as overriding
the keys. Takes into account if the key is a hook and replaces that.
Update the input dict with update_dict which in this context are treated as
overriding the keys. Takes into account if the key is a hook and replaces that.
"""
for k, v in kwargs.items():
if k in context.input_context:
context.input_context.update({k: v})
elif f"{k}->" in context.input_context:
for k, v in update_dict.items():
if k in input_dict:
input_dict.update({k: v})
elif f"{k}->" in input_dict:
# Replace the keys and value in the same position it was in
context.input_context = {
input_dict = {
key if key != f"{k}->" else k: value if key != f"{k}->" else v
for key, value in context.input_context.items()
for key, value in input_dict.items()
}
elif f"{k}_>" in input_dict:
# Same but for private hooks
input_dict = {
key if key != f"{k}_>" else k: value if key != f"{k}_>" else v
for key, value in input_dict.items()
}
return input_dict


def update_hook_with_kwargs_and_flags(hook: ModelMetaclass, kwargs: dict) -> dict:
Expand Down Expand Up @@ -1097,9 +1105,19 @@ def run_source(context: 'Context', args: list, kwargs: dict, flags: list) -> Opt
else:
# If there are no declarative hooks defined, use the kwargs to override values
# within the context.
update_input_context_with_kwargs(context=context, kwargs=kwargs)
context.input_context = update_input_context(
input_dict=context.input_context,
update_dict=kwargs,
)
# Apply overrides
context.input_context = update_input_context(
input_dict=context.input_context,
update_dict=context.override_context,
)

for i in flags:
# TODO: This should use `update_input_context` as we don't know if the key has
# a hook in it -> Right? It has not been expanded...
# Process flags by setting key to true
context.input_context.update({i: True})

Expand Down Expand Up @@ -1136,6 +1154,7 @@ def parse_tmp_context(context: Context, element: Any, existing_context: dict):
calling_file=context.calling_file,
verbose=context.verbose,
env_=context.env_,
override_context=context.override_context,
)
walk_sync(context=tmp_context, element=element)

Expand All @@ -1152,8 +1171,8 @@ def function_walk(
many returnable string keys. Function is meant to be implanted into a function
object and called either as `exec` or some other arbitrary method.
"""
if input_element is None:
# If there is no `exec` method, input_element is None so we infer that the
if input_element == {}:
# If there is no `exec` method, input_element is {} so we infer that the
# input fields are to be returned. This is useful if the user would like to
# validate a dict easily with a function and is the only natural meaning of
# a function call without an exec method.
Expand Down Expand Up @@ -1198,6 +1217,7 @@ def function_walk(
calling_directory=self.calling_directory,
calling_file=self.calling_file,
env_=self.env_,
override_context=self.override_context,
)
walk_sync(context=tmp_context, element=input_element.copy())

Expand Down Expand Up @@ -1256,6 +1276,11 @@ def create_function_model(
# Macro to expand all keys properly so that a field's default can be parsed
func_dict = function_field_to_parseable_macro(func_dict, context, func_name)

# Apply overrides to input fields
for k, v in context.override_context.items():
if k in func_dict:
func_dict[k] = v

# Implement inheritance
if 'extends' in func_dict and func_dict['extends'] is not None:
base_hook = get_hook(func_dict['extends'], context)
Expand All @@ -1267,12 +1292,18 @@ def create_function_model(

# fmt: off
# Validate raw input params against pydantic object where values will be used later
exec_ = None
exec_ = {}
if 'exec' in func_dict:
exec_ = func_dict.pop('exec')
elif 'exec<-' in func_dict:
exec_ = func_dict.pop('exec<-')

# Apply overrides to exec_
exec_ = update_input_context(
input_dict=exec_,
update_dict=context.override_context,
)

# Special vars
function_input = FunctionInput(
exec_=exec_,
Expand Down Expand Up @@ -1322,12 +1353,7 @@ def create_function_model(
else:
new_func[k] = (type(v['default']), Field(**v))
else:
raise exceptions.MalformedFunctionFieldException(
f"Function field {k} must have either a `type` or `default` field "
f"where the type can be inferred.",
function_name=func_name,
context=context,
) from None
new_func[k] = (dict, v)
elif v in literals:
new_func[k] = (locate(v).__name__, Field(...))
elif isinstance(v, (str, int, float, bool)):
Expand Down Expand Up @@ -1358,6 +1384,7 @@ def create_function_model(
private_hooks=context.private_hooks,
calling_directory=context.calling_directory,
calling_file=context.calling_file,
override_context=context.override_context,
# Causes TypeError in pydantic -> __subclasscheck__
# env_=context.env_,
)
Expand Down
1 change: 1 addition & 0 deletions tackle/providers/logic/hooks/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def run_key(self, value):
calling_directory=self.calling_directory,
calling_file=self.calling_file,
verbose=self.verbose,
override_context=self.override_context,
)
walk_sync(context=tmp_context, element=value.copy())

Expand Down
1 change: 1 addition & 0 deletions tackle/providers/tackle/hooks/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def exec(self) -> Union[dict, list]:
calling_directory=self.calling_directory,
calling_file=self.calling_file,
verbose=self.verbose,
override_context=self.override_context,
)
walk_sync(context=tmp_context, element=self.items.copy())

Expand Down
3 changes: 2 additions & 1 deletion tackle/providers/tackle/hooks/tackle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tackle.models import BaseHook, Field


class TackleHook(BaseHook):
class TackleHook(BaseHook, smart_union=True):
"""Hook for calling external tackle providers."""

hook_type: str = 'tackle'
Expand Down Expand Up @@ -103,6 +103,7 @@ def exec(self) -> dict:
global_args=self.additional_args,
find_in_parent=self.find_in_parent,
verbose=self.verbose,
override_context=self.override_context,
)

return output_context
1 change: 1 addition & 0 deletions tackle/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def create_jinja_hook(context: 'Context', hook: 'ModelMetaclass') -> 'JinjaHook'
key_path=context.key_path,
verbose=context.verbose,
env_=context.env_,
override_context=context.override_context,
),
)

Expand Down
3 changes: 3 additions & 0 deletions tests/main/fixtures/dict-input-overrides.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
this: stuff
that: things

this_private: stuff
that_private: things
4 changes: 4 additions & 0 deletions tests/main/fixtures/dict-input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ this->: input This?

that->: input That?

this_private_>: input This?

that_private_>: input That?

foo: 1
13 changes: 13 additions & 0 deletions tests/main/fixtures/func-exec-input.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

<-:

exec:
this->: input This?

that->: input That?

this_private_>: input This?

that_private_>: input That?

foo: 1
7 changes: 7 additions & 0 deletions tests/main/fixtures/func-input.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

<-:
this->: input This?

that->: input That?

foo: 1
31 changes: 30 additions & 1 deletion tests/main/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def test_main_input_dict(change_curdir_fixtures):
input_dict = {
'this': 1,
'that': 2,
'this_private': 1,
'that_private': 2,
'stuff': 'things', # Missing key
'foo': 2, # Non-hook key
}
Expand All @@ -64,7 +66,20 @@ def test_main_input_dict(change_curdir_fixtures):

def test_main_from_cli_input_dict(change_curdir_fixtures, capsys):
"""Test same as above but from command line."""
main(["dict-input.yaml", "--this", "1", "--that", "2", "--print"])
main(
[
"dict-input.yaml",
"--this",
"1",
"--that",
"2",
"--this_private",
"1",
"--that_private",
"2",
"--print",
]
)
assert 'this' in capsys.readouterr().out


Expand All @@ -77,6 +92,20 @@ def test_main_overrides_str(change_curdir_fixtures):
main(["dict-input.yaml", "--override", "dict-input-overrides.yaml"])


def test_main_overrides_str_for_func(change_curdir_fixtures):
"""Test that we can override inputs for a default hook."""
o = tackle("func-input.yaml", override="dict-input-overrides.yaml")
# Should normally throw error with prompt
assert o['this'] == "stuff"


def test_main_overrides_str_for_func_exec(change_curdir_fixtures):
"""Test that we can override inputs for a default hook."""
o = tackle("func-exec-input.yaml", override="dict-input-overrides.yaml")
# Should normally throw error with prompt
assert o['this'] == "stuff"


def test_main_overrides_str_not_found_error(change_curdir_fixtures):
"""Test that we get error on ."""
with pytest.raises(exceptions.UnknownInputArgumentException):
Expand Down
2 changes: 0 additions & 2 deletions tests/parser/functions/test_functions_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
('field-require.yaml', exceptions.HookParseException),
# Check that type is one of literals.
('field-bad-type.yaml', exceptions.MalformedFunctionFieldException),
# Check that type or default is given.
('field-type-or-default.yaml', exceptions.MalformedFunctionFieldException),
]


Expand Down

0 comments on commit d3ccf0a

Please sign in to comment.