# WorkGraph example to run Molecular Dynamics

## Aim

As an example, we start from a structure, run a molecular dynamics (MD) simulation, compute descriptors, and then use a filtering function to split the resulting structures into `train.xyz`, `test.xyz`, and `valid.xyz`.

Load the AiiDA profile, structure, model and code:

In [1]:
# Load the AiiDA profile before creating any nodes or running processes.
from aiida import load_profile
load_profile()

Profile<uuid='60b17659a9844c4bbd3bef8de0a8f417' name='presto'>

In [2]:
from aiida_mlip.data.model import ModelData
# Small MACE-MP model from the janus-core test suite, wrapped as a ModelData node.
# cache_dir="mlips" is the local directory used for the downloaded file.
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

In [3]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read, iread

# Starting structure: rock-salt NaCl with lattice constant a = 5.63 (Angstrom).
# To start from a CIF file instead, use e.g.:
# init_structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
init_structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))

In [4]:
from aiida.orm import load_code
# janus-core code registered in this profile as label "janus" on computer "localhost".
janus_code = load_code("janus@localhost")

In [5]:
from aiida.orm import Str, Float, Bool, Int, Dict
# Inputs for the MD calculation. The code, model, arch, device and metadata
# entries are also reused for the descriptors tasks further down.
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
    "ensemble": Str("NVT"),
    "struct": init_structure,
    # Extra janus-core MD options: 10 steps, presumably saving a trajectory
    # frame every 2 steps ("traj-every") -- confirm against janus-core docs.
    "md_kwargs": Dict(
        {
            "steps": 10,
            "traj-every": 2
        }
    )
}

In [6]:
from aiida.plugins import CalculationFactory

# Load the aiida-mlip calculation classes from their registered entry points.
mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")

# Single Descriptor

In [7]:
from aiida_workgraph import WorkGraph
# Minimal workgraph with a single MD task, run directly from the notebook.
wg = WorkGraph("MD_workgraph")
md_calc = wg.add_task(
    mdCalc,
    name="md_calc",
    **inputs
)
wg.run()

11/12/2025 04:50:19 PM <25252> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [722|WorkGraphEngine|continue_workgraph]: tasks ready to run: md_calc
11/12/2025 04:50:20 PM <25252> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [722|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 725


/home/qoj42292/.aiida/scratch/presto/95/52/a2c5-f25b-4934-94b4-0a6f54decdaa md-summary.yml


11/12/2025 04:50:30 PM <25252> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [722|WorkGraphEngine|update_task_state]: Task: md_calc, type: CALCJOB, finished.
11/12/2025 04:50:30 PM <25252> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [722|WorkGraphEngine|continue_workgraph]: tasks ready to run: 
11/12/2025 04:50:30 PM <25252> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [722|WorkGraphEngine|finalize]: Finalize workgraph.


{}

In [8]:
print(type(wg.tasks.md_calc.outputs.traj_file.value))  # node type behind the traj_file socket after the run

<class 'aiida.orm.nodes.data.singlefile.SinglefileData'>


In [9]:
# Add one descriptors task per MD trajectory frame (up to 4 frames), so that
# descriptors are computed on intermediate structures rather than repeatedly
# on the final one. Requires `md_calc` to have run (wg.run() above), because
# the trajectory file is read here in the notebook.
from aiida.orm import SinglefileData, StructureData

# Read the trajectory once; index=":" makes ase.io.read return every frame.
with wg.tasks.md_calc.outputs.traj_file.value.as_path() as path:
    frames = read(path, index=":")

# Slicing (instead of indexing frames 0..3) tolerates shorter trajectories.
for i, atoms in enumerate(frames[:4]):
    descriptors_calc = wg.add_task(
        descriptorsCalc,
        name=f"descriptors_calc_{i}",
        # Same code/model/runtime settings as the MD task.
        code=inputs["code"],
        model=inputs["model"],
        arch=inputs["arch"],
        device=inputs["device"],
        metadata=inputs["metadata"],
        struct=StructureData(ase=atoms),
    )

defining outputnode


In [10]:
wg  # display the workgraph as an interactive widget

NodeGraphWidget(settings={'minimap': True}, states={'graph_inputs': 'FINISHED', 'graph_outputs': 'FINISHED', '…

In [None]:
from sample_split import process_and_split_data

# Split the computed structures into train/test/valid sets.
# NOTE(review): only `descriptors_calc` from the last loop iteration is wired in
# here, so the other descriptors tasks do not feed this step -- confirm intended.
split_data = wg.add_task(
    process_and_split_data,
    name="split_data",
    struct=descriptors_calc.outputs.xyz_output
)


In [None]:
wg  # display the workgraph again, now including the split task

In [None]:
wg.tasks  # list the tasks currently in the workgraph

In [None]:
# Uncomment to run the workgraph with the tasks added above:
# wg.run()

# Additional information

In [None]:
mdCalc.get_description()["spec"]["outputs"].keys()  # output ports declared by the MD calculation spec

In [None]:
# `md_calc.outputs.traj_file` is a task socket; after the run its `.value` is
# the SinglefileData node holding the trajectory. ase.io.read needs a real
# file path, so expose the stored file via as_path().
trajectory = md_calc.outputs.traj_file.value
print(trajectory)
with trajectory.as_path() as path:
    output = StructureData(ase=read(path, index=1))  # second frame of the trajectory
print(output)

In [None]:
print('outputs of mdCalc:', md_calc.outputs)  # output sockets of the md_calc task


In [None]:
print(type(wg.tasks.md_calc.outputs.traj_file.value))  # node type behind the traj_file socket

In [None]:
# To list the outputs of mdCalc, uncomment the following (use "inputs" instead of "outputs" to list the inputs):
# mdCalc.get_description()["spec"]["outputs"].keys()

# Loop Descriptor

The workflow does not yet have a value for `wg.tasks.md.outputs.traj_file` while the graph is being built.

Usually you can simply pass a socket from one task to the next. Here, however, we need the path of the trajectory file in order to read it, so the reading has to happen inside a task.

That task waits for `md_task` to finish and then reads its output.

In [None]:
# read the trajectory data with iread
with md_calc.outputs.traj_output.open(mode="r") as traj_file:
    traj = list(iread(traj_file, format="extxyz"))  
print(f"Number of frames in trajectory: {len(traj)}")

In [None]:
from aiida.orm import Bool, Int
from aiida_workgraph import WorkGraph, task


# The MD trajectory only exists after the `md` task has run. While the graph is
# being built, `md_task.outputs.traj_file` is a socket, not a file, so it cannot
# be passed to ase.io.read directly. Wrapping the read in a calcfunction task
# lets the WorkGraph engine run it once the MD output is available.
@task.calcfunction()
def extract_frame(traj_file, index):
    """Return frame ``index`` (an Int node) of the extxyz trajectory ``traj_file`` as a StructureData."""
    from aiida.orm import StructureData
    from ase.io import read

    with traj_file.open(mode="r") as handle:
        atoms = read(handle, index=index.value, format="extxyz")
    return StructureData(ase=atoms)


with WorkGraph("MD_Simple") as wg:

    # MD simulation
    md_task = wg.add_task(
        mdCalc,
        name="md",
        **inputs
    )

    # Descriptors on intermediate structures taken from the MD trajectory.
    # TODO: derive the number of frames from the trajectory instead of
    # hard-coding it (geom_opt.ipynb loops over trajectory.numsteps).
    for i in range(2):
        frame_task = wg.add_task(
            extract_frame,
            name=f"frame_{i}",
            traj_file=md_task.outputs.traj_file,
            index=Int(i),
        )
        desc_task = wg.add_task(
            descriptorsCalc,
            name=f"descriptors_{i}",
            code=inputs['code'],
            model=inputs['model'],
            arch=inputs['arch'],
            device=inputs['device'],
            metadata=inputs['metadata'],
            calc_per_element=Bool(True),
            struct=frame_task.outputs.result,
        )

    # To compute descriptors on the final structure instead, add a single
    # descriptors task with struct=md_task.outputs.final_structure.

wg.run()

In [None]:
wg  # display the MD + descriptors workgraph