In [1]:
from aiida import load_profile  
load_profile()

Profile<uuid='1ad5c4ff2c1141ee9b4a511ac6859016' name='presto'>

In [2]:
from aiida.orm import load_code
from aiida_mlip.data.model import ModelData

uri = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_uri(uri, architecture="mace_mp")

janus_code = load_code("janus@localhost")
qe_code = load_code("qe@localhost")

In [3]:
from aiida.plugins import CalculationFactory
descriptorsCalc = CalculationFactory("mlip.descriptors")

In [None]:
from aiida_workgraph import WorkGraph, task
from pathlib import Path
from ase.io import read, write, iread
from ase import Atoms
import numpy as np
from aiida.orm import SinglefileData, Float, InstalledCode, List, Dict, KpointsData, StructureData, load_group, Str, Bool, Int
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_workgraph.manager import get_current_graph
from ase import units
import tempfile
from pathlib import Path
from random import shuffle

@task.calcfunction(outputs=["scaled_structures"])
def create_scales(
    min_v: Float,
    max_v:Float,
    num_structs: int,
    **structures
): 
    lattice_scalars = np.cbrt(np.linspace(min_v.value, max_v.value, num_structs.value))
    scaled_structures = {}

    for structure in structures.values():

        atom = structure.get_ase()
        cell = atom.get_cell()

        for i, s in enumerate(lattice_scalars):
            scaled_atom = atom.copy()
            scaled_atom.set_cell(cell * s, scale_atoms=True)
            struct_data = f"struct{i}"
            scaled_structures[struct_data] = StructureData(ase=scaled_atom)
    
    print(scaled_structures)

    return {
        "scaled_structures": scaled_structures
    }

@task.graph()
def qe(
    code: InstalledCode,
    kpoints_mesh: List,
    task_metadata: Dict,
    **scaled_structures,
    ):

    wg = get_current_graph()

    kpoints = KpointsData()
    kpoints.set_kpoints_mesh(kpoints_mesh)

    pseudo_family = load_group('SSSP/1.3/PBE/efficiency')
    
    output_structures = {}

    for i, structs in scaled_structures:
        print(i)
        print(structs)
            
    #         structure = StructureData(ase=structs)
    #         pseudos = pseudo_family.get_pseudos(structure=structure)

    #         ecutwfc, ecutrho = pseudo_family.get_recommended_cutoffs(
    #             structure=structure,
    #             unit='Ry',
    #         )

    #         pw_params = {
    #             "CONTROL": {
    #                 "calculation": "scf",
    #                 'tprnfor': True,
    #                 'tstress': True,
    #             },
    #             "SYSTEM": {
    #                 "ecutwfc": ecutwfc,
    #                 "ecutrho": ecutrho,
    #             },
    #         }
            
    #         qe_task = wg.add_task(
    #             PwCalculation,
    #             code=code,
    #             parameters=pw_params,
    #             kpoints=kpoints,
    #             pseudos=pseudos,
    #             metadata=task_metadata.value,
    #             structure=structure,
    #         )

    #         output_structures[f"struct{i}"] = {
    #                 "trajectory":qe_task.outputs.output_trajectory,
    #                 "parameters": qe_task.outputs.output_parameters
    #             }
        
    #     wg.update_ctx({
    #         "structures": output_structures
    #     })

    # return {
    #     "structures": wg.ctx.structures,
    # }

@task.calcfunction(outputs=["test_file", "train_file", "valid_file"])
def create_train_files(**structures):

    # tmpfile = tempfile.NamedTemporaryFile(suffix=".extxyz")
    
    structures_stack = list(structures.keys())
    shuffle(structures_stack)
    
    n = len(structures_stack)
    i1 = int(n*0.7)
    i2 = int(n*0.9)

    training_split = {
        "test":structures_stack[:i1],
        "train":structures_stack[i1:i2],
        "valid":structures_stack[i2:]
    }

    files = {}

    for split, split_structures in training_split.items():
        tmpfile = tempfile.NamedTemporaryFile(suffix=f"{split}.extxyz")
        for struct in split_structures:

            trajectory = structures[struct]["trajectory"]
            fileStructure = trajectory.get_structure(index=0)
            fileAtoms = fileStructure.get_ase()

            stress = trajectory.arrays["stress"][0]
            converted_stress = stress * units.GPa
            fileAtoms.info["qe_stress"] = converted_stress

            fileAtoms.info["units"] = {"energy": "eV","forces": "ev/Ang","stress": "ev/Ang^3"}
            fileAtoms.set_array("qe_forces", trajectory.arrays["forces"][0])

            parameters = structures[struct]["parameters"]
            fileParams = parameters.get_dict()
            fileAtoms.info["qe_energy"] = fileParams["energy"]
            
            write(Path(tmpfile.name), fileAtoms, append=True)

        files[f"{split}_file"] = SinglefileData(tmpfile)
    
    for filename, file in files.items():
        with file.as_path() as path:
            num_structs = len(read(path, index=":"))
        print(f"{filename} has {num_structs} structures")

    return{
        "test_file": files["test_file"],
        "train_file": files["train_file"],
        "valid_file": files["valid_file"]
    }


In [5]:
calc_inputs = {
    "code": janus_code,
    "model": model,
    "arch": Str(model.architecture),
    "device": Str("cuda"),
    "metadata": {"options": {"resources": {"num_machines": 1}}},
}

scales_inputs = {

}

qe_inputs = {
    "task_metadata": Dict({
            "options": {
                "resources": {
                    "tot_num_mpiprocs":1,
                    'num_mpiprocs_per_machine':1,
                    'num_cores_per_mpiproc':8,
                },
                "max_wallclock_seconds": 3600,
                "queue_name": "scarf",
                "qos": "scarf",
                "environment_variables": {},
                "withmpi": True,
                "prepend_text": """
                """,
                "append_text": "",
            },
    }),
    "kpoints_mesh": List([1, 1, 1]),
    "code": qe_code,
}

In [6]:
with WorkGraph("EOS_workflow") as wg:

    initial_structure = str(Path("../structures/NaCl-traj.xyz").resolve())

    structures = {}
    for i, struct in enumerate(iread(initial_structure)):
        structures[f"structs{i}"] = StructureData(ase=struct)

    scales_task = wg.add_task(
        create_scales,
        min_v= 0.95,
        max_v= 1.05,
        num_structs= 12,
        **structures
    )

    qe_task = wg.add_task(
        qe,
        **qe_inputs,
        scaled_structures=scales_task.outputs.scaled_structures
    )
    
    # train_task = wg.add_task(
    #     create_train_files,
    #     structures=qe_task.outputs.structures
    # )
 

In [7]:
wg.run()

11/21/2025 11:53:52 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16387|WorkGraphEngine|continue_workgraph]: tasks ready to run: create_scales


{'struct0': <StructureData: uuid: 8584cf18-ed5e-4340-9628-60c648d6d597 (unstored)>, 'struct1': <StructureData: uuid: a66774bf-b5da-4bf9-ab31-a4cddf4a6bca (unstored)>, 'struct2': <StructureData: uuid: 6312ee83-1f49-426d-baa4-a11df8023a55 (unstored)>, 'struct3': <StructureData: uuid: 5fc3e5bd-d14c-452c-8ccd-78c9ea315ace (unstored)>, 'struct4': <StructureData: uuid: 2e7ce5f5-8deb-4d7c-a233-99b73f3937df (unstored)>, 'struct5': <StructureData: uuid: be3c16ce-2fd1-4c12-9a69-5448b0c552b4 (unstored)>, 'struct6': <StructureData: uuid: 20fd22d9-44b3-44e7-a8af-95cc0bf4b227 (unstored)>, 'struct7': <StructureData: uuid: 26019825-d84c-46b3-9313-d42ca5af7223 (unstored)>, 'struct8': <StructureData: uuid: 9880190a-b06c-40e0-91c7-44e6b1a5fcef (unstored)>, 'struct9': <StructureData: uuid: 13d3d141-ba09-4b93-9a4c-d64dac4e5e0d (unstored)>, 'struct10': <StructureData: uuid: 24b6c6c2-0a6c-42cc-b9a0-45698797cab6 (unstored)>, 'struct11': <StructureData: uuid: 00dcb033-752b-4640-a2a8-e2a060028985 (unstored)>}


11/21/2025 11:53:52 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16387|WorkGraphEngine|update_task_state]: Task: create_scales, type: CALCFUNCTION, finished.
11/21/2025 11:53:52 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16387|WorkGraphEngine|continue_workgraph]: tasks ready to run: qe
11/21/2025 11:53:52 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16387|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 16404


<class 'ase.atoms.Atoms'>


11/21/2025 11:53:53 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16404|WorkGraphEngine|continue_workgraph]: tasks ready to run: 
11/21/2025 11:53:53 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16404|WorkGraphEngine|finalize]: Finalize workgraph.
11/21/2025 11:53:53 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16387|WorkGraphEngine|update_task_state]: Task: qe, type: GRAPH, finished.
11/21/2025 11:53:53 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16387|WorkGraphEngine|continue_workgraph]: tasks ready to run: 
11/21/2025 11:53:53 AM <97822> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [16387|WorkGraphEngine|finalize]: Finalize workgraph.


{}

In [8]:
wg.tasks.qe.inputs.scaled_structures

TaskSocketNamespace(name='scaled_structures', sockets=[])