In [1]:
%cd ..

/Users/treycole/Codes/WanPy


In [2]:
from WanPy.pythtb_Wannier import *
import WanPy.models as models
import WanPy.plotting as plot

from pythtb import *
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import sympy as sp 
import scipy

In [3]:
delta = 1
t0 = 0.4
tprime = 0.5

model = models.chessboard(t0, tprime, delta).make_supercell([[2,0], [0,2]])

orbs = model.get_orb()
n_orb = model.get_num_orbitals()
n_occ = int(n_orb/2)
lat_vecs = model.get_lat() # lattice vectors

low_E_sites = np.arange(0, model.get_num_orbitals(), 2)
high_E_sites = np.arange(1, model.get_num_orbitals(), 2)

In [4]:
# get Bloch eigenstates on 2D k-mesh for Wannierization (exclude endpoints)
nkx = 16
nky = 16
Nk = nkx*nky
k_mesh = gen_k_mesh(nkx, nky, flat=False, endpoint=False)
u_wfs = wf_array(model, [nkx, nky])
for i in range(k_mesh.shape[0]):
    for j in range(k_mesh.shape[1]):
        u_wfs.solve_on_one_point(k_mesh[i,j], [i,j])

In [5]:
# Wannierization via single-shot projection
low_E_sites = np.arange(0, model.get_num_orbitals(), 2)
high_E_sites = np.arange(1, model.get_num_orbitals(), 2)
omit_sites = 4
tf_list = list(np.setdiff1d(low_E_sites, [omit_sites])) # delta on lower energy sites omitting the last site
psi_wfs = get_bloch_wfs(model, u_wfs, k_mesh, inverse=False)
psi_tilde = get_psi_tilde(psi_wfs, tf_list, state_idx=None)
u_tilde = get_bloch_wfs(model, psi_tilde, k_mesh, inverse=True)

In [6]:
M = k_overlap_mat(u_tilde, orbs=orbs) # [kx, ky, b, m, n]
spread, expc_rsq, expc_r_sq = spread_recip(model, M, decomp=True)

print(rf"Spread from M_kb of \tilde{{u_nk}} = {spread[0]}")
print(rf"Omega_I from M_kb of \tilde{{u_nk}} = {spread[1]}")
print(rf"Omega_til from M_kb of \tilde{{u_nk}} = {spread[2]}")

Spread from M_kb of \tilde{u_nk} = (0.8461633960951493+0j)
Omega_I from M_kb of \tilde{u_nk} = 0.7542011847776007
Omega_til from M_kb of \tilde{u_nk} = (0.09196221131753476+0j)


In [7]:
import time

start = time.time()

outer_states = u_wfs._wfs[..., :n_occ, :]
util_min_Wan = find_optimal_subspace(
    model, outer_states, u_tilde, iter_num=100, print_=True)

fin = time.time()
print()
print(f"time = {fin-start} s")

0 Omega_I: 78.19051933852954
1 Omega_I: 1.429297801543418
2 Omega_I: 0.7565582177634868
3 Omega_I: 0.7906845350320626
4 Omega_I: 0.7576156495331096
5 Omega_I: 0.7497332112352671
6 Omega_I: 0.7456588442469121
7 Omega_I: 0.7486079714089131
8 Omega_I: 0.7417211377551347
9 Omega_I: 0.7408304345835235
10 Omega_I: 0.7417020174486776
11 Omega_I: 0.7379441790069434
12 Omega_I: 0.7368364037165922
13 Omega_I: 0.7371938070946873
14 Omega_I: 0.7345374432578546
15 Omega_I: 0.7334702047081478
16 Omega_I: 0.7335027955863718
17 Omega_I: 0.7314445120150175
18 Omega_I: 0.7304267970356165
19 Omega_I: 0.7302841398962187
20 Omega_I: 0.7285845015665421
21 Omega_I: 0.727627061185721
22 Omega_I: 0.7273828645983782
23 Omega_I: 0.7259246579971904
24 Omega_I: 0.725028121539037
25 Omega_I: 0.7247248590821832
26 Omega_I: 0.7234425894936869
27 Omega_I: 0.7226044031397256
28 Omega_I: 0.7222678587168325
29 Omega_I: 0.7211215967226594
30 Omega_I: 0.7203380873948082
31 Omega_I: 0.719984679673525
32 Omega_I: 0.718948185

In [8]:
psi_til_min = get_bloch_wfs(model, util_min_Wan, k_mesh)
state_idx = list(range(psi_til_min.shape[2]))
psi_til_til_min = get_psi_tilde(psi_til_min, tf_list, state_idx=state_idx)
u_til_til_min = get_bloch_wfs(model, psi_til_til_min, k_mesh, inverse=True)

start = time.time()

U, _ = find_min_unitary(model, M, iter_num=100, eps=1e-3, print_=True)

fin = time.time()
print()
print(f"time = {fin-start} s")

0 Omega_til = (0.09195327881141376+0j), Grad mag:  1.79919
1 Omega_til = (0.09194482837507115+0j), Grad mag:  1.79636
2 Omega_til = (0.0919368252874401+0j), Grad mag:  1.79353
3 Omega_til = (0.09192923795461073+0j), Grad mag:  1.79070
4 Omega_til = (0.0919220375701006+0j), Grad mag:  1.78786
5 Omega_til = (0.09191519781864292+0j), Grad mag:  1.78502
6 Omega_til = (0.09190869461706523+0j), Grad mag:  1.78218
7 Omega_til = (0.09190250588685778+0j), Grad mag:  1.77934
8 Omega_til = (0.09189661135397702+0j), Grad mag:  1.77650
9 Omega_til = (0.0918909923720953+0j), Grad mag:  1.77366
10 Omega_til = (0.09188563176615594+0j), Grad mag:  1.77081
11 Omega_til = (0.0918805136935179+0j), Grad mag:  1.76797
12 Omega_til = (0.09187562352043335+0j), Grad mag:  1.76512
13 Omega_til = (0.09187094771191026+0j), Grad mag:  1.76227
14 Omega_til = (0.09186647373327084+0j), Grad mag:  1.75942
15 Omega_til = (0.09186218996199973+0j), Grad mag:  1.75657
16 Omega_til = (0.09185808560865133+0j), Grad mag:  1.

In [8]:
psi_til_min = get_bloch_wfs(model, util_min_Wan, k_mesh)
state_idx = list(range(psi_til_min.shape[2]))
psi_til_til_min = get_psi_tilde(psi_til_min, tf_list, state_idx=state_idx)
u_til_til_min = get_bloch_wfs(model, psi_til_til_min, k_mesh, inverse=True)

start = time.time()

U, _ = find_min_unitary(model, M, iter_num=100, eps=1e-3, print_=True)

fin = time.time()
print()
print(f"time = {fin-start} s")

0 Omega_til = (0.09195327881141376+0j), Grad mag:  1.79919
1 Omega_til = (0.09194482837507115+0j), Grad mag:  1.79636
2 Omega_til = (0.0919368252874401+0j), Grad mag:  1.79353
3 Omega_til = (0.09192923795461073+0j), Grad mag:  1.79070
4 Omega_til = (0.0919220375701006+0j), Grad mag:  1.78786
5 Omega_til = (0.09191519781864292+0j), Grad mag:  1.78502
6 Omega_til = (0.09190869461706523+0j), Grad mag:  1.78218
7 Omega_til = (0.09190250588685778+0j), Grad mag:  1.77934
8 Omega_til = (0.09189661135397702+0j), Grad mag:  1.77650
9 Omega_til = (0.0918909923720953+0j), Grad mag:  1.77366
10 Omega_til = (0.09188563176615594+0j), Grad mag:  1.77081
11 Omega_til = (0.0918805136935179+0j), Grad mag:  1.76797
12 Omega_til = (0.09187562352043335+0j), Grad mag:  1.76512
13 Omega_til = (0.09187094771191026+0j), Grad mag:  1.76227
14 Omega_til = (0.09186647373327084+0j), Grad mag:  1.75942
15 Omega_til = (0.09186218996199973+0j), Grad mag:  1.75657
16 Omega_til = (0.09185808560865133+0j), Grad mag:  1.

In [None]:
# outer window of entangled bands is full occupied manifold
outer_states = u_wfs_Wan._wfs[..., :n_occ, :]
w0_max_loc = max_loc_Wan(model, u_wfs_Wan, tf_list, outer_states, 
        iter_num_omega_i=300, iter_num_omega_til=500,
        state_idx=None, print_=True, return_uwfs=False, eps=2e-3
        )

Wan_idx = 0
plot.plot_Wan(w0_max_loc, Wan_idx, orbs, lat_vecs, plot_decay=True, show=True)