# WorkGraph Tutorial: MD Simulation with DFT Labeling and MLIP Fine-tuning

This tutorial demonstrates how to build a complex workflow with AiiDA WorkGraph that mixes local and remote execution. MLIP tasks run locally, while Quantum ESPRESSO runs on a remote system (SCARF) to generate DFT-labeled training data, which is then used to fine-tune a machine learning interatomic potential (MLIP).

## Workflow Steps

1. **MD Simulation** → Generate diverse NaCl configurations (6 snapshots)
2. **Descriptors** → Calculate MLIP features for filtering (dynamic tasks)
3. **Data Split** → FPS filtering into train/validation/test sets
4. **DFT Labeling** → Quantum ESPRESSO energies, forces, stresses (nested dynamic tasks)
5. **Training Files** → Format ExtXYZ files with DFT labels
6. **Fine-tuning** → Train MACE model on DFT data

## WorkGraph Features

- **Dynamic task generation**: Create tasks at runtime based on trajectory length
- **Nested namespaces**: Organize complex hierarchical outputs
- **Automatic dependencies**: Tasks execute when inputs become available
- **Parallel execution**: Independent calculations run simultaneously
- **Provenance tracking**: Full AiiDA database tracking of all calculations and data for reproducibility
- **Mixed local and remote execution**: MD runs locally while DFT calculations execute on HPC clusters
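
To make the "automatic dependencies" and "parallel execution" points concrete, here is a minimal sketch of dependency-driven scheduling using only the Python standard library. It is not how WorkGraph is implemented; the task names and the two-snapshot layout are illustrative assumptions loosely modelled on the steps above.

```python
from graphlib import TopologicalSorter

# Map each task to the tasks whose outputs it consumes (names are illustrative).
deps = {
    "md": set(),
    "descriptors_0": {"md"},
    "descriptors_1": {"md"},
    "split": {"descriptors_0", "descriptors_1"},
    "qe_train_0": {"split"},
    "qe_test_0": {"split"},
    "create_train_file": {"qe_train_0", "qe_test_0"},
    "train": {"create_train_file"},
}

sorter = TopologicalSorter(deps)
sorter.prepare()

batches = []
while sorter.is_active():
    # Every task whose inputs are now available; these could run side by side.
    ready = sorted(sorter.get_ready())
    batches.append(ready)
    sorter.done(*ready)

for batch in batches:
    print(batch)
```

Each printed batch holds tasks that are independent of one another, which is why the per-snapshot descriptor and DFT calculations can run simultaneously while the split and training steps wait for their inputs.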

The initial setup closely follows the other tutorials, such as ../calculations/singlepoint.ipynb, which explains each step in more detail.


### Load AiiDA Profile

In [None]:
from aiida import load_profile
load_profile()

### Verify SSSP Pseudopotentials

Quantum ESPRESSO requires pseudopotentials for DFT calculations. We check that the SSSP library is installed:

In [None]:
from aiida.orm import load_group
try:
    pseudo_family = load_group('SSSP/1.3/PBE/efficiency')
    print(f"SSSP pseudopotentials installed ({len(pseudo_family.nodes)} pseudos)")
except Exception:
    print("SSSP not installed. Run: aiida-pseudo install sssp -v 1.3 -p efficiency")

### Load ML Potential Model

We download a pre-trained MACE model from the janus-core repository. This foundation model will be:
- Used for the initial MD simulation
- Used to calculate descriptors for filtering
- Fine-tuned on our DFT-labeled data

The `ModelData.from_uri()` function automatically caches the model locally to avoid repeated downloads.

In [None]:
from aiida_mlip.data.model import ModelData
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

### Define Initial Structure

We create a rocksalt NaCl structure with a lattice parameter of 5.63 Å. This serves as the starting point for our MD simulation.

In [None]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read, iread

# structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
init_structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))

### Load Computational Codes

We need two codes:
- **janus@localhost**: For MLIP calculations (MD, descriptors, training) - runs locally
- **qe@scarf**: For DFT calculations with Quantum ESPRESSO - runs on HPC cluster

See examples > tutorials > aiida_setup for instructions on setting up the janus code and external computers.

In [None]:
from aiida.orm import load_code
janus_code = load_code("janus@localhost")
qe_code = load_code("qe@scarf")

### Define Calculation Types

In [None]:
from aiida.plugins import CalculationFactory

mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")
trainCalc = CalculationFactory("mlip.train")

## Task Decorators & Dynamic Outputs

WorkGraph provides decorators to convert functions and calculations into workflow tasks. There are three main types:

- `task(CalculationClass)`: Wrap AiiDA calculations
- `@task.calcfunction`: Wrap AiiDA calcfunctions
- `@task.graph`: Create a sub-workflow (task graph)

**WorkGraph Feature**: The `@task.graph` decorator with `dynamic(SinglefileData)` outputs allows creating variable numbers of tasks at runtime.

**Scientific Purpose**: Calculate MLIP descriptors per structure for filtering diverse configurations. Descriptors characterize local atomic environments and enable intelligent selection of training data.

The loop creates one descriptor calculation per MD snapshot, running in parallel:

In [None]:
from aiida_workgraph import WorkGraph, task, dynamic, namespace
from aiida.orm import SinglefileData

descriptorsTask = task(descriptorsCalc)

@task.graph(outputs=namespace(final_structs=dynamic(SinglefileData)))
def descriptors_task(
    code,
    model,
    device,
    arch,
    file,
):
    results = {}
    
    with file.as_path() as path:
        for i, structs in enumerate(iread(path)):
            structure = StructureData(ase=structs)

            desc_calc = descriptorsTask(
                code=code,
                model=model,
                device=device,
                arch=arch,
                struct=structure,
                calc_per_element=True,
                metadata={"options": {"resources": {"num_machines": 1}}}
            )

            results[f"struct{i}"] = desc_calc.xyz_output

    return {"final_structs": results}

## Nested Dynamic Tasks: DFT Calculations

**WorkGraph Feature**: `namespace()` with nested `dynamic()` outputs organizes hierarchical results. Each QE calculation returns both trajectory (forces/stress) and parameters (energy).

**Scientific Purpose**: Run SCF DFT calculations with Quantum ESPRESSO to obtain reference energies, forces, and stresses for training. Uses SSSP pseudopotentials with recommended cutoffs and Γ-point sampling.

In [None]:
from aiida.orm import (
    TrajectoryData,
    StructureData, 
    load_group, 
    KpointsData, 
    InstalledCode, 
    List, 
    Dict,
)
from pathlib import Path
import yaml
from aiida_quantumespresso.calculations.pw import PwCalculation

PwTask = task(PwCalculation)
qe_output = namespace(trajectory=TrajectoryData, parameters=Dict)

@task.graph(
    outputs=namespace(
        test_file=dynamic(qe_output),
        train_file=dynamic(qe_output),
        valid_file=dynamic(qe_output),
    )
)
def qe(
    code: InstalledCode,
    kpoints_mesh: List,
    task_metadata: Dict,
    test_file: SinglefileData,
    train_file: SinglefileData,
    valid_file: SinglefileData
    ):

    kpoints = KpointsData()
    kpoints.set_kpoints_mesh(kpoints_mesh)

    pseudo_family = load_group('SSSP/1.3/PBE/efficiency')
    
    files = {"test_file": test_file, "train_file": train_file, "valid_file": valid_file}
    results = {"test_file": {}, "train_file": {}, "valid_file": {}}
    
    for file_name, file in files.items():
        with file.as_path() as path:
            for i, structs in enumerate(iread(path, format="extxyz")):
                
                structure = StructureData(ase=structs)
                pseudos = pseudo_family.get_pseudos(structure=structure)

                ecutwfc, ecutrho = pseudo_family.get_recommended_cutoffs(
                    structure=structure,
                    unit='Ry',
                )

                pw_params = {
                    "CONTROL": {
                        "calculation": "scf",
                        'tprnfor': True,
                        'tstress': True,
                    },
                    "SYSTEM": {
                        "ecutwfc": ecutwfc,
                        "ecutrho": ecutrho,
                    },
                }
                
                qe_task = PwTask(
                    code=code,
                    parameters=pw_params,
                    kpoints=kpoints,
                    pseudos=pseudos,
                    metadata=task_metadata.value,
                    structure=structure,
                )
                
                # Key each structure's outputs by its index within the split file.
                results[file_name][f"struct_{i}"] = {
                    "trajectory": qe_task.output_trajectory,
                    "parameters": qe_task.output_parameters,
                }
    return results
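
The shape of the nested outputs can be easier to see with plain Python objects. The sketch below mirrors the dictionary layout that `qe` returns, with strings standing in for the real `TrajectoryData` and `Dict` nodes. The frame counts per split are hypothetical.

```python
# Hypothetical number of structures in each split file.
splits = {"train_file": 2, "test_file": 1, "valid_file": 1}

# Outer keys match the namespace names; inner keys are created dynamically,
# one per structure, each holding the two outputs declared in `qe_output`.
results = {
    name: {
        f"struct_{i}": {
            "trajectory": f"<trajectory {name}/{i}>",
            "parameters": f"<parameters {name}/{i}>",
        }
        for i in range(n_frames)
    }
    for name, n_frames in splits.items()
}

# Walk the hierarchy: split file -> structure -> named output.
for file_name, structs in results.items():
    for key, outputs in structs.items():
        print(file_name, key, sorted(outputs))

# One DFT calculation is created per inner entry.
n_calcs = sum(len(structs) for structs in results.values())
print(n_calcs)
```

The number of inner entries is only known once the split files have been read, which is why each top-level output is declared as `dynamic(qe_output)` rather than with a fixed set of names.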

## CalcFunction Tasks: Training File Creation

**WorkGraph Feature**: `@task.calcfunction` wraps Python functions with provenance tracking. Here the function receives the nested QE outputs and processes them into training files.

**Scientific Purpose**: Extract DFT results (energy, forces, stress) from QE trajectory data, convert units (stress from GPa to eV/Å³), and write ExtXYZ files for MACE training.

In [None]:
from aiida_mlip.data.config import JanusConfigfile
from aiida.orm import Dict
from ase.io import write
from ase import units

@task.calcfunction(outputs = ["JanusConfigfile"])
def create_train_file(**inputs):

    training_files = {}
    
    for file_name, structs in inputs.items():
        path = Path(f"mlip_{file_name}.extxyz")

        for struct_out_params in structs.values():
            
            trajectory = struct_out_params["trajectory"]

            fileStructure = trajectory.get_structure(index=0)
            fileAtoms = fileStructure.get_ase()

            stress = trajectory.arrays["stress"][0]
            converted_stress = stress * units.GPa
            fileAtoms.info["qe_stress"] = converted_stress

            fileAtoms.info["units"] = {
                "energy": "eV",
                "forces": "eV/Ang",
                "stress": "eV/Ang^3"
            }
            fileAtoms.set_array("qe_forces", trajectory.arrays["forces"][0])

            parameters = struct_out_params["parameters"]
            fileParams = parameters.get_dict()
            fileAtoms.info["qe_energy"] = fileParams["energy"]
            write(path, fileAtoms, append=True)

        training_files[file_name] = str(path.resolve())

    # Overwrite rather than append so reruns do not accumulate duplicate keys.
    with open("JanusConfigfile.yml", "w") as f:
        yaml.safe_dump(training_files, f, sort_keys=False)

    return {"JanusConfigfile": JanusConfigfile(Path("JanusConfigfile.yml").resolve())}
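
The stress conversion in `create_train_file` relies on `units.GPa` from ASE, which is the value of 1 GPa expressed in eV/Å³. As a worked check, the factor can be derived from SI definitions alone. The sketch below uses the exact 2019 SI value of the elementary charge, so the trailing digits may differ slightly from the constant shipped with a given ASE version.

```python
# Exact SI definitions (2019 redefinition).
ELEMENTARY_CHARGE = 1.602176634e-19  # joules per electronvolt
ANGSTROM = 1.0e-10                   # metres per angstrom


def gpa_to_ev_per_ang3(stress_gpa: float) -> float:
    """Convert a stress value from GPa to eV/Å^3."""
    pascal = stress_gpa * 1.0e9                   # GPa -> Pa, i.e. J/m^3
    joule_per_ang3 = pascal * ANGSTROM**3         # J/m^3 -> J/Å^3
    return joule_per_ang3 / ELEMENTARY_CHARGE     # J/Å^3 -> eV/Å^3


factor = gpa_to_ev_per_ang3(1.0)
print(f"1 GPa = {factor:.6f} eV/Å^3")
```

Multiplying a stress tensor given in GPa by this factor therefore yields values in eV/Å³, matching the energy and force units written to the ExtXYZ files.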

## Data Filtering and Splitting

In [None]:
from aiida.orm import Int
from sample_split import process_and_split_data

@task.calcfunction(outputs = ["test_file", "train_file", "valid_file"])
def create_qe_files(**inputs):

    n_samples = Int(len(inputs['trajectory_data']))
    print(n_samples)

    files = process_and_split_data( 
        **inputs,
        n_samples=n_samples,
    )

    return {
        "train_file": SinglefileData(files["train_file"]),
        "test_file": SinglefileData(files["test_file"]),
        "valid_file": SinglefileData(files["valid_file"])
    }
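
The `process_and_split_data` helper is imported from `sample_split` and is not shown in this tutorial. To give a feel for the farthest point sampling (FPS) idea it is described as using, here is a minimal standard-library sketch. It is an illustration under assumed inputs, not the helper's actual implementation; the function name and the one-dimensional "descriptors" are hypothetical.

```python
import math


def farthest_point_sampling(points, n_select, start=0):
    """Greedy FPS: repeatedly pick the point farthest from those already chosen."""
    if not 0 < n_select <= len(points):
        raise ValueError("n_select must be between 1 and len(points)")

    selected = [start]
    # Distance from every point to its nearest selected point so far.
    min_dist = [math.dist(p, points[start]) for p in points]

    while len(selected) < n_select:
        next_idx = max(range(len(points)), key=min_dist.__getitem__)
        selected.append(next_idx)
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], math.dist(p, points[next_idx]))

    return selected


# Hypothetical 1-D descriptors for six MD snapshots.
descriptors = [(0.0,), (0.1,), (0.2,), (5.0,), (5.1,), (10.0,)]
picked = farthest_point_sampling(descriptors, n_select=3)
print(picked)
```

Starting from the first snapshot, the selection jumps to the most distant descriptor and then to the one farthest from both, so near-duplicate configurations are skipped in favour of diverse ones.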

## Input Parameters

In [None]:
from aiida.orm import Str, Float, Bool, Dict
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
    "ensemble": Str("NVT"),
    "struct": init_structure,
    "md_kwargs": Dict(
        {
            "steps": 10,
            "traj-every": 2
        }
    )
}
calc_inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
}
split_task_inputs = {
    "config_types": Str(""),
    "prefix": Str(""),
    "scale": Float(1.0e5),
    "append_mode": Bool(False),
}
qe_inputs = {
    "task_metadata": Dict({
            "options": {
                "resources": {
                    "num_machines": 1,
                    "num_mpiprocs_per_machine": 32,
                },
                "max_wallclock_seconds": 3600,
                "queue_name": "scarf",
                "qos": "scarf",
                "environment_variables": {},
                "withmpi": True,
                "prepend_text": """
                    module purge
                    module use /work4/scd/scarf562/eb-common/modules/all
                    module load amd-modules
                    module load QuantumESPRESSO/7.2-foss-2023a
                """,
                "append_text": "",
            },
    }),
    "kpoints_mesh": List([1, 1, 1]),
    "code": qe_code,
}
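
The `md_kwargs` above set `steps` to 10 and `traj-every` to 2. Assuming the trajectory also records the initial frame at step 0 (an assumption consistent with the 6 snapshots quoted in the workflow steps), the expected number of frames can be checked with a short calculation. The helper name is hypothetical.

```python
def expected_snapshots(steps: int, traj_every: int, include_initial: bool = True) -> int:
    """Count the steps that are multiples of traj_every, optionally including step 0."""
    first = 0 if include_initial else traj_every
    return len(range(first, steps + 1, traj_every))


n_frames = expected_snapshots(steps=10, traj_every=2)
print(n_frames)
```

Each of these frames becomes one descriptor calculation, and the frames that survive the split each become one DFT calculation, so these two settings control the size of the whole workflow.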

## Assembling the WorkGraph

In [None]:
mdTask = task(mdCalc)
trainTask = task(trainCalc)

@task.graph()
def md_training_workflow(
    janus_code,
    qe_code,
    model,
    md_kwargs,
    ensemble,
    init_structure,
    split_inputs,
    qe_metadata,
    qe_kpoints,
):
    md_calc = mdTask(
        code=janus_code,
        model=model,
        arch=Str(model.architecture),
        device=Str("cpu"),
        metadata={"options": {"resources": {"num_machines": 1}}},
        ensemble=ensemble,
        struct=init_structure,
        md_kwargs=md_kwargs
    )

    descriptors_calc = descriptors_task(
        code=janus_code,
        model=model,
        device=Str("cpu"),
        arch=Str(model.architecture),
        file=md_calc.traj_file
    )
   
    split_task = create_qe_files(
        **split_inputs,
        trajectory_data=descriptors_calc.final_structs,
    )
    
    qe_task = qe(
        code=qe_code,
        kpoints_mesh=qe_kpoints,
        task_metadata=qe_metadata,
        test_file=split_task.test_file,
        train_file=split_task.train_file,
        valid_file=split_task.valid_file,
    )

    training_files = create_train_file(
        test_file=qe_task.test_file,
        train_file=qe_task.train_file,
        valid_file=qe_task.valid_file,
    )

    train_task = trainTask(
        mlip_config=training_files.JanusConfigfile,
        code=janus_code,
        foundation_model=model,
        metadata={"options": {"resources": {"num_machines": 1}}},
        fine_tune=True,
    )

### Build the WorkGraph

In [None]:
wg = md_training_workflow.build(
    janus_code=janus_code,
    qe_code=qe_code,
    model=model,
    md_kwargs=inputs["md_kwargs"],
    ensemble=inputs["ensemble"],
    init_structure=init_structure,
    split_inputs=split_task_inputs,
    qe_metadata=qe_inputs["task_metadata"],
    qe_kpoints=qe_inputs["kpoints_mesh"],
)

### Run the Workflow

In [None]:
wg.run()

### Visualize

In [None]:
wg

## Results

**Scientific**: Fine-tuned MACE model for NaCl with DFT-labeled training data (energies, forces, stresses)

**WorkGraph**: Demonstrated dynamic task creation, nested namespaces, automatic dependencies, and parallel execution with full provenance tracking