# **Free-electron bands in a periodic lattice**

Still to do:
- check the units of energy
- reuse the band visualizer (the data is already returned *almost* in the required format)
- add docstrings!

On the user side:
- show the Brillouin zone (BZ) using seekpath's BZ visualizer
- offer a choice of 'standard' known lattices (BCC, FCC, tetragonal, ...); the full list can probably be taken from seekpath


In [None]:
import numpy as np
import seekpath
import re
import matplotlib
from ase.dft.dos import linear_tetrahedron_integration as lti
from ase.dft.kpoints import monkhorst_pack
from ase.cell import Cell

In [None]:
def prettify(label):
    """
    Turn a seekpath k-point label into a LaTeX label for matplotlib.

    The names GAMMA, DELTA, LAMBDA and SIGMA become the corresponding Greek
    symbols, and the (optional) character after an underscore becomes a
    subscript, e.g. 'X_1' -> 'X$_{1}$'.

    :param label: a string to prettify
    :return: the prettified string
    """
    greek_letters = (
        ('GAMMA', r'$\Gamma$'),
        ('DELTA', r'$\Delta$'),
        ('LAMBDA', r'$\Lambda$'),
        ('SIGMA', r'$\Sigma$'),
    )
    for plain_name, latex_symbol in greek_letters:
        label = label.replace(plain_name, latex_symbol)
    return re.sub(r'_(.?)', r'$_{\1}$', label)

In [None]:
def _get_band_energies(kpoints_list, b1, b2, b3, g_vectors_range):
    energy_data_curves = np.zeros(((2*g_vectors_range+1)**3, len(kpoints_list)), dtype=np.float_)

    cnt = 0
    for g_i in range(-g_vectors_range,g_vectors_range+1):
        for g_j in range(-g_vectors_range,g_vectors_range+1):
            for g_k in range(-g_vectors_range,g_vectors_range+1):
                g_vector = b1 * g_i + b2*g_j + b3 * g_k
                energy_data_curves[cnt] = np.sum(0.5*(kpoints_list + g_vector)**2, axis=1)# This is k^2 - NOTE: units to be double checked!
                cnt += 1


    # bands are ordered as follows: first band, second band, ...
    return energy_data_curves

def _compute_dos(kpts, G, ranges):
    eigs = []
    n = ranges
    
    for i in range(-n, n+1):
        for j in range(-n, n+1):
            for k in range(-n, n+1):
                g_vector = i*G[0] + j*G[1] + k*G[2]
                eigs.append(np.sum(0.5*(kpts + g_vector)**2, axis=3))

    eigs = np.moveaxis(eigs, 0, -1)
    return eigs
    

In [None]:
def get_bands(real_lattice_bohr, reference_distance = 0.025, g_vectors_range = 3):
    """
    Compute free-electron bands along the seekpath high-symmetry path.

    Simple way to get the band path automatically: build a structure with a
    single atom (atomic number 1, i.e. H) at the origin of the given lattice
    and let seekpath suggest the path. NOTE: this might not give the most
    general path, as e.g. there are two options for cubic FCC (cF1 and cF2 in
    seekpath), but it should be general enough for this tool.

    :param real_lattice_bohr: 3x3 real-space lattice, passed to seekpath as
        the cell (seekpath convention: rows are the lattice vectors)
    :param reference_distance: k-point spacing passed to seekpath
    :param g_vectors_range: range of G indices passed to _get_band_energies
    :return: list with one dict per path segment, with keys 'start_label',
        'end_label', 'kpoints_list', 'kpoints_x', 'energy_bands', 'b1',
        'b2', 'b3' (the reciprocal primitive vectors, repeated per segment)
    """
    # A single H atom at the origin: only the lattice matters for the path.
    structure = (real_lattice_bohr, [[0., 0., 0.]], [1])
    path = seekpath.get_explicit_k_path(structure, reference_distance=reference_distance)

    b1, b2, b3 = np.array(path['reciprocal_primitive_lattice'])
    labels = path['explicit_kpoints_labels']
    linear_coords = np.array(path['explicit_kpoints_linearcoord'])
    abs_kpoints = np.array(path['explicit_kpoints_abs'])

    segments_data = []
    for start, end in path['explicit_segments']:
        # Segment bounds are half-open: the last point is at index end - 1.
        segment_kpoints = abs_kpoints[start:end]
        segments_data.append({
            'start_label': labels[start],
            'end_label': labels[end - 1],
            'kpoints_list': segment_kpoints,
            'kpoints_x': linear_coords[start:end],
            'energy_bands': _get_band_energies(segment_kpoints, b1, b2, b3, g_vectors_range),
            'b1': b1,
            'b2': b2,
            'b3': b3,
        })
    return segments_data

In [None]:
%matplotlib widget

import time
import matplotlib.pyplot as plt
from ipywidgets import Output, Button, RadioButtons, IntSlider, HBox, VBox

# Lattice parameter (bohr) used to build the three cubic lattices below.
alat_bohr = 7.72

# Real-space lattice vectors (one per row, in bohr), in the same order as the
# cell-type options: 0 = simple cubic, 1 = FCC, 2 = BCC.  The cell volumes
# are (alat_bohr / 2)**3 times 1, 2 and 4 respectively.
lattices = np.array([
    [[1, 0, 0], [0, 1, 0], [0, 0, 1]],      # simple cubic
    [[0, 1, 1], [1, 0, 1], [1, 1, 0]],      # FCC primitive vectors
    [[-1, 1, 1], [1, -1, 1], [1, 1, -1]],   # BCC primitive vectors
], dtype=float) * alat_bohr / 2.0

# Currently selected lattice (simple cubic by default).
real_lattice_bohr = lattices[0]

In [None]:
# Widgets: an output area hosting the figure, the DOS button and the controls.
# The reciprocal lattice G needed by the DOS is set by plot_bandstructure.
output = Output()
btdos = Button(description="Compute DOS")
# The option order matches the first index of `lattices`.
cell_type = RadioButtons(options=['Simple cubic', 'FCC', 'BCC'], value='Simple cubic', description="Cell type:")
# Monkhorst-Pack grid size along each reciprocal direction (used for the DOS).
nkpt = IntSlider(value=2, min=2, max=11, description="Number of kpoint in each dimension:")
# Range of G-vector indices used for the DOS eigenvalues.
grange = IntSlider(value=1, min=0, max=3, description="Gvector range:")

def on_celltype_changed(c):
    """
    Switch to the lattice selected in `cell_type` and redraw the band structure.

    Both panels are cleared: a DOS shown in ax[1] was computed for the
    previous lattice and would otherwise stay next to the new bands.

    :param c: change notification from ipywidgets (unused)
    """
    global real_lattice_bohr
    real_lattice_bohr = lattices[cell_type.index]
    ax[0].clear()
    ax[1].clear()
    plot_bandstructure('bands')
    

# Redraw everything whenever a different cell type is selected.
cell_type.observe(on_celltype_changed, names='value')

def compute_dos_lti(c):
    """
    Compute the free-electron DOS with the linear tetrahedron method and plot it.

    Eigenvalues are evaluated on an nkpt x nkpt x nkpt Monkhorst-Pack grid
    expressed in the reciprocal basis G (set by plot_bandstructure). The
    energy grid and DOS are stored in the globals dosx and dosy.

    :param c: the button that triggered the callback (unused)
    """
    global dosx, dosy
    ax[1].clear()
    shape = (nkpt.value, nkpt.value, nkpt.value)
    # Monkhorst-Pack points are fractional coordinates of the reciprocal
    # basis; multiplying by the rows of G gives Cartesian k-points.
    kpts = np.dot(monkhorst_pack(shape), G).reshape(shape + (3,))

    eigs = _compute_dos(kpts, G, grange.value)

    # Use the real-space cell dual to G for the tetrahedra, so that it always
    # matches the k-grid. G comes from seekpath's primitive cell, which need
    # not coincide with real_lattice_bohr for a general input lattice.
    cell = 2 * np.pi * np.linalg.inv(G).T

    dosx = np.linspace(0, 10, 500)
    dosy = lti(cell, eigs, dosx)

    plot_dos(eigs)


# Compute and plot the DOS when the button is pressed.
btdos.on_click(compute_dos_lti)


def plot_bandstructure(c):
    """
    Draw the free-electron band structure of `real_lattice_bohr` in ax[0].

    Side effects: sets the global G (rows are the reciprocal primitive
    vectors b1, b2, b3 returned by get_bands, used later for the DOS),
    prepares the energy axis of the DOS panel ax[1], and disables the DOS
    button while the bands are being computed.

    Energies are 0.5*|k+G|^2 with k in inverse bohr, i.e. Hartree (atomic
    units).

    :param c: unused; callers pass a placeholder string
    """
    global G
    btdos.disabled = True

    segments_data = get_bands(real_lattice_bohr)
    first_segment = segments_data[0]
    G = np.array([first_segment['b1'], first_segment['b2'], first_segment['b3']])

    x_ticks = []
    x_labels = []

    for segment_data in segments_data:
        start_label = prettify(segment_data['start_label'])
        if not x_labels:
            x_labels.append(start_label)
            x_ticks.append(segment_data['kpoints_x'][0])
        elif x_labels[-1] != start_label:
            # The path jumps here: show both labels on the same tick.
            x_labels[-1] += "|" + start_label
        x_labels.append(prettify(segment_data['end_label']))
        x_ticks.append(segment_data['kpoints_x'][-1])

        # energy_bands is (nbands, nkpoints); matplotlib draws one line per
        # column of a 2-D y array, so all bands are plotted in one call.
        ax[0].plot(segment_data['kpoints_x'], np.transpose(segment_data['energy_bands']), 'k')

    ax[0].set_ylim([0, 5])
    ax[0].set_ylabel('Free-electron energy (Ha)')
    ax[0].set_xlim([np.min(x_ticks), np.max(x_ticks)])
    ax[0].set_xticks(x_ticks)
    ax[0].set_xticklabels(x_labels)
    ax[0].grid(axis='x', color='red', linestyle='-', linewidth=0.5)

    # The DOS panel shares the energy range; its y axis shows energies,
    # the DOS itself is plotted along x.
    ax[1].set_ylim([0, 5])
    ax[1].yaxis.tick_right()
    ax[1].yaxis.set_label_position("right")
    ax[1].set_ylabel('Energy (Ha)')

    btdos.disabled = False

    
def plot_dos(eigs):
    """
    Plot the DOS panel ax[1]: the LTI result, the analytical free-electron
    DOS and a (rescaled) histogram of the eigenvalues.

    Reads the globals dosx/dosy (energy grid and LTI DOS) and stores the two
    curves in the globals lanaly/llti so a later call can replace them.

    :param eigs: array of eigenvalues; only values in [0, 5] enter the histogram
    """
    global lanaly, llti

    # Remove the curves from a previous call, if any. Each curve is handled
    # separately so that a failure on one does not leave the other behind.
    for name in ('lanaly', 'llti'):
        artist = globals().get(name)
        if artist is None:
            continue
        try:
            artist.remove()
        except (ValueError, NotImplementedError):
            # The artist was already removed from its axes.
            pass

    llti, = ax[1].plot(dosy, dosx, 'r-', label='LTI')

    # Analytical free-electron DOS for E = k^2/2 (atomic units, no spin
    # degeneracy): g(E) = V / (2 pi^2) * sqrt(2 E), V = real-space cell volume.
    volume = abs(np.linalg.det(real_lattice_bohr))
    analy_x = np.linspace(0, 5, 500)
    analy_y = 1.0 / (2.0 * np.pi**2) * 2.0**0.5 * analy_x**0.5 * volume

    lanaly, = ax[1].plot(analy_y, analy_x, 'b', label='Analytical solution')

    counts, edges = np.histogram(np.ravel(eigs), bins=50, range=(0.0, 5.0))
    # Rescale the histogram to half the analytical DOS at the top of the
    # energy window (shape comparison only); avoid 0/0 for an empty histogram.
    peak = counts.max()
    if peak > 0:
        scaled = counts / peak * analy_y[-1] * 0.5
    else:
        scaled = counts.astype(float)
    widths = np.diff(edges)
    # barh centres each bar on its y value by default, so pass bin centres.
    centers = edges[:-1] + 0.5 * widths
    ax[1].barh(centers, scaled, color='yellow', edgecolor='black', height=widths, label="Histogram")

    ax[1].set_ylim([0, 5])
    ax[1].set_xlim([0, analy_y.max() + 0.1])
    ax[1].legend(loc=4)
    ax[1].yaxis.tick_right()
    ax[1].yaxis.tick_right()
    
# Create the figure inside the output widget and draw the initial bands.
# (fig and ax are module-level names here, so no `global` is needed.)
with output:
    fig, ax = plt.subplots(1, 2, figsize=(8,6))
    plot_bandstructure('bands')

In [None]:
# Show the figure, then the controls (cell type next to the two sliders)
# and finally the DOS button.
controls = HBox([cell_type, VBox([nkpt, grange])])
display(output, controls, btdos)