In [None]:
from aiida import load_profile  
# Load the default AiiDA profile so the following cells can load codes and create nodes.
load_profile()

In [None]:
from aiida.orm import load_code
from aiida_mlip.data.model import ModelData

# MACE-MP "small" model taken from the janus-core test suite. The URL points at
# the `main` branch, so the downloaded file is not pinned to a release.
uri = (
    "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
)
model = ModelData.from_uri(uri, architecture="mace_mp")

# Codes looked up by label on the `localhost` computer (janus first, then QE).
janus_code, qe_code = (
    load_code(label) for label in ("janus@localhost", "qe@localhost")
)

In [None]:
from aiida.plugins import CalculationFactory
# Calculation class for the `mlip.descriptors` entry point.
# NOTE(review): not referenced by any of the cells below — confirm it is still needed.
descriptorsCalc = CalculationFactory(entry_point_name="mlip.descriptors")

In [None]:
from aiida_workgraph import WorkGraph, task
from pathlib import Path
from ase.io import read, write, iread
from ase import Atoms
import numpy as np
from aiida.orm import SinglefileData, Float, InstalledCode, List, Dict, KpointsData, StructureData, load_group, Str, Bool, Int
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_workgraph.manager import get_current_graph
from ase import units
import tempfile
from pathlib import Path
from random import shuffle

@task.calcfunction(outputs=["scaled_file"])
def create_scales(
    min_v: Float,
    max_v: Float,
    num_structs: int,
    structures_path
):
    """Write volume-scaled copies of the input frames to an extended-XYZ file.

    ``num_structs`` volume scale factors are spaced evenly over
    ``[min_v, max_v]`` and converted to lattice scale factors with a cube root.
    Frame ``i`` of ``structures_path`` is given the cell of frame 0 multiplied
    by lattice factor ``i``; atomic positions move with the cell
    (``scale_atoms=True``).

    Args:
        min_v: Smallest volume scale factor (``Float`` node).
        max_v: Largest volume scale factor (``Float`` node).
        num_structs: Number of scaled structures to produce. Although annotated
            ``int``, the calcfunction receives an AiiDA ``Int`` node, hence
            ``.value``.
        structures_path: Node whose ``.value`` is a path to an ASE-readable
            file containing at least ``num_structs`` frames.

    Returns:
        ``{"scaled_file": SinglefileData}`` holding the scaled frames in
        extended-XYZ format.

    Raises:
        ValueError: if ``num_structs`` is below 1, or the input file has fewer
            than ``num_structs`` frames.
    """
    n_scaled = num_structs.value
    if n_scaled < 1:
        raise ValueError(f"num_structs must be at least 1, got {n_scaled}")

    frames = read(structures_path.value, index=":")
    if len(frames) < n_scaled:
        # Each scale factor is applied to its own frame, so the file must
        # provide one frame per requested structure.
        raise ValueError(
            f"{structures_path.value} contains {len(frames)} frame(s) but "
            f"num_structs={n_scaled} were requested"
        )

    # get_cell() returns a copy, so scaling frame 0 below leaves it untouched.
    reference_cell = frames[0].get_cell()
    lattice_scalars = np.cbrt(np.linspace(min_v.value, max_v.value, n_scaled))

    scaled_frames = frames[:n_scaled]
    for frame, scale in zip(scaled_frames, lattice_scalars):
        frame.set_cell(reference_cell * scale, scale_atoms=True)

    # Write by path inside a temporary directory: portable (no reopening of an
    # already-open temp file) and always cleaned up. SinglefileData copies the
    # content at construction, so the node outlives the directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        scaled_path = str(Path(tmpdir) / "scaled.extxyz")
        write(scaled_path, scaled_frames, format="extxyz")
        scaled_file = SinglefileData(scaled_path)

    return {"scaled_file": scaled_file}

@task.graph(outputs=["structures"])
def qe(
    code: InstalledCode,
    kpoints_mesh: List,
    task_metadata: Dict,
    scaled_file: SinglefileData,
):
    """Add one Quantum ESPRESSO SCF task per frame of ``scaled_file``.

    Every frame becomes a ``StructureData`` run through ``PwCalculation`` with
    forces (``tprnfor``) and stress (``tstress``) enabled. It uses
    pseudopotentials from the ``SSSP/1.3/PBE/efficiency`` group and that
    family's recommended ``ecutwfc``/``ecutrho`` (in Ry).

    Args:
        code: Code passed to every ``PwCalculation`` task.
        kpoints_mesh: Mesh given to ``KpointsData.set_kpoints_mesh``; one
            ``KpointsData`` is shared by all tasks.
        task_metadata: ``Dict`` whose ``.value`` is used as each task's metadata.
        scaled_file: Extended-XYZ file of input structures.

    Returns:
        ``{"structures": ...}`` mapping ``"struct<i>"`` to a dict holding the
        i-th task's ``output_trajectory`` and ``output_parameters`` sockets.
    """
    graph = get_current_graph()

    # A single k-point mesh node reused by every SCF task.
    kpoints = KpointsData()
    kpoints.set_kpoints_mesh(kpoints_mesh)

    pseudo_family = load_group('SSSP/1.3/PBE/efficiency')

    scf_outputs = {}
    with scaled_file.as_path() as extxyz_path:
        for index, frame in enumerate(iread(extxyz_path, format="extxyz")):
            structure = StructureData(ase=frame)
            pseudos = pseudo_family.get_pseudos(structure=structure)
            ecutwfc, ecutrho = pseudo_family.get_recommended_cutoffs(
                structure=structure, unit='Ry'
            )

            parameters = {
                "CONTROL": {"calculation": "scf", 'tprnfor': True, 'tstress': True},
                "SYSTEM": {"ecutwfc": ecutwfc, "ecutrho": ecutrho},
            }

            pw_task = graph.add_task(
                PwCalculation,
                code=code,
                parameters=parameters,
                kpoints=kpoints,
                pseudos=pseudos,
                # Re-read per task so each task receives its own metadata dict.
                metadata=task_metadata.value,
                structure=structure,
            )

            scf_outputs[f"struct{index}"] = {
                "trajectory": pw_task.outputs.output_trajectory,
                "parameters": pw_task.outputs.output_parameters,
            }

        graph.update_ctx({"structures": scf_outputs})

    return {"structures": graph.ctx.structures}

@task.calcfunction(outputs=["test_file", "train_file", "valid_file"])
def create_train_files(**structures):
    """Split QE results into train/test/validation extended-XYZ files.

    The keys of ``structures`` are shuffled and split 70 % / 20 % / 10 % into
    the training, test and validation sets. Each entry must provide a
    ``"trajectory"`` (with ``stress`` and ``forces`` arrays) and a
    ``"parameters"`` node (whose dict has an ``energy`` key), as produced by the
    ``qe`` graph task.

    For every structure, frame 0 of its trajectory is written with
    ``qe_energy``, ``qe_forces`` and ``qe_stress``. The stress is multiplied by
    ``ase.units.GPa``, i.e. treated as GPa and stored in eV/Ang^3.

    NOTE: the shuffle is unseeded, so the split differs between runs.

    Returns:
        ``{"test_file", "train_file", "valid_file"}`` mapped to
        ``SinglefileData``. A split that receives no structures yields an
        empty file.
    """
    keys = list(structures)
    shuffle(keys)

    n = len(keys)
    i1 = int(n * 0.7)
    i2 = int(n * 0.9)

    # The largest (70 %) share is the training set.
    training_split = {
        "train": keys[:i1],
        "test": keys[i1:i2],
        "valid": keys[i2:],
    }

    files = {}
    # Files are written by path inside a temporary directory (portable, always
    # cleaned up); SinglefileData copies the content at construction.
    with tempfile.TemporaryDirectory() as tmpdir:
        for split, split_keys in training_split.items():
            split_path = Path(tmpdir) / f"{split}.extxyz"
            split_path.touch()  # the file must exist even if the split is empty

            frames = []
            for key in split_keys:
                result = structures[key]
                trajectory = result["trajectory"]
                atoms = trajectory.get_structure(index=0).get_ase()

                # Stress/forces from the first trajectory frame, via the public
                # ArrayData.get_array accessor.
                stress = trajectory.get_array("stress")[0]
                atoms.info["qe_stress"] = stress * units.GPa
                atoms.info["units"] = {"energy": "eV","forces": "ev/Ang","stress": "ev/Ang^3"}
                atoms.set_array("qe_forces", trajectory.get_array("forces")[0])

                atoms.info["qe_energy"] = result["parameters"].get_dict()["energy"]
                frames.append(atoms)

            if frames:
                write(str(split_path), frames, format="extxyz")
            files[f"{split}_file"] = SinglefileData(str(split_path))
            print(f"{split}_file has {len(frames)} structures")

    return {
        "test_file": files["test_file"],
        "train_file": files["train_file"],
        "valid_file": files["valid_file"],
    }


In [None]:
# Inputs for a janus (MLIP) calculation.
# NOTE(review): `calc_inputs` is not used by the workflow cells below — confirm.
calc_inputs = dict(
    code=janus_code,
    model=model,
    arch=Str(model.architecture),
    device=Str("cuda"),
    metadata={"options": {"resources": {"num_machines": 1}}},
)

# Volume range (as fractions of the reference volume) and number of structures
# for `create_scales`.
scales_inputs = dict(min_v=0.95, max_v=1.05, num_structs=12)

# Inputs for the `qe` graph task; `task_metadata` is forwarded to every
# PwCalculation.
qe_inputs = dict(
    task_metadata=Dict({
        "options": {
            "resources": {
                "tot_num_mpiprocs": 1,
                "num_mpiprocs_per_machine": 1,
                "num_cores_per_mpiproc": 8,
            },
            "max_wallclock_seconds": 3600,
            "queue_name": "scarf",
            "qos": "scarf",
            "environment_variables": {},
            "withmpi": True,
            "prepend_text": """
                """,
            "append_text": "",
        },
    }),
    kpoints_mesh=List([1, 1, 1]),
    code=qe_code,
)

In [None]:
# Input trajectory, resolved against the kernel's current working directory.
initial_structure = str(Path("../structures/NaCl-traj.xyz").resolve())

with WorkGraph("EOS_workflow") as wg:
    # 1) Volume-scaled structures from the input frames.
    scales_task = wg.add_task(
        create_scales,
        **scales_inputs,
        structures_path=initial_structure,
    )

    # 2) One QE SCF calculation per scaled structure.
    qe_task = wg.add_task(
        qe,
        scaled_file=scales_task.outputs.scaled_file,
        **qe_inputs,
    )

    # 3) Split the QE results into train/test/validation files.
    train_task = wg.add_task(
        create_train_files,
        structures=qe_task.outputs.structures,
    )
 

In [None]:
# Execute the workflow in this kernel (rather than submitting it to the daemon).
wg.run()

In [None]:
# Print the full extended-XYZ content of the generated test split.
test_split_node = wg.tasks.create_train_files.outputs.test_file.value
print(test_split_node.get_content())