Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
dd9927f
Add separate_data
samwaseda Jul 25, 2025
69be4b1
add tests
samwaseda Jul 25, 2025
04108a6
black
samwaseda Jul 25, 2025
4433b1f
typo
samwaseda Jul 25, 2025
6daf78a
Update flowrep/workflow.py
samwaseda Jul 25, 2025
19c8111
Convert to str
samwaseda Jul 25, 2025
576cdb2
forgot to convert
samwaseda Jul 25, 2025
bbadc0d
Implement get and set entry and add tests
samwaseda Jul 26, 2025
229f176
append nodes as prefix
samwaseda Jul 26, 2025
67a3809
And of course I forgot black
samwaseda Jul 26, 2025
1adf1c0
Merge pull request #9 from pyiron/dot
samwaseda Jul 26, 2025
e4818fe
Add docstring
samwaseda Jul 26, 2025
517ddcc
correct errors and add more tests
samwaseda Jul 26, 2025
3896675
black
samwaseda Jul 26, 2025
ab8fa5f
Add get_type etc., which come from pyiron_database, but since pyiron_…
samwaseda Jul 31, 2025
f62a5a0
Remove __class__ because it's not pyiron_workflow
samwaseda Jul 31, 2025
da27e66
Add tests and remove unused functions
samwaseda Jul 31, 2025
0ede732
Update hashing algorithm and add tests
samwaseda Jul 31, 2025
026009d
Remove json because it's not used
samwaseda Jul 31, 2025
8e11b2c
I'm hopeless
samwaseda Jul 31, 2025
8d889f4
Make functions public for pyiron_database
samwaseda Jul 31, 2025
ebbc641
Add tests
samwaseda Jul 31, 2025
e0786c5
black
samwaseda Jul 31, 2025
8e2f68e
Docstring and type hints
samwaseda Jul 31, 2025
ade1bf9
Remove unused workflow decorator and remove also tests for separate_d…
samwaseda Jul 31, 2025
779381d
Update flowrep/workflow.py
samwaseda Jul 31, 2025
647b9ab
Update flowrep/workflow.py
samwaseda Jul 31, 2025
b9e89a0
Update flowrep/workflow.py
samwaseda Jul 31, 2025
eb70e3c
Add json in the end
samwaseda Jul 31, 2025
25a345d
Black again
samwaseda Jul 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions flowrep/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import builtins
import copy
import dataclasses
import hashlib
import inspect
import json
import textwrap
from collections import deque
from collections.abc import Callable, Iterable
Expand Down Expand Up @@ -1254,3 +1256,151 @@ def parse_workflow(
edges=edges,
metadata=metadata,
)


def get_workflow_graph(workflow_dict: dict[str, Any]) -> nx.DiGraph:
    """
    Build the directed graph belonging to a workflow dictionary.

    The edges stored under ``workflow_dict["edges"]`` are combined with the
    additional edges produced by ``_get_missing_edges``, sorted, and turned
    into a ``networkx.DiGraph``. The graph is then handed to
    ``_replace_input_ports`` before being returned.

    Args:
        workflow_dict (dict[str, Any]): The workflow dictionary; it must
            contain an ``"edges"`` entry.

    Returns:
        nx.DiGraph: The graph returned by ``_replace_input_ports``.
    """
    declared = cast(list[tuple[str, str]], workflow_dict["edges"])
    inferred = list(_get_missing_edges(declared))
    # Sorting makes the edge insertion order independent of the input order.
    graph = nx.DiGraph(sorted(declared + inferred))
    return _replace_input_ports(graph, workflow_dict)


def _replace_input_ports(
    graph: nx.DiGraph, workflow_dict: dict[str, Any]
) -> nx.DiGraph:
    """
    Return a copy of the graph whose source nodes are replaced by their data.

    Every node without incoming edges is looked up in ``workflow_dict`` via
    ``_get_entry``. If the entry has a ``"value"`` the node is relabelled to
    that value; otherwise, if it has a ``"default"``, the node is relabelled
    to the default. Nodes with neither keep their original label.

    Args:
        graph (nx.DiGraph): The workflow graph; it is not modified.
        workflow_dict (dict[str, Any]): The workflow dictionary that holds the
            input entries.

    Returns:
        nx.DiGraph: The relabelled copy of ``graph``.
    """
    relabelled = graph.copy()
    # Iterate over a snapshot because relabelling changes the node set.
    for port in list(relabelled.nodes):
        if relabelled.in_degree(port) != 0:
            continue
        assert port.startswith("inputs.")
        entry = _get_entry(workflow_dict, port)
        # "value" takes precedence over "default".
        # NOTE(review): inputs that resolve to the same value are relabelled
        # to the same node label - confirm that this merge is intended.
        for field in ("value", "default"):
            if field in entry:
                nx.relabel_nodes(relabelled, {port: entry[field]}, copy=False)
                break
    return relabelled


def get_hashed_node_dict(
    node: str, graph: nx.DiGraph, nodes_dict: dict[str, dict]
) -> dict[str, Any]:
    """
    Assemble the dictionary that identifies a node for hashing and databases.

    The result has three entries: ``"nodes"`` (the function metadata from
    ``_get_function_metadata`` plus the list ``"connected_inputs"``),
    ``"inputs"`` (one entry per input of the node) and ``"outputs"`` (the
    output names). An input fed by another node is stored as
    ``"<hash of that node>@<output name>"`` and listed in
    ``"connected_inputs"``; any other input is stored as the label of the
    graph node feeding it.

    Args:
        node (str): The name of the node to be described.
        graph (nx.DiGraph): The directed graph of the workflow, containing the
            ports ``"<node>.inputs.<name>"``.
        nodes_dict (dict[str, dict]): The metadata of all nodes, keyed by name.

    Returns:
        dict[str, Any]: The dictionary representation of the node.

    Raises:
        ValueError: If the node entry has no ``"function"`` key.
    """
    node_data = nodes_dict[node]
    if "function" not in node_data:
        raise ValueError("Hashing works only on flat data")
    metadata = _get_function_metadata(node_data["function"])
    outputs = list(node_data["outputs"].keys())
    inputs: dict[str, Any] = {}
    linked: list[str] = []
    for name in node_data["inputs"]:
        # Each input port is expected to have exactly one incoming edge.
        sources = list(graph.predecessors(f"{node}.inputs.{name}"))
        assert len(sources) == 1
        source = sources[0]
        upstream = list(graph.predecessors(source))
        if len(upstream) == 0:
            # The source has no predecessor of its own: use its label as is.
            inputs[name] = source
            continue
        assert len(upstream) == 1
        # The source is the output port of another node: refer to it by the
        # hash of that node and the last component of the port name.
        upstream_hash = get_node_hash(upstream[0], graph, nodes_dict)
        port_name = source.split(".")[-1]
        inputs[name] = upstream_hash + "@" + port_name
        linked.append(name)
    metadata["connected_inputs"] = linked
    return {"nodes": metadata, "inputs": inputs, "outputs": outputs}


def get_node_hash(node: str, graph: nx.DiGraph, nodes_dict: dict[str, dict]) -> str:
    """
    Compute the SHA-256 hash that identifies a node.

    The dictionary from ``get_hashed_node_dict`` is serialised to JSON with
    sorted keys, encoded as UTF-8 and hashed.

    Args:
        node (str): The name of the node to be hashed.
        graph (nx.DiGraph): The directed graph of the workflow.
        nodes_dict (dict[str, dict]): The metadata of all nodes, keyed by name.

    Returns:
        str: The hexadecimal SHA-256 digest.
    """
    payload = get_hashed_node_dict(node=node, graph=graph, nodes_dict=nodes_dict)
    # Sorted keys keep the serialisation independent of dict insertion order.
    serialized = json.dumps(payload, sort_keys=True)
    digest = hashlib.sha256(serialized.encode("utf-8"))
    return digest.hexdigest()


def _get_entry(data: dict[str, Any], key: str) -> Any:
"""
Get a value from a nested dictionary at the specified key path.

Args:
data (dict[str, Any]): The dictionary to search.
key (str): The key path to retrieve the value from, separated by dots.

Returns:
Any: The value at the specified key path.

Raises:
KeyError: If the key path does not exist in the dictionary.
"""
for item in key.split("."):
data = data[item]
return data


def _set_entry(
data: dict[str, Any], key: str, value: Any, create_missing: bool = False
) -> None:
"""
Set a value in a nested dictionary at the specified key path.

Args:
data (dict[str, Any]): The dictionary to modify.
key (str): The key path to set the value at, separated by dots.
value (Any): The value to set.
create_missing (bool): Whether to create missing keys in the path.
"""
keys = key.split(".")
for k in keys[:-1]:
if k not in data:
if create_missing:
data[k] = {}
else:
raise KeyError(f"Key '{k}' not found in data.")
data = data[k]
data[keys[-1]] = value


def _get_function_metadata(cls: Callable) -> dict[str, str]:
module = cls.__module__
qualname = cls.__qualname__
from importlib import import_module

base_module = import_module(module.split(".")[0])
version = (
base_module.__version__
if hasattr(base_module, "__version__")
else "not_defined"
)
return {
"module": module,
"qualname": qualname,
"version": version,
}
72 changes: 72 additions & 0 deletions tests/unit/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,78 @@ def g(x):

self.assertRaises(NotImplementedError, fwf.get_workflow_dict, f)

def test_get_hashed_node_dict(self):
    """Check the hashable dictionaries of an input-fed and a connected node."""

    def workflow_with_data(a=10, b=20):
        x = add(a, b)
        y = multiply(x, b)
        return x, y

    workflow_dict = fwf.get_workflow_dict(workflow_with_data)
    graph = fwf.get_workflow_graph(workflow_dict)
    nodes = workflow_dict["nodes"]

    # "add_0" only receives workflow inputs, so the defaults appear verbatim.
    add_dict = fwf.get_hashed_node_dict("add_0", graph, nodes)
    expected_add = {
        "nodes": {
            "module": add.__module__,
            "qualname": "add",
            "version": "not_defined",
            "connected_inputs": [],
        },
        "inputs": {"x": 10, "y": 20},
        "outputs": ["output"],
    }
    self.assertEqual(add_dict, expected_add)

    # "multiply_0" takes the output of "add_0", referenced through its hash.
    add_hashed = fwf.get_node_hash("add_0", graph, nodes)
    multiply_dict = fwf.get_hashed_node_dict("multiply_0", graph, nodes)
    expected_multiply = {
        "nodes": {
            "module": multiply.__module__,
            "qualname": "multiply",
            "version": "not_defined",
            "connected_inputs": ["x"],
        },
        "inputs": {"x": add_hashed + "@output", "y": 20},
        "outputs": ["output"],
    }
    self.assertEqual(multiply_dict, expected_multiply)

    # This workflow's "add_0" is expected to be rejected with ValueError
    # (presumably because it is not a flat function node - verify).
    graph = fwf.get_workflow_graph(example_workflow._semantikon_workflow)
    other_nodes = example_workflow._semantikon_workflow["nodes"]
    with self.assertRaises(ValueError):
        fwf.get_hashed_node_dict("add_0", graph, other_nodes)

def test_get_and_set_entry(self):
    """Check reading and writing workflow entries through dotted key paths."""

    def yet_another_workflow(a=10, b=20):
        x = add(a, b)
        y = multiply(x, b)
        return x, y

    workflow_dict = fwf.get_workflow_dict(yet_another_workflow)

    # An existing path returns the stored default.
    default_of_a = fwf._get_entry(workflow_dict, "inputs.a.default")
    self.assertEqual(default_of_a, 10)

    # "x" is not a workflow input, so the lookup has to fail.
    with self.assertRaises(KeyError):
        fwf._get_entry(workflow_dict, "inputs.x.default")

    # A value written with _set_entry can be read back with _get_entry.
    fwf._set_entry(workflow_dict, "inputs.a.value", 42)
    value_of_a = fwf._get_entry(workflow_dict, "inputs.a.value")
    self.assertEqual(value_of_a, 42)

def test_get_function_metadata(self):
    """Check module, qualified name and fallback version of ``operation``."""
    metadata = fwf._get_function_metadata(operation)
    expected = {
        "module": operation.__module__,
        "qualname": operation.__qualname__,
        "version": "not_defined",
    }
    self.assertEqual(metadata, expected)


# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Loading