# WorkGraph example to run Molecular Dynamics

## Aim

This notebook demonstrates how to run a molecular dynamics (MD) simulation with WorkGraph. As an example, we start from a rock-salt NaCl crystal, run an MD simulation that saves snapshots of the trajectory over time, and then compute descriptors for each snapshot. The workgraph uses a filtering function to split the resulting data into three files (train, validation and test) for training the machine learning model.
The goal is to show how to run the descriptors calculation on each snapshot of the MD simulation.

Note: the NVT ensemble keeps temperature and volume constant. It is still to be checked whether this is adequate, or whether the MD run requires a temperature change.
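As background for the NVT note: the thermostat holds the *kinetic* temperature near its target, and that temperature follows from equipartition, T = 2E_kin / (3Nk_B). The stdlib sketch below computes it from per-atom masses and velocities. It assumes SI units and 3N degrees of freedom (no constraints or centre-of-mass correction), and the helper `kinetic_temperature` is hypothetical, not an aiida-mlip or janus-core function.

```python
import math

K_B = 1.380649e-23  # Boltzmann constant in J/K (exact SI value)
AMU = 1.66053906660e-27  # atomic mass unit in kg


def kinetic_temperature(masses, velocities):
    """Instantaneous kinetic temperature T = 2 * E_kin / (3 * N * k_B).

    masses: per-atom masses in kg
    velocities: per-atom (vx, vy, vz) tuples in m/s
    Assumption: 3N degrees of freedom, nothing removed for constraints
    or centre-of-mass motion.
    """
    if len(masses) != len(velocities):
        raise ValueError("need one velocity vector per atom")
    e_kin = sum(
        0.5 * m * (vx * vx + vy * vy + vz * vz)
        for m, (vx, vy, vz) in zip(masses, velocities)
    )
    return 2.0 * e_kin / (3.0 * len(masses) * K_B)


# Two Na atoms whose velocity components each have magnitude sqrt(k_B*T/m),
# which corresponds to T = 300 K.
m_na = 22.98976928 * AMU
v = math.sqrt(K_B * 300.0 / m_na)
T = kinetic_temperature([m_na, m_na], [(v, v, v), (-v, v, -v)])
print(f"{T:.1f} K")  # → 300.0 K
```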

Load the AiiDA profile, structure, model and codes:

In [1]:
from aiida import load_profile
load_profile()

Profile<uuid='60b17659a9844c4bbd3bef8de0a8f417' name='presto'>

In [2]:
from aiida_mlip.data.model import ModelData
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

In [3]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read, iread

# structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
init_structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))

In [4]:
from aiida.orm import load_code
janus_code = load_code("janus@localhost")
qe_code = load_code("qe@scarf")

We load the calculation plugins used in this workflow (MD, descriptors and training). The inputs for each calculation are defined further below and can be changed depending on the configuration you are running.

In [5]:
from aiida.plugins import CalculationFactory

mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")
trainCalc = CalculationFactory("mlip.train")

Next, configure the descriptors task to run on each trajectory snapshot. Because the snapshots only exist once the MD calculation has finished, we create the descriptors tasks dynamically within the same task using `get_current_graph()`, one for each structure.

In [6]:
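from aiida.orm import StructureData
from aiida.plugins import CalculationFactory
from aiida_workgraph import WorkGraph, task
from aiida_workgraph.manager import get_current_graph
from ase.io import iread

@task.graph(outputs=["structs"])
def descriptors_task(
    code,
    model,
    device,
    arch,
    file,
):
    """Run a descriptors calculation on every snapshot in an MD trajectory file."""
    descriptorsCalc = CalculationFactory("mlip.descriptors")
    # The graph this task is building, so tasks can be added dynamically
    wg = get_current_graph()
    final_structures = {}

    # Read the trajectory file written by the MD calculation, one frame at a time
    with file.as_path() as path:
        for i, structs in enumerate(iread(path)):
            structure = StructureData(ase=structs)

            # One descriptors task per snapshot
            desc_calc = wg.add_task(
                descriptorsCalc,
                code=code,
                model=model,
                device=device,
                arch=arch,
                struct=structure,
                metadata={"options": {"resources": {"num_machines": 1}}},
            )

            # Collect the descriptors output file of each snapshot (an AiiDA
            # node); raw ASE Atoms objects cannot be stored as graph outputs
            final_structures[f"structs{i}"] = desc_calc.outputs.xyz_output

    wg.update_ctx({"structs": final_structures})

    return {"structs": wg.ctx.structs}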
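The fan-out pattern used above (one task per trajectory frame, collected under `structs{i}` keys) can be illustrated without AiiDA. `ToyGraph` below is a hypothetical stand-in that only records task names and inputs; its naming rule is an assumption chosen to reproduce the `Descriptors`, `Descriptors1`, ... names that appear in the run log, not WorkGraph's actual implementation.

```python
from dataclasses import dataclass, field


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a WorkGraph: records task names and inputs only."""

    tasks: dict = field(default_factory=dict)

    def add_task(self, label, **inputs):
        # First task keeps the bare label; later ones get a numeric suffix
        name, n = label, 0
        while name in self.tasks:
            n += 1
            name = f"{label}{n}"
        self.tasks[name] = inputs
        return name


def fan_out(graph, frames):
    """Add one descriptors task per frame, keyed like `final_structures`."""
    outputs = {}
    for i, frame in enumerate(frames):
        outputs[f"structs{i}"] = graph.add_task("Descriptors", struct=frame)
    return outputs


frames = [f"frame-{k}" for k in range(6)]  # six snapshots, as in the log
g = ToyGraph()
out = fan_out(g, frames)
print(list(g.tasks))
# → ['Descriptors', 'Descriptors1', 'Descriptors2', 'Descriptors3', 'Descriptors4', 'Descriptors5']
```

The point of the pattern is that the number of tasks is decided at run time by the length of the trajectory, not when the graph is first defined.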

Next, configure the Quantum ESPRESSO (QE) task by defining the code and input parameters. Since we need to run QE on multiple structures, we create one `PwCalculation` task per structure dynamically within the same task using `get_current_graph()`. This lets us run QE for each structure and return the corresponding `TrajectoryData` and output parameters for each.

In [7]:
from aiida_workgraph import task
from aiida_workgraph.manager import get_current_graph
from aiida.orm import StructureData, load_group, KpointsData, SinglefileData, InstalledCode, List, Dict
from ase.io import iread
from pathlib import Path
import yaml
from aiida_quantumespresso.calculations.pw import PwCalculation
from sample_split import process_and_split_data


@task.graph(outputs = ["test_file", "train_file", "valid_file"])
def qe(
    code: InstalledCode,
    kpoints_mesh: List,
    task_metadata: Dict,
    test_file: SinglefileData,
    train_file: SinglefileData,
    valid_file: SinglefileData
    ):

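    # The graph being built, so one PwCalculation can be added per structure
    wg = get_current_graph()

    # The same k-point mesh is used for every structure
    kpoints = KpointsData()
    kpoints.set_kpoints_mesh(kpoints_mesh)

    # SSSP 1.3 PBE efficiency pseudopotential family (installed with aiida-pseudo)
    pseudo_family = load_group('SSSP/1.3/PBE/efficiency')

    files = {"test_file": test_file, "train_file": train_file, "valid_file": valid_file}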

    for file_name, file in files.items():
        with file.as_path() as path:
            for i, structs in enumerate(iread(path, format="extxyz")):
                
                structure = StructureData(ase=structs)
                pseudos = pseudo_family.get_pseudos(structure=structure)

                ecutwfc, ecutrho = pseudo_family.get_recommended_cutoffs(
                    structure=structure,
                    unit='Ry',
                )

                pw_params = {
                    "CONTROL": {
                        "calculation": "scf",
                        'tprnfor': True,
                        'tstress': True,
                    },
                    "SYSTEM": {
                        "ecutwfc": ecutwfc,
                        "ecutrho": ecutrho,
                    },
                }
                
                qe_task = wg.add_task(
                    PwCalculation,
                    code=code,
                    parameters=pw_params,
                    kpoints=kpoints,
                    pseudos=pseudos,
                    metadata=task_metadata.value,
                    structure=structure,
                )
                
                # Dotted keys nest in the context, e.g. wg.ctx.test_file.struct0
                structfile = f"{file_name}.struct{i}"

                wg.update_ctx({
                    structfile: {
                        "trajectory": qe_task.outputs.output_trajectory,
                        "parameters": qe_task.outputs.output_parameters
                    }
                })

    return {
        "test_file": wg.ctx.test_file,
        "train_file": wg.ctx.train_file,
        "valid_file": wg.ctx.valid_file
    }    
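The `qe` task stores each result under a dotted key such as `test_file.struct0` and then returns `wg.ctx.test_file`, which relies on the dot acting as a namespace separator. The plain-Python sketch below shows that grouping; `nest_dotted` is a hypothetical helper illustrating the idea, not WorkGraph's implementation, and the strings stand in for the real output sockets.

```python
def nest_dotted(flat):
    """Turn {"test_file.struct0": v} into {"test_file": {"struct0": v}}."""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Placeholder strings stand in for the QE output sockets
ctx = {}
for file_name in ("test_file", "train_file", "valid_file"):
    for i in range(2):
        ctx[f"{file_name}.struct{i}"] = {
            "trajectory": f"trajectory of {file_name} #{i}",
            "parameters": f"parameters of {file_name} #{i}",
        }

nested = nest_dotted(ctx)
print(sorted(nested["train_file"]))  # → ['struct0', 'struct1']
```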

The next task extracts the required data from the QE calculations and writes the training files. For each of the test, train and validation sets it creates `mlip_[file]_file.extxyz`, then appends the file paths to the example `JanusConfigfile.yml` that we provide. This config file sets all of the inputs for fine-tuning, so modify it according to your needs. Finally, the task returns a `JanusConfigfile` node, which is used as input to the training calculation.

In [8]:
from pathlib import Path

import yaml
from aiida_mlip.data.config import JanusConfigfile
from ase import units
from ase.io import write

@task.calcfunction(outputs=["JanusConfigfile"])
def create_train_file(**inputs):
    """Write one extxyz file per data split from the QE results and add it to the Janus config."""
    training_files = {}

    for file_name, structs in inputs.items():
        path = Path(f"mlip_{file_name}.extxyz")

        for struct_out_params in structs.values():
            trajectory = struct_out_params["trajectory"]

            # Structure of the single SCF step
            fileStructure = trajectory.get_structure(index=0)
            fileAtoms = fileStructure.get_ase()

            # QE stress is in GPa; multiplying by units.GPa converts to eV/Ang^3
            stress = trajectory.get_array("stress")[0]
            fileAtoms.info["qe_stress"] = stress * units.GPa

            fileAtoms.info["units"] = {"energy": "eV", "forces": "eV/Ang", "stress": "eV/Ang^3"}
            fileAtoms.set_array("qe_forces", trajectory.get_array("forces")[0])

            # Total energy (eV) from the QE output parameters
            fileParams = struct_out_params["parameters"].get_dict()
            fileAtoms.info["qe_energy"] = fileParams["energy"]
            write(path, fileAtoms, append=True)

        training_files[file_name] = str(path.resolve())

    # Append the file paths to the example fine-tuning config
    with open("JanusConfigfile.yml", "a") as f:
        yaml.safe_dump(training_files, f, sort_keys=False)

    return {"JanusConfigfile": JanusConfigfile(Path("JanusConfigfile.yml").resolve())}
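The stress conversion in `create_train_file` multiplies a value in GPa by ASE's `units.GPa`, which yields eV/Å^3 because ASE's internal units are eV and Å. The stdlib sketch below derives the same factor from exact SI constants, assuming (as the cell does) that the QE trajectory reports stress in GPa; `gpa_to_ev_per_ang3` is a hypothetical helper for illustration only.

```python
EV_IN_J = 1.602176634e-19  # 1 eV in joules (exact)
ANG3_IN_M3 = 1e-30         # 1 Å^3 in cubic metres

# 1 GPa = 1e9 Pa = 1e9 J/m^3; express that energy density in eV per Å^3
GPA_IN_EV_PER_ANG3 = 1e9 * ANG3_IN_M3 / EV_IN_J


def gpa_to_ev_per_ang3(stress_gpa):
    """Convert a number or a (nested) list of stress components from GPa to eV/Å^3."""
    if isinstance(stress_gpa, (list, tuple)):
        return [gpa_to_ev_per_ang3(s) for s in stress_gpa]
    return stress_gpa * GPA_IN_EV_PER_ANG3


print(f"{GPA_IN_EV_PER_ANG3:.6e}")  # → 6.241509e-03
```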

In [9]:
# Inputs for the MD calculation (`inputs`), common janus calculation inputs (`calc_inputs`),
# the train/valid/test split (`split_task_inputs`) and QE (`qe_inputs`)

from aiida.orm import Str, Float, Bool, Int, Dict
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
    "ensemble": Str("NVT"),
    "struct": init_structure,
    "md_kwargs": Dict(
        {
            "steps": 10,
            "traj-every": 2
        }
    )
}
calc_inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
}

split_task_inputs = {
    "config_types": Str(""),
    "prefix": Str(""),
    "scale": Float(1.0e5),
    "append_mode": Bool(False),
}

qe_inputs = {
    "task_metadata": Dict({
            "options": {
                "resources": {
                    "num_machines": 1,
                    "num_mpiprocs_per_machine": 32,
                },
                "max_wallclock_seconds": 3600,
                "queue_name": "scarf",
                "qos": "scarf",
                "environment_variables": {},
                "withmpi": True,
                "prepend_text": """
                    module purge
                    module use /work4/scd/scarf562/eb-common/modules/all
                    module load amd-modules
                    module load QuantumESPRESSO/7.2-foss-2023a
                """,
                "append_text": "",
            },
    }),
    "kpoints_mesh": List([1, 1, 1]),
    "code": qe_code,
}
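With `steps: 10` and `traj-every: 2`, the run log below shows six descriptors tasks (`Descriptors` to `Descriptors5`), one per saved frame. The quick check below estimates how many snapshots to expect; it assumes the initial structure is written as frame 0 and every `traj-every`-th step after it, which is consistent with that log but is an assumption about janus-core's output. `expected_snapshots` is a hypothetical helper.

```python
def expected_snapshots(steps, traj_every, include_initial=True):
    """Number of frames if steps 0, traj_every, 2*traj_every, ... <= steps are saved."""
    if traj_every <= 0:
        raise ValueError("traj_every must be positive")
    n = steps // traj_every
    return n + 1 if include_initial else n


md_kwargs = {"steps": 10, "traj-every": 2}
n_frames = expected_snapshots(md_kwargs["steps"], md_kwargs["traj-every"])
print(n_frames)  # → 6
```

This is also the number of `PwCalculation` jobs that will eventually be submitted, so it is worth checking before increasing `steps`.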

This task wraps a plain Python function, `process_and_split_data` from `sample_split`, to demonstrate the flexibility of tasks and how you can run Python functions with WorkGraph. It has to be a `calcfunction` because it returns new `SinglefileData` nodes for the test, train and validation files, and a calcfunction records their provenance.

In [10]:
@task.calcfunction(outputs = ["test_file", "train_file", "valid_file"])
def create_qe_files(**inputs):
     
    files = process_and_split_data(**inputs)

    return {
        "train_file": SinglefileData(files["train_file"]),
        "test_file": SinglefileData(files["test_file"]),
        "valid_file": SinglefileData(files["valid_file"])
    }
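The implementation of `sample_split.process_and_split_data` is not shown in this notebook, so the following is only a hypothetical sketch of one common way to split frames into train, validation and test sets: shuffle the frame indices with a fixed seed and cut them by fraction. The function name, fractions and seed are all assumptions, not the module's actual logic.

```python
import random


def split_indices(n_frames, fractions=(0.8, 0.1, 0.1), seed=0):
    """Shuffle frame indices reproducibly and split them into train/valid/test."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    indices = list(range(n_frames))
    random.Random(seed).shuffle(indices)
    n_train = round(fractions[0] * n_frames)
    n_valid = round(fractions[1] * n_frames)
    return {
        "train_file": sorted(indices[:n_train]),
        "valid_file": sorted(indices[n_train:n_train + n_valid]),
        "test_file": sorted(indices[n_train + n_valid:]),
    }


splits = split_indices(10)
print({name: len(idx) for name, idx in splits.items()})
# → {'train_file': 8, 'valid_file': 1, 'test_file': 1}
```

A fixed seed keeps the split reproducible between runs, which matters when the same split must be used for both fine-tuning and evaluation.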

In [13]:
with WorkGraph("MD") as wg:

    md_calc = wg.add_task(
        mdCalc,
        name="md_calc",
        **inputs
    )

    descriptors_calc = wg.add_task(
        descriptors_task,
        code=janus_code,
        model=model,
        device=Str("cpu"),
        arch=Str(model.architecture),
        file=md_calc.outputs.traj_file
    )

    # Draft data-split step, not yet connected to the graph (two versions kept for reference)
    # split_task = wg.add_task(
    #     create_split_files,
    #     name="split_data",
    #     config_types=Str(""),
    #     prefix=Str(""),
    #     scale=Float(1.0e5),
    #     append_mode=Bool(False),
    #     trajectory_data=descriptors_calc.outputs.structs,  
    # )

    # split_task = wg.add_task(
    #     process_and_split_data,
    #     config_types= Str(""),
    #     n_samples=Int(len(descriptors_calc.outputs.structs)),
    #     prefix= Str(""),
    #     scale= Float(1.0e5),
    #     append_mode= Bool(False),
    #     # create_aiida_files, 
    #     # **split_task_inputs,
    #     trajectory_data=descriptors_calc.outputs.structs,
    #     # n_samples= Int(len(final_structures)),
    # )

    


In [14]:
wg.run()

11/21/2025 10:42:25 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2412|WorkGraphEngine|continue_workgraph]: tasks ready to run: md_calc
11/21/2025 10:42:26 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2412|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 2415


/home/qoj42292/.aiida/scratch/presto/e1/22/d68d-bd03-4f16-9029-a191c0c8904e md-summary.yml


11/21/2025 10:42:37 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2412|WorkGraphEngine|update_task_state]: Task: md_calc, type: CALCJOB, finished.
11/21/2025 10:42:37 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2412|WorkGraphEngine|continue_workgraph]: tasks ready to run: descriptors_task


defining outputnode


11/21/2025 10:42:38 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2412|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 2432
11/21/2025 10:42:38 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2432|WorkGraphEngine|continue_workgraph]: tasks ready to run: Descriptors,Descriptors1,Descriptors2,Descriptors3,Descriptors4,Descriptors5
11/21/2025 10:42:40 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2432|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 2436, 2440, 2444, 2448, 2452, 2456
11/21/2025 10:42:56 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2432|WorkGraphEngine|update_task_state]: Task: Descriptors, type: CALCJOB, finished.
11/21/2025 10:42:57 AM <9195> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [2432|WorkGraphEngine|continue_workgraph]: tasks ready to run: 
11/21/2025 10:42:57 AM <9195> a

{}

In [15]:
print(wg.tasks)

NodeCollection(parent = "MD", nodes = ["graph_inputs", "graph_outputs", "graph_ctx", "md_calc", "descriptors_task"])


In [None]:
wg

In [17]:
! verdi process list -a

  PK  Created    Process label                        ♻    Process State     Process status
----  ---------  -----------------------------------  ---  ----------------  ---------------------------------------------------------------------------------------------------
 114  50D ago    Singlepoint                               ⏹ Finished [0]
 132  50D ago    GeomOpt                                   ⏹ Finished [0]
 153  49D ago    GeomOpt                                   ⏹ Finished [0]
 164  49D ago    prepare_struct_inputs                     ⏹ Finished [0]
 170  49D ago    Singlepoint                               ⏹ Finished [0]
 188  49D ago    GeomOpt                                   ⏹ Finished [0]
 199  49D ago    prepare_struct_inputs                     ⏹ Finished [0]
 205  49D ago    Singlepoint                               ⏹ Finished [0]
 223  49D ago    GeomOpt                                   ⏹ Finished [0]
 234  49D ago    prepare_struct_inputs                     ⏹

# Additional information

In [None]:
print(descriptorsCalc.spec().inputs)


In [None]:
print(type(wg.tasks.md_calc.outputs.traj_file.value))

In [None]:
mdCalc.get_description()["spec"]["outputs"].keys()

In [None]:
print('outputs of mdCalc:', md_calc.outputs)


In [None]:
print('outputs of descriptors_task:', descriptors_calc.outputs)


In [None]:
# from aiida.orm import load_node
# traj = load_node(PK)  # replace PK with the pk of the trajectory file node

# # Count the number of snapshots in the MD trajectory file
# with wg.tasks.md_calc.outputs.traj_file.value.as_path() as path:
#     traj_length = len(list(iread(path)))

# Loop Descriptor

When the graph is built, `wg.tasks.md_calc.outputs.traj_file` has no value yet; it only exists once `md_calc` has run.

Usually you can pass a socket straight into the next task. Here, however, we need the file path in order to read the trajectory, so the reading has to happen inside a task.

That task waits for `md_calc` to finish and then reads its output, which is what `descriptors_task` does.
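As a loose analogy (not how WorkGraph schedules tasks), Python's `concurrent.futures` shows the same constraint: a future, like an unfilled socket, can be handed around, but a file cannot be opened until the upstream job has produced it. `run_md` and `read_frames` below are hypothetical stand-ins for `md_calc` and `descriptors_task`.

```python
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def run_md(workdir):
    """Stand-in for md_calc: writes a toy six-frame trajectory and returns its path."""
    path = Path(workdir) / "traj.txt"
    path.write_text("\n".join(f"frame {i}" for i in range(6)))
    return path


def read_frames(path):
    """Stand-in for descriptors_task: needs the real file, so it must run after MD."""
    return path.read_text().splitlines()


with tempfile.TemporaryDirectory() as tmp, ThreadPoolExecutor() as pool:
    md_future = pool.submit(run_md, tmp)
    # md_future is like an unfilled socket: there is no file path to open yet.
    # .result() blocks until the upstream job has finished and returned its path.
    frames = read_frames(md_future.result())

print(len(frames))  # → 6
```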

In [None]:
# from sample_split import process_and_split_data
# # @task.calcfunction(outputs = ["test_file", "train_file", "valid_file"])
# # def create_split_files(**inputs):
     
# #     files = process_and_split_data(**inputs)

# #     return {
# #         "train_file": SinglefileData(files["train_file"]),
# #         "test_file": SinglefileData(files["test_file"]),
# #         "valid_file": SinglefileData(files["valid_file"])
# #     }

# @task.calcfunction(outputs=["test_file", "train_file", "valid_file"])
# def create_split_files(trajectory_data, config_types, prefix, scale, append_mode):
#     """Create split files using plain Python function call"""
    
#     # Call the plain Python function directly (not as a task)
#     files = process_and_split_data(
#         trajectory_data=trajectory_data,
#         config_types=config_types,
#         prefix=prefix,
#         scale=scale,
#         append_mode=append_mode
#     )
    
#     return {
#         "train_file": SinglefileData(files["train_file"]),
#         "test_file": SinglefileData(files["test_file"]),
#         "valid_file": SinglefileData(files["valid_file"])
#     }