From d3ccf0aed8fd60d3e287a98cc166e5817f859023 Mon Sep 17 00:00:00 2001 From: robcxyz Date: Mon, 12 Dec 2022 10:49:17 +0530 Subject: [PATCH] fix: override inputs so that they are able to apply to declarative hook execs and are passed through remote calls --- tackle/main.py | 10 ++- tackle/models.py | 6 +- tackle/parser.py | 65 +++++++++++++------ tackle/providers/logic/hooks/match.py | 1 + tackle/providers/tackle/hooks/block.py | 1 + tackle/providers/tackle/hooks/tackle.py | 3 +- tackle/render.py | 1 + tests/main/fixtures/dict-input-overrides.yaml | 3 + tests/main/fixtures/dict-input.yaml | 4 ++ tests/main/fixtures/func-exec-input.yaml | 13 ++++ tests/main/fixtures/func-input.yaml | 7 ++ tests/main/test_main.py | 31 ++++++++- .../functions/test_functions_exceptions.py | 2 - 13 files changed, 116 insertions(+), 31 deletions(-) create mode 100644 tests/main/fixtures/func-exec-input.yaml create mode 100644 tests/main/fixtures/func-input.yaml diff --git a/tackle/main.py b/tackle/main.py index 5cc53ba77..4f0be2e27 100644 --- a/tackle/main.py +++ b/tackle/main.py @@ -12,7 +12,7 @@ def get_global_kwargs(kwargs): """Check for unknown kwargs and return so they can be consumed later.""" global_kwargs = {} for k, v in kwargs.items(): - if k not in Context.__fields__: + if k not in Context.__fields__ and k != 'override': global_kwargs.update({k: v}) if global_kwargs == {}: return None @@ -70,7 +70,8 @@ def tackle( if os.path.exists(overrides): override_dict = read_config_file(overrides) if override_dict is not None: - context.global_kwargs.update(override_dict) + # context.global_kwargs.update(override_dict) + context.override_context.update(override_dict) else: raise exceptions.UnknownInputArgumentException( f"The `override` input={overrides}, when given as a string must " @@ -78,7 +79,10 @@ def tackle( context=context, ) elif isinstance(overrides, dict): - context.global_kwargs.update(overrides) + context.override_context.update(overrides) + + # context.global_kwargs.update(overrides) + # context.global_kwargs.pop('override') # Main loop output = update_source(context) diff --git a/tackle/models.py b/tackle/models.py index 82d06388a..55165387a 100644 --- a/tackle/models.py +++ b/tackle/models.py @@ -42,6 +42,7 @@ class BaseContext(BaseModel): private_context: Union[dict, list] = None temporary_context: Union[dict, list] = None existing_context: dict = {} + override_context: dict = Field({}, description="A dict to override inputs with.") key_path: list = [] key_path_block: list = [] @@ -82,15 +83,10 @@ class Context(BaseContext): input_string: str = None input_dir: str = None input_file: str = None - override_context: Union[str, dict] = Field( - None, description="A str for a file or dict to override inputs with." - ) function_fields: list = None function_dict: dict = None - # return_: bool = False - hook_dirs: list = Field( None, description="A list of additional directories to import hooks." ) diff --git a/tackle/parser.py b/tackle/parser.py index 07ec363ef..aa672cd59 100644 --- a/tackle/parser.py +++ b/tackle/parser.py @@ -477,6 +477,7 @@ def parse_hook( verbose=context.verbose, env_=context.env_, is_hook_call=True, + override_context=context.override_context, ) except TypeError as e: # TODO: Improve -> This is an error when we have multiple of the same @@ -840,20 +841,27 @@ def walk_sync(context: 'Context', element): set_key(context=context, value=element) -def update_input_context_with_kwargs(context: 'Context', kwargs: dict): +def update_input_context(input_dict: dict, update_dict: dict) -> dict: """ - Update the input dict with kwargs which in this context are treated as overriding - the keys. Takes into account if the key is a hook and replaces that. + Update the input dict with update_dict which in this context are treated as + overriding the keys. Takes into account if the key is a hook and replaces that. """ - for k, v in kwargs.items(): - if k in context.input_context: - context.input_context.update({k: v}) - elif f"{k}->" in context.input_context: + for k, v in update_dict.items(): + if k in input_dict: + input_dict.update({k: v}) + elif f"{k}->" in input_dict: # Replace the keys and value in the same position it was in - context.input_context = { + input_dict = { key if key != f"{k}->" else k: value if key != f"{k}->" else v - for key, value in context.input_context.items() + for key, value in input_dict.items() + } + elif f"{k}_>" in input_dict: + # Same but for private hooks + input_dict = { + key if key != f"{k}_>" else k: value if key != f"{k}_>" else v + for key, value in input_dict.items() } + return input_dict def update_hook_with_kwargs_and_flags(hook: ModelMetaclass, kwargs: dict) -> dict: @@ -1097,9 +1105,19 @@ def run_source(context: 'Context', args: list, kwargs: dict, flags: list) -> Opt else: # If there are no declarative hooks defined, use the kwargs to override values # within the context. - update_input_context_with_kwargs(context=context, kwargs=kwargs) + context.input_context = update_input_context( + input_dict=context.input_context, + update_dict=kwargs, + ) + # Apply overrides + context.input_context = update_input_context( + input_dict=context.input_context, + update_dict=context.override_context, + ) for i in flags: + # TODO: This should use `update_input_context` as we don't know if the key has + # a hook in it -> Right? It has not been expanded... # Process flags by setting key to true context.input_context.update({i: True}) @@ -1136,6 +1154,7 @@ def parse_tmp_context(context: Context, element: Any, existing_context: dict): calling_file=context.calling_file, verbose=context.verbose, env_=context.env_, + override_context=context.override_context, ) walk_sync(context=tmp_context, element=element) @@ -1152,8 +1171,8 @@ def function_walk( many returnable string keys. Function is meant to be implanted into a function object and called either as `exec` or some other arbitrary method. """ - if input_element is None: - # If there is no `exec` method, input_element is None so we infer that the + if input_element == {}: + # If there is no `exec` method, input_element is {} so we infer that the # input fields are to be returned. This is useful if the user would like to # validate a dict easily with a function and is the only natural meaning of # a function call without an exec method. @@ -1198,6 +1217,7 @@ def function_walk( calling_directory=self.calling_directory, calling_file=self.calling_file, env_=self.env_, + override_context=self.override_context, ) walk_sync(context=tmp_context, element=input_element.copy()) @@ -1256,6 +1276,11 @@ def create_function_model( # Macro to expand all keys properly so that a field's default can be parsed func_dict = function_field_to_parseable_macro(func_dict, context, func_name) + # Apply overrides to input fields + for k, v in context.override_context.items(): + if k in func_dict: + func_dict[k] = v + # Implement inheritance if 'extends' in func_dict and func_dict['extends'] is not None: base_hook = get_hook(func_dict['extends'], context) @@ -1267,12 +1292,18 @@ def create_function_model( # fmt: off # Validate raw input params against pydantic object where values will be used later - exec_ = None + exec_ = {} if 'exec' in func_dict: exec_ = func_dict.pop('exec') elif 'exec<-' in func_dict: exec_ = func_dict.pop('exec<-') + # Apply overrides to exec_ + exec_ = update_input_context( + input_dict=exec_, + update_dict=context.override_context, + ) + # Special vars function_input = FunctionInput( exec_=exec_, @@ -1322,12 +1353,7 @@ def create_function_model( else: new_func[k] = (type(v['default']), Field(**v)) else: - raise exceptions.MalformedFunctionFieldException( - f"Function field {k} must have either a `type` or `default` field " - f"where the type can be inferred.", - function_name=func_name, - context=context, - ) from None + new_func[k] = (dict, v) elif v in literals: new_func[k] = (locate(v).__name__, Field(...)) elif isinstance(v, (str, int, float, bool)): @@ -1358,6 +1384,7 @@ def create_function_model( private_hooks=context.private_hooks, calling_directory=context.calling_directory, calling_file=context.calling_file, + override_context=context.override_context, # Causes TypeError in pydantic -> __subclasscheck__ # env_=context.env_, ) diff --git a/tackle/providers/logic/hooks/match.py b/tackle/providers/logic/hooks/match.py index 220efb9d4..91d2fcbff 100644 --- a/tackle/providers/logic/hooks/match.py +++ b/tackle/providers/logic/hooks/match.py @@ -94,6 +94,7 @@ def run_key(self, value): calling_directory=self.calling_directory, calling_file=self.calling_file, verbose=self.verbose, + override_context=self.override_context, ) walk_sync(context=tmp_context, element=value.copy()) diff --git a/tackle/providers/tackle/hooks/block.py b/tackle/providers/tackle/hooks/block.py index ed66d4fdb..829606bb9 100644 --- a/tackle/providers/tackle/hooks/block.py +++ b/tackle/providers/tackle/hooks/block.py @@ -39,6 +39,7 @@ def exec(self) -> Union[dict, list]: calling_directory=self.calling_directory, calling_file=self.calling_file, verbose=self.verbose, + override_context=self.override_context, ) walk_sync(context=tmp_context, element=self.items.copy()) diff --git a/tackle/providers/tackle/hooks/tackle.py b/tackle/providers/tackle/hooks/tackle.py index 5e4f64b68..c5cca0b79 100644 --- a/tackle/providers/tackle/hooks/tackle.py +++ b/tackle/providers/tackle/hooks/tackle.py @@ -5,7 +5,7 @@ from tackle.models import BaseHook, Field -class TackleHook(BaseHook): +class TackleHook(BaseHook, smart_union=True): """Hook for calling external tackle providers.""" hook_type: str = 'tackle' @@ -103,6 +103,7 @@ def exec(self) -> dict: global_args=self.additional_args, find_in_parent=self.find_in_parent, verbose=self.verbose, + override_context=self.override_context, ) return output_context diff --git a/tackle/render.py b/tackle/render.py index 5cdbee667..ad19ccad7 100644 --- a/tackle/render.py +++ b/tackle/render.py @@ -61,6 +61,7 @@ def create_jinja_hook(context: 'Context', hook: 'ModelMetaclass') -> 'JinjaHook' key_path=context.key_path, verbose=context.verbose, env_=context.env_, + override_context=context.override_context, ), ) diff --git a/tests/main/fixtures/dict-input-overrides.yaml b/tests/main/fixtures/dict-input-overrides.yaml index bc04c7458..4d1b85e17 100644 --- a/tests/main/fixtures/dict-input-overrides.yaml +++ b/tests/main/fixtures/dict-input-overrides.yaml @@ -1,2 +1,5 @@ this: stuff that: things + +this_private: stuff +that_private: things \ No newline at end of file diff --git a/tests/main/fixtures/dict-input.yaml b/tests/main/fixtures/dict-input.yaml index e430102a1..0cd045481 100644 --- a/tests/main/fixtures/dict-input.yaml +++ b/tests/main/fixtures/dict-input.yaml @@ -3,4 +3,8 @@ this->: input This? that->: input That? +this_private_>: input This? + +that_private_>: input That? + foo: 1 diff --git a/tests/main/fixtures/func-exec-input.yaml b/tests/main/fixtures/func-exec-input.yaml new file mode 100644 index 000000000..058020a81 --- /dev/null +++ b/tests/main/fixtures/func-exec-input.yaml @@ -0,0 +1,13 @@ + +<-: + + exec: + this->: input This? + + that->: input That? + + this_private_>: input This? + + that_private_>: input That? + + foo: 1 diff --git a/tests/main/fixtures/func-input.yaml b/tests/main/fixtures/func-input.yaml new file mode 100644 index 000000000..2045c8e4e --- /dev/null +++ b/tests/main/fixtures/func-input.yaml @@ -0,0 +1,7 @@ + +<-: + this->: input This? + + that->: input That? + + foo: 1 diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 85fb41e97..3ac99c6c9 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -54,6 +54,8 @@ def test_main_input_dict(change_curdir_fixtures): input_dict = { 'this': 1, 'that': 2, + 'this_private': 1, + 'that_private': 2, 'stuff': 'things', # Missing key 'foo': 2, # Non-hook key } @@ -64,7 +66,20 @@ def test_main_input_dict(change_curdir_fixtures): def test_main_from_cli_input_dict(change_curdir_fixtures, capsys): """Test same as above but from command line.""" - main(["dict-input.yaml", "--this", "1", "--that", "2", "--print"]) + main( + [ + "dict-input.yaml", + "--this", + "1", + "--that", + "2", + "--this_private", + "1", + "--that_private", + "2", + "--print", + ] + ) assert 'this' in capsys.readouterr().out @@ -77,6 +92,20 @@ def test_main_overrides_str(change_curdir_fixtures): main(["dict-input.yaml", "--override", "dict-input-overrides.yaml"]) +def test_main_overrides_str_for_func(change_curdir_fixtures): + """Test that we can override inputs for a default hook.""" + o = tackle("func-input.yaml", override="dict-input-overrides.yaml") + # Should normally throw error with prompt + assert o['this'] == "stuff" + + +def test_main_overrides_str_for_func_exec(change_curdir_fixtures): + """Test that we can override inputs for a default hook.""" + o = tackle("func-exec-input.yaml", override="dict-input-overrides.yaml") + # Should normally throw error with prompt + assert o['this'] == "stuff" + + def test_main_overrides_str_not_found_error(change_curdir_fixtures): """Test that we get error on .""" with pytest.raises(exceptions.UnknownInputArgumentException): diff --git a/tests/parser/functions/test_functions_exceptions.py b/tests/parser/functions/test_functions_exceptions.py index 2894bd5f9..c38180d16 100644 --- a/tests/parser/functions/test_functions_exceptions.py +++ b/tests/parser/functions/test_functions_exceptions.py @@ -14,8 +14,6 @@ ('field-require.yaml', exceptions.HookParseException), # Check that type is one of literals. ('field-bad-type.yaml', exceptions.MalformedFunctionFieldException), - # Check that type or default is given. - ('field-type-or-default.yaml', exceptions.MalformedFunctionFieldException), ]