# WorkGraph example to run Molecular Dynamics

## Aim

As an example, we start from a structure, run an MD simulation, compute descriptors, and then use a filtering function to split the resulting structures into `train.xyz`, `test.xyz`, and `valid.xyz`.

Load the AiiDA profile, structure, model and code:

In [1]:
from aiida import load_profile
load_profile()

Profile<uuid='60b17659a9844c4bbd3bef8de0a8f417' name='presto'>

In [2]:
from aiida_mlip.data.model import ModelData
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

In [3]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read, iread

# structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
init_structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))

In [4]:
from aiida.orm import load_code
janus_code = load_code("janus@localhost")

In [5]:
from aiida.orm import Str, Float, Bool, Int, Dict
# Keyword inputs for the MD calculation; a later cell unpacks this dict
# into `wg.add_task(mdCalc, ..., **inputs)`.
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
    "ensemble": Str("NVT"),
    # Starting structure: the NaCl rocksalt cell built in the cell above.
    "struct": init_structure,
    # Extra MD options. Presumably "steps" is the number of MD steps and
    # "traj-every" the interval between saved trajectory frames — TODO confirm
    # against the MD calculation's documentation.
    "md_kwargs": Dict(
        {
            "steps": 10,
            "traj-every": 2
        }
    )
}

In [6]:
from aiida.plugins import CalculationFactory

mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")

In [7]:
from aiida_workgraph import WorkGraph, task
from aiida_workgraph.manager import get_current_graph

@task.graph(outputs=["structs"])
def descriptors_task(
    code,
    model,
    device,
    arch,
    file,
):
    """
    Add one descriptors calculation to the current graph per trajectory frame.

    Parameters
    ----------
    code
        Code forwarded to every descriptors calculation.
    model
        Model forwarded to every descriptors calculation.
    device
        Device forwarded to every descriptors calculation.
    arch
        Architecture forwarded to every descriptors calculation.
    file
        Object exposing ``as_path()``; the file at that path is read frame by
        frame with ``iread``.

    Returns
    -------
    dict
        ``{"structs": ...}``, where the value is the graph context entry that
        maps ``structs0``, ``structs1``, ... to the ``xyz_output`` of the
        calculation added for the corresponding frame.
    """
    calc_class = CalculationFactory("mlip.descriptors")
    graph = get_current_graph()
    outputs_by_frame = {}

    # The path is only used inside the context manager, and ``iread`` is
    # consumed there, so every frame is read within the ``with`` block.
    with file.as_path() as traj_path:
        for index, frame in enumerate(iread(traj_path)):
            frame_structure = StructureData(ase=frame)

            frame_task = graph.add_task(
                calc_class,
                code=code,
                model=model,
                device=device,
                arch=arch,
                struct=frame_structure,
                metadata={"options": {"resources": {"num_machines": 1}}},
            )

            outputs_by_frame[f"structs{index}"] = frame_task.outputs.xyz_output

    # Publish the collected outputs through the graph context and return them
    # from there.
    graph.update_ctx({"structs": outputs_by_frame})

    return {"structs": graph.ctx.structs}

In [8]:
@task.calcfunction(outputs=["test_file", "train_file", "valid_file"])
def create_aiida_files(**inputs):
    """
    Split the supplied data and wrap the resulting files as AiiDA nodes.

    Parameters
    ----------
    **inputs
        Keyword arguments forwarded unchanged to ``process_and_split_data``.

    Returns
    -------
    dict
        ``train_file``, ``test_file`` and ``valid_file``, each a
        ``SinglefileData`` built from the matching entry of the mapping
        returned by ``process_and_split_data``.
    """
    # Imported locally: no earlier notebook cell brings these names into
    # scope, so the function would otherwise raise NameError when the
    # workgraph runs it.
    from aiida.orm import SinglefileData
    from sample_split import process_and_split_data

    files = process_and_split_data(**inputs)

    return {
        "train_file": SinglefileData(files["train_file"]),
        "test_file": SinglefileData(files["test_file"]),
        "valid_file": SinglefileData(files["valid_file"]),
    }

In [9]:
# Build the "MD" workgraph: MD run -> one descriptors calculation per
# trajectory frame -> split into train/test/valid files.
with WorkGraph("MD") as wg:

    # MD calculation, configured by the `inputs` dict assembled above.
    md_calc = wg.add_task(
        mdCalc,
        name="md_calc",
        **inputs
    )

    # Graph task that reads the MD trajectory file and adds a descriptors
    # calculation for every frame it contains.
    descriptors_calc = wg.add_task(
        descriptors_task,
        code=janus_code,
        model=model,
        device=Str("cpu"),
        arch=Str(model.architecture),
        file=md_calc.outputs.traj_file
    )
    
    # Calcfunction that turns the collected descriptor outputs into the
    # train/test/valid files.
    split_task = wg.add_task(
        create_aiida_files, 
        # **split_task_inputs,
        trajectory_data=descriptors_calc.outputs.structs,  # Use the output from descriptors_task
        # n_samples= Int(len(final_structures)),
    )

In [10]:
wg.run()

11/19/2025 10:56:41 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1183|WorkGraphEngine|continue_workgraph]: tasks ready to run: md_calc
11/19/2025 10:56:41 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1183|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 1186


/home/qoj42292/.aiida/scratch/presto/fe/3f/c429-18ef-46c5-bf69-5d7dd023892f md-summary.yml


11/19/2025 10:56:53 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1183|WorkGraphEngine|update_task_state]: Task: md_calc, type: CALCJOB, finished.
11/19/2025 10:56:54 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1183|WorkGraphEngine|continue_workgraph]: tasks ready to run: descriptors_task


defining outputnode


11/19/2025 10:56:54 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1183|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 1203
11/19/2025 10:56:54 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1203|WorkGraphEngine|continue_workgraph]: tasks ready to run: Descriptors,Descriptors1,Descriptors2,Descriptors3,Descriptors4,Descriptors5
11/19/2025 10:56:56 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1203|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 1207, 1211, 1215, 1219, 1223, 1227
11/19/2025 10:57:18 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1203|WorkGraphEngine|update_task_state]: Task: Descriptors, type: CALCJOB, finished.
11/19/2025 10:57:18 AM <27790> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1203|WorkGraphEngine|update_task_state]: Task: Descriptors1, type: CALCJOB, finished.
11/

{}

In [11]:
wg

NodeGraphWidget(settings={'minimap': True}, states={'graph_inputs': 'FINISHED', 'graph_outputs': 'FINISHED', '…

In [None]:
# @task.calcfunction(outputs = ["test_file", "train_file", "valid_file"])
# def create_aiida_files(**inputs):
     
#     files = process_and_split_data(**inputs)

#     return {
#         "train_file": SinglefileData(files["train_file"]),
#         "test_file": SinglefileData(files["test_file"]),
#         "valid_file": SinglefileData(files["valid_file"])
#     }

In [None]:

# split_task = wg.add_task(
#     create_aiida_files, 
#     # **split_task_inputs,
#     trajectory_data=descriptors_calc.outputs.structs,  # Use the output from descriptors_task
#     # n_samples= Int(len(final_structures)),
# )

In [None]:
wg.run()

In [None]:
# Scratch cell: attach the splitting calcfunction directly to the graph.
# NOTE(review): `final_structures` is only assigned inside `descriptors_task`
# and in a commented-out cell below, so it is presumably undefined at notebook
# scope — confirm before running this cell.
split_task = wg.add_task(
        create_aiida_files, 
        # **split_task_inputs,
        trajectory_data=final_structures,
        # n_samples= Int(len(final_structures)),
        )


# Single Descriptor

In [None]:

# wg = WorkGraph("MD_workgraph")
# md_calc = wg.add_task(
#     mdCalc,
#     name="md_calc",
#     **inputs
# )
# wg.run()

In [None]:
print(type(wg.tasks.md_calc.outputs.traj_file.value))

In [None]:
# from aiida.orm import load_node
# traj = load_node(PK) 

# # print(len(list(iread(wg.tasks.md_calc.outputs.traj_file.value.as_path()))))
# traj_length = (wg.tasks.md_calc.outputs.traj_file.value.as_path()).numsteps

In [None]:
print(descriptorsCalc.spec().inputs)


In [None]:
# # from aiida.orm.nodes import as_path

# from aiida.orm import SinglefileData, StructureData
# final_structures = {}

# # for i, struct in enumerate(iread(wg.tasks.md_calc.outputs.traj_file.value.as_path())):
# # for i in enumerate(iread(wg.tasks.md_calc.outputs.traj_file.value.as_path())):
# for i in range(4):
#     with wg.tasks.md_calc.outputs.traj_file.value.as_path() as path:

#         output=StructureData(ase=read(path, index=i))
#         # print(f"Read structure at index {i}: {output}")
        
#     descriptors_calc = wg.add_task(
#         descriptorsCalc,
#         name="descriptors_calc_%d" % i,
#         # struct=md_calc.outputs.final_structure,
#         struct=output,
#         # code=inputs['code'],
#         # model=inputs['model'],
#         # arch=inputs['arch'],
#         # device=inputs['device'],
#         # metadata=inputs['metadata'],
#         # calc_per_element=Bool(True),
#         # trajectory_data=output,
#     )
#     final_structures[f"structs{i}"] = descriptors_calc.outputs.xyz_output

In [None]:
wg

In [None]:
from sample_split import process_and_split_data

# Scratch cell: add `process_and_split_data` to the graph as its own task.
# NOTE(review): `desc_calc` is only bound inside `descriptors_task`, so it is
# presumably undefined at notebook scope — confirm which task output is meant.
split_data = wg.add_task(
    process_and_split_data,
    name="split_data",
    struct=desc_calc.outputs.xyz_output
)


In [None]:
wg

In [None]:
wg.tasks

In [None]:
# wg.run()

# Additional information

In [None]:
mdCalc.get_description()["spec"]["outputs"].keys()

In [None]:
# Inspect the MD task's `traj_output` and try to load its second frame.
# NOTE(review): `trajectory` is a task output socket here; `read` presumably
# needs a path or file object instead — confirm.
trajectory = md_calc.outputs.traj_output
print(trajectory)
output= StructureData(ase=read(trajectory, index=1))
print(output)

In [None]:
print('outputs of mdCalc:', md_calc.outputs)


In [None]:
print(type(wg.tasks.md_calc.outputs.traj_file.value))

In [None]:
# To find inputs/outputs of mdcalc uncomment following: 
# mdCalc.get_description()["spec"]["outputs"].keys()

# Loop Descriptor

 When the graph is being built, the workflow does not yet have an output for `wg.tasks.md.outputs.traj_file`.

 Usually you can just pass a socket to the next task. But because we need the file path in order to read the trajectory, you have to create a task.

 That task must wait for the MD task to run and then read its output.

In [None]:
# Read every frame of the trajectory with iread and report how many were found.
# NOTE(review): this calls `.open()` directly on the `traj_output` socket, while
# other cells reach the stored node through `.value` — confirm which is intended.
with md_calc.outputs.traj_output.open(mode="r") as traj_file:
    traj = list(iread(traj_file, format="extxyz"))  
print(f"Number of frames in trajectory: {len(traj)}")

In [None]:
from aiida_workgraph import WorkGraph

# Scratch attempt: build an MD task and descriptors tasks in one flat graph.
# NOTE(review): the notes above this cell say the MD output is not available
# while the graph is being built, so the `read(trajectory, ...)` call below
# presumably cannot load frames at this point — compare with
# `descriptors_task`, which reads the file inside a graph task instead.
with WorkGraph("MD_Simple") as wg:
    
    # MD simulation
    md_task = wg.add_task(
        mdCalc,
        name="md",
        **inputs
    )
    # Output socket of the MD task; assigned before `wg.run()` is called.
    trajectory = md_task.outputs.traj_output
   
    # output=StructureData(ase=read(trajectory, index=":"))
    # for i in enumerate(output):
    # for i in range(trajectory.numsteps): (see geom_opt.ipynb)
    for i in range(2):
        output=StructureData(ase=read(trajectory, index=i))
        # Descriptors on intermediate structures
        # `desc_task` is rebound on every iteration, so only the last task
        # stays reachable through this name.
        desc_task = wg.add_task(
            descriptorsCalc,
            name=f"descriptors_{i}",
            code=inputs['code'],
            model=inputs['model'],
            arch=inputs['arch'],
            device=inputs['device'],
            metadata=inputs['metadata'],
            calc_per_element=Bool(True),
            # NOTE(review): both `struct` and `trajectory_data` are passed;
            # confirm the descriptors calculation accepts both.
            struct=md_task.outputs.final_structure,
            trajectory_data=output,
            
        )
    # # Descriptors on final structure
    
    # desc_task = wg.add_task(
    #     descriptorsCalc,
    #     name="descriptors",
    #     code=inputs['code'],
    #     model=inputs['model'],
    #     arch=inputs['arch'],
    #     device=inputs['device'],
    #     metadata=inputs['metadata'],
    #     calc_per_element=Bool(True),
    #     struct=md_task.outputs.final_structure,
    # )

wg.run()

In [None]:
wg