In [1]:
from aiida import load_profile
from aiida.orm import Int
from aiida_workgraph import task, WorkGraph

In [2]:
import os
import subprocess
import matplotlib.pyplot as plt
import numpy as np
from ase.build import bulk
from ase.io import write
from ase.atoms import Atoms
from adis_tools.parsers import parse_pw

In [3]:
from pickle import loads

In [4]:
def write_input(input_dict, working_directory="."):
    """Write the Quantum ESPRESSO input file ``input.pwi`` for one calculation.

    Parameters
    ----------
    input_dict : mapping
        Must provide the keys ``"structure"`` (keyword arguments used to
        build an ``Atoms`` object), ``"kpts"``, ``"calculation"``,
        ``"smearing"`` and ``"pseudopotentials"``.
    working_directory : str, optional
        Directory that receives ``input.pwi``; it is created when missing.

    Returns
    -------
    None
    """
    # Create the target directory first so the writer can open the file in it.
    os.makedirs(working_directory, exist_ok=True)
    input_path = os.path.join(working_directory, "input.pwi")

    # Pull the individual settings out of the input dictionary.
    atoms = Atoms(**input_dict["structure"])
    kpoint_grid = input_dict["kpts"]
    namelist_settings = {
        "calculation": input_dict["calculation"],
        "occupations": "smearing",
        "degauss": input_dict["smearing"],
    }
    pseudo_files = input_dict["pseudopotentials"]

    # NOTE(review): ASE's Espresso writer documents `crystal_coordinates` as
    # the switch for fractional positions; confirm that `Crystal=True` is
    # actually honored by the installed ASE version.
    write(
        filename=input_path,
        images=atoms,
        Crystal=True,
        kpts=kpoint_grid,
        input_data=namelist_settings,
        pseudopotentials=pseudo_files,
        tstress=True,
        tprnfor=True,
    )

In [5]:
def collect_output(working_directory="."):
    """Read ``pwscf.xml`` from a finished run and return its key results.

    Parameters
    ----------
    working_directory : str, optional
        Directory that contains the ``pwscf.xml`` file to parse.

    Returns
    -------
    dict
        ``"structure"`` (the parsed structure converted with ``todict()``),
        ``"energy"`` and ``"volume"`` (from ``get_volume()`` on the parsed
        structure).
    """
    xml_path = os.path.join(working_directory, "pwscf.xml")
    parsed = parse_pw(xml_path)

    # The parsed structure supplies both the dictionary form and the volume.
    final_atoms = parsed["ase_structure"]

    results = {"structure": final_atoms.todict()}
    results["energy"] = parsed["energy"]
    results["volume"] = final_atoms.get_volume()
    return results

In [6]:
@task.calcfunction(
    outputs=[{"name": key} for key in ("energy", "volume", "structure")]
)
def calculate_qe(working_directory, input_dict):
    """Run one Quantum ESPRESSO ``pw.x`` calculation and return its results.

    The input file is written with ``write_input``, ``pw.x`` is started
    through ``mpirun`` inside ``working_directory``, and the results are read
    back with ``collect_output``.

    Parameters
    ----------
    working_directory : str
        Directory in which the input is written and ``pw.x`` is executed.
    input_dict : mapping
        Calculation settings, forwarded unchanged to ``write_input``.

    Returns
    -------
    dict
        The dictionary produced by ``collect_output`` with the keys
        ``"structure"``, ``"energy"`` and ``"volume"``.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status.
    """
    write_input(input_dict=input_dict, working_directory=working_directory)

    # The command relies on shell redirection to store the program output in
    # output.pwo, hence shell=True; the value returned by check_output is unused.
    pw_command = "mpirun -np 1 pw.x -in input.pwi > output.pwo"
    subprocess.check_output(pw_command, cwd=working_directory, shell=True)

    results = collect_output(working_directory=working_directory)
    return results

In [7]:
# One output is declared per possible strain, so at most 100 strains are supported.
@task.calcfunction(outputs=[{"name": str(i)} for i in range(100)])
def generate_structures(structure, strain_lst):
    """Create one rescaled copy of ``structure`` for every strain value.

    Each cell vector is multiplied by ``strain ** (1 / 3)`` and the atoms are
    scaled along with the cell, so the cell volume changes by the factor
    ``strain``.

    Parameters
    ----------
    structure : mapping
        Keyword arguments used to build an ``Atoms`` object.
    strain_lst : iterable of float
        Scaling factors to apply.

    Returns
    -------
    dict
        Maps ``"0"``, ``"1"``, ... (position in ``strain_lst``) to the
        ``todict()`` form of the corresponding strained structure.
    """
    strained_by_index = {}
    for index, strain in enumerate(strain_lst):
        # Start from a fresh copy of the reference structure every time.
        atoms = Atoms(**structure)
        scaled_cell = atoms.cell * strain ** (1 / 3)
        atoms.set_cell(scaled_cell, scale_atoms=True)
        strained_by_index[str(index)] = atoms.todict()
    return strained_by_index

In [8]:
@task.calcfunction()
def plot_energy_volume_curve(volume_lst, energy_lst):
    """Plot energy against volume and save the figure as ``evcurve.png``.

    The curve is drawn on a dedicated figure that is closed afterwards.
    Drawing on pyplot's implicit "current figure" instead would add the curve
    to whatever is already on that figure (for example the curve of an
    earlier run in the same kernel) and would leave the figure open. Because
    the figure is closed here, the image file is the only place where the
    plot can be viewed.

    Parameters
    ----------
    volume_lst : sequence
        Values for the x axis.
    energy_lst : sequence
        Values for the y axis, in the same order as ``volume_lst``.

    Returns
    -------
    None
        The plot is only written to ``evcurve.png`` in the current working
        directory.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(volume_lst, energy_lst)
        ax.set_xlabel("Volume")
        ax.set_ylabel("Energy")
        fig.savefig("evcurve.png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

In [9]:
@task.calcfunction()
def get_bulk_structure(element, a, cubic):
    """Build a bulk crystal with ``ase.build.bulk`` and return it as a dict.

    Parameters
    ----------
    element : str
        Passed to ``bulk`` as ``name``.
    a : float
        Passed to ``bulk`` as the lattice constant ``a``.
    cubic : bool
        Passed to ``bulk`` as ``cubic``.

    Returns
    -------
    dict
        The ``todict()`` representation of the generated structure.
    """
    atoms = bulk(name=element, a=a, cubic=cubic)
    return atoms.todict()



In [10]:
# Pseudopotential file name per chemical symbol; passed on to the QE input
# writer through the "pseudopotentials" entry of every input dictionary.
pseudopotentials = {"Al": "Al.pbe-n-kjpaw_psl.1.0.0.UPF"}

In [11]:
# Load the AiiDA profile before any task is created or run.
load_profile()

Profile<uuid='7bb8761123324468bb98821cbb757251' name='presto'>

In [12]:
# The WorkGraph instance that all tasks below are added to.
wg = WorkGraph("my_workflow")

In [13]:
# Task 1: build the cubic Al bulk structure (a=4.05) as a dictionary.
structure = wg.add_task(get_bulk_structure, name="get_bulk_structure", element="Al", a=4.05, cubic=True)

In [14]:
# Task 2: variable-cell relaxation ("vc-relax") of the bulk structure, run in
# the "mini" directory.
# NOTE(review): the output socket of the previous task is nested inside a
# plain dict here. The recorded `wg.run()` failure further down
# ("cannot pickle '_thread.RLock' object") may be caused by such nested
# sockets - confirm against the aiida-workgraph documentation on linking.
calc_mini = wg.add_task(calculate_qe,
    name="calculate_qe",
    working_directory="mini",
    input_dict={
        "structure": structure.outputs.result,
        "pseudopotentials": pseudopotentials,
        "kpts": (3, 3, 3),
        "calculation": "vc-relax",
        "smearing": 0.02,
    },
)

In [15]:
# Number of strained structures to generate; generate_structures declares
# 100 outputs, so this value must not exceed 100.
number_of_strains = 5
# Task 3: rescale the relaxed structure with evenly spaced factors from 0.9
# to 1.1.
structure_lst = wg.add_task(generate_structures,
    name="generate_structures",
    structure=calc_mini.outputs.structure,
    strain_lst=np.linspace(0.9, 1.1, number_of_strains),
)

In [16]:
# Task 4: one "scf" calculation per strained structure. Each task gets its own
# name and working directory derived from the strain index.
job_strain_lst = []
for i in range(number_of_strains):
    calc_strain = wg.add_task(
        calculate_qe,
        name=f"calculate_qe_{i}",
        working_directory=f"strain_{i}",
        input_dict={
            # Output of generate_structures that belongs to this index.
            "structure": getattr(structure_lst.outputs, str(i)),
            "pseudopotentials": pseudopotentials,
            "kpts": (3, 3, 3),
            "calculation": "scf",
            "smearing": 0.02,
        },
    )
    job_strain_lst.append(calc_strain)

In [17]:
# Task 5: collect the volume and energy of every strained calculation and plot
# the energy-volume curve.
# NOTE(review): both inputs are plain Python lists of task output sockets; the
# recorded `wg.run()` failure ("cannot pickle '_thread.RLock' object") may be
# related to passing sockets inside containers - confirm against the
# aiida-workgraph documentation.
plot = wg.add_task(plot_energy_volume_curve,
    name="plot",
    volume_lst=[job.outputs.volume for job in job_strain_lst],
    energy_lst=[job.outputs.energy for job in job_strain_lst],
)

In [18]:
# Execute the whole graph.
# NOTE(review): the recorded output of this cell is
# "TypeError: cannot pickle '_thread.RLock' object", i.e. the workflow did not
# run to completion. A likely cause is that task output sockets are passed
# inside plain dicts/lists in the cells above - TODO confirm and fix.
wg.run()

TypeError: cannot pickle '_thread.RLock' object

In [None]:
# Display the result output of the plot task.
# NOTE(review): plot_energy_volume_curve has no return statement and only
# writes evcurve.png, so this output may not hold a useful value - confirm.
plot.outputs.result