# Nested Workflows with AiiDA

This notebook demonstrates nested workflow support with load and write functionality.

## Define Nested Workflow with AiiDA

In [1]:
from python_workflow_definition.aiida import write_workflow_json, load_workflow_json

from aiida_workgraph import WorkGraph, task, namespace
from aiida import orm, load_profile

load_profile()

workflow_json_filename = "nested_test.pwd.json"

In [2]:
from workflow import (
    get_sum as _get_sum,
    get_prod_and_div as _get_prod_and_div,
    get_square as _get_square,
)

In [3]:
# Wrap the functions with @task decorator
get_prod_and_div = task(outputs=["prod", "div"])(_get_prod_and_div)
get_sum = task(_get_sum)
get_square = task(_get_square)

### Create Nested Workflow

In [4]:
# Create nested workflow manually (corresponds to prod_div.json)
nested_wg = WorkGraph(
    name="nested_workflow",
    inputs=namespace(x=namespace, y=namespace),
    outputs=namespace(result=namespace),
)

# Add tasks to nested workflow
t1 = nested_wg.add_task(get_prod_and_div)
t2 = nested_wg.add_task(get_sum)
t3 = nested_wg.add_task(get_square)

# Connect nested workflow inputs to first task
nested_wg.add_link(nested_wg.inputs.x, t1.inputs.x)
nested_wg.add_link(nested_wg.inputs.y, t1.inputs.y)

# Connect tasks within nested workflow
nested_wg.add_link(t1.outputs.prod, t2.inputs.x)
nested_wg.add_link(t1.outputs.div, t2.inputs.y)
nested_wg.add_link(t2.outputs.result, t3.inputs.x)

# Connect nested workflow output
nested_wg.outputs.result = t3.outputs.result

# Set default values for nested workflow inputs
nested_wg.inputs.x.value = orm.Float(1)
nested_wg.inputs.y.value = orm.Float(2)

### Create Main Workflow with Nested Workflow

In [5]:
# Create main workflow (corresponds to main.pwd.json)
main_wg = WorkGraph(
    name="main_workflow",
    inputs=namespace(a=namespace, b=namespace, c=namespace),
    outputs=namespace(final_result=namespace),
)

# Add tasks to main workflow
preprocessing = main_wg.add_task(get_prod_and_div)
nested_task = main_wg.add_task(nested_wg)  # Add the nested workflow as a task
postprocessing = main_wg.add_task(get_sum)

# Connect main workflow inputs to preprocessing
main_wg.add_link(main_wg.inputs.a, preprocessing.inputs.x)
main_wg.add_link(main_wg.inputs.c, preprocessing.inputs.y)

# Connect preprocessing to nested workflow
main_wg.add_link(preprocessing.outputs.prod, nested_task.inputs.x)
main_wg.add_link(preprocessing.outputs.div, nested_task.inputs.y)

# Connect nested workflow to postprocessing
main_wg.add_link(nested_task.outputs.result, postprocessing.inputs.x)
main_wg.add_link(main_wg.inputs.b, postprocessing.inputs.y)

# Connect main workflow output
main_wg.outputs.final_result = postprocessing.outputs.result

# Set default values for main workflow inputs
main_wg.inputs.a.value = orm.Float(3)
main_wg.inputs.b.value = orm.Float(2)
main_wg.inputs.c.value = orm.Float(4)

### Export Workflow to JSON

In [6]:
write_workflow_json(wg=main_wg, file_name=workflow_json_filename)
print(f"Exported workflow to {workflow_json_filename}")

Exported workflow to nested_test.pwd.json


In [7]:
!cat {workflow_json_filename}

[38;5;238m───────┬────────────────────────────────────────────────────────────────────────[0m
       [38;5;238m│ [0mFile: [1mnested_test.pwd.json[0m
[38;5;238m───────┼────────────────────────────────────────────────────────────────────────[0m
[38;5;238m   1[0m   [38;5;238m│[0m [38;5;231m{[0m
[38;5;238m   2[0m   [38;5;238m│[0m [38;5;231m  [0m[38;5;208m"[0m[38;5;208mversion[0m[38;5;208m"[0m[38;5;231m:[0m[38;5;231m [0m[38;5;186m"[0m[38;5;186m0.1.1[0m[38;5;186m"[0m[38;5;231m,[0m
[38;5;238m   3[0m   [38;5;238m│[0m [38;5;231m  [0m[38;5;208m"[0m[38;5;208mnodes[0m[38;5;208m"[0m[38;5;231m:[0m[38;5;231m [0m[38;5;231m[[0m
[38;5;238m   4[0m   [38;5;238m│[0m [38;5;231m    [0m[38;5;231m{[0m
[38;5;238m   5[0m   [38;5;238m│[0m [38;5;231m      [0m[38;5;208m"[0m[38;5;208mid[0m[38;5;208m"[0m[38;5;231m:[0m[38;5;231m [0m[38;5;141m0[0m[38;5;231m,[0m
[38;5;238m   6[0m   [38;5;238m│[0m [38;5;231m      [0m[38;5;208m"[0m

### Check Nested Workflow File

In [8]:
!cat nested_1.json

[38;5;238m───────┬────────────────────────────────────────────────────────────────────────[0m
       [38;5;238m│ [0mFile: [1mnested_1.json[0m
[38;5;238m───────┼────────────────────────────────────────────────────────────────────────[0m
[38;5;238m   1[0m   [38;5;238m│[0m [38;5;231m{[0m
[38;5;238m   2[0m   [38;5;238m│[0m [38;5;231m  [0m[38;5;208m"[0m[38;5;208mversion[0m[38;5;208m"[0m[38;5;231m:[0m[38;5;231m [0m[38;5;186m"[0m[38;5;186m0.1.1[0m[38;5;186m"[0m[38;5;231m,[0m
[38;5;238m   3[0m   [38;5;238m│[0m [38;5;231m  [0m[38;5;208m"[0m[38;5;208mnodes[0m[38;5;208m"[0m[38;5;231m:[0m[38;5;231m [0m[38;5;231m[[0m
[38;5;238m   4[0m   [38;5;238m│[0m [38;5;231m    [0m[38;5;231m{[0m
[38;5;238m   5[0m   [38;5;238m│[0m [38;5;231m      [0m[38;5;208m"[0m[38;5;208mid[0m[38;5;208m"[0m[38;5;231m:[0m[38;5;231m [0m[38;5;141m0[0m[38;5;231m,[0m
[38;5;238m   6[0m   [38;5;238m│[0m [38;5;231m      [0m[38;5;208m"[0m[38;5;

## Load and Verify Workflow

In [9]:
# Load the workflow back
wg_loaded = load_workflow_json(workflow_json_filename)

print(f"Loaded workflow: {wg_loaded.name}")
print(f"Number of tasks: {len([t for t in wg_loaded.tasks if t.name not in ['graph_inputs', 'graph_outputs', 'graph_ctx']])}")

# Check inputs
print("\nInputs:")
for name, socket in wg_loaded.inputs._sockets.items():
    if not name.startswith('_') and name != 'metadata':
        if hasattr(socket, 'value') and socket.value is not None:
            value = socket.value.value if hasattr(socket.value, 'value') else socket.value
            print(f"  {name} = {value}")

# Check for nested workflows
print("\nNested workflows:")
for task in wg_loaded.tasks:
    if hasattr(task, 'tasks'):
        nested_tasks = [t for t in task.tasks if t.name not in ['graph_inputs', 'graph_outputs', 'graph_ctx']]
        if len(nested_tasks) > 0:
            print(f"  Found '{task.name}' with {len(nested_tasks)} tasks")
            # Check nested workflow defaults
            for subtask in task.tasks:
                if subtask.name == 'graph_inputs' and hasattr(subtask, 'outputs'):
                    print("    Default inputs:")
                    for out in subtask.outputs:
                        if hasattr(out, '_name') and not out._name.startswith('_'):
                            value = out.value.value if hasattr(out.value, 'value') else out.value
                            print(f"      {out._name} = {value}")

Loaded workflow: WorkGraph
Number of tasks: 3

Inputs:
  a = 3
  b = 2
  c = 4

Nested workflows:
  Found 'WorkGraph' with 3 tasks
    Default inputs:
      x = 1
      y = 2


## Round-Trip Test

In [10]:
import json
from pathlib import Path

# Export the loaded workflow again
roundtrip_file = "nested_roundtrip.pwd.json"
write_workflow_json(wg_loaded, roundtrip_file)

# Compare the two exports
with open(workflow_json_filename) as f1, open(roundtrip_file) as f2:
    data1 = json.load(f1)
    data2 = json.load(f2)

match = json.dumps(data1, sort_keys=True) == json.dumps(data2, sort_keys=True)
print(f"Round-trip test: {'PASS' if match else 'FAIL'}")

if not match:
    print("\nDifferences found!")
    raise AssertionError("Round-trip test failed")
else:
    print("Workflow export/import is stable and idempotent")

Round-trip test: PASS
Workflow export/import is stable and idempotent


## Load Workflow with Other Frameworks

### Load Workflow with jobflow

In [11]:
# from python_workflow_definition.jobflow import load_workflow_json

In [12]:
# from jobflow.managers.local import run_locally

In [13]:
# flow = load_workflow_json(file_name=workflow_json_filename)

In [14]:
# result = run_locally(flow)
# result

### Load Workflow with pyiron_base

In [15]:
# from python_workflow_definition.pyiron_base import load_workflow_json

In [16]:
# delayed_object_lst = load_workflow_json(file_name=workflow_json_filename)
# delayed_object_lst[-1].draw()

In [17]:
# delayed_object_lst[-1].pull()

### Load Workflow with pyiron_workflow

In [18]:
# from python_workflow_definition.pyiron_workflow import load_workflow_json

In [19]:
# wf = load_workflow_json(file_name=workflow_json_filename)

In [20]:
# wf.draw(size=(10, 10))

In [21]:
# wf.run()

## Cleanup

In [22]:
# Clean up test files
import os
for f in [workflow_json_filename, roundtrip_file, "nested_1.json"]:
    if os.path.exists(f):
        os.remove(f)
        print(f"Removed {f}")

Removed nested_test.pwd.json
Removed nested_roundtrip.pwd.json
Removed nested_1.json
