# WorkGraph example to run Molecular Dynamics

## Aim

As an example, we start from a structure, run an MD simulation, compute descriptors, and then use a filtering function to split the resulting structures into `train.xyz`, `test.xyz`, and `valid.xyz`.

Load the AiiDA profile, structure, model and code:

In [None]:
from aiida import load_profile
load_profile()

In [None]:
from aiida_mlip.data.model import ModelData
# URI of the small MACE-MP test model file in the janus-core repository.
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
# Build the ModelData node from that URI, labelled with the "mace_mp" architecture.
# NOTE(review): cache_dir="mlips" is presumably the local download cache — confirm.
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

In [None]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read, iread

# Alternative starting point (disabled): read the structure from a CIF file.
# structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
# Starting structure for the MD run: bulk NaCl in the rocksalt structure with
# lattice parameter 5.63, built with ASE and wrapped in a StructureData node.
init_structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))

In [None]:
from aiida.orm import load_code
# Load the code identified by "janus@localhost"; it is passed to every calculation below.
# NOTE(review): assumes this code has already been set up in the loaded profile.
janus_code = load_code("janus@localhost")

In [None]:
from aiida.orm import Str, Float, Bool, Int, Dict
# Settings forwarded to the MD run through the "md_kwargs" input.
# NOTE(review): presumably "steps" is the number of MD steps and "traj-every"
# the interval at which trajectory frames are written — confirm.
md_options = Dict(
    {
        "steps": 10,
        "traj-every": 2,
    }
)

# Job options: request a single machine.
job_metadata = {"options": {"resources": {"num_machines": 1}}}

# Keyword inputs for the MD calculation.
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": job_metadata,
    "ensemble": Str("NVT"),
    "struct": init_structure,
    "md_kwargs": md_options,
}

In [None]:
from aiida.plugins import CalculationFactory

# Look up the calculation classes registered under the "mlip.md" and
# "mlip.descriptors" entry points.
mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")

In [None]:
from aiida_workgraph import WorkGraph, task
from aiida_workgraph.manager import get_current_graph

@task.graph(outputs=["structs"])
def descriptors_task(
    code,
    model,
    device,
    arch,
    file,
):
    """Add one descriptors calculation per frame of a trajectory file.

    The trajectory is read frame by frame with ``ase.io.iread``. For each
    frame a "mlip.descriptors" calculation is added to the current graph, and
    the frame is recorded under the key ``structs<i>``.

    Parameters
    ----------
    code, model, device, arch
        Forwarded unchanged to every descriptors calculation.
    file
        Object exposing ``as_path()`` as a context manager that yields the
        path of the trajectory file to read.

    Returns
    -------
    dict
        ``{"structs": ...}`` where the value is read back from the graph
        context after it has been updated with the per-frame mapping.
    """
    calc_class = CalculationFactory("mlip.descriptors")
    graph = get_current_graph()
    frames_by_key = {}

    with file.as_path() as traj_path:
        for index, frame in enumerate(iread(traj_path)):
            # One descriptors calculation per trajectory frame.
            graph.add_task(
                calc_class,
                code=code,
                model=model,
                device=device,
                arch=arch,
                struct=StructureData(ase=frame),
                metadata={"options": {"resources": {"num_machines": 1}}},
            )

            # NOTE(review): the raw frame read from the trajectory is stored
            # here, not the output of the descriptors calculation added above
            # (the disabled alternative was its `outputs.xyz_output`), so the
            # calculation's result is currently unused downstream — confirm
            # which one is intended.
            frames_by_key[f"structs{index}"] = frame

    graph.update_ctx({"structs": frames_by_key})

    return {"structs": graph.ctx.structs}

In [None]:
from sample_split import process_and_split_data
# @task.calcfunction(outputs = ["test_file", "train_file", "valid_file"])
# def create_split_files(**inputs):
     
#     files = process_and_split_data(**inputs)

#     return {
#         "train_file": SinglefileData(files["train_file"]),
#         "test_file": SinglefileData(files["test_file"]),
#         "valid_file": SinglefileData(files["valid_file"])
#     }

@task.calcfunction(outputs=["test_file", "train_file", "valid_file"])
def create_split_files(trajectory_data, config_types, prefix, scale, append_mode):
    """Split trajectory data into train/test/validation files.

    Thin wrapper that calls the plain Python function
    ``process_and_split_data`` directly (not as a task) and wraps each file
    it reports in a ``SinglefileData`` node so the files become task outputs.

    Parameters
    ----------
    trajectory_data, config_types, prefix, scale, append_mode
        Forwarded unchanged, by keyword, to ``process_and_split_data``.

    Returns
    -------
    dict
        ``"train_file"``, ``"test_file"`` and ``"valid_file"`` mapped to
        ``SinglefileData`` nodes.
    """
    # Imported here because SinglefileData is not imported anywhere else in
    # this notebook; without it the task fails with a NameError when it runs.
    from aiida.orm import SinglefileData

    files = process_and_split_data(
        trajectory_data=trajectory_data,
        config_types=config_types,
        prefix=prefix,
        scale=scale,
        append_mode=append_mode,
    )

    # NOTE(review): assumes each entry of `files` is something SinglefileData
    # accepts (e.g. a file path) — confirm against process_and_split_data.
    return {
        key: SinglefileData(files[key])
        for key in ("train_file", "test_file", "valid_file")
    }

In [None]:
with WorkGraph("MD") as wg:
    # Step 1: the MD calculation, configured by the `inputs` dictionary.
    md_calc = wg.add_task(mdCalc, name="md_calc", **inputs)

    # Step 2: descriptors for the MD trajectory; linked to step 1 through
    # the trajectory file output.
    descriptors_calc = wg.add_task(
        descriptors_task,
        code=janus_code,
        model=model,
        device=Str("cpu"),
        arch=Str(model.architecture),
        file=md_calc.outputs.traj_file,
    )

    # Step 3: split the structures from step 2 into train/test/valid files.
    split_task = wg.add_task(
        create_split_files,
        name="split_data",
        trajectory_data=descriptors_calc.outputs.structs,
        config_types=Str(""),
        prefix=Str(""),
        scale=Float(1.0e5),
        append_mode=Bool(False),
    )

    # split_task = wg.add_task(
    #     process_and_split_data,
    #     config_types= Str(""),
    #     n_samples=Int(len(descriptors_calc.outputs.structs)),
    #     prefix= Str(""),
    #     scale= Float(1.0e5),
    #     append_mode= Bool(False),
    #     # create_aiida_files, 
    #     # **split_task_inputs,
    #     trajectory_data=descriptors_calc.outputs.structs,
    #     # n_samples= Int(len(final_structures)),
    # )

    


In [None]:
# Execute the assembled WorkGraph.
wg.run()

In [None]:
# Show the WorkGraph object as the cell output.
wg

# Additional information

In [None]:
# Print the input specification of the descriptors calculation class.
print(descriptorsCalc.spec().inputs)


In [None]:
# Print the type of the value held by the MD task's `traj_file` output.
print(type(wg.tasks.md_calc.outputs.traj_file.value))

In [None]:
# List the names of the outputs declared in the MD calculation's spec.
mdCalc.get_description()["spec"]["outputs"].keys()

In [None]:
# Print the outputs of the `md_calc` task added to the WorkGraph.
print('outputs of mdCalc:', md_calc.outputs)


In [None]:
# Print the type of the value held by the MD task's `traj_file` output
# (same check as in an earlier cell of this section).
print(type(wg.tasks.md_calc.outputs.traj_file.value))

In [None]:
# To list the outputs of mdCalc, uncomment the following line:
# mdCalc.get_description()["spec"]["outputs"].keys()

In [None]:
# from aiida.orm import load_node
# traj = load_node(PK) 

# # print(len(list(iread(wg.tasks.md_calc.outputs.traj_file.value.as_path()))))
# traj_length = (wg.tasks.md_calc.outputs.traj_file.value.as_path()).numsteps

# Loop Descriptor

The workflow does not have an output yet for `wg.tasks.md_calc.outputs.traj_file`.

Usually you can just pass in a socket. But because we have to get the path of the file in order to read it, you have to create a task.

You need to create a task which waits for the MD task (`md_calc`) to run and then gets that output.