# Prepare Molecule for SPM Simulation with AiiDA

In [None]:
import os
import os.path
import re
import time
import threading
import subprocess

import ase
import ase.io
import nglview
import numpy as np

import ipympl
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display, clear_output
from fileupload import FileUploadWidget
from tempfile import NamedTemporaryFile

from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
# The AiiDA ORM import that follows presumably requires a loaded database
# environment; load the configured profile (settings.AIIDADB_PROFILE) unless
# an environment has already been loaded in this kernel.
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)
from aiida.orm import DataFactory

In [None]:
%%html
<!-- hide matplotlib figure title -->
<style> .ui-dialog-titlebar { display: none; } </style>

## Step 1: Upload an .xyz file or select an example molecule

In [None]:
def on_file_upload(c):
    """Load the molecule uploaded through ``file_upload`` and rebuild the scene.

    Observer for the widget's ``data`` trait. The uploaded payload is written
    to a temporary file whose name ends with the original filename, so that
    ``ase.io.read`` can infer the file format from the extension.

    :param c: trait-change notification passed by ``observe`` (unused).
    """
    global mol
    data = file_upload.data
    if not isinstance(data, bytes):
        # the temporary file is opened in binary mode; encode text payloads
        data = data.encode("utf-8")
    # Write through the temporary file's own (binary) handle: a separate
    # text-mode handle rejects bytes on Python 3. The context manager closes
    # and deletes the file even if parsing fails.
    with NamedTemporaryFile(suffix=file_upload.filename) as tmp:
        tmp.write(data)
        tmp.flush()  # make the content visible to the reader below
        mol = ase.io.read(tmp.name)
    setup_new_mol()

def on_click_example(b):
    """Replace the current molecule with one of the bundled example files.

    :param b: the button that was clicked (``btn_eg1`` or ``btn_eg2``).
    """
    global mol
    example_files = ((btn_eg1, "mol_start.xyz"), (btn_eg2, "mol_start2.xyz"))
    for button, path in example_files:
        if b == button:
            mol = ase.io.read(path)
    setup_new_mol()
    
    
# TODO: FileUploadWidget does not fire an event when the same file is uploaded twice.
file_upload = FileUploadWidget("Upload Molecule")
file_upload.observe(on_file_upload, names='data')

# Buttons that load the bundled example molecules; both share one handler.
btn_eg1, btn_eg2 = (widgets.Button(description=label)
                    for label in ('Example 1', 'Example 2'))
btn_eg1.on_click(on_click_example)
btn_eg2.on_click(on_click_example)

display(widgets.HBox(children=[file_upload, btn_eg1, btn_eg2]))

In [None]:
def setup_new_mol():
    """Place the current ``mol`` on a fresh gold slab and redraw the viewer.

    Clears the cell output, rebuilds ``mol_on_au`` via
    ``prepare_mol_on_slab()``, removes every component currently shown in
    ``viewer`` and adds the new structure (wrapped in ``ASEStructure2``)
    with ball-and-stick and unit-cell representations. The adaptor is kept
    in the global ``viewer_struct`` so ``update_ngview`` can update it.
    """
    global viewer_struct

    clear_output()
    print("Found %i atoms."%len(mol))
    prepare_mol_on_slab()

    # Remove old components. Iterate over a snapshot: remove_component()
    # may shrink viewer._ngl_component_ids, and iterating the live list
    # while it shrinks would skip every other component.
    for component_id in list(viewer._ngl_component_ids):
        viewer.remove_component(component_id)

    viewer_struct = ASEStructure2(mol_on_au)
    viewer.add_trajectory(viewer_struct)
    viewer.add_ball_and_stick()
    viewer.add_unitcell()
    viewer.center_view()
    #TODO https://github.com/arose/ngl/blob/master/src/controls/viewer-controls.js
    #w.orientation = [[-1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,0.1]] 

In [None]:
def prepare_mol_on_slab(padding=10.0, cell_height=40):
    """Build the periodic cell, lift ``mol`` above a gold slab and combine them.

    Mutates the global ``mol`` in place (cell, pbc, positions) and sets the
    globals ``au_slab`` (the slab atoms) and ``mol_on_au`` (``mol + au_slab``,
    i.e. molecule atoms first).

    The lateral cell size is the molecule's x/y extent plus ``padding``,
    rounded to a whole number of slab-template repetitions, so the slab fits
    the periodic cell exactly.

    :param padding: space added to the molecule's x and y extent before
        rounding (default 10, the previously hard-coded value).
    :param cell_height: cell length along z (default 40, the previously
        hard-coded value).
    """
    global mol_on_au, mol, au_slab

    # Repeat lengths of the slab template along x and y, z of the topmost
    # template layer, and the gap kept between that layer and the lowest
    # molecule atom.
    aux = 2.973735971
    auy = 5.150661791
    auz_top = 18.49
    dminau = 2.3

    # determine cell size: molecule extent + padding, snapped to template repeats
    extent_x = np.amax(mol.positions[:, 0]) - np.amin(mol.positions[:, 0])
    extent_y = np.amax(mol.positions[:, 1]) - np.amin(mol.positions[:, 1])
    nx = max(1, int(round((extent_x + padding) / aux)))
    ny = max(1, int(round((extent_y + padding) / auy)))
    mol.cell = (nx * aux, ny * auy, cell_height)
    mol.pbc = (True, True, True)

    # position molecule a bit above gold slab
    mol.center()
    minz = np.amin(mol.positions[:, 2])
    dz = (-minz + auz_top + dminau)
    mol.positions[:, 2] += dz

    # template for gold slab: (symbol, x, y, z) of one rectangular repeat unit
    au_templ = [    ('H'  ,   1.502273797989   ,  3.091176986694   , 10.434000015259),
                    ('H'  ,   0.017487669364   ,  0.521309018135   , 10.439200401306),
                    ('Au' ,   0.011242070235   ,  3.954530715942   , 11.276800155640),
                    ('Au' ,   1.498110055923   ,  1.378893256187   , 11.276800155640),
                    ('Au' ,   0.011242070235   ,  2.237301826477   , 13.704799652100),
                    ('Au' ,   1.498110055923   ,  4.812939167023   , 13.704799652100),
                    ('Au' ,   1.498110055923   ,  3.095984896024   , 16.073392868042),
                    ('Au' ,   0.011242070235   ,  0.520347436269   , 16.073581695557),
                    ('Au' ,   0.011242070235   ,  3.954530715942   , 18.493516921997),
                    ('Au' ,   1.498110055923   ,  1.378893256187   , 18.495817184448), ]

    # Generate the slab: every template atom is repeated nx times along x
    # and ny times along y (template-atom-major order). The atoms are
    # collected first and turned into one Atoms object, instead of appending
    # atom by atom (each append copies the whole array).
    symbols = []
    positions = []
    for sym, tx, ty, tz in au_templ:
        for i in range(nx):
            for j in range(ny):
                symbols.append(sym)
                positions.append((tx + i*aux, ty + j*auy, tz))
    au_slab = ase.Atoms(symbols=symbols, positions=positions)

    mol_on_au = mol + au_slab

## Step 2: Position the molecule on the surface

In [None]:
# TODO merge this upstream
from nglview.base_adaptor import Structure, Trajectory
from nglview.utils.py_utils import tempfolder

class ASEStructure2(Trajectory, Structure):
    """nglview adaptor exposing an ``ase.Atoms`` object as a one-frame trajectory.

    The wrapped atoms can be swapped by assigning ``_ase_atoms``;
    ``get_coordinates`` then returns the new positions.
    """

    def __init__(self, ase_atoms, ext='pdb', params=None):
        super(ASEStructure2, self).__init__()
        self.path = ''
        self.ext = ext
        # A fresh dict per instance: a shared ``{}`` default would leak
        # mutations from one instance into every other one.
        self.params = {} if params is None else params
        self._ase_atoms = ase_atoms

    def get_structure_string(self):
        """Return the wrapped atoms serialized as PDB text."""
        # ASE picks the output format from the '.pdb' extension; tempfolder()
        # presumably switches to a scratch directory for the duration.
        with tempfolder():
            self._ase_atoms.write('tmp.pdb')
            with open('tmp.pdb') as fh:
                return fh.read()

    def get_coordinates(self, index):
        """Return the current atom positions; ``index`` is ignored (single frame)."""
        return self._ase_atoms.positions

    @property
    def n_frames(self):
        """Number of frames exposed to nglview: always 1."""
        return 1

In [None]:
# The single NGL widget that setup_new_mol() fills and update_ngview() refreshes.
viewer = nglview.NGLWidget()
display(viewer)

In [None]:
#<i class="fa fa-repeat" aria-hidden="true"></i>
# Buttons that rotate/shift the molecule relative to the slab.
btn_rotleft, btn_rotright, btn_up, btn_down = (
    widgets.Button(description=label)
    for label in ('Rotate left', 'Rotate right', 'Move up', 'Move down')
)


def on_click_dir(b):
    """Nudge the molecule relative to the slab and refresh the viewer.

    Rotate buttons call ``mol.rotate`` with ``a=+0.1`` / ``a=-0.1`` about z
    around the centre of mass; move buttons shift ``mol`` by +0.1 / -0.1
    along z. ``mol_on_au`` is then rebuilt from the moved molecule and the
    unchanged slab.

    :param b: the button that was clicked.
    """
    global mol_on_au
    for button, angle in ((btn_rotleft, +0.1), (btn_rotright, -0.1)):
        if b == button:
            mol.rotate(a=angle, v="z", center="COM")
    for button, shift in ((btn_up, 0.1), (btn_down, -0.1)):
        if b == button:
            mol.positions[:, 2] += shift

    mol_on_au = mol + au_slab
    update_ngview()
    
# All four movement buttons share one handler, which dispatches on the
# clicked button.
btn_rotleft.on_click(on_click_dir)
btn_rotright.on_click(on_click_dir)
btn_up.on_click(on_click_dir)
btn_down.on_click(on_click_dir)
display(widgets.HBox(children=[btn_up, btn_down, btn_rotleft, btn_rotright]))

In [None]:
def update_ngview():
    """Point the viewer's structure adaptor at ``mol_on_au`` and trigger a redraw.

    The existing component is updated in place rather than replaced.
    ``on_frame_changed(None)`` is used to make nglview re-fetch the
    coordinates (presumably through ``get_coordinates``; confirm against the
    installed nglview version).
    """
    viewer_struct._ase_atoms = mol_on_au
    viewer.on_frame_changed(None)

In [None]:
def read_traj(fn):
    """Copy the molecule positions of the last complete frame of ``fn`` into ``mol``.

    ``fn`` is an XYZ trajectory of the whole ``mol_on_au`` system. The first
    ``len(mol)`` atoms of a frame are taken to be the molecule, matching the
    ``mol + au_slab`` order in which the system is written. Slab atoms are
    not read back. Afterwards the global ``mol_on_au`` is rebuilt.

    The file may still be growing while CP2K runs, so a trailing partial
    frame is ignored. If the file holds no complete frame yet, ``mol`` is
    left unchanged.

    :param fn: path of the XYZ trajectory file.
    """
    global mol, au_slab, mol_on_au
    with open(fn) as fh:
        lines = fh.read().strip().split("\n")
    try:
        natoms = int(lines[0])
    except ValueError:
        return  # empty file, or the header is still being written
    frame_len = natoms + 2  # atom-count line + comment line + atom lines
    nframes = len(lines) // frame_len
    if nframes == 0:
        return
    # index of the first atom line of the last complete frame; computed from
    # the start so the slice also works when the file holds only the molecule
    first = nframes * frame_len - natoms
    for i, line in enumerate(lines[first:first + len(mol)]):
        x, y, z = line.split()[1:4]
        mol[i].position = [float(x), float(y), float(z)]
    mol_on_au = mol + au_slab
    #TODO read energies

In [None]:
def update_plot(ax, energies=None):
    """Redraw the energy-per-optimization-step curve on ``ax``.

    :param ax: matplotlib axes to draw on (cleared first).
    :param energies: sequence of energies, one per step. If omitted, the
        module-level ``energies`` is used, as in earlier versions of this
        function; a ``NameError`` is raised if that does not exist.
    """
    if energies is None:
        # Backward compatibility: the original signature read a global.
        try:
            energies = globals()["energies"]
        except KeyError:
            raise NameError("name 'energies' is not defined")
    ax.clear()
    ax.plot(energies, marker="o")
    ax.set_xlabel("Optimization Step")
    ax.set_ylabel("Energy [a.u.]")
    ax.figure.canvas.draw()

## Step 3: Run geometry optimization (DFTB)

In [None]:
# Widgets for following a CP2K run: an HTML area for the log tail and an
# initially empty energy-per-step plot.
output_area = widgets.HTML()
fig, ax = plt.subplots(figsize=(7, 4))
ax.set(xlabel="Optimization Step", ylabel="Energy [a.u.]")

def update_energies_positions(ax):
    """Refresh the energy plot and the 3D view from CP2K's trajectory file.

    Does nothing until ``PROJECT-pos-1.xyz`` exists. Energies are parsed
    from the ``E = <value>`` fields of the trajectory; values that cannot be
    parsed (for example a field CP2K is still writing) are skipped instead
    of aborting the caller's monitoring thread.

    :param ax: matplotlib axes to draw the energy curve on (cleared first).
    """
    if not os.path.exists("PROJECT-pos-1.xyz"):
        return
    with open("PROJECT-pos-1.xyz") as fh:
        traj = fh.read()

    energies = []
    for field in re.findall(r" E =[ \t]*(\S+)", traj):
        try:
            energies.append(float(field.rstrip(",")))
        except ValueError:
            pass  # incomplete or non-numeric value; skip it

    ax.clear()
    ax.plot(energies, marker="o")
    ax.set_xlabel("Optimization Step")
    ax.set_ylabel("Energy [a.u.]")
    ax.figure.canvas.draw()

    read_traj("PROJECT-pos-1.xyz")
    update_ngview()
    

def output_worker():
    """Poll the running CP2K process and mirror its progress in the UI.

    Started in a background thread by ``start_cp2k``. Once a second it shows
    the last 20 lines of ``mol_on_au.out`` in ``output_area``, followed by
    one dot per second without new output. Whenever the output changed, it
    also refreshes the energy plot and structure. After CP2K exits it shows
    the complete log plus the exit code. The controls are re-enabled in all
    cases, even if a refresh raises, so the UI never stays stuck in the
    "running" state.
    """
    try:
        from html import escape
    except ImportError:  # Python 2
        from cgi import escape

    def read_log():
        # The log may not exist yet (or at all, if CP2K failed to start).
        if not os.path.exists("mol_on_au.out"):
            return ""
        with open("mol_on_au.out") as fh:
            return fh.read()  # TODO seek forward instead of re-reading

    pre_tag = '<pre style="width:600px; max-height:250px; overflow-x:auto; line-height:1em; font-size:0.8em;">'
    dots = 0
    latest = ""
    try:
        while cp2k_process.poll() is None:
            time.sleep(1)
            # update output window with the last 20 lines of the log
            output = "\n".join(read_log().split("\n")[-20:])
            dots += 1
            if latest != output:  # new output
                dots = 0  # reset dot counter
                update_energies_positions(ax)
            latest = output
            # escape: the log is plain text rendered inside an HTML widget
            output_area.value = (pre_tag + escape(output)
                                 + "\n" + ("." * dots) + "\n" + '</pre>')

        # read one last time entirely
        update_energies_positions(ax)
        output = escape(read_log())
        output += "\n\nCP2K finished, exit code: %s" % cp2k_process.returncode
        output_area.value = pre_tag + output + '</pre>'
    finally:
        max_force.disabled = False
        btn_startstop.description = "Start CP2K"
    

def start_cp2k():
    """Write the CP2K input for the current geometry and launch CP2K.

    Writes ``mol.xyz`` and ``mol_on_au.xyz``, fills the placeholders of the
    template ``mol_on_au_cp2k.tmpl`` (molecule/gold atom index ranges, cell
    lengths, MAX_FORCE) into ``mol_on_au.inp``, and removes output files of
    a previous run. It then starts ``cp2k.psmp`` in the background plus a
    thread running ``output_worker`` to follow it. If anything fails before
    CP2K is running, the controls are restored and the exception re-raised.
    """
    global cp2k_process
    max_force.disabled = True
    btn_startstop.description = "Stop CP2K"

    try:
        # construct CP2K input; molecule atoms come first, then the slab
        mol.write("mol.xyz")
        mol_on_au = mol + au_slab
        mol_on_au.write("mol_on_au.xyz")
        with open("mol_on_au_cp2k.tmpl") as fh:
            cp2k_inp = fh.read()
        # np.asarray handles both a plain ndarray cell and ASE Cell objects
        cell_abc = tuple(np.diag(np.asarray(mol.cell)))
        replacements = (
            ("<first_mol_atom>", "1"),
            ("<last_mol_atom>", "%d" % len(mol)),
            ("<first_au_atom>", "%d" % (len(mol) + 1)),
            ("<last_au_atom>", "%d" % len(mol_on_au)),
            ("<cell_abc>", "%f %f %f" % cell_abc),
            ("<max_force>", "%f" % max_force.value),
        )
        for placeholder, value in replacements:
            cp2k_inp = cp2k_inp.replace(placeholder, value)
        # close (and thereby flush) the input file before CP2K reads it
        with open("mol_on_au.inp", "w") as fh:
            fh.write(cp2k_inp)

        # remove results of a previous run so the monitor doesn't show them
        for stale in ("mol_on_au.out", "PROJECT-pos-1.xyz"):
            if os.path.exists(stale):
                os.remove(stale)

        # start cp2k
        cmd = ["cp2k.psmp", "-i", "mol_on_au.inp", "-o", "mol_on_au.out"]
        cp2k_process = subprocess.Popen(cmd)
    except Exception:
        # nothing is running: undo the "running" UI state, then propagate
        max_force.disabled = False
        btn_startstop.description = "Start CP2K"
        raise

    # start output thread
    threading.Thread(target=output_worker).start()


def on_click_startstop(e):
    """Start CP2K, or stop the running CP2K process, depending on the button label.

    :param e: the clicked button (unused; ``btn_startstop``'s label decides).
    """
    if btn_startstop.description.startswith("Start"):
        start_cp2k()
    elif cp2k_process.poll() is None:
        # Only signal a process that is still running: it may have exited
        # just before the monitor thread reset the button label.
        cp2k_process.kill()
       

btn_startstop = widgets.Button(description="Start CP2K")
btn_startstop.on_click(on_click_startstop)

# Convergence threshold written into the <max_force> placeholder of the
# CP2K input template.
max_force = widgets.FloatSlider(
    description='MAX_FORCE:',
    value=1e-3,
    min=1e-4,
    max=1e-2,
    step=1e-5,
    readout_format='.1e',
    layout=widgets.Layout(width="600px"),
)

display(widgets.VBox(
    children=[max_force, btn_startstop, output_area, fig.canvas.manager.canvas]))

## Step 4: Store structure in the AiiDA database

In [None]:
def on_click_store(b):
    """Store the current ``mol_on_au`` as an AiiDA ``structure`` node.

    The text of ``inp_descr`` becomes the node description, and the stored
    node's repr is printed.

    :param b: the clicked button (unused).
    """
    structure = DataFactory('structure')(ase=mol_on_au)
    structure.description = inp_descr.value
    structure.store()
    print("Stored in AiiDA: " + repr(structure))

# Optional free-text description attached to the stored structure node.
inp_descr = widgets.Text(placeholder="Description (optional)")
btn_store = widgets.Button(description='Store in AiiDA')
btn_store.on_click(on_click_store)
display(widgets.HBox(children=[btn_store, inp_descr]))