In [1]:
from aiida import load_profile
from aiida.orm import Int
from aiida_workgraph import WorkGraph
from aiida.engine import calcfunction
from aiida_workgraph.decorator import build_task_from_callable

In [2]:
from inspect import isfunction

In [3]:
def get_kwargs(lst):
    """Index a list of edges by the input they feed.

    Args:
        lst: iterable of edge dicts, each with the keys ``'targetHandle'``,
            ``'source'`` and ``'sourceHandle'``.

    Returns:
        dict mapping every ``'targetHandle'`` to
        ``{'source': ..., 'sourceHandle': ...}`` taken from the same edge.
        If a handle appears more than once, the last edge wins.
    """
    by_handle = {}
    for edge in lst:
        handle = edge['targetHandle']
        by_handle[handle] = {
            'source': edge['source'],
            'sourceHandle': edge['sourceHandle'],
        }
    return by_handle

In [4]:
def wrap_function(func, **kwargs):
    """Turn a plain function into a calcfunction that carries its own task.

    Args:
        func: the callable to wrap.
        **kwargs: optional settings; the keys read here are
            ``"inputs"`` (default ``[]``), ``"outputs"`` (default ``[]``)
            and ``"identifier"`` (default: ``func.__name__``).

    Returns:
        The ``calcfunction``-decorated callable, with three extra attributes:
        ``identifier``, and ``task`` / ``node`` which both point at the object
        returned by ``build_task_from_callable``.
    """
    # Step 1: decorate with calcfunction.
    calc_func = calcfunction(func)
    # Step 2: build the task from the decorated callable.
    task = build_task_from_callable(
        calc_func,
        inputs=kwargs.get("inputs", []),
        outputs=kwargs.get("outputs", []),
    )
    # A missing or falsy identifier falls back to the function's own name.
    calc_func.identifier = kwargs.get("identifier", None) or func.__name__
    calc_func.task = task
    calc_func.node = task
    return calc_func

In [5]:
def group_edges(edges_lst):
    """Group edges by their target node.

    Edges are accumulated per target regardless of where they appear in
    ``edges_lst``, so edges for the same target do not have to be adjacent.
    An empty edge list yields an empty dict.

    Args:
        edges_lst: iterable of edge dicts with the keys ``'target'``,
            ``'targetHandle'``, ``'source'`` and ``'sourceHandle'``.

    Returns:
        dict mapping each target id to
        ``{targetHandle: {'source': ..., 'sourceHandle': ...}}``.
        Targets are keyed in order of first appearance; if the same
        ``(target, targetHandle)`` pair occurs more than once, the last edge wins.
    """
    total_dict = {}
    for ed in edges_lst:
        # Same per-edge mapping that get_kwargs produces, built incrementally
        # so that non-adjacent edges of one target are merged, not overwritten.
        total_dict.setdefault(ed['target'], {})[ed['targetHandle']] = {
            'source': ed['source'],
            'sourceHandle': ed['sourceHandle'],
        }
    return total_dict

In [6]:
def get_output_labels(edges_lst):
    """Collect, per source node, the output handles that edges read from.

    Edges whose ``'sourceHandle'`` is ``None`` are ignored. Each handle is
    listed once per source, in order of first appearance, even when the same
    output feeds several targets.

    Args:
        edges_lst: iterable of edge dicts with at least the keys ``'source'``
            and ``'sourceHandle'``.

    Returns:
        dict mapping each source id to a list of its distinct output handles.
    """
    output_label_dict = {}
    for ed in edges_lst:
        handle = ed['sourceHandle']
        if handle is None:
            continue
        labels = output_label_dict.setdefault(ed["source"], [])
        # One output consumed by several targets must be declared only once.
        if handle not in labels:
            labels.append(handle)
    return output_label_dict

In [7]:
def get_function_dict(nodes_dict, output_label_dict):
    """Wrap every function-valued node with ``wrap_function``.

    Args:
        nodes_dict: dict mapping node ids to node values; only values for
            which ``inspect.isfunction`` is true are processed.
        output_label_dict: dict mapping node ids to lists of output labels.
            When a function node has an entry here, the labels are passed to
            ``wrap_function`` as ``outputs=[{"name": label}, ...]``.

    Returns:
        dict mapping the id of each function node to its wrapped function.
    """
    wrapped = {}
    for node_id, node in nodes_dict.items():
        if not isfunction(node):
            continue
        if node_id in output_label_dict:
            outputs = [{"name": label} for label in output_label_dict[node_id]]
            wrapped[node_id] = wrap_function(func=node, outputs=outputs)
        else:
            wrapped[node_id] = wrap_function(func=node)
    return wrapped

In [8]:
def build_workgraph(function_dict, total_dict, nodes_dict, label="my_workflow"):
    """Create a WorkGraph with one task per wrapped function and wire its inputs.

    Tasks are added as soon as all the tasks they read from exist, so the
    iteration order of ``function_dict`` does not have to be a topological
    order of the graph. Function nodes without incoming edges are added
    without inputs.

    Args:
        function_dict: dict mapping node ids to wrapped functions; each
            function's ``__name__`` is used as the task name.
        total_dict: dict mapping target node ids to
            ``{targetHandle: {'source': ..., 'sourceHandle': ...}}``.
        nodes_dict: dict mapping node ids to node values; consulted for
            inputs whose ``'sourceHandle'`` is ``None``.
        label: name passed to ``WorkGraph``.

    Returns:
        The populated ``WorkGraph``.

    Raises:
        ValueError: if an input names an output handle of a node that is not
            in ``function_dict``, or if the remaining tasks can never be
            resolved (e.g. a dependency cycle).
    """
    wg = WorkGraph(label)
    mapping_dict = {}  # node id -> name of the task already added for it
    pending = list(function_dict.items())
    while pending:
        deferred = []
        for k, v in pending:
            kwargs = {}
            ready = True
            # A function node may legitimately have no incoming edges.
            for tk, tv in total_dict.get(k, {}).items():
                source, handle = tv['source'], tv['sourceHandle']
                if source in mapping_dict:
                    kwargs[tk] = wg.tasks[mapping_dict[source]].outputs[handle]
                elif handle is None:
                    kwargs[tk] = nodes_dict[source]
                elif source in function_dict:
                    # Producer task not added yet: retry this node later.
                    ready = False
                    break
                else:
                    raise ValueError(
                        f"Input '{tk}' of node {k!r} reads output '{handle}' "
                        f"of node {source!r}, which is not a function node."
                    )
            if not ready:
                deferred.append((k, v))
                continue
            name = v.__name__
            wg.add_task(v, name=name, **kwargs)
            mapping_dict[k] = name
        if len(deferred) == len(pending):
            # A full pass without progress means the rest can never resolve.
            unresolved = [k for k, _ in deferred]
            raise ValueError(
                f"Cannot resolve the inputs of nodes {unresolved!r}; "
                "check the edges for cyclic dependencies."
            )
        pending = deferred
    return wg

In [9]:
def add_x_and_y(x, y):
    """Add two values and return both operands together with their sum.

    Returns:
        dict with the keys ``"x"`` and ``"y"`` (the inputs, passed through
        unchanged) and ``"z"`` (``x + y``).
    """
    # NOTE(review): the recorded `wg.run()` output shows this task excepting
    # once it is wrapped as a calcfunction. Possibly because x and y are
    # returned as-is rather than as new values -- confirm against the full
    # traceback before relying on the "x"/"y" outputs.
    total = x + y
    return dict(x=x, y=y, z=total)

In [10]:
def add_x_and_y_and_z(x, y, z):
    """Return ``x + y + z``, added left to right."""
    return x + y + z

In [11]:
# Example graph edges. Each edge feeds input `targetHandle` of node `target`
# from node `source`; `sourceHandle` names the output of the source that is
# read, or is None when the source node's value itself is passed in.
edges_lst = [
    # Node 1 takes all three of its inputs from the outputs of node 0.
    {'target': 1, 'targetHandle': 'z', 'source': 0, 'sourceHandle': 'z'},
    {'target': 1, 'targetHandle': 'x', 'source': 0, 'sourceHandle': 'x'},
    {'target': 1, 'targetHandle': 'y', 'source': 0, 'sourceHandle': 'y'},
    # Node 0 takes its inputs directly from the value nodes 2 and 3.
    {'target': 0, 'targetHandle': 'x', 'source': 2, 'sourceHandle': None},
    {'target': 0, 'targetHandle': 'y', 'source': 3, 'sourceHandle': None},
]

In [12]:
# Node id -> node content: functions become tasks, anything else is a plain
# input value referenced by edges whose 'sourceHandle' is None.
nodes_dict = {
    0: add_x_and_y,
    1: add_x_and_y_and_z,
    2: 1,  # value fed to input 'x' of node 0
    3: 2,  # value fed to input 'y' of node 0
}

In [13]:
# Output handles that the edges read from each source node.
output_label_dict = get_output_labels(edges_lst)
output_label_dict

{0: ['z', 'x', 'y']}

In [14]:
# Incoming connections of every target node, keyed by input handle.
total_dict = group_edges(edges_lst=edges_lst)
total_dict

{1: {'z': {'source': 0, 'sourceHandle': 'z'},
  'x': {'source': 0, 'sourceHandle': 'x'},
  'y': {'source': 0, 'sourceHandle': 'y'}},
 0: {'x': {'source': 2, 'sourceHandle': None},
  'y': {'source': 3, 'sourceHandle': None}}}

In [15]:
# Wrap the function-valued nodes; value nodes (2 and 3) are skipped.
function_dict = get_function_dict(nodes_dict=nodes_dict, output_label_dict=output_label_dict)
function_dict 

{0: <function __main__.add_x_and_y(x, y)>,
 1: <function __main__.add_x_and_y_and_z(x, y, z)>}

In [16]:
# Load the AiiDA profile before the workgraph is built and run.
load_profile()

Profile<uuid='7bb8761123324468bb98821cbb757251' name='presto'>

In [17]:
# Assemble the WorkGraph from the wrapped functions and the grouped edges.
wg = build_workgraph(function_dict=function_dict, total_dict=total_dict, nodes_dict=nodes_dict, label="my_workflow")
wg

Widget dependency not found. To visualize the workgraph, please install the widget dependency. Use 'pip install aiida-workgraph[widget]' if installing from PyPI. For local source installations, use 'pip install .[widget]' and then build the JavaScript library. Refer to the documentation for more details.


NodeGraph(name="my_workflow, uuid="79495230-ae25-11ef-8590-9ec6524ec879")

In [18]:
# Execute the workgraph. NOTE(review): the recorded output below shows the
# add_x_and_y task excepting and an empty result -- this cell does not
# currently complete successfully.
wg.run()

11/29/2024 08:42:25 AM <57231> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [58|WorkGraphEngine|continue_workgraph]: Continue workgraph.
11/29/2024 08:42:25 AM <57231> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [58|WorkGraphEngine|continue_workgraph]: tasks ready to run: add_x_and_y
11/29/2024 08:42:25 AM <57231> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [58|WorkGraphEngine|run_tasks]: Run task: add_x_and_y, type: CALCFUNCTION


------------------------------------------------------------


11/29/2024 08:42:26 AM <57231> aiida.orm.nodes.process.calculation.calcfunction.CalcFunctionNode: [REPORT] [61|add_x_and_y|on_except]: Traceback (most recent call last):
  File "/home/jan/mambaforge/lib/python3.12/site-packages/plumpy/base/state_machine.py", line 324, in transition_to
    self._enter_next_state(new_state)
  File "/home/jan/mambaforge/lib/python3.12/site-packages/plumpy/base/state_machine.py", line 388, in _enter_next_state
    self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)
  File "/home/jan/mambaforge/lib/python3.12/site-packages/plumpy/base/state_machine.py", line 300, in _fire_state_event
    callback(self, hook, state)
  File "/home/jan/mambaforge/lib/python3.12/site-packages/plumpy/processes.py", line 331, in <lambda>
    lambda _s, _h, from_state: self.on_entered(cast(Optional[process_states.State], from_state)),
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jan/mambaforge/lib/pytho

{}