# Compare Bandstructures

In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
# Load the AiiDA database environment with the profile named in
# aiida.backends.settings.AIIDADB_PROFILE, but only if it is not loaded yet
# (e.g. when this cell is executed a second time in the same kernel).
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)
    
from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job.quantumespresso.pw import PwCalculation
from aiida.orm.calculation.job.quantumespresso.pp import PpCalculation

from ase.data import covalent_radii, atomic_numbers
from ase.data.colors import cpk_colors
from ase.neighborlist import NeighborList

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.image import imread

import urlparse
import numpy as np
import ipywidgets as ipw
from base64 import b64decode
from StringIO import StringIO 
from IPython.display import clear_output

In [None]:
def get_calc_by_label(workcalc, calc_class, label):
    """Return the unique finished calculation with ``label`` output by ``workcalc``.

    :param workcalc: WorkCalculation node whose called calculations are searched
        (matched by its ``uuid``).
    :param calc_class: ORM class of the wanted calculation, e.g. ``PwCalculation``.
    :param label: label of the wanted calculation.
    :returns: the matching calculation node.
    :raises AssertionError: if not exactly one calculation matches, or if the
        match is not in the ``'FINISHED'`` state.
    """
    qb = QueryBuilder()
    qb.append(WorkCalculation, filters={'uuid': workcalc.uuid})
    qb.append(calc_class, output_of=WorkCalculation, filters={'label': label})

    # Fetch the rows once instead of count() followed by first(): one query, not two.
    results = qb.all()

    # Explicit raises instead of `assert` so the checks survive `python -O`;
    # AssertionError is kept so the exception type seen by callers is unchanged.
    if len(results) != 1:
        raise AssertionError(
            "expected exactly one {} labelled {!r} below workcalc {}, found {}".format(
                calc_class.__name__, label, workcalc.uuid, len(results)))

    calc = results[0][0]
    state = calc.get_state()
    if state != 'FINISHED':
        raise AssertionError(
            "{} labelled {!r} below workcalc {} is in state {!r}, expected 'FINISHED'".format(
                calc_class.__name__, label, workcalc.uuid, state))
    return calc

In [None]:
def make_plots(pk_list):
    """Plot structure thumbnails and band structures of several workchains side by side.

    For each PK a group of columns is drawn: the structure thumbnail (decoded
    from the node's base64 ``thumbnail`` extra) on top, spanning one column per
    spin channel, and one band-structure panel per spin channel below it.
    Bands are plotted relative to the ``vacuum_level`` extra, the ``homo``
    extra is marked with a dashed red line, and all band panels share one
    y axis.

    :param pk_list: sequence of workcalculation PKs (ints or numeric strings,
        e.g. the values parsed from the notebook URL query). If it is empty,
        a message is printed and nothing is plotted.
    """
    # IPython injects `display` as a builtin; import it explicitly so this
    # function does not depend on that.
    from IPython.display import display

    clear_output()
    if not pk_list:
        # The GridSpec built below needs at least one column.
        print("Nothing to compare.")
        return

    # display progress bar
    progress = ipw.IntProgress(description='Plotting...', max=len(pk_list))
    display(progress)

    fig = plt.figure()
    band_axes = []    # one band panel per (workcalc, spin), in column order
    struct_axes = {}  # (first_col, end_col) -> thumbnail axes spanning those columns

    try:
        for i, pk in enumerate(pk_list):
            progress.value = i

            workcalc = load_node(pk=int(pk))
            bands_calc = get_calc_by_label(workcalc, PwCalculation, "bands")
            structure = bands_calc.inp.structure

            bands = bands_calc.out.output_band.get_bands()
            if bands.ndim == 2:
                # No spin dimension: add a leading one so the code below is uniform.
                bands = bands[None, :, :]
            nspins, nkpoints = bands.shape[:2]

            # Quantities that are the same for every spin channel of this workcalc.
            vacuum_level = workcalc.get_extra('vacuum_level')
            homo = workcalc.get_extra('homo')
            lumo = workcalc.get_extra('lumo')
            gap = workcalc.get_extra('gap')
            abs_mag = workcalc.get_extra('absolute_magnetization')
            tot_mag = workcalc.get_extra('total_magnetization')
            caption = ('Abs. magn.: {}$\\mu_B$\nTot. magn.: {}$\\mu_B$\n'
                       'Band gap: {:.3f} eV').format(abs_mag, tot_mag, gap)
            xlabel = '$k\\AA^{-1}$\n\n' + caption
            center = (homo + lumo) / 2.0
            # k axis runs from 0 to pi/Lx, Lx being the first cell length
            # (presumably the periodic direction -- confirm).
            x_data = np.linspace(0.0, np.pi / structure.cell_lengths[0], nkpoints)

            # Thumbnail. Placeholder axes; the real position is set after the loop.
            img = imread(StringIO(b64decode(workcalc.get_extra('thumbnail'))))
            ax1 = fig.add_axes([0, 0, 1, 1], label='thumbnail-%d' % i)
            ax1.imshow(img)
            ax1.set_axis_off()
            ax1.set_title(workcalc.get_extra('formula'))
            first_col = len(band_axes)
            struct_axes[(first_col, first_col + nspins)] = ax1

            # One band panel per spin channel.
            for ispin in range(nspins):
                sharey = band_axes[-1] if band_axes else None
                # The unique label guarantees matplotlib creates a new axes
                # instead of returning an existing one with the same arguments.
                ax2 = fig.add_axes([0, 0, 1, 1], label='bands-%d-%d' % (i, ispin),
                                   sharey=sharey)
                if band_axes:
                    # y axis is shared: keep tick marks, hide repeated labels.
                    # Booleans, not 'on'/'off' strings: those are deprecated in
                    # matplotlib and a non-empty string like 'off' is truthy.
                    ax2.tick_params(axis='y', which='both', left=True,
                                    right=False, labelleft=False)
                else:
                    ax2.set_ylabel('E(eV)')
                band_axes.append(ax2)

                ax2.set_xlabel(xlabel)
                ax2.axhline(y=homo, linewidth=2, color='red', ls='--')
                # A 2-D y array is drawn as one line per column, i.e. per band.
                ax2.plot(x_data, bands[ispin, :, :] - vacuum_level, color='gray')
                # All band axes share y, so the limits set for the last panel
                # end up applying to every panel.
                ax2.set_ylim([center - 3.0, center + 3.0])
                ax2.set_title("Spin %d" % ispin)

        # apply proper layout: thumbnails in grid row 0, band panels in rows 1-4
        ncols = len(band_axes)
        fig.set_size_inches(2.8 * ncols, 16)
        gs = GridSpec(5, ncols)
        for (first_col, end_col), ax in struct_axes.items():
            ax.set_position(gs[0, first_col:end_col].get_position(fig))
        for col, ax in enumerate(band_axes):
            ax.set_position(gs[1:5, col].get_position(fig))
    finally:
        # Remove the progress bar even if loading or plotting fails.
        progress.close()

    plt.show()

In [None]:
# The PKs to compare are passed in the notebook URL query, e.g. ?pk=12&pk=34;
# parse_qs collects repeated keys into a list.
url = urlparse.urlsplit(jupyter_notebook_url)
params = urlparse.parse_qs(url.query)
if 'pk' not in params:
    print("Nothing to compare.")
else:
    make_plots(params['pk'])