# Compare Band Structures

In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
# Connect to the AiiDA database of the profile named in the backend settings.
# The check makes re-executing this cell safe: the environment is only loaded
# when it is not loaded yet.
_dbenv_already_loaded = is_dbenv_loaded()
if not _dbenv_already_loaded:
    load_dbenv(profile=settings.AIIDADB_PROFILE)
    
from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job.quantumespresso.pw import PwCalculation
from aiida.orm.calculation.job.quantumespresso.pp import PpCalculation

from ase.data import covalent_radii, atomic_numbers
from ase.data.colors import cpk_colors
from ase.neighborlist import NeighborList

from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
import ipywidgets as ipw
import numpy as np
import urlparse

In [None]:
def get_calc_by_label(workcalc, calc_class, label):
    """Return the single finished calculation of ``calc_class`` labelled ``label``
    that is an output of ``workcalc``.

    Parameters
    ----------
    workcalc : WorkCalculation
        The parent work calculation; matched by its UUID.
    calc_class : type
        Calculation class to query for (e.g. ``PwCalculation``).
    label : str
        Value of the calculation's ``label`` attribute.

    Raises
    ------
    ValueError
        If not exactly one matching calculation exists.
    RuntimeError
        If the matching calculation is not in the ``'FINISHED'`` state.
    """
    qb = QueryBuilder()
    qb.append(WorkCalculation, filters={'uuid': workcalc.uuid})
    qb.append(calc_class, output_of=WorkCalculation, filters={'label': label})

    # Fetch the rows once so the uniqueness check and the returned node come
    # from the same query (and the check survives `python -O`, unlike assert).
    rows = qb.all()
    if len(rows) != 1:
        raise ValueError(
            "Expected exactly one {} labelled {!r} under work calculation {}, "
            "found {}".format(calc_class.__name__, label, workcalc.uuid, len(rows)))

    calc = rows[0][0]
    state = calc.get_state()
    if state != 'FINISHED':
        raise RuntimeError(
            "Calculation {} labelled {!r} is in state {!r}, expected 'FINISHED'"
            .format(calc.uuid, label, state))
    return calc

In [None]:
def make_plots(pk_list, erange=(-3.0, 3.0)):
    """Plot the band structures of several work calculations side by side.

    Each PK gets one column of a 5-row grid: row 0 holds a thumbnail of the
    structure (via ``plot_thumbnail``), rows 1-4 hold the bands of the work
    calculation's "bands" PwCalculation. Band energies are shifted by the
    ``vacuum_level`` extra, the ``homo`` extra is marked with a dashed red
    line, and every column shares the y axis of the first one.

    Parameters
    ----------
    pk_list : sequence
        Work calculation PKs (strings or ints; each is passed through ``int()``).
    erange : (float, float), optional
        Lower and upper bounds, in eV, of the visible energy window relative
        to each column's midgap energy ``(homo + lumo) / 2``.

    The figure is shown with ``plt.show()``; nothing is returned. An empty
    ``pk_list`` draws nothing.
    """
    if not pk_list:
        # A grid with zero columns cannot be laid out; nothing to plot.
        return

    # display progress bar
    progress = ipw.IntProgress(description='Plotting...', max=len(pk_list))
    display(progress)

    ncols = len(pk_list)
    fig = plt.figure()
    fig.set_size_inches(2.8 * ncols, 16)
    gs = GridSpec(5, ncols)
    fig.text(0.05, 0.5, 'E(eV)', va='center', rotation='vertical')

    for i, pk in enumerate(pk_list):
        if i == 0:
            ax1 = fig.add_subplot(gs[1:5, i])
        else:
            # Share the energy axis with the previous column and hide the
            # duplicated y tick labels. Booleans replace the deprecated
            # 'on'/'off' string values.
            ax1 = fig.add_subplot(gs[1:5, i], sharey=ax_old)
            ax1.tick_params(axis='y', which='both',
                            left=True, right=False, labelleft=False)
        ax_old = ax1  # needed for next iteration

        workcalc = load_node(pk=int(pk))

        vacuum_level = workcalc.get_extra('vacuum_level')
        # homo/lumo are drawn without the vacuum shift applied to the bands,
        # so they are presumably stored relative to the vacuum level already.
        homo = workcalc.get_extra('homo')
        lumo = workcalc.get_extra('lumo')
        gap = workcalc.get_extra('gap')
        abs_mag = workcalc.get_extra('absolute_magnetization')
        tot_mag = workcalc.get_extra('total_magnetization')

        # "\\mu" keeps the literal backslash for mathtext without relying on
        # the invalid "\m" escape sequence.
        caption = ('Abs. magn.: {}$\\mu_B$\nTot. magn.: {}$\\mu_B$\n'
                   'Band gap: {:.3f} eV').format(abs_mag, tot_mag, gap)
        ax1.set_xlabel(caption)
        ax1.axhline(y=homo, linewidth=2, color='red', ls='--')

        bands_calc = get_calc_by_label(workcalc, PwCalculation, "bands")
        bands = bands_calc.out.output_band.get_bands()
        structure = bands_calc.inp.structure
        nspins, nkpoints, nbands = bands.shape

        ispin = 0  # spin hard-coded: only the first spin channel is plotted

        # NOTE(review): assumes the k-points run evenly from Gamma to pi/Lx
        # along the first cell vector -- confirm against the bands workchain.
        Lx = structure.cell_lengths[0]
        x_data = np.linspace(0.0, np.pi / Lx, nkpoints)
        y_datas = bands[ispin, :, :] - vacuum_level

        # A 2-D y array draws one line per column, i.e. one line per band.
        ax1.plot(x_data, y_datas, color='gray')

        center = (homo + lumo) / 2.0
        ax1.set_ylim([center + erange[0], center + erange[1]])

        ax2 = fig.add_subplot(gs[0:1, i])
        plot_thumbnail(ax2, structure)

        # Count finished columns so the bar reaches its maximum at the end.
        progress.value = i + 1

    progress.close()
    plt.show()

In [None]:
def plot_thumbnail(ax, structure):
    """Draw a top-down (x-y projection) ball-and-stick sketch of ``structure``.

    The structure is repeated twice along the first cell vector. Each atom is
    drawn as a disc of half its covalent radius in its CPK color. Each bond is
    drawn as two half-sticks, one from each atom to the bond midpoint, in that
    atom's color.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; its aspect, limits, background, x label and ticks are set.
    structure : StructureData
        Any object providing ``get_ase()``.
    """
    ase_struct = structure.get_ase()
    s = ase_struct.repeat((2, 1, 1))
    cov_radii = [covalent_radii[a.number] for a in s]
    nl = NeighborList(cov_radii, bothways=True, self_interaction=False)
    nl.update(s)

    ax.set_aspect(1)
    ax.axes.set_xlim([0, s.cell[0][0]])
    ax.axes.set_ylim([5, s.cell[1][1] - 5])
    # Axes.set_axis_bgcolor was removed in matplotlib 2.2; set_facecolor is
    # its replacement. Fall back to the old name on versions lacking it.
    set_background = getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor
    set_background((0.423, 0.690, 0.933))
    ax.axes.get_yaxis().set_visible(False)

    name = ase_struct.get_chemical_formula()  # formula of the unrepeated cell
    ax.set_xlabel(name, fontsize=12)
    ax.tick_params(axis='x', which='both',
                   bottom=False, top=False, labelbottom=False)

    for at in s:
        # circles
        x0, y0, _ = at.position
        n = atomic_numbers[at.symbol]
        ax.add_artist(plt.Circle((x0, y0), covalent_radii[n] * 0.5,
                                 color=cpk_colors[n], fill=True, clip_on=True))
        # bonds: half-stick from this atom to the midpoint with each neighbor
        for theneig in nl.get_neighbors(at.index)[0]:
            x, y, _ = (s[theneig].position + at.position) / 2
            # NOTE(review): the periodic offsets returned by get_neighbors are
            # ignored, so neighbors across the cell boundary yield a distant
            # midpoint; this projected-length cutoff presumably drops them.
            if (x - x0)**2 + (y - y0)**2 < 2:
                ax.plot([x0, x], [y0, y], color=cpk_colors[n],
                        linewidth=2, linestyle='-')

In [None]:
# Read the work calculation PKs to compare from the notebook URL's query
# string, e.g. "...?pk=123&pk=456". NOTE(review): `jupyter_notebook_url` is not
# defined in this notebook; presumably it is injected by the Jupyter
# environment (e.g. appmode) -- confirm.
url = urlparse.urlsplit(jupyter_notebook_url)
params = urlparse.parse_qs(url.query)
# parse_qs maps each key to the list of all its values.
pk_list = params.get('pk', [])
if pk_list:
    make_plots(pk_list)
else:
    print("Nothing to compare.")