# Projected Density of States

In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
# Load the configured AiiDA profile before the aiida.orm imports below.
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job import JobCalculation

import re
try:
    import urlparse  # Python 2
except ImportError:  # Python 3: same urlsplit / parse_qs API
    import urllib.parse as urlparse
import numpy as np
from xml.etree import ElementTree
import matplotlib.pyplot as plt
import ipywidgets as ipw
from IPython.display import clear_output
import nglview

In [None]:
def get_calc_by_label(workcalc, label):
    """Return the unique finished JobCalculation labelled ``label`` under ``workcalc``.

    Queries for JobCalculations that are outputs of the WorkCalculation with
    the same uuid as ``workcalc`` and whose label equals ``label``.

    Raises:
        ValueError: if not exactly one such calculation exists, or if the
            calculation's state is not 'FINISHED'.
    """
    qb = QueryBuilder()
    qb.append(WorkCalculation, filters={'uuid': workcalc.uuid})
    qb.append(JobCalculation, output_of=WorkCalculation, filters={'label': label})

    # Fetch the rows once rather than issuing separate count() and first() queries.
    rows = qb.all()
    if len(rows) != 1:
        raise ValueError("Expected exactly one calculation labelled %r under workcalc %s, found %d"
                         % (label, workcalc.uuid, len(rows)))
    calc = rows[0][0]

    state = calc.get_state()
    if state != 'FINISHED':
        raise ValueError("Calculation %r under workcalc %s is in state %r, expected 'FINISHED'"
                         % (label, workcalc.uuid, state))
    return calc

In [None]:
# The notebook URL is expected to carry ?pk=<pk of the work calculation>.
# NOTE(review): jupyter_notebook_url is not defined in this notebook; presumably
# it is injected by the hosting app - confirm.
url = urlparse.urlsplit(jupyter_notebook_url)
params = urlparse.parse_qs(url.query)
pk = params['pk'][0]  # reuse the already-parsed query string

workcalc = load_node(pk=int(pk))
vacuum_level = workcalc.get_extra('vacuum_level')
homo = workcalc.get_extra('homo')
lumo = workcalc.get_extra('lumo')

pdos_calc = get_calc_by_label(workcalc, "export_pdos")
bands_calc = get_calc_by_label(workcalc, "bands")
structure = bands_calc.inp.structure
ase_struct = structure.get_ase()
natoms = len(ase_struct)

# Band energies; a 2-D result gets a leading spin axis so that `bands` is
# always indexed as [spin, kpoint, band].
bands = bands_calc.out.output_band.get_bands()
if bands.ndim == 2:
    bands = bands[None, :, :]

In [None]:
# Parse atomic_proj.xml and read the array dimensions stored under HEADER.
fn = pdos_calc.out.retrieved.get_abs_path('atomic_proj.xml')
root = ElementTree.parse(fn).getroot()


def _read_header_int(tag):
    """Return the integer text of the HEADER/<tag> element of ``root``."""
    return int(root.find('HEADER/' + tag).text)


nbands = _read_header_int('NUMBER_OF_BANDS')
nkpoints = _read_header_int('NUMBER_OF_K-POINTS')
nspins = _read_header_int('NUMBER_OF_SPIN_COMPONENTS')
natwfcs = _read_header_int('NUMBER_OF_ATOMIC_WFC')

In [None]:
# kpoint weights, parsed from the whitespace-separated WEIGHT_OF_K-POINTS text
_weights_node = root.find('WEIGHT_OF_K-POINTS')
kpoint_weights = np.fromstring(_weights_node.text, sep=' ')

In [None]:
# eigenvalues: read per spin and k-point, convert Ry -> eV and shift by the
# vacuum level; stored as eigvalues[spin, band, kpoint]
eigvalues = np.zeros((nspins, nbands, nkpoints))
for ispin in range(nspins):
    for ik in range(nkpoints):
        eig_path = 'EIGENVALUES/K-POINT.%d/EIG.%s' % (ik + 1, ispin + 1)
        eig_ry = np.fromstring(root.find(eig_path).text, sep='\n')
        eigvalues[ispin, :, ik] = eig_ry * 13.60569806589 - vacuum_level  # Ry -> eV

In [None]:
# projections: squared magnitude of each complex projection, stored as
# projections[spin, band, kpoint, atomic wfc]
projections = np.zeros((nspins, nbands, nkpoints, natwfcs))
for ispin in range(nspins):
    for ik in range(nkpoints):
        for iwfc in range(natwfcs):
            node = root.find('PROJECTIONS/K-POINT.%d/SPIN.%d/ATMWFC.%d' % (ik + 1, ispin + 1, iwfc + 1))
            # entries are "real,imag" pairs; turn commas into separators and pair them up
            flat = np.fromstring(node.text.replace(",", "\n"), sep="\n")
            re_im = flat.reshape(nbands, 2)
            projections[ispin, :, ik, iwfc] = np.sum(np.square(re_im), axis=1)  # |z|^2

In [None]:
# Parse the mapping atomic wavefunction (atmwfc) -> atom from the calculation's
# output file. Both numbers are 1-based as printed, e.g.:
#     state #   2: atom   1 (C  ), wfc  2 (l=1 m= 1)

fn = pdos_calc.out.retrieved.get_abs_path('aiida.out')
with open(fn) as f:
    content = f.read()
# Raw string: "\s" and "\d" in a plain literal are invalid escape sequences.
m = re.findall(r"\n\s+state #\s*(\d+): atom\s*(\d+) ", content)
atmwfc2atom = {int(i): int(j) for i, j in m}
assert len(atmwfc2atom) == natwfcs, \
    "parsed %d atomic wfcs from aiida.out, expected %d" % (len(atmwfc2atom), natwfcs)
assert len(set(atmwfc2atom.values())) == natoms, \
    "parsed wfcs on %d distinct atoms, expected %d" % (len(set(atmwfc2atom.values())), natoms)

In [None]:
def w0gauss(x, n):
    """Smearing function of Methfessel-Paxton order ``n`` evaluated at ``x``.

    For ``n == 0`` this is the plain Gaussian exp(-x^2)/sqrt(pi). For higher
    orders, correction terms are accumulated with a two-step recurrence
    (the same update order as QE's ``w0gauss``). The exponent x^2 is capped
    at 200. ``x`` may be a scalar or a numpy array.
    """
    expo = np.minimum(200.0, x**2)
    gauss = np.exp(-expo)
    weight = gauss / np.sqrt(np.pi)
    if n == 0:
        return weight

    h_odd = 0.0
    h_even = gauss
    count = 0
    coeff = 1.0 / np.sqrt(np.pi)
    for order in range(1, n + 1):
        # the statement order below matters: each step uses the previous values
        h_odd = 2.0 * x * h_even - 2.0 * count * h_odd
        count += 1
        coeff = -coeff / (order * 4.0)
        h_even = 2.0 * x * h_odd - 2.0 * count * h_even
        count += 1
        weight = weight + coeff * h_even
    return weight

In [None]:
def calc_pdos(sigma, ngauss, Emin, Emax, atmwfcs, DeltaE=0.01):
    """Compute the broadened (projected) density of states on an energy grid.

    Uses the module-level ``eigvalues`` (eV), ``projections`` and
    ``kpoint_weights``.

    Args:
        sigma: broadening width (eV).
        ngauss: Methfessel-Paxton order passed to ``w0gauss`` (0 = Gaussian).
        Emin, Emax: energy window (eV), same reference as ``eigvalues``.
        atmwfcs: indices of the atomic wavefunctions to sum over, or ``None``
            to sum over all of them. An empty list gives an all-zero PDOS.
        DeltaE: spacing of the energy grid (eV).

    Returns:
        ``(x, y)``: the energy grid ``x`` and the PDOS ``y`` with shape
        ``(len(x), nspins)``.
    """
    x = np.arange(Emin, Emax, DeltaE)

    # Evaluate all spins, bands and kpoints at once. Broadcasting the grid
    # against eigvalues[spin, band, kpoint] avoids materializing a tiled copy.
    arg = (x[:, None, None, None] - eigvalues) / sigma
    delta = w0gauss(arg, n=ngauss) / sigma

    if atmwfcs is None:
        p = np.sum(projections, axis=3)  # sum over all atmwfcs
    else:
        p = np.sum(projections[:, :, :, atmwfcs], axis=3)  # sum over selected atmwfcs

    c = delta * p * kpoint_weights
    y = np.sum(c, axis=(2, 3))  # sum over bands and kpoints

    return x, y

In [None]:
def plot_pdos(ax, pdos, ispin):
    """Plot one spin channel of a PDOS vertically into ``ax``.

    Args:
        ax: matplotlib axes to draw into.
        pdos: ``(x, y)`` tuple: energy grid ``x`` and PDOS ``y`` indexed as
            ``y[energy, spin]``.
        ispin: spin channel to plot.

    The x range is shared by all spin channels (based on the maximum of the
    whole ``y``). It is left unset when ``y`` has no positive value, because
    that would give matplotlib identical limits.
    """
    x, y = pdos
    ax.plot(y[:, ispin], x)  # vertical plot: DOS on x axis, energy on y axis
    ymax = np.amax(y) if np.size(y) else 0.0
    if ymax > 0:
        ax.set_xlim(0, 1.1 * ymax)
    ax.set_xlabel('arbitrary units')

In [None]:
def plot_bands(ax, ispin):
    """Plot the band structure of spin channel ``ispin`` into ``ax``.

    Uses the module-level ``bands`` array (indexed [spin, kpoint, band]),
    shifted by ``vacuum_level``. The bands are drawn against a uniform k grid
    from 0 to pi / L, where L is the first cell length of ``structure``. The
    HOMO energy is marked with a dashed red line.
    """
    # Local name only: do not shadow the XML-derived globals nspins/nkpoints/nbands.
    nkpts = bands.shape[1]

    ax.set_title("Spin %d" % ispin)
    ax.set_xlabel(r'$k\AA^{-1}$')
    ax.axhline(y=homo, linewidth=2, color='red', ls='--')

    # NOTE(review): assumes the k path runs along the first cell vector - confirm.
    Lx = structure.cell_lengths[0]
    x_max = np.pi / Lx
    x_data = np.linspace(0.0, x_max, nkpts)
    y_datas = bands[ispin, :, :] - vacuum_level  # shape (nkpoints, nbands)

    ax.set_xlim(0, x_max)
    # A 2-D y makes matplotlib draw each column (one band) as its own line.
    ax.plot(x_data, y_datas, color='gray')

In [None]:
def plot_all():
    """Draw band-structure and PDOS panels for every spin channel.

    The PDOS is projected onto the atomic wavefunctions located on the atoms
    in the global ``selected_atoms`` (0-based indices). When nothing is
    selected, it is summed over all atomic wavefunctions. Broadening and
    Methfessel-Paxton order come from the slider widgets. The energy window
    is +-3 eV around the HOMO/LUMO midpoint.
    """
    # Collect all atmwfcs (0-based) on the selected atoms; atmwfc2atom is 1-based.
    # Sorted so the printed list and the selection order are deterministic.
    if selected_atoms:
        atmwfcs = sorted(k - 1 for k, v in atmwfc2atom.items() if v - 1 in selected_atoms)
        print("Selected atmwfcs: " + str(atmwfcs))
    else:
        print("No atom selected - showing full PDOS")
        atmwfcs = None  # all

    sigma = sigma_slider.value
    ngauss = ngauss_slider.value

    fig = plt.figure()
    fig.set_size_inches(12, 16)
    center = (homo + lumo) / 2.0
    Emin, Emax = center - 3.0, center + 3.0
    pdos = calc_pdos(ngauss=ngauss, sigma=sigma, Emin=Emin, Emax=Emax, atmwfcs=atmwfcs)

    ncols = 2 * nspins  # one band panel + one PDOS panel per spin channel
    sharey = None
    for ispin in range(nspins):
        # band plot
        ax1 = fig.add_subplot(1, ncols, 2 * ispin + 1, sharey=sharey)
        if sharey is None:
            ax1.set_ylabel('E(eV)')
            sharey = ax1
        else:
            # booleans: string 'on'/'off' values are deprecated in matplotlib
            ax1.tick_params(axis='y', which='both', left=True, right=False, labelleft=False)
        plot_bands(ax=ax1, ispin=ispin)

        # pdos plot
        ax2 = fig.add_subplot(1, ncols, 2 * ispin + 2, sharey=sharey)
        ax2.tick_params(axis='y', which='both', left=True, right=False, labelleft=False)
        plot_pdos(ax=ax2, pdos=pdos, ispin=ispin)

    sharey.set_ylim(Emin, Emax)
    plt.show()

In [None]:
def on_picked(c):
    """nglview ``picked`` observer: toggle the clicked atom and redraw.

    Adds or removes the picked atom index in the global ``selected_atoms``,
    re-creates the viewer representations with the selection shown in red,
    and redraws the plots. Changes that do not contain an atom pick (such as
    the reset to ``{}`` below) are ignored.
    """
    global selected_atoms

    if 'atom' not in viewer.picked:
        return  # did not click on atom
    with plot_out:
        clear_output()
        # NOTE(review): called twice as in the original, presumably so that
        # the previous red selection representation is removed as well - confirm.
        viewer.component_0.remove_ball_and_stick()
        viewer.component_0.remove_ball_and_stick()
        viewer.add_ball_and_stick()

        idx = viewer.picked['atom']['index']

        # toggle
        if idx in selected_atoms:
            selected_atoms.remove(idx)
        else:
            selected_atoms.add(idx)

        # Only highlight when something is selected; otherwise the selection
        # string would be a bare "@".
        if selected_atoms:
            sel_str = ",".join(str(i) for i in sorted(selected_atoms))
            viewer.add_representation('ball+stick', selection="@" + sel_str, color='red', aspectRatio=3.0)

        # Reset, otherwise immediately selecting the same atom again won't create a change event.
        viewer.picked = {}

        plot_all()

In [None]:
# Explicit import instead of relying on the `display` name IPython injects.
from IPython.display import display


def on_change(c):
    """Slider observer: clear the plot output area and redraw all plots.

    ``c`` is the change notification passed by ``observe`` (unused); ``None``
    is also accepted for the initial draw.
    """
    with plot_out:
        clear_output()
        plot_all()

style = {"description_width":"200px"}
layout = ipw.Layout(width="600px")
sigma_slider = ipw.FloatSlider(description="Broadening [eV]", min=0.01, max=0.5, value=0.1, step=0.01,
                               continuous_update=False, layout=layout, style=style)
sigma_slider.observe(on_change, names='value')
ngauss_slider = ipw.IntSlider(description="Methfessel-Paxton order", min=0, max=3, value=0,
                              continuous_update=False, layout=layout, style=style)
ngauss_slider.observe(on_change, names='value')

selected_atoms = set()  # 0-based indices of atoms picked in the viewer
viewer = nglview.NGLWidget()

viewer.add_component(nglview.ASEStructure(ase_struct)) # adds ball+stick
viewer.add_unitcell()
viewer.center_view()

viewer.observe(on_picked, names='picked')
plot_out = ipw.Output()

display(sigma_slider, ngauss_slider, viewer, plot_out)
on_change(None)