# WorkGraph Example: MD Simulation with DFT Labeling and Model Fine-tuning

## Overview

This tutorial demonstrates an end-to-end workflow for generating training data and fine-tuning a machine learning interatomic potential (MLIP). The workflow consists of:

1. **Molecular Dynamics (MD)**: Generate diverse structural configurations of NaCl
2. **Descriptor Calculation**: Compute MLIP descriptors for structure filtering
3. **Data Splitting**: Filter and split structures into train/validation/test sets
4. **DFT Labeling**: Calculate accurate energies, forces, and stresses using Quantum ESPRESSO
5. **Model Fine-tuning**: Train a MACE model on the DFT-labeled data


## Workflow Architecture

```
Initial Structure (NaCl)
        ↓
    [MD Simulation] → Trajectory (6 snapshots)
        ↓
    [Descriptors] → Structure features
        ↓
    [Filter & Split] → Train/Valid/Test sets
        ↓
    [Quantum ESPRESSO] → DFT energies/forces/stresses
        ↓
    [Create Training Files] → ExtXYZ with DFT labels
        ↓
    [Fine-tune MACE] → Improved MLIP model
```


### Load AiiDA Profile

In [1]:
from aiida import load_profile
load_profile()

Profile<uuid='60b17659a9844c4bbd3bef8de0a8f417' name='presto'>

### Verify SSSP Pseudopotentials

Quantum ESPRESSO requires pseudopotentials for DFT calculations. We check that the SSSP library is installed:

In [2]:
from aiida.orm import load_group
try:
    pseudo_family = load_group('SSSP/1.3/PBE/efficiency')
    print(f" SSSP pseudopotentials installed ({len(pseudo_family.nodes)} pseudos)")
except Exception:
    print(" SSSP not installed. Run: aiida-pseudo install sssp -v 1.3 -p efficiency")

 SSSP pseudopotentials installed (103 pseudos)


### Load ML Potential Model
We download a pre-trained MACE model from the janus-core repository. This foundation model will be:
- Used for the initial MD simulation
- Used to calculate descriptors for filtering
- Fine-tuned on our DFT-labeled data

`ModelData.from_uri()` caches the downloaded model locally in `cache_dir` (here `mlips`) to avoid repeated downloads.
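The caching idea amounts to a download-once lookup keyed on the URI. The sketch below is a plain-Python illustration, not `aiida-mlip`'s implementation. The `cached_fetch` helper and its `fetch` callback are hypothetical, and naming the cached file after the last component of the URI path is an assumption.

```python
import tempfile
from pathlib import Path
from typing import Callable
from urllib.parse import urlparse


def cached_fetch(uri: str, cache_dir: Path, fetch: Callable[[str], bytes]) -> Path:
    """Return a local copy of `uri`, calling `fetch` only on a cache miss (hypothetical helper)."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Assumption: the cached file is named after the last component of the URI path.
    target = cache_dir / Path(urlparse(uri).path).name
    if not target.exists():
        target.write_bytes(fetch(uri))  # only reached the first time
    return target


# Usage with a fake downloader that records how often it is called.
calls = []


def fake_fetch(uri: str) -> bytes:
    calls.append(uri)
    return b"model-bytes"


with tempfile.TemporaryDirectory() as tmp:
    uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
    first = cached_fetch(uri, Path(tmp) / "mlips", fake_fetch)
    second = cached_fetch(uri, Path(tmp) / "mlips", fake_fetch)
    same_path = first == second
    cached_name = first.name

print(same_path, cached_name, len(calls))  # True mace_mp_small.model 1
```

The second call finds the file on disk and never touches the network, which is the behaviour that makes re-running the notebook cheap.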

In [3]:
from aiida_mlip.data.model import ModelData
uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp", cache_dir="mlips")

### Define Initial Structure

We create a rocksalt NaCl structure with a lattice parameter of 5.63 Å. This serves as the starting point for our MD simulation.

**Why NaCl?** It's a simple ionic crystal that demonstrates:
- Multi-element systems
- Different atomic environments
- Structural variations during MD

In [4]:
from aiida.orm import StructureData
from ase.build import bulk
from ase.io import read, iread

# Alternatively, load a structure from file, e.g.:
# structure = StructureData(ase=read("Structures/qmof-ffeef76.cif"))
init_structure = StructureData(ase=bulk("NaCl", "rocksalt", 5.63))
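To make the lattice parameter concrete: ASE's `bulk()` returns the two-atom primitive (fcc) cell for rocksalt unless `cubic=True` is passed. That cell is a quarter of the 8-atom conventional cube, and the Na-Cl nearest-neighbour distance is half the lattice parameter. The stdlib-only arithmetic below uses the standard fcc primitive vectors; it does not call ASE.

```python
import math

a = 5.63  # conventional cubic lattice parameter in Å, as passed to bulk()

# Standard fcc primitive vectors (rows), in Å
cell = [
    [0.0, a / 2, a / 2],
    [a / 2, 0.0, a / 2],
    [a / 2, a / 2, 0.0],
]


def det3(m):
    """Determinant of a 3x3 matrix, i.e. the cell volume for row vectors."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


primitive_volume = abs(det3(cell))  # a**3 / 4, holds 1 Na + 1 Cl
nearest_neighbour = a / 2  # Na-Cl distance along a cube edge
primitive_length = math.sqrt(sum(x * x for x in cell[0]))  # a / sqrt(2)

print(f"primitive cell volume:   {primitive_volume:.2f} Å^3")  # 44.61
print(f"Na-Cl distance:          {nearest_neighbour:.3f} Å")  # 2.815
print(f"primitive vector length: {primitive_length:.3f} Å")  # 3.981
```

A cell this small matters later: it is why the k-point sampling, not the cell size, limits the accuracy of the DFT labels.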

### Load Computational Codes

We need two codes:
- **janus@localhost**: For MLIP calculations (MD, descriptors, training) - runs locally
- **qe@scarf**: For DFT calculations with Quantum ESPRESSO - runs on HPC cluster

See examples > tutorials > aiida_setup for instructions on setting up the janus code and external computers.

In [5]:
from aiida.orm import load_code
janus_code = load_code("janus@localhost")
qe_code = load_code("qe@scarf")

### Define Calculation Types

We load three calculation classes from the `aiida-mlip` plugin with `CalculationFactory`:
- `mdCalc`: Molecular dynamics simulation
- `descriptorsCalc`: Compute structural descriptors
- `trainCalc`: Fine-tune the MLIP model

In [6]:
from aiida.plugins import CalculationFactory

mdCalc = CalculationFactory("mlip.md")
descriptorsCalc = CalculationFactory("mlip.descriptors")
trainCalc = CalculationFactory("mlip.train")

## Task 1: Descriptors Calculation

### Purpose
After MD generates trajectory snapshots, we calculate descriptors for each structure. These descriptors:
- Characterize the local atomic environment
- Provide the features used to filter and split the structures (Task 4)
- Are calculated per element (separate Na and Cl descriptors, via `calc_per_element=True`)

### Implementation Details
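This task uses a **dynamic graph** pattern:
- Loops over each structure in the trajectory file
- Creates a separate `Descriptors` calculation task for each structure
- Collects each task's `xyz_output` (the structure annotated with its descriptors) in a dictionary keyed `struct0`, `struct1`, ...
- Returns that dictionary as the `final_structs` dynamic output namespace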
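Stripped of AiiDA, the pattern is a loop that creates one independent task per frame and gathers each task's output under a `struct{i}` key. This is a conceptual sketch only: `fan_out` and `run_task` are hypothetical stand-ins for the graph task and the `Descriptors` calculation, and nothing is submitted anywhere.

```python
from typing import Callable, Dict, Iterable, List


def fan_out(frames: Iterable[List[str]], run_task: Callable[[List[str]], str]) -> Dict[str, Dict[str, str]]:
    """Run one task per frame and collect the outputs by key (hypothetical helper).

    `frames` stands in for iread(path); `run_task` for a Descriptors calculation.
    """
    results = {}
    for i, frame in enumerate(frames):
        # Each frame gets its own task, so a workflow engine could run them in parallel.
        results[f"struct{i}"] = run_task(frame)
    return {"final_structs": results}


# Hypothetical frames: each is just a list of chemical symbols here.
trajectory = [["Na", "Cl"]] * 3
outputs = fan_out(trajectory, lambda frame: f"descriptors for {'+'.join(frame)}")
print(outputs)
```

Because the number of frames is only known once MD has finished, the task list cannot be fixed in advance; building it inside a graph task at run time is what makes the graph "dynamic".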



In [7]:
from aiida_workgraph import WorkGraph, task, dynamic, namespace
from aiida.orm import SinglefileData

descriptorsTask = task(descriptorsCalc)

@task.graph(outputs=namespace(final_structs=dynamic(SinglefileData)))
def descriptors_task(
    code,
    model,
    device,
    arch,
    file,
):
    results = {}
    
    with file.as_path() as path:
        for i, structs in enumerate(iread(path)):
            structure = StructureData(ase=structs)

            desc_calc = descriptorsTask(
                code=code,
                model=model,
                device=device,
                arch=arch,
                struct=structure,
                calc_per_element=True,
                metadata={"options": {"resources": {"num_machines": 1}}}
            )

            results[f"struct{i}"] = desc_calc.xyz_output

    return {"final_structs": results}



## Task 2: Quantum ESPRESSO DFT Calculations

### Purpose
Run high-accuracy DFT calculations to obtain reference data:
- **Total energy**: Ground state energy of each structure
- **Forces**: Atomic forces for training
- **Stress tensor**: Training target for the model's stress predictions, which matter for e.g. NPT dynamics and equations of state

### Implementation Details

This task processes three files (train, test, validation) and runs DFT calculations on each structure:

**Pseudopotentials**: Automatically retrieves appropriate pseudopotentials for Na and Cl from SSSP library

**Cutoff energies**: Uses recommended values from SSSP for convergence:
- `ecutwfc`: Plane-wave kinetic energy cutoff
- `ecutrho`: Charge density cutoff

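**K-points**: Uses a 1×1×1 mesh (Γ-point only) to keep this demonstration cheap. This is not converged for the two-atom primitive NaCl cell; production-quality labels would need a denser mesh.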

**DFT Parameters**:
- SCF calculation (self-consistent field)
- `tprnfor=True`: Calculate forces
- `tstress=True`: Calculate stress tensor
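To see what these settings amount to, the sketch below takes the larger of the per-element recommended cutoffs (to our understanding, this is what `get_recommended_cutoffs` does for a multi-element structure) and renders the `CONTROL`/`SYSTEM` namelists plus an equivalent `K_POINTS` card in pw.x input syntax. The cutoff numbers are placeholders, not the actual SSSP values, and `aiida-quantumespresso` writes many more keys (e.g. `pseudo_dir`, `nat`, `ntyp`) than shown.

```python
# Placeholder per-element cutoffs in Ry: (ecutwfc, ecutrho). NOT real SSSP values.
recommended = {"Na": (40.0, 320.0), "Cl": (45.0, 360.0)}


def max_cutoffs(elements, table):
    """Use the largest recommended cutoffs among the elements present."""
    ecutwfc = max(table[el][0] for el in elements)
    ecutrho = max(table[el][1] for el in elements)
    return ecutwfc, ecutrho


def fortran_value(value):
    """Format a Python value as a Fortran namelist literal."""
    if isinstance(value, bool):  # check bool first: bool is a subclass of int
        return ".true." if value else ".false."
    if isinstance(value, str):
        return f"'{value}'"
    return repr(value)


def render_namelists(params):
    """Render {namelist: {key: value}} as &NAME ... / blocks."""
    lines = []
    for name, entries in params.items():
        lines.append(f"&{name}")
        lines.extend(f"  {key} = {fortran_value(val)}" for key, val in entries.items())
        lines.append("/")
    return "\n".join(lines)


ecutwfc, ecutrho = max_cutoffs(["Na", "Cl"], recommended)
pw_params = {
    "CONTROL": {"calculation": "scf", "tprnfor": True, "tstress": True},
    "SYSTEM": {"ecutwfc": ecutwfc, "ecutrho": ecutrho},
}
text = render_namelists(pw_params)
# An unshifted 1x1x1 automatic mesh contains only the Gamma point.
text += "\nK_POINTS automatic\n  1 1 1 0 0 0"
print(text)
```

Taking the maximum matters because one plane-wave basis is shared by all atoms, so the hardest pseudopotential sets the cutoff.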



In [8]:
from aiida.orm import (
    TrajectoryData,
    StructureData, 
    load_group, 
    KpointsData, 
    InstalledCode, 
    List, 
    Dict,
)
from pathlib import Path
import yaml
from aiida_quantumespresso.calculations.pw import PwCalculation

PwTask = task(PwCalculation)
qe_output = namespace(trajectory=TrajectoryData, parameters=Dict)

@task.graph(
    outputs=namespace(
        test_file=dynamic(qe_output),
        train_file=dynamic(qe_output),
        valid_file=dynamic(qe_output),
    )
)
def qe(
    code: InstalledCode,
    kpoints_mesh: List,
    task_metadata: Dict,
    test_file: SinglefileData,
    train_file: SinglefileData,
    valid_file: SinglefileData
    ):

    kpoints = KpointsData()
    kpoints.set_kpoints_mesh(kpoints_mesh)

    pseudo_family = load_group('SSSP/1.3/PBE/efficiency')
    
    files = {"test_file": test_file, "train_file": train_file, "valid_file": valid_file}
    results = {"test_file": {}, "train_file": {}, "valid_file": {}}
    
    for file_name, file in files.items():
        with file.as_path() as path:
            for i, structs in enumerate(iread(path, format="extxyz")):
                
                structure = StructureData(ase=structs)
                pseudos = pseudo_family.get_pseudos(structure=structure)

                ecutwfc, ecutrho = pseudo_family.get_recommended_cutoffs(
                    structure=structure,
                    unit='Ry',
                )

                pw_params = {
                    "CONTROL": {
                        "calculation": "scf",
                        'tprnfor': True,
                        'tstress': True,
                    },
                    "SYSTEM": {
                        "ecutwfc": ecutwfc,
                        "ecutrho": ecutrho,
                    },
                }
                
                qe_task = PwTask(
                    code=code,
                    parameters=pw_params,
                    kpoints=kpoints,
                    pseudos=pseudos,
                    metadata=task_metadata.value,
                    structure=structure,
                )
                
                # Collect the outputs under the split-file name and structure index
                results[file_name][f"struct_{i}"] = {
                    "trajectory": qe_task.output_trajectory,
                    "parameters": qe_task.output_parameters,
                }
    return results

## Task 3: Create Training Files

Extract DFT results from Quantum ESPRESSO and format them for MLIP training.

1. **Extract data** from QE output:
   - Energy from `output_parameters`
   - Forces from `output_trajectory`
   - Stress tensor from `output_trajectory`

2. **Unit conversion**:
   - Stress: Convert from GPa (the unit of the `stress` array in `output_trajectory`) to eV/Å³ by multiplying by `ase.units.GPa`
   - Energy: Already in eV
   - Forces: Already in eV/Å

3. **Create ExtXYZ files**:
   - `mlip_train_file.extxyz`: Structures for training
   - `mlip_valid_file.extxyz`: Structures for validation
   - `mlip_test_file.extxyz`: Structures for testing

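4. **Update config file**:
   - Appends the file paths (keys `train_file`, `valid_file`, `test_file`) to `JanusConfigfile.yml` in the working directory
   - This config file is expected to already contain the remaining training parameters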
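`ase.units.GPa` is one gigapascal expressed in ASE's native eV/Å³, so `stress * units.GPa` is the same as the explicit conversion below. This is a stdlib-only sketch; the 1.5 GPa stress is a made-up example, and ASE's constant may differ in the last digits depending on the CODATA version it uses.

```python
E_CHARGE = 1.602176634e-19  # J per eV (exact since the 2019 SI redefinition)
ANGSTROM3 = 1e-30  # m^3 per Å^3

# 1 GPa = 1e9 Pa = 1e9 J/m^3; express that energy density in eV/Å^3
EV_PER_ANG3_PER_GPA = 1e9 * ANGSTROM3 / E_CHARGE


def gpa_to_ev_per_ang3(stress):
    """Convert a 3x3 stress tensor given as nested lists from GPa to eV/Å^3."""
    return [[component * EV_PER_ANG3_PER_GPA for component in row] for row in stress]


# Hypothetical isotropic stress of 1.5 GPa
stress_gpa = [[1.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, 1.5]]
converted = gpa_to_ev_per_ang3(stress_gpa)
print(EV_PER_ANG3_PER_GPA)  # about 0.00624 eV/Å^3 per GPa
print(converted[0][0])
```

Equivalently, 1 eV/Å³ is about 160.2 GPa, a useful sanity check when comparing stresses in the two unit systems.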


The extended XYZ format includes:
- Atomic positions
- Lattice vectors
- DFT energy in `info["qe_energy"]`
- DFT forces as the per-atom array `qe_forces`
- DFT stress tensor in `info["qe_stress"]`
- Unit specifications in `info["units"]`
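For orientation, an extended XYZ frame is an atom count, a comment line of `key=value` pairs (the `info` entries plus `Lattice`, `Properties` and `pbc`), then one row per atom whose columns are declared in `Properties`. The hand-written sketch below mimics that layout for the energy and forces only; in the notebook ASE's `write` produces the real files (including stress and units), and its exact formatting will differ. `extxyz_frame` and the values passed to it are hypothetical.

```python
def extxyz_frame(symbols, positions, cell, energy, forces):
    """Build one extended-XYZ frame as text (minimal, hypothetical helper)."""
    lattice = " ".join(f"{x:.6f}" for vec in cell for x in vec)
    # Comment line: per-frame info as key=value; Properties declares per-atom columns (name:type:ncols)
    header = (
        f'Lattice="{lattice}" '
        "Properties=species:S:1:pos:R:3:qe_forces:R:3 "
        f'qe_energy={energy:.6f} pbc="T T T"'
    )
    lines = [str(len(symbols)), header]
    for symbol, pos, force in zip(symbols, positions, forces):
        columns = [f"{v:.6f}" for v in (*pos, *force)]
        lines.append(" ".join([symbol, *columns]))
    return "\n".join(lines) + "\n"


# Hypothetical values: fcc primitive cell with a = 5.63 Å, made-up energy, zero forces.
h = 5.63 / 2
frame = extxyz_frame(
    symbols=["Na", "Cl"],
    positions=[(0.0, 0.0, 0.0), (h, h, h)],
    cell=[(0.0, h, h), (h, 0.0, h), (h, h, 0.0)],
    energy=-100.0,
    forces=[(0.0, 0.0, 0.0), (0.0, 0.0, 0.0)],
)
print(frame)
```

Prefixing the keys with `qe_` keeps the DFT labels separate from any energies or forces a calculator might attach, so the training config can point at them explicitly.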

In [9]:
from aiida_mlip.data.config import JanusConfigfile
from aiida.orm import Dict
from ase.io import write
from ase import units

@task.calcfunction(outputs=["JanusConfigfile"])
def create_train_file(**inputs):

    training_files = {}
    
    for file_name, structs in inputs.items():
        path = Path(f"mlip_{file_name}.extxyz")
        # Start from an empty file so re-running does not append duplicate frames
        path.unlink(missing_ok=True)

        for struct_out_params in structs.values():
            
            trajectory = struct_out_params["trajectory"]

            # Structure of the (single) SCF step
            fileStructure = trajectory.get_structure(index=0)
            fileAtoms = fileStructure.get_ase()

            # Stress from the QE parser is in GPa; convert to eV/Å^3
            stress = trajectory.get_array("stress")[0]
            converted_stress = stress * units.GPa
            fileAtoms.info["qe_stress"] = converted_stress

            fileAtoms.info["units"] = {
                "energy": "eV",
                "forces": "eV/Ang",
                "stress": "eV/Ang^3"
            }
            fileAtoms.set_array("qe_forces", trajectory.get_array("forces")[0])

            parameters = struct_out_params["parameters"]
            fileParams = parameters.get_dict()
            fileAtoms.info["qe_energy"] = fileParams["energy"]
            write(path, fileAtoms, append=True)

        training_files[file_name] = str(path.resolve())

    # Append the dataset paths to the existing Janus training config
    with open("JanusConfigfile.yml", "a") as f:
        yaml.safe_dump(training_files, f, sort_keys=False)

    return {"JanusConfigfile": JanusConfigfile(Path("JanusConfigfile.yml").resolve())}

## Task 4: Data Filtering and Splitting

Split the MD trajectory into train/validation/test sets using FPS (Farthest Point Sampling).

Note: warnings about `k is too large` appear when there are fewer structures than the target split size; the algorithm adjusts automatically.
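Farthest point sampling greedily picks, at each step, the structure whose descriptor vector is farthest from every structure already picked, which spreads the selection across descriptor space. The sketch below is a generic pure-Python version for illustration only. It is not the code in `sample_split.process_and_split_data`, and its rule for dividing the leftovers between validation and test is a hypothetical choice. Like the behaviour described in the note, it simply clamps `k` when there are fewer points than requested.

```python
import math


def farthest_point_sampling(points, k, start=0):
    """Greedily pick k indices, each time taking the point farthest from those already picked."""
    k = min(k, len(points))  # clamp: cannot select more points than exist
    if k <= 0:
        return []
    chosen = [start]
    # Distance from every point to its closest already-chosen point.
    nearest = [math.dist(p, points[start]) for p in points]
    while len(chosen) < k:
        candidates = [i for i in range(len(points)) if i not in chosen]
        nxt = max(candidates, key=lambda i: nearest[i])
        chosen.append(nxt)
        nearest = [min(d, math.dist(p, points[nxt])) for d, p in zip(nearest, points)]
    return chosen


def fps_split(points, n_train):
    """FPS picks go to training; leftovers alternate between validation and test (hypothetical rule)."""
    train = farthest_point_sampling(points, n_train)
    rest = [i for i in range(len(points)) if i not in train]
    return train, rest[0::2], rest[1::2]


# Hypothetical 2-D descriptor vectors for six frames.
descriptors = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0), (0.9, 1.0), (0.0, 1.0), (0.5, 0.5)]
train, valid, test = fps_split(descriptors, n_train=4)
print(train, valid, test)  # [0, 2, 4, 5] [1] [3]
```

Near-duplicates such as frames 0 and 1 end up in different sets, which is the point: the training set covers the spread of environments rather than many copies of the same one.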


In [10]:
from aiida.orm import Int
from sample_split import process_and_split_data

@task.calcfunction(outputs = ["test_file", "train_file", "valid_file"])
def create_qe_files(**inputs):

    n_samples = Int(len(inputs['trajectory_data']))
    print(n_samples)

    files = process_and_split_data( 
        **inputs,
        n_samples=n_samples,
    )

    return {
        "train_file": SinglefileData(files["train_file"]),
        "test_file": SinglefileData(files["test_file"]),
        "valid_file": SinglefileData(files["valid_file"])
    }

## Configure Input Parameters


**NVT Ensemble**: 
- Maintains constant Number of particles, Volume, and Temperature
- Suitable for sampling structural variations
- Uses a thermostat to control temperature

See https://stfc.github.io/janus-core/tutorials/cli/md.html for other ensembles.

In [11]:
from aiida.orm import Str, Float, Bool, Dict
inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
    "ensemble": Str("NVT"),
    "struct": init_structure,
    "md_kwargs": Dict(
        {
            "steps": 10,
            "traj-every": 2
        }
    )
}
calc_inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cpu"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
}
split_task_inputs = {
    "config_types": Str(""),
    "prefix": Str(""),
    "scale": Float(1.0e5),
    "append_mode": Bool(False),
}
qe_inputs = {
    "task_metadata": Dict({
            "options": {
                "resources": {
                    "num_machines": 1,
                    "num_mpiprocs_per_machine": 32,
                },
                "max_wallclock_seconds": 3600,
                "queue_name": "scarf",
                "qos": "scarf",
                "environment_variables": {},
                "withmpi": True,
                "prepend_text": """
                    module purge
                    module use /work4/scd/scarf562/eb-common/modules/all
                    module load amd-modules
                    module load QuantumESPRESSO/7.2-foss-2023a
                """,
                "append_text": "",
            },
    }),
    "kpoints_mesh": List([1, 1, 1]),
    "code": qe_code,
}
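The `md_kwargs` above determine how many structures flow through the rest of the workflow. Assuming the initial configuration is written as frame 0 and then one frame every `traj-every` steps (consistent with the 6 frames reported in the run log for `steps: 10`, `traj-every: 2`), the count is a one-liner. `n_frames` is a hypothetical helper.

```python
def n_frames(steps: int, traj_every: int) -> int:
    """Frames saved if frame 0 is written and then one every `traj_every` steps (assumption)."""
    return steps // traj_every + 1


frames = n_frames(steps=10, traj_every=2)
print(frames)  # 6 -> 6 Descriptors tasks and 6 PwCalculations
print(n_frames(steps=1000, traj_every=10))  # 101
```

Because every frame becomes one DFT job, this number sets the cost of the QE stage: a longer run with the same `traj-every` multiplies the HPC jobs accordingly.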


## Build and Run the WorkGraph

### WorkGraph Structure

The workflow is defined using AiiDA WorkGraph, which:
- Automatically manages task dependencies
- Enables parallel execution where possible
- Tracks provenance of all calculations
- Handles job submission to HPC
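Conceptually, a dependency-driven engine repeatedly starts every task whose inputs are ready. The sketch below uses Python's standard `graphlib` module (Python 3.9+) to show the resulting "waves" for a simplified version of this workflow. It illustrates the scheduling idea only, not how the WorkGraph engine is implemented, and the task names are hypothetical.

```python
from graphlib import TopologicalSorter

n = 3  # frames; the run in this notebook produced 6

# Each task maps to the set of tasks it depends on (hypothetical names).
dependencies = {f"desc_{i}": {"md"} for i in range(n)}
dependencies["split"] = {f"desc_{i}" for i in range(n)}
dependencies.update({f"pw_{i}": {"split"} for i in range(n)})
dependencies["create_train_file"] = {f"pw_{i}" for i in range(n)}
dependencies["train"] = {"create_train_file"}

sorter = TopologicalSorter(dependencies)
sorter.prepare()
waves = []
while sorter.is_active():
    ready = sorted(sorter.get_ready())  # tasks whose inputs are all available
    waves.append(ready)
    sorter.done(*ready)  # pretend they all finished
print(waves)
```

Only the descriptor and DFT stages fan out; everything else is a chain, so the wall time is dominated by the slowest QE job in the parallel wave.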


In [12]:
mdTask = task(mdCalc)
trainTask = task(trainCalc)

@task.graph()
def md_training_workflow(
    janus_code,
    qe_code,
    model,
    md_kwargs,
    ensemble,
    init_structure,
    split_inputs,
    qe_metadata,
    qe_kpoints,
):
    md_calc = mdTask(
        code=janus_code,
        model=model,
        arch=Str(model.architecture),
        device=Str("cpu"),
        metadata={"options": {"resources": {"num_machines": 1}}},
        ensemble=ensemble,
        struct=init_structure,
        md_kwargs=md_kwargs
    )

    descriptors_calc = descriptors_task(
        code=janus_code,
        model=model,
        device=Str("cpu"),
        arch=Str(model.architecture),
        file=md_calc.traj_file
    )
   
    split_task = create_qe_files(
        **split_inputs,
        trajectory_data=descriptors_calc.final_structs,
    )
    
    qe_task = qe(
        code=qe_code,
        kpoints_mesh=qe_kpoints,
        task_metadata=qe_metadata,
        test_file=split_task.test_file,
        train_file=split_task.train_file,
        valid_file=split_task.valid_file,
    )

    training_files = create_train_file(
        test_file=qe_task.test_file,
        train_file=qe_task.train_file,
        valid_file=qe_task.valid_file,
    )

    train_task = trainTask(
        mlip_config=training_files.JanusConfigfile,
        code=janus_code,
        foundation_model=model,
        metadata={"options": {"resources": {"num_machines": 1}}},
        fine_tune=True,
    )

### Build the WorkGraph

In [13]:
wg = md_training_workflow.build(
    janus_code=janus_code,
    qe_code=qe_code,
    model=model,
    md_kwargs=inputs["md_kwargs"],
    ensemble=inputs["ensemble"],
    init_structure=init_structure,
    split_inputs=split_task_inputs,
    qe_metadata=qe_inputs["task_metadata"],
    qe_kpoints=qe_inputs["kpoints_mesh"],
)

### Execution

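`wg.run()` executes the entire workflow in the current Python process and blocks until it finishes, printing progress reports as tasks start and complete. Calculation jobs are still sent to their computers (e.g. the QE jobs to SCARF). To hand the workflow to the AiiDA daemon instead, use `wg.submit()`.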

In [14]:
wg.run()

12/24/2025 10:06:57 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|continue_workgraph]: tasks ready to run: MD
12/24/2025 10:06:57 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 7538


/home/qoj42292/.aiida/scratch/presto/e6/99/6657-765e-4311-bf7a-a6f8dec4d5ae md-summary.yml


12/24/2025 10:07:13 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|update_task_state]: Task: MD, type: CALCJOB, finished.
12/24/2025 10:07:15 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|continue_workgraph]: tasks ready to run: descriptors_task
12/24/2025 10:07:17 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 7555
12/24/2025 10:07:18 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7555|WorkGraphEngine|continue_workgraph]: tasks ready to run: Descriptors,Descriptors1,Descriptors2,Descriptors3,Descriptors4,Descriptors5
12/24/2025 10:07:21 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7555|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 7560, 7565, 7570, 7575, 7580, 7585

uuid: da09e2c1-ac92-4b8d-8fea-a0db15725209 (unstored) value: 6
create files: train_file=PosixPath('train.xyz'), valid_file=PosixPath('valid.xyz') and test_file=PosixPath('test.xyz')
Processing: ('all', 'aiida'), 6 frames
  ('all', 'aiida'): total=6, train_target=4,                     vt_target=2


12/24/2025 10:08:05 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|update_task_state]: Task: create_qe_files, type: CALCFUNCTION, finished.
12/24/2025 10:08:05 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|continue_workgraph]: tasks ready to run: qe
12/24/2025 10:08:06 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7535|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 7633
12/24/2025 10:08:06 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7633|WorkGraphEngine|continue_workgraph]: tasks ready to run: PwCalculation,PwCalculation1,PwCalculation2,PwCalculation3,PwCalculation4,PwCalculation5
12/24/2025 10:08:08 AM <916425> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [7633|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 7635, 7637, 7639, 7641, 7643, 764

{}


### Visualize the WorkGraph

In [15]:
wg


### After successful completion, you have:

1. **Fine-tuned MACE model**: The foundation model adapted to the DFT-labeled NaCl data
2. **Training datasets**: ExtXYZ files with DFT labels
3. **Complete provenance**: Full AiiDA history of all calculations