# WorkGraph example to run Molecular Dynamics

## Aim

As an example, we start from a structure, run an MD simulation, compute descriptors for the resulting trajectory, and then use a filtering function to split the resulting structures into `train.xyz`, `test.xyz`, and `valid.xyz`.

Load the AiiDA profile, structure, model and code:

In [None]:
from aiida import load_profile
load_profile()

In [None]:
from aiida_mlip.data.model import ModelData
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

In [None]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read

# structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))

In [None]:
from aiida.orm import load_code
janus_code = load_code("janus@localhost")

Inputs should include the model, code, metadata, and any other keyword arguments expected by the calculation we are running.

To find out which arguments are expected, run `janus md --help`. It shows that `arch`, `struct` and `ensemble` are required; here we use the NVT ensemble.
Further MD options, such as the number of steps, are nested in the `md_kwargs` Dict. The trajectory output is recorded every 100 steps.

In [None]:
from aiida.orm import Str, Float, Bool, Int, Dict
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
    "ensemble": Str("NVT"),
    "struct": structure,
    "md_kwargs": Dict(
        {
            "steps": 400,
        }
    )
}
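Because `arch`, `struct` and `ensemble` are required, a quick check before building the task can catch a missing key early. The helper below is a hypothetical sketch (`missing_required` is not part of aiida-mlip or janus-core), and it uses plain strings as stand-ins for the AiiDA nodes in `inputs`:

```python
def missing_required(
    inputs: dict,
    required: tuple = ("arch", "struct", "ensemble"),
) -> list:
    """Hypothetical helper: return the required keys absent from `inputs`."""
    return [key for key in required if key not in inputs]


# Stand-in dictionary with "struct" deliberately left out.
example_inputs = {"arch": "mace_mp", "ensemble": "NVT"}
missing = missing_required(example_inputs)
print(missing)
```

An empty list means all three required keys are present; the same call works on the real `inputs` dictionary defined above.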

We now load the calculations we want to run:

In [None]:
from aiida.plugins import CalculationFactory

mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")

Now we can create our WorkGraph and add the MD calculation as its first task:

In [None]:
from aiida_workgraph import WorkGraph
wg = WorkGraph("MD_workgraph")
md_calc = wg.add_task(
    mdCalc,
    name="md_calc",
    **inputs
)


To list the outputs of `mdCalc` (or its inputs, by replacing `"outputs"` with `"inputs"`), use:

In [None]:
mdCalc.get_description()["spec"]["outputs"].keys()

Next, add the descriptors calculation, using the trajectory output of the MD task as its input structure:

In [None]:
descriptors_calc = wg.add_task(
    descriptorsCalc,
    name="descriptors_calc",
    struct=md_calc.outputs.traj_output
)

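Finally, pass the descriptors output to `process_and_split_data`, the filtering function that splits the structures into the train, test and validation sets: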

In [None]:
from sample_split import process_and_split_data

split_data = wg.add_task(
    process_and_split_data,
    name="split_data",
    struct=descriptors_calc.outputs.xyz_output
)
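The implementation of `process_and_split_data` lives in the local `sample_split` module and is not shown here. To give a rough idea of what a splitting step can look like, here is a minimal sketch of a seeded random split. The name `split_frames`, the default fractions and the seed are all assumptions made for illustration, and the actual function may filter on the computed descriptors rather than split at random:

```python
import random
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def split_frames(
    frames: Sequence[T],
    train_frac: float = 0.8,
    test_frac: float = 0.1,
    seed: int = 42,
) -> Dict[str, List[T]]:
    """Hypothetical helper: shuffle frames and split them into three sets.

    Whatever is left after the train and test sets goes to validation.
    """
    if train_frac < 0 or test_frac < 0 or train_frac + test_frac > 1:
        raise ValueError("fractions must be non-negative and sum to at most 1")

    # Copy before shuffling so the caller's sequence is left untouched.
    shuffled = list(frames)
    random.Random(seed).shuffle(shuffled)

    n_train = int(len(shuffled) * train_frac)
    n_test = int(len(shuffled) * test_frac)
    return {
        "train": shuffled[:n_train],
        "test": shuffled[n_train:n_train + n_test],
        "valid": shuffled[n_train + n_test:],
    }


# Integers stand in for trajectory frames here.
splits = split_frames(list(range(10)))
sizes = {name: len(items) for name, items in splits.items()}
print(sizes)
```

In the real workflow the frames would be the structures read from the descriptors output, and each set would be written to its own file (`train.xyz`, `test.xyz`, `valid.xyz`).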


Visualise the WorkGraph:

In [None]:
wg

We can inspect the tasks of the WorkGraph and their outputs, and then run it:

In [None]:
wg.tasks

In [None]:
wg.tasks["descriptors_calc"].outputs

In [None]:
# Uncomment to execute the WorkGraph
# wg.run()