# WorkGraph example to run Molecular Dynamics

## Aim

This notebook demonstrates how to run a molecular dynamics (MD) simulation with a WorkGraph. As an example, we start from a salt (NaCl) crystal structure, run an MD simulation that saves snapshots of the trajectory at regular intervals, and then compute descriptors for each snapshot. The workgraph then uses a filtering function to split the resulting data into three files (train, test and validation) for training a machine learning model.
The main goal is to show how to run the descriptors calculation on each snapshot of the MD simulation.

Note: the NVT ensemble keeps temperature and volume constant. Whether this is adequate here, or whether the MD run needs a temperature change, is still an open question.

Load the aiida profile, structure, model and code:

In [41]:
from aiida import load_profile
load_profile()

Profile<uuid='60b17659a9844c4bbd3bef8de0a8f417' name='presto'>

In [42]:
from aiida_mlip.data.model import ModelData
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

In [43]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read, iread

# structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
init_structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))

In [44]:
from aiida.orm import load_code
janus_code = load_code("janus@localhost")

In [45]:
from aiida.orm import Str, Float, Bool, Int, Dict
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
    "ensemble": Str("NVT"),
    "struct": init_structure,
    "md_kwargs": Dict(
        {
            "steps": 10,
            "traj-every": 2
        }
    )
}
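
As a rough check on how many snapshots to expect: with `steps` set to 10 and `traj-every` set to 2, a frame is written every second step. The sketch below assumes the initial frame at step 0 is written as well, which matches the six descriptor tasks reported in the run log further down. The helper name `expected_frames` is hypothetical and not part of janus-core or aiida-mlip.

```python
def expected_frames(steps: int, traj_every: int) -> list[int]:
    """Return the MD step indices at which a trajectory frame is expected.

    Assumes frames are written at step 0 and then every `traj_every` steps.
    """
    return list(range(0, steps + 1, traj_every))


frames = expected_frames(steps=10, traj_every=2)
print(frames)       # [0, 2, 4, 6, 8, 10]
print(len(frames))  # 6
```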

In [46]:
from aiida.plugins import CalculationFactory

mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")

Before setting up the work graph, we configure the descriptors step to run on each trajectory snapshot. The trajectory file has to be read before the snapshots are available, so we create the descriptor tasks dynamically inside a single graph task, using `get_current_graph()` to add one task per structure.

In [47]:
from aiida_workgraph import WorkGraph, task
from aiida_workgraph.manager import get_current_graph

@task.graph(outputs=["structs"])
def descriptors_task(
    code,
    model,
    device,
    arch,
    file,
):
    """Add one descriptors calculation per snapshot in the trajectory file."""
    descriptorsCalc = CalculationFactory("mlip.descriptors")
    wg = get_current_graph()
    final_structures = {}

    # Read the trajectory frame by frame; each frame gets its own task.
    with file.as_path() as path:
        for i, structs in enumerate(iread(path)):
            structure = StructureData(ase=structs)

            desc_calc = wg.add_task(
                descriptorsCalc,
                code=code,
                model=model,
                device=device,
                arch=arch,
                struct=structure,
                metadata={"options": {"resources": {"num_machines": 1}}},
            )

            # Collect the snapshot itself; the commented-out line would
            # collect the descriptors output instead.
            final_structures[f"structs{i}"] = structs
            # final_structures[f"structs{i}"] = desc_calc.outputs.xyz_output

    # Expose the collected structures through the graph context.
    wg.update_ctx({"structs": final_structures})

    return {"structs": wg.ctx.structs}

In [49]:
with WorkGraph("MD") as wg:

    md_calc = wg.add_task(
        mdCalc,
        name="md_calc",
        **inputs
    )

    descriptors_calc = wg.add_task(
        descriptors_task,
        code=janus_code,
        model=model,
        device=Str("cpu"),
        arch=Str(model.architecture),
        file=md_calc.outputs.traj_file
    )

    split_task = wg.add_task(
        create_split_files,
        name="split_data",
        config_types=Str(""),
        prefix=Str(""),
        scale=Float(1.0e5),
        append_mode=Bool(False),
        trajectory_data=descriptors_calc.outputs.structs,  
    )

    # split_task = wg.add_task(
    #     process_and_split_data,
    #     config_types= Str(""),
    #     n_samples=Int(len(descriptors_calc.outputs.structs)),
    #     prefix= Str(""),
    #     scale= Float(1.0e5),
    #     append_mode= Bool(False),
    #     # create_aiida_files, 
    #     # **split_task_inputs,
    #     trajectory_data=descriptors_calc.outputs.structs,
    #     # n_samples= Int(len(final_structures)),
    # )

    


In [50]:
wg.run()

11/20/2025 02:43:47 PM <56463> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1813|WorkGraphEngine|continue_workgraph]: tasks ready to run: md_calc
11/20/2025 02:43:47 PM <56463> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1813|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 1816


/home/qoj42292/.aiida/scratch/presto/34/d1/9b86-1210-42a4-aa21-3e0cf7bf2dcc md-summary.yml


11/20/2025 02:44:00 PM <56463> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1813|WorkGraphEngine|continue_workgraph]: tasks ready to run: descriptors_task
11/20/2025 02:44:00 PM <56463> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1813|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 1833
11/20/2025 02:44:01 PM <56463> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1833|WorkGraphEngine|continue_workgraph]: tasks ready to run: Descriptors,Descriptors1,Descriptors2,Descriptors3,Descriptors4,Descriptors5
11/20/2025 02:44:02 PM <56463> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1833|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 1837, 1841, 1845, 1849, 1853, 1857
11/20/2025 02:44:21 PM <56463> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [1833|WorkGraphEngine|update_task_state]: Task: Descriptors3, type: CALCJOB, finished.
11/20/202

{}

In [51]:
wg

NodeGraphWidget(settings={'minimap': True}, states={'graph_inputs': 'FINISHED', 'graph_outputs': 'FINISHED', …

# Additional information

In [52]:
print(descriptorsCalc.spec().inputs)


{
    "_attrs": {
        "default": [],
        "dynamic": false,
        "help": null,
        "required": "True",
        "valid_type": "<class 'aiida.orm.nodes.data.data.Data'>"
    },
    "arch": {
        "help": "MLIP architecture to use for calculation",
        "is_metadata": "False",
        "name": "arch",
        "non_db": "False",
        "required": "False",
        "valid_type": "(<class 'aiida.orm.nodes.data.str.Str'>, <class 'NoneType'>)"
    },
    "calc_kwargs": {
        "help": "Keyword arguments to pass to selected calculator.",
        "is_metadata": "False",
        "name": "calc_kwargs",
        "non_db": "False",
        "required": "False",
        "valid_type": "(<class 'aiida.orm.nodes.data.dict.Dict'>, <class 'NoneType'>)"
    },
    "calc_per_atom": {
        "help": "Calculate descriptors for each atom.",
        "is_metadata": "False",
        "name": "calc_per_atom",
        "non_db": "False",
        "required": "False",
        "valid_type": "(<class

In [53]:
print(type(wg.tasks.md_calc.outputs.traj_file.value))

<class 'aiida.orm.nodes.data.singlefile.SinglefileData'>


In [54]:
mdCalc.get_description()["spec"]["outputs"].keys()

dict_keys(['_attrs', 'remote_folder', 'remote_stash', 'retrieved', 'std_output', 'log_output', 'results_dict', 'summary', 'stats_file', 'traj_file', 'traj_output', 'final_structure'])

In [55]:
print('outputs of mdCalc:', md_calc.outputs)


outputs of mdCalc: TaskSocketNamespace(name='outputs', sockets=['remote_folder', 'remote_stash', 'retrieved', 'std_output', 'log_output', 'results_dict', 'summary', 'stats_file', 'traj_file', 'traj_output', 'final_structure', '_outputs', '_wait'])


In [63]:
print('outputs of descriptors_task:', descriptors_calc.outputs)


outputs of descriptors_task: TaskSocketNamespace(name='outputs', sockets=['structs', '_outputs', '_wait'])


In [57]:
# To list the outputs of mdCalc, uncomment the following:
# mdCalc.get_description()["spec"]["outputs"].keys()

In [58]:
# from aiida.orm import load_node
# traj = load_node(PK) 

# # print(len(list(iread(wg.tasks.md_calc.outputs.traj_file.value.as_path()))))
# traj_length = (wg.tasks.md_calc.outputs.traj_file.value.as_path()).numsteps

# Loop Descriptor

When the workgraph is built, `wg.tasks.md_calc.outputs.traj_file` does not hold a value yet, because the MD calculation has not run.

Usually you can simply pass an output socket to the next task. Here, however, we need the path to the trajectory file in order to read it, so the reading has to happen inside a task.

That task waits for `md_calc` to finish and then reads its output.
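
The point above can be illustrated without AiiDA. The sketch below uses a hypothetical `Socket` class (not the aiida-workgraph one) whose value is empty while the graph is being built and is only filled in once the upstream step has run. The step that reads the frames therefore has to run afterwards.

```python
class Socket:
    """Hypothetical stand-in for a task output: empty until the task runs."""

    def __init__(self):
        self.value = None


def run_md(socket):
    """Pretend MD step: fills the socket with three placeholder frames."""
    socket.value = ["frame0", "frame1", "frame2"]


def read_frames(socket):
    """Downstream step: can only read frames once the socket has a value."""
    if socket.value is None:
        raise RuntimeError("upstream task has not run yet")
    return {f"structs{i}": frame for i, frame in enumerate(socket.value)}


traj = Socket()
print(traj.value)  # None: nothing to read at build time

run_md(traj)
structs = read_frames(traj)
print(sorted(structs))  # ['structs0', 'structs1', 'structs2']
```

The keys follow the same `structs{i}` naming used in `descriptors_task`.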

In [None]:
# from sample_split import process_and_split_data
# # @task.calcfunction(outputs = ["test_file", "train_file", "valid_file"])
# # def create_split_files(**inputs):
     
# #     files = process_and_split_data(**inputs)

# #     return {
# #         "train_file": SinglefileData(files["train_file"]),
# #         "test_file": SinglefileData(files["test_file"]),
# #         "valid_file": SinglefileData(files["valid_file"])
# #     }

# @task.calcfunction(outputs=["test_file", "train_file", "valid_file"])
# def create_split_files(trajectory_data, config_types, prefix, scale, append_mode):
#     """Create split files using plain Python function call"""
    
#     # Call the plain Python function directly (not as a task) ??
#     files = process_and_split_data(
#         trajectory_data=trajectory_data,
#         config_types=config_types,
#         prefix=prefix,
#         scale=scale,
#         append_mode=append_mode
#     )
    
#     return {
#         "train_file": SinglefileData(files["train_file"]),
#         "test_file": SinglefileData(files["test_file"]),
#         "valid_file": SinglefileData(files["valid_file"])
#     }
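
The wrapper above delegates to `process_and_split_data`, which is imported from `sample_split` and is not shown in this notebook. As a rough illustration only, a three-way split of snapshot keys might look like the sketch below. The function name `split_keys`, the 80/10/10 fractions and the fixed seed are all assumptions, not the behaviour of the real function.

```python
import random


def split_keys(keys, train_frac=0.8, valid_frac=0.1, seed=0):
    """Shuffle keys and split them into train/valid/test lists.

    Hypothetical helper; fractions and seed are illustrative assumptions.
    """
    shuffled = list(keys)
    random.Random(seed).shuffle(shuffled)

    n_train = int(len(shuffled) * train_frac)
    n_valid = int(len(shuffled) * valid_frac)

    return {
        "train": shuffled[:n_train],
        "valid": shuffled[n_train:n_train + n_valid],
        "test": shuffled[n_train + n_valid:],
    }


keys = [f"structs{i}" for i in range(10)]
splits = split_keys(keys)
print({name: len(members) for name, members in splits.items()})
# {'train': 8, 'valid': 1, 'test': 1}
```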