## Integrating EPWpy with MLIP workflow

EPW workflows written in `EPWpy` can seamlessly integrate with machine learning interatomic potential (MLIP) workflows, which are usually written in Python.

In this notebook, we calculate the phonon dispersion using the `PWscf` and `PHonon` executables of [Quantum ESPRESSO](https://www.quantum-espresso.org/), accessed via `EPWpy`, and compare it with the corresponding output generated with a universal interatomic potential (UIP) named [M3GNet](https://www.nature.com/articles/s43588-022-00349-3).

### Getting started

In [None]:
from typing import Dict
import os
import sys
sys.path.insert(0, os.path.abspath(f'{os.getcwd()}/../EPWpy'))
import json

import numpy as np

from EPWpy import EPWpy

In [None]:
# Set up the EPWpy calculation object for diamond.
# The first argument is a plain dictionary of calculation parameters; the
# remaining keyword arguments are forwarded to EPWpy unchanged.
diamond = EPWpy(
    dict(
        prefix='diamond',
        # The inner quotes are part of the value itself
        # (presumably written verbatim into the pw.x input -- confirm).
        restart_mode="'from_scratch'",
        structure='diamond.POSCAR',
        mass=[12.011],
        pseudo='automatic',
        ecutwfc='90',
        verbosity='high',
        pseudo_dir='automatic',
    ),
    code='',
    env='mpirun',
    system='C',
)

### EPWpy workflow

In [None]:
# Calculate phonon frequencies using pw.x and ph.x

# TODO: Write code here


### MLIP workflow

In [None]:
# Install and import packages required for this section
! pip install torch pymatgen ase phonopy matgl pydantic > /dev/null
import numpy as np
import torch
from pymatgen.core import Structure
from phonopy import Phonopy
from phonopy.structure.atoms import PhonopyAtoms
from phonopy.phonon.band_structure import (
    get_band_qpoints_and_path_connections,
)
from ase import Atoms
from ase.calculators.calculator import Calculator

In [None]:
# Build a pymatgen Structure from the cell and atomic positions held by EPWpy.
# NOTE(review): Structure interprets `coords` as fractional coordinates by
# default -- confirm that EPWpy stores 'atomic_pos' in crystal coordinates.
pw_cell = diamond.default_pw_cell_parameters
pw_positions = diamond.default_pw_atomic_positions
struct = Structure(
    lattice=pw_cell['vec'],
    species=pw_positions['atoms'],
    coords=pw_positions['atomic_pos'],
)
struct

In [None]:
# Relax the structure according to MLIP
# The model name gets its own variable so that `calc` only ever refers to the
# ASE calculator object (previously `calc` was first a str, then a Calculator).
mlip_name = 'm3gnet'

torch.set_default_device('cpu')  # need this for now, maybe matgl folks would fix this later
# `struct` is rebound on purpose: the cells below read the relaxed geometry
# through this name.
struct = struct.relax(calculator=mlip_name, steps=1000, fmax=1e-5, verbose=True)
# Reuse the calculator that the relaxation attached to the returned structure
# for the force evaluations further down.
calc: Calculator = struct.calc

In [None]:
def phonopyatoms2atoms(pnpatoms: PhonopyAtoms) -> Atoms:
    """
    phonopy.PhonopyAtoms -> ase.Atoms converter

    Parameters
    ----------
    pnpatoms : PhonopyAtoms
        Cell (unit cell or supercell) to convert.

    Returns
    -------
    Atoms
        ASE atoms carrying the same symbols, positions and cell, flagged as
        periodic along all three lattice vectors.
    """
    return Atoms(
        symbols=pnpatoms.symbols,
        positions=pnpatoms.positions,
        cell=pnpatoms.cell,
        # ase.Atoms defaults to pbc=False. Without this flag a calculator sees
        # the crystal supercell as an isolated cluster, and the forces on the
        # displaced supercells no longer describe the periodic solid.
        pbc=True,
    )

# --- Finite-displacement settings -------------------------------------------
supercell_dim = 5                  # supercell repetitions along each lattice vector
factors = [supercell_dim] * 3
dx = 0.04  # in Angstrom

# --- Phonopy unit cell built from the (relaxed) pymatgen structure -----------
unitcell = PhonopyAtoms(
    symbols=[species.symbol for species in struct.species],
    cell=struct.lattice.matrix,
    scaled_positions=struct.frac_coords,
)

# --- Displaced supercells and their forces -----------------------------------
phonon = Phonopy(unitcell, supercell_matrix=np.diag(factors))
phonon.generate_displacements(distance=dx)

# Evaluate forces one displaced supercell at a time with the calculator `calc`.
forces_per_supercell = []
for displaced_cell in phonon.supercells_with_displacements:
    forces_per_supercell.append(calc.get_forces(phonopyatoms2atoms(displaced_cell)))
sets_of_forces = np.array(forces_per_supercell)

# --- Force constants ----------------------------------------------------------
phonon.forces = sets_of_forces
phonon.produce_force_constants()
phonon.symmetrize_force_constants(level=1)

# Get phonon band structure data
# Each entry pairs a tick label with the q-point it marks, so labels and
# coordinates cannot drift out of step.
# NOTE(review): coordinates look like reduced coordinates of the primitive FCC
# reciprocal cell -- confirm that `unitcell` is the primitive cell.
high_symmetry_path = [
    ('$\\Gamma$', [0.0, 0.0, 0.0]),
    ('$X$', [1/2, 0.0, 1/2]),
    ('$W$', [1/2, 1/4, 3/4]),
    ('$L$', [1/2, 1/2, 1/2]),
    ('$\\Gamma$', [0.0, 0.0, 0.0]),
]
band_path_labels = [label for label, _qpt in high_symmetry_path]
# A single continuous path (Path 1) through all the points above.
band_paths = [[qpt for _label, qpt in high_symmetry_path]]

# Sample q-points along the path with phonopy's default density; the returned
# path-connection flags are not needed here.
qpoints, _ = get_band_qpoints_and_path_connections(band_paths)
phonon.run_band_structure(qpoints)
data = phonon.get_band_structure_dict()

In [None]:
# Stitch the per-segment band data into single arrays. Consecutive segments
# share their junction q-point, so the first row of every segment is dropped
# and the very first q-point of the path is put back once at the front.
qpoint_segments = data['qpoints']
frequency_segments = data['frequencies']
qpoints = np.vstack(
    [qpoint_segments[0][:1, :]] + [segment[1:, :] for segment in qpoint_segments]
)
frequencies = np.vstack(
    [frequency_segments[0][:1, :]] + [segment[1:, :] for segment in frequency_segments]
)

# Cumulative path length, measured directly on the q-point coordinates as
# stored in `data` (no reciprocal-lattice metric is applied).
step_lengths = np.linalg.norm(np.diff(qpoints, axis=0), axis=1)
qdist = np.concatenate((np.zeros(1), np.cumsum(step_lengths)))

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Plot the phonon dispersion obtained with the MLIP along the band path.
fig, ax = plt.subplots(figsize=(8., 6.), facecolor='w')
ax.plot(qdist, frequencies, 'r-')
ax.set_xlim(qdist.min(), qdist.max())

# Place one tick per high-symmetry label. The indices are derived from the
# data instead of a hard-coded stride of 50, so the ticks stay aligned if the
# q-point density changes. Assumes every path segment holds the same number
# of q-points.
tick_indices = np.linspace(0, len(qdist) - 1, len(band_path_labels)).round().astype(int)
ax.set_xticks(qdist[tick_indices])
ax.set_xticklabels(band_path_labels)

ax.grid()
# Raw string: '\o' is not a valid escape sequence in a normal string literal.
ax.set_ylabel(r'$\omega$ (THz)')
plt.show()