initial draft at a caching hook implementation #652

Closed
18 changes: 13 additions & 5 deletions kedro/runner/runner.py
@@ -32,7 +32,7 @@
 import logging
 from abc import ABC, abstractmethod
 from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, as_completed, wait
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Iterable, Tuple
 
 from kedro.framework.hooks import get_hook_manager
 from kedro.io import AbstractDataSet, DataCatalog
@@ -222,15 +222,19 @@ def _collect_inputs_from_hook(
     inputs: Dict[str, Any],
     is_async: bool,
     run_id: str = None,
-) -> Dict[str, Any]:
+) -> Tuple[Dict[str, Any], bool]:
     inputs = inputs.copy()  # shallow copy to prevent in-place modification by the hook
     hook_manager = get_hook_manager()
     hook_response = hook_manager.hook.before_node_run(  # pylint: disable=no-member
         node=node, catalog=catalog, inputs=inputs, is_async=is_async, run_id=run_id,
     )
 
     additional_inputs = {}
+    skip = False
     for response in hook_response:
+        if response is not None and isinstance(response, bool):
+            skip = response
+            continue
         if response is not None and not isinstance(response, dict):
             response_type = type(response).__name__
             raise TypeError(
@@ -240,7 +244,7 @@
         response = response or {}
         additional_inputs.update(response)
 
-    return additional_inputs
+    return additional_inputs, skip
 
 
 def _call_node_run(
@@ -289,10 +293,12 @@ def _run_node_sequential(node: Node, catalog: DataCatalog, run_id: str = None) -> Node:
 
     is_async = False
 
-    additional_inputs = _collect_inputs_from_hook(
+    additional_inputs, skip = _collect_inputs_from_hook(
         node, catalog, inputs, is_async, run_id=run_id
     )
     inputs.update(additional_inputs)
+    if skip:
+        return node
 
     outputs = _call_node_run(node, catalog, inputs, is_async, run_id=run_id)
 
@@ -324,10 +330,12 @@ def _run_node_async(node: Node, catalog: DataCatalog, run_id: str = None) -> Node:
         wait(inputs.values(), return_when=ALL_COMPLETED)
         inputs = {key: value.result() for key, value in inputs.items()}
         is_async = True
-        additional_inputs = _collect_inputs_from_hook(
+        additional_inputs, skip = _collect_inputs_from_hook(
             node, catalog, inputs, is_async, run_id=run_id
         )
         inputs.update(additional_inputs)
+        if skip:
+            return node
 
         outputs = _call_node_run(node, catalog, inputs, is_async, run_id=run_id)
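With this change, a before_node_run hook can respond with a dict of additional inputs (the existing behaviour) or with a bool that the runner treats as a skip flag. As context, a minimal sketch of a hook using the new protocol — the class name and the "already seen" caching criterion are illustrative assumptions, not part of this PR; only the bool-return convention comes from the runner change above:

from kedro.framework.hooks import hook_impl


class SkipIfSeenHook:
    """Hypothetical hook: asks the runner to skip nodes it has already run."""

    def __init__(self):
        self._seen = set()  # node names executed so far (stand-in cache key)

    @hook_impl
    def before_node_run(self, node, catalog, inputs, is_async, run_id):
        if node.name in self._seen:
            return True  # runner sets skip=True and returns the node unexecuted
        self._seen.add(node.name)
        return None  # None (or a dict of extra inputs) lets the node run as usual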
147 changes: 147 additions & 0 deletions tests/framework/hooks/test_caching_hooks.py
@@ -0,0 +1,147 @@
import json
import os
from pathlib import Path
from random import random
from tempfile import gettempdir
from unittest.mock import Mock

from kedro.extras.caching.CachingHook import LocalFileCachingHook, get_function_fingerprint
from kedro.extras.datasets.json import JSONDataSet
from kedro.framework.hooks import get_hook_manager
from kedro.io import DataCatalog, MemoryDataSet
from kedro.pipeline import Pipeline, node
from kedro.runner import SequentialRunner

STATIC_STRING = "ABC"


# some functions to use for equality evaluations
def ident(x):
    return x


def a_loop():
    for i in range(10):
        a = "1"


def with_imported():
    os.listdir()


def a_func():
    # with comment
    a_loop()
    a = 2
    ident(STATIC_STRING)
    with_imported()
    return {"a": a}


def b_func(a_out):
    # with comment
    a_loop()
    b = 2
    ident(STATIC_STRING)
    with_imported()
    return {"b": b}


def _run_pipeline_twice(pipeline: Pipeline, catalog: DataCatalog, second_pipeline=None):
    second_pipeline = second_pipeline or pipeline
    manager = get_hook_manager()
    runner = SequentialRunner()
    tmp_dir = gettempdir()
    state_path = Path(tmp_dir) / "state.kstate"
    state_path.unlink(missing_ok=True)  # ensure no old file in use

    # run pipeline 1st time
    hook1 = LocalFileCachingHook(state_path)
    manager.register(hook1)
    runner.run(pipeline, catalog, "1")
    manager.unregister(hook1)
    hook1._persist()

    # run pipeline 2nd time
    hook2 = LocalFileCachingHook(state_path)
    manager.register(hook2)
    runner.run(second_pipeline, catalog, "2")
    manager.unregister(hook2)
    hook2._persist()
    with open(state_path) as f:
        state = f.read()
    return hook1, hook2, state


def test_simple_in_memory_pipeline():
    # Assemble nodes into a pipeline
    a = Mock(return_value={"a": 1})
    b = Mock()
    pipeline = Pipeline([
        node(lambda: a(), inputs=None, outputs="a_out", name="A"),
        node(lambda x: b(x), inputs="a_out", outputs="b_out", name="B"),
    ])
    data_catalog = DataCatalog({"a_out": MemoryDataSet(), "b_out": MemoryDataSet()})
    hook1, hook2, state_content = _run_pipeline_twice(pipeline, data_catalog)

    # neither node can be skipped, since MemoryDataSet outputs are excluded from caching
    assert a.call_count == 2
    assert b.call_count == 2

    # check state_content
    state = json.loads(state_content)
    assert state['datasets']['a_out'] == 3
    assert state['datasets']['b_out'] == 4


def test_with_edited_pipeline(tmp_path):
    # unchanged nodes are skipped on the second run; the newly added node still executes
    a = Mock(return_value={"a": 1})
    b = Mock(return_value={"b": 1})
    c = Mock()

    initial_pipeline = Pipeline([
        node(lambda: a(), inputs=None, outputs="a_out", name="A"),
        node(lambda x: b(x), inputs="a_out", outputs="b_out", name="B"),
    ])
    edited_pipeline = Pipeline([
        node(lambda: a(), inputs=None, outputs="a_out", name="A"),
        node(lambda x: b(x), inputs="a_out", outputs="b_out", name="B"),
        node(lambda x: c(x), inputs="b_out", outputs=None, name="C"),
    ])

    data_catalog = DataCatalog(
        {"a_out": JSONDataSet(str(tmp_path / "a.json")), "b_out": JSONDataSet(str(tmp_path / "b.json"))}
    )
    hook1, hook2, state_content = _run_pipeline_twice(
        initial_pipeline, data_catalog, second_pipeline=edited_pipeline
    )

    # each node is called only once: A and B are skipped on the second run
    assert a.call_count == 1
    assert b.call_count == 1
    assert c.call_count == 1

    # check state_content
    state = json.loads(state_content)
    assert state['datasets']['a_out'] == 1
    assert state['datasets']['b_out'] == 2


def test_expect_two_functions_equal():
    # co_name is ignored here because two functions in the same namespace can't share a name
    b = a_func
    hash_one = get_function_fingerprint(a_func, ["co_name"])
    hash_two = get_function_fingerprint(b, ["co_name"])
    assert hash_one == hash_two

    # works for lambdas
    hash_one = get_function_fingerprint(lambda: print(1))
    hash_two = get_function_fingerprint(lambda: print(2))
    assert hash_one != hash_two

    hash_one = get_function_fingerprint(lambda: print(1))
    hash_two = get_function_fingerprint(lambda: print(1))
    assert hash_one == hash_two

    # works for randomness
    hash_one = get_function_fingerprint(lambda: random())
    hash_two = get_function_fingerprint(lambda: random())
    assert hash_one == hash_two
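The assertions above pin down the contract assumed for get_function_fingerprint: byte-identical code hashes equal (even for random() calls, since the code is hashed, not its output), differing constants hash unequal, and code-object attributes such as co_name can be excluded via the second argument. A minimal sketch consistent with these tests — the attribute list and hashing scheme are assumptions; the actual implementation lives in kedro/extras/caching/CachingHook.py and may differ:

import hashlib

# code-object attributes assumed to take part in the fingerprint
_CO_ATTRS = ("co_name", "co_code", "co_consts", "co_names", "co_varnames", "co_argcount")


def get_function_fingerprint(func, ignore=None):
    """Hash selected attributes of func.__code__, skipping any listed in `ignore`."""
    ignore = set(ignore or [])
    digest = hashlib.sha256()
    for attr in _CO_ATTRS:
        if attr not in ignore:
            digest.update(repr(getattr(func.__code__, attr)).encode())
    return digest.hexdigest()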