In [None]:
from wanpy.wpythtb import *
from wanpy.plot import *
from pythtb import *
from models import Haldane
import os

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

In [None]:
# Haldane tight-binding parameters (passed positionally to Haldane(delta, t, t2))
delta = 1
t = 1
t2 = -0.3

# Square n x n supercell of the primitive Haldane model
n_super_cell = 2
supercell_vecs = [[n_super_cell, 0], [0, n_super_cell]]
model = Haldane(delta, t, t2).make_supercell(supercell_vecs)

#############

# Supercell geometry and orbital bookkeeping
n_orb = model.get_num_orbitals()
lat_vecs = model.get_lat()
orb_vecs = model.get_orb()

all_orbs = np.arange(n_orb)
low_E_sites = all_orbs[0::2]   # even-indexed orbitals
high_E_sites = all_orbs[1::2]  # odd-indexed orbitals
n_occ = int(n_orb // 2)        # half of the bands are treated as occupied

# Chern number of the occupied bands (Bloch object built with 20, 20 — presumably the k-mesh size)
bloch_eigstates = Bloch(model, 20, 20)
bloch_eigstates.solve_model()
chern = bloch_eigstates.chern_num().real

model_str = f'C={chern:.1f}_Delta={delta}_t={t}_t2={t2}'

print(f"Low energy sites: {low_E_sites}")
print(f"High energy sites: {high_E_sites}")
print(f"Chern # occupied: {chern: .1f}")

In [None]:
### Trial wavefunctions

# NOTE(review): `low_E` is not read anywhere in this notebook; it looks like a leftover switch
# between trial-function choices. Kept so the notebook namespace is unchanged.
low_E = False

# Single low-energy orbital excluded from the trial functions.
omit_sites = 6
if omit_sites not in low_E_sites:
    # np.setdiff1d would silently omit nothing, giving n_tfs == n_occ and a misleading file suffix.
    raise ValueError(
        f"omit_sites={omit_sites} is not one of the low-energy sites {low_E_sites}."
    )

# Delta-function trial wavefunctions (weight 1) on every low-energy site except `omit_sites`.
# Convert to plain ints so the printed trial list is readable under NumPy 2 scalar reprs.
tf_sites = [int(orb) for orb in np.setdiff1d(low_E_sites, [omit_sites])]
tf_list = [[(orb, 1)] for orb in tf_sites]
n_tfs = len(tf_list)
Wan_frac = n_tfs / n_occ

# Suffix identifying this model + trial-function choice in saved file names.
save_sfx = model_str + f'_tfx={np.array(tf_sites, dtype=int)}'

print(f"Trial wavefunctions: {tf_list}")
print(f"# of Wannier functions: {n_tfs}")
print(f"# of occupied bands: {n_occ}")
print(f"Wannier fraction: {Wan_frac}")
print(save_sfx)

In [None]:
sv_dir = 'data'


def _load_npy_dict(path):
    """Load an object stored in a ``.npy`` file via ``np.save`` and return ``.item()`` of it.

    Parameters
    ----------
    path : str
        Path to the ``.npy`` file.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist. This notebook only reads previously saved results, so a
        missing file means those results have not been generated for this parameter set.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Saved result not found: '{path}'")
    # NOTE: allow_pickle=True can execute arbitrary code on load; only load files you produced.
    return np.load(path, allow_pickle=True).item()


sv_prefix = 'WF_max_loc'
file_name = f"{sv_dir}/{sv_prefix}_{save_sfx}"

WF = _load_npy_dict(f"{file_name}.npy")

sv_prefix = 'WF_loc_steps'
file_name = f"{sv_dir}/{sv_prefix}_{save_sfx}"

loc_steps = _load_npy_dict(f"{file_name}.npy")

In [None]:
# Average spread at each localization step. Divide by the number of entries rather than a fixed
# constant so the average stays right if the number of Wannier functions changes.
# NOTE(review): assumes val['Omega'] holds one spread per Wannier function — confirm.
for key, val in loc_steps.items():
    omegas = val['Omega']
    print(f"{key}: Omega = {sum(omegas)/len(omegas): .4f}")

In [None]:
# Raw contents of every localization step, for inspection.
for step_name in loc_steps:
    print(step_name, loc_steps[step_name])

# Plotting

In [None]:
# Real-space density of Wannier function 0; trailing ';' suppresses echoing the returned figure objects.
plot_density(WF, 0, show_lattice=False, lat_size=2, cbar=False, return_fig=True, interpolate=False);

In [None]:
idx = 0  # index passed to plot_decay / plot_density (presumably the Wannier-function index)
fig, ax = plot_decay(WF, idx, fit_rng=[5, 20], return_fig=True)
ax.legend(bbox_to_anchor=(0.6, 1.0))

# Inset in the upper-right corner of the decay plot showing the density of the same function.
inset_ax = inset_axes(ax, width="30%", height="30%", loc='upper right')
fig, inset_ax = plot_density(
    WF, idx, show_lattice=False, lat_size=2, cbar=False, interpolate=True, fig=fig, ax=inset_ax, return_fig=True)
inset_ax.set_xticks([])
inset_ax.set_yticks([])
inset_ax.axis('off')

plt.subplots_adjust(top=0.98, left=0.12, bottom=0.11, right=0.98)
os.makedirs('images', exist_ok=True)  # savefig does not create missing directories
plt.savefig(f'images/decay_and_dens_{save_sfx}.png', dpi=700)

In [None]:
title = (
  "Haldane model \n"
  fr"$C = {chern: .1f}$, $\Delta = {delta}$, $t= {t: .2f}, t_2 = {t2: .2f}$"
  )

# Marker styles for the Wannier centers, the omitted orbital, and the two kinds of lattice sites
# (kwargs_lat_ev / kwargs_lat_odd: presumably even- and odd-indexed orbitals — confirm in plot_centers).
kwargs_centers = {'marker': 'o', 'c': 'dodgerblue', 'alpha': 0.6}
kwargs_omit = {'s': 60, 'marker': 'x', 'c': 'crimson', 'zorder': 3}
kwargs_lat_ev = {'s': 15, 'marker': 'o', 'c': 'k'}
kwargs_lat_odd = {'s': 15, 'marker': 'o', 'facecolors': 'none', 'edgecolors': 'k'}

fig, ax = plot_centers(
    WF, title=title, center_scale=50, omit_sites=[omit_sites], section_home_cell=True,
    color_home_cell=False, translate_centers=True, kwargs_centers=kwargs_centers,
    kwargs_lat_ev=kwargs_lat_ev, kwargs_lat_odd=kwargs_lat_odd, kwargs_omit=kwargs_omit,
    pmx=3, pmy=3, legend=True)

# Strip ticks and the title from the saved figure (delete set_title('') to keep `title`).
ax.set_xticks([])
ax.set_yticks([])
ax.set_title('')

sv_dir = 'images'
sv_prefix = 'Wan_centers'
file_name = f"{sv_dir}/{sv_prefix}_{save_sfx}.png"
os.makedirs(sv_dir, exist_ok=True)  # savefig does not create missing directories

plt.subplots_adjust(top=0.98, left=0.05, bottom=0.01, right=.95)

plt.savefig(file_name, dpi=700)

In [None]:
# Energy eigenstates on the mesh, indexed [..., band, orbital]
u_energy = WF.energy_eigstates.get_states()["Cell periodic"]
P, Q = WF.energy_eigstates.get_projector(return_Q=True)  # full band projector

u_occ = u_energy[..., :n_occ, :]  # occupied energy eigenstates
# P_occ[..., i, j] = sum_n u_occ[..., n, i] * conj(u_occ[..., n, j])
P_occ = np.einsum("...ni, ...nj -> ...ij", u_occ, u_occ.conj())
Q_occ = np.eye(P_occ.shape[-1]) - P_occ  # occ complement

u_tilde = WF.tilde_states.get_states()["Cell periodic"]  # reduced tilde states
P_triv, Q_triv = WF.tilde_states.get_projector(return_Q=True)  # tilde space projectors

## Projectors on full mesh
P_top = P_occ - P_triv  # part of the occupied space not spanned by the tilde states
Q_top = np.eye(P_top.shape[-1]) - P_top

# Occupied states not covered by the trial functions: the expected rank of P_top.
n_top = n_occ - n_tfs
if n_top < 1:
    raise ValueError(f"Expected a non-empty complementary subspace, got n_occ - n_tfs = {n_top}.")

eigvals, eigvecs = np.linalg.eigh(P_top)  # states spanning complement
# eigh sorts eigenvalues ascending and returns eigenvectors as columns, so the last n_top columns
# are the eigenvalue-1 states (assuming the tilde space lies inside the occupied space, so that
# P_top is a rank-n_top projector). Swap the last two axes to get the [..., state, orbital]
# layout used by u_occ; for n_top == 1 this is shape (..., 1, n_orb).
u_top = np.swapaxes(eigvecs[..., :, eigvecs.shape[-1] - n_top:], -1, -2)

In [None]:
# High-symmetry points (reduced coordinates) and the path G -> K -> M -> K' -> G -> M through them
hs_points = {'G': [0, 0], 'K': [2/3, 1/3], 'M': [.5, .5], 'Kp': [1/3, 2/3]}
k_path = [hs_points[name] for name in ('G', 'K', 'M', 'Kp', 'G', 'M')]
k_label = (r'$\Gamma $', r'$K$', r'$M$', r'$K^\prime$', r'$\Gamma $', r'$M$')
k_vec, k_dist, k_node = model.k_path(k_path, 501, report=False)

# Exact eigen-energies and eigenstates of the model along the path
evals, evecs = model.solve_ham(k_vec, return_eigvecs=True)
n_eigs = evecs.shape[-2]

# Wannier-interpolated bands: first for the reduced (tilde) subspace, then for the
# complementary subspace by passing u_top in place of the tilde states.
interp_energies_triv, interp_unk_triv = WF.interp_energies(k_vec, ret_eigvecs=True)
interp_energies_top, interp_unk_top = WF.interp_energies(k_vec, u_tilde=u_top, ret_eigvecs=True)

In [None]:
def _plot_weighted_bands(ax, k_dist, energies, states, sites, ref_energies=None):
    """Plot bands on ``ax``, colouring each point by the band's weight on ``sites``.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    k_dist : 1D path coordinates (x-axis).
    energies : array indexed ``[k, band]``.
    states : array whose last axis is the orbital index; assumes a ``[k, band, orbital]``
        layout so the summed weight can be indexed as ``[k, band]``.
    sites : orbital indices whose summed ``|amplitude|**2`` sets the colour (0..1, 'plasma').
    ref_energies : optional ``[k, band]`` array drawn as faint reference lines underneath.

    Returns
    -------
    The scatter collection of the last band drawn (used for the colour bar), or None if
    ``energies`` has no bands.
    """
    if ref_energies is not None:
        for n in range(ref_energies.shape[-1]):
            ax.plot(k_dist, ref_energies[:, n], c='k', lw=2, zorder=0, alpha=0.25)

    weight = abs(states) ** 2
    col = np.sum([weight[..., i] for i in sites], axis=0)
    scat = None
    for n in range(energies.shape[-1]):
        ax.plot(k_dist, energies[:, n], c='k', lw=2, zorder=0)
        scat = ax.scatter(
            k_dist, energies[:, n], c=col[:, n],
            cmap='plasma', marker='o', s=2, vmin=0, vmax=1, zorder=2)
    return scat


def _decorate_band_axis(fig, ax, scat, k_node, panel_label):
    """Add the sublattice colour bar, high-symmetry lines, labels and limits to one band panel."""
    # Pass ax explicitly so the colour bar is attached to this panel rather than whichever
    # axes is current (older matplotlib defaulted to gca()).
    cbar = fig.colorbar(scat, ax=ax, ticks=[1, 0], pad=0.01)
    cbar.ax.set_yticklabels([r'$\psi_B$', r'$\psi_A$'], size=12)
    cbar.ax.tick_params(size=0)

    ax.set_xlim(0, k_node[-1])
    for node in k_node:
        ax.axvline(x=node, linewidth=0.5, color='k', zorder=1)

    ax.set_ylabel(r"Energy $E(\mathbf{{k}})$", size=12)
    ax.yaxis.labelpad = 10
    ax.set_ylim(-3.5, -0.18)
    ax.text(-.1, -0.6, panel_label, size=12)


fig, axs = plt.subplots(3, 1, sharex=True, constrained_layout=True)

# (a) exact bands; (b) reduced Wannier-interpolated bands; (c) complementary-subspace bands.
# Panels (b) and (c) draw the exact bands faintly underneath for comparison.
band_panels = [
    (evals, evecs, None, '(a)'),
    (interp_energies_triv, interp_unk_triv, evals, '(b)'),
    (interp_energies_top, interp_unk_top, evals, '(c)'),
]
for ax, (energies, states, ref_energies, panel_label) in zip(axs, band_panels):
    scat = _plot_weighted_bands(ax, k_dist, energies, states, high_E_sites, ref_energies=ref_energies)
    _decorate_band_axis(fig, ax, scat, k_node, panel_label)

# With sharex=True the panels share one x ticker, so ticks set here apply to all panels;
# tick labels are only shown on the bottom panel.
axs[-1].set_xticks(k_node)
if k_label is not None:
    axs[-1].set_xticklabels(k_label, size=12)

os.makedirs('images', exist_ok=True)  # savefig does not create missing directories
fig.savefig(f"images/interp_all_bands_{save_sfx}.png", dpi=700)

In [None]:
np.shape(P_triv)  # shape of the tilde-space projector array

In [None]:
# P_triv and P_top = P_occ - P_triv should be mutually orthogonal, so their product should
# vanish at every point of the mesh, not just the one displayed below.
prod = np.einsum("...ij, ...jk -> ...ik", P_triv, P_top)
print(f"max |P_triv P_top| over the mesh: {np.abs(prod).max():.2e}")

prod[0, 0].round(3)

In [None]:
# Cartesian position of every orbital: r = f_0 * a_0 + f_1 * a_1, with f the reduced orbital
# coordinates (orb_vecs rows) and a_0, a_1 the lattice vectors (lat_vecs rows).
orb_frac = np.asarray(orb_vecs)
lat = np.asarray(lat_vecs)
cart = orb_frac[:, 0:1] * lat[0] + orb_frac[:, 1:2] * lat[1]

pos = {'xs': list(cart[:, 0]), 'ys': list(cart[:, 1])}

In [None]:
x, y = (np.array(pos[axis_key]) for axis_key in ('xs', 'ys'))

In [None]:
site_weight = np.abs(u_top[13, 6]) ** 2  # |amplitude|^2 of u_top at (presumably) mesh point (13, 6)
plt.scatter(x, y, c=site_weight)
plt.colorbar()

In [None]:
site_weight = np.abs(u_top[14, 7]) ** 2  # |amplitude|^2 of u_top at (presumably) mesh point (14, 7)
plt.scatter(x, y, c=site_weight)
plt.colorbar()

In [None]:
np.shape(u_top)  # (..., n_states, n_orb) layout of the complement states

In [None]:
u_top[13, 7, ...]  # complement state(s) at index (13, 7) of the leading axes

In [None]:
# |amplitude|^2 (rounded to 2 decimals) of the first complement state on the omitted orbital at
# every mesh point. Index with `omit_sites` rather than a hard-coded orbital index so the check
# follows the trial-function choice.
u_omit_site = (np.abs(u_top[..., 0, omit_sites]) ** 2).round(2)
# Mesh indices where that rounded weight is smallest; rounding groups near-equal minima together.
idx = np.where(u_omit_site == np.amin(u_omit_site))

In [None]:
# Tuple of index arrays (one per axis of u_omit_site, presumably the mesh axes) where the
# complement state's rounded weight on the omitted orbital is minimal.
idx

In [None]:
# Rounded weight on the omitted orbital at those minimizing indices.
u_omit_site[idx]