In [1]:
%cd ..

/Users/treycole/Codes/WanPy


In [2]:
%load_ext line_profiler
%timeit

In [3]:
from WanPy.WanPy import *
import WanPy.models as models
import WanPy.plotting as plot

from pythtb import *
import numba
import numpy as np
import sympy as sp
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import sympy as sp 
import scipy

In [4]:
delta = 1
t = -1
t2 = 0.2

model = models.Haldane(delta, t, t2).make_supercell([[2,0], [0,2]])

orbs = model.get_orb()
n_orb = model.get_num_orbitals()
n_occ = int(n_orb/2)
lat_vecs = model.get_lat() # lattice vectors

low_E_sites = np.arange(0, model.get_num_orbitals(), 2)
high_E_sites = np.arange(1, model.get_num_orbitals(), 2)

In [5]:
nkx = 50
nky = 50
Nk = nkx*nky

k_mesh = gen_k_mesh(nkx, nky, flat=False, endpoint=False)
k_mesh_flat = gen_k_mesh(nkx, nky, flat=True, endpoint=False)

u_wfs_Wan = wf_array(model, [nkx, nky])
for i in range(k_mesh.shape[0]):
    for j in range(k_mesh.shape[1]):
        u_wfs_Wan.solve_on_one_point(k_mesh[i,j], [i,j])
psi_wfs_Wan = get_bloch_wfs(orbs, u_wfs_Wan, k_mesh)

In [6]:
omit_sites = 4
tf_list = list(np.setdiff1d(low_E_sites, [omit_sites])) # delta on lower energy sites omitting the last site

W0, psi_til_wan = Wannierize(orbs, u_wfs_Wan, tf_list, ret_psi_til=True)
u_tilde_wan = get_bloch_wfs(orbs, psi_til_wan, k_mesh, inverse=True)
M = k_overlap_mat(lat_vecs, orbs, u_tilde_wan)  # [kx, ky, b, m, n]

In [7]:
spread, expc_rsq, expc_r_sq = spread_recip(lat_vecs, M, decomp=True)

print("After first projection")
print(rf"Spread = {spread[0]}")
print(rf"Omega_I = {spread[1]}")
print(rf"Omega_til = {spread[2]}")

After first projection
Spread = [0.34095246 0.34095246 0.34095246]
Omega_I = 0.7670224405244426
Omega_til = 0.2558349249342902


In [8]:
# outer window of entangled bands is full occupied manifold
outer_states = u_wfs_Wan._wfs[..., :n_occ, :]

In [9]:
W0_max_loc, Wf_cntrs = max_loc_Wan(lat_vecs, orbs, u_wfs_Wan, tf_list, outer_states, 
        iter_num_omega_i=10, iter_num_omega_til=10,
        Wan_idxs=None, verbose=True, return_uwfs=False, return_wf_centers=True, eps=1e-3
        )

0 Omega_I: 0.7603453117849567
1 Omega_I: 0.7533208472432963
2 Omega_I: 0.7453002017518733
3 Omega_I: 0.7368492908633715
4 Omega_I: 0.7285388581697609
5 Omega_I: 0.7207154115157237
6 Omega_I: 0.7134991151273813
7 Omega_I: 0.7068874786939804
8 Omega_I: 0.7008287554377128
9 Omega_I: 0.6952592999804748
0 Omega_til = 0.1498241078573886, Grad mag: 2353.355788873548
1 Omega_til = 0.1542872280639713, Grad mag: 2352.6498024917914
2 Omega_til = 0.16094568347257604, Grad mag: 2352.1006979706403
3 Omega_til = 0.16905939170496667, Grad mag: 2351.686878728778
4 Omega_til = 0.17808522113905578, Grad mag: 2351.371107937544
5 Omega_til = 0.18759975358537687, Grad mag: 2351.116322002369
6 Omega_til = 0.19729064357227555, Grad mag: 2350.9105072622906
7 Omega_til = 0.20692517147627956, Grad mag: 2350.7404886839126
8 Omega_til = 0.21633545922086908, Grad mag: 2350.598566797831
9 Omega_til = 0.2254033415838086, Grad mag: 2350.478906866375
Post processing report:
 --------------- 
Quadratic spread = [0.30688

In [10]:
%timeit max_loc_Wan(lat_vecs, orbs, u_wfs_Wan, tf_list, outer_states, iter_num_omega_i=10, iter_num_omega_til=10, verbose=True, return_uwfs=False, return_wf_centers=True, eps=1e-3)

0 Omega_I: 0.7603453117849567
1 Omega_I: 0.7533208472432963
2 Omega_I: 0.7453002017518733
3 Omega_I: 0.7368492908633715
4 Omega_I: 0.7285388581697609
5 Omega_I: 0.7207154115157237
6 Omega_I: 0.7134991151273813
7 Omega_I: 0.7068874786939804
8 Omega_I: 0.7008287554377128
9 Omega_I: 0.6952592999804748
0 Omega_til = 0.1498241078573886, Grad mag: 2353.355788873548
1 Omega_til = 0.1542872280639713, Grad mag: 2352.6498024917914
2 Omega_til = 0.16094568347257604, Grad mag: 2352.1006979706403
3 Omega_til = 0.16905939170496667, Grad mag: 2351.686878728778
4 Omega_til = 0.17808522113905578, Grad mag: 2351.371107937544
5 Omega_til = 0.18759975358537687, Grad mag: 2351.116322002369
6 Omega_til = 0.19729064357227555, Grad mag: 2350.9105072622906
7 Omega_til = 0.20692517147627956, Grad mag: 2350.7404886839126
8 Omega_til = 0.21633545922086908, Grad mag: 2350.598566797831
9 Omega_til = 0.2254033415838086, Grad mag: 2350.478906866375
Post processing report:
 --------------- 
Quadratic spread = [0.30688

In [12]:
%timeit max_loc_Wan(lat_vecs, orbs, u_wfs_Wan, tf_list, outer_states, iter_num_omega_i=10, iter_num_omega_til=10, verbose=True, return_uwfs=False, return_wf_centers=True, eps=1e-3)

0 Omega_I: 0.7603453117849567
1 Omega_I: 0.7533208472432963
2 Omega_I: 0.7453002017518733
3 Omega_I: 0.7368492908633715
4 Omega_I: 0.7285388581697609
5 Omega_I: 0.7207154115157237
6 Omega_I: 0.7134991151273813
7 Omega_I: 0.7068874786939804
8 Omega_I: 0.7008287554377128
9 Omega_I: 0.6952592999804748
0 Omega_til = 0.14819846050774663, Grad mag: 123.74818548476773
1 Omega_til = 0.14801822242475177, Grad mag: 123.03762269272599
2 Omega_til = 0.14788516635086332, Grad mag: 122.3237322245452
3 Omega_til = 0.14778073965719996, Grad mag: 121.6833475549908
4 Omega_til = 0.14769612170581264, Grad mag: 121.1080088593022
5 Omega_til = 0.1476261038019672, Grad mag: 120.59202755022105
6 Omega_til = 0.14756726714208115, Grad mag: 120.12889458420827
7 Omega_til = 0.14751725343534777, Grad mag: 119.7132960245399
8 Omega_til = 0.1474743649512787, Grad mag: 119.34024044037784
9 Omega_til = 0.1474373301026542, Grad mag: 119.00531934123241
Post processing report:
 --------------- 
Quadratic spread = [0.280

In [11]:
%timeit get_max_loc_uwfs(lat_vecs, orbs, u_tilde_wan, eps=1e-3, iter_num=100, verbose=True, tol=1e-17)

0 Omega_til = 0.2712040979133299, Grad mag: 2870.9475345915416
1 Omega_til = 0.31604076854527074, Grad mag: 2870.1686524410466
2 Omega_til = 0.3826379293570119, Grad mag: 2868.872617673531
3 Omega_til = 0.4631049082157471, Grad mag: 2867.949417964843
4 Omega_til = 0.5525793696825393, Grad mag: 2867.1540193389083
5 Omega_til = 0.6467093309345324, Grad mag: 2866.4097480517726
6 Omega_til = 0.7425287609500256, Grad mag: 2865.655702561201
7 Omega_til = 0.8377325790958348, Grad mag: 2864.912738987589
8 Omega_til = 0.93068653886654, Grad mag: 2864.157939388015
9 Omega_til = 1.020234361780789, Grad mag: 2863.401984605218
10 Omega_til = 1.1055990047872584, Grad mag: 2862.6497021012697
11 Omega_til = 1.1862932582663732, Grad mag: 2861.908150849283
12 Omega_til = 1.2620505510993714, Grad mag: 2861.183953245296
13 Omega_til = 1.3327702173673546, Grad mag: 2860.4829555816955
14 Omega_til = 1.3984744557497448, Grad mag: 2859.80997360281
15 Omega_til = 1.459274623796338, Grad mag: 2859.168709555866


KeyboardInterrupt: 

In [42]:
%timeit find_min_unitary(lat_vecs, M, iter_num=100, eps=1e-3, verbose=True, tol=1e-17)

0 Omega_til = [0.15491402], Grad mag: 4.726324345017292
1 Omega_til = [0.1548091], Grad mag: 4.700296710623302
2 Omega_til = [0.15471152], Grad mag: 4.6743730249579025
3 Omega_til = [0.15462057], Grad mag: 4.6485551397037685
4 Omega_til = [0.15453561], Grad mag: 4.622844737455427
5 Omega_til = [0.15445608], Grad mag: 4.597243344872234
6 Omega_til = [0.15438147], Grad mag: 4.571752344646321
7 Omega_til = [0.15431134], Grad mag: 4.546372986410806
8 Omega_til = [0.15424529], Grad mag: 4.5211063966986735
9 Omega_til = [0.15418297], Grad mag: 4.495953588049235
10 Omega_til = [0.15412406], Grad mag: 4.470915467347904
11 Omega_til = [0.15406828], Grad mag: 4.445992843474945
12 Omega_til = [0.15401538], Grad mag: 4.421186434330243
13 Omega_til = [0.15396512], Grad mag: 4.396496873293582
14 Omega_til = [0.15391731], Grad mag: 4.371924715173378
15 Omega_til = [0.15387176], Grad mag: 4.347470441691008
16 Omega_til = [0.1538283], Grad mag: 4.323134466542797
17 Omega_til = [0.15378678], Grad mag: 4

Timer unit: 1e-09 s

Total time: 0.758165 s
File: /var/folders/nn/m4t491h92ss8vwl56z761h6c0000gn/T/ipykernel_64561/3231049135.py
Function: find_min_unitary at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents

In [44]:
(2.84-.75)/2.84

0.7359154929577465

In [43]:
%lprun -f find_min_unitary_og find_min_unitary_og(lat_vecs, M, iter_num=100, eps=1e-3, verbose=True, tol=1e-17)

0 Omega_til = [0.15491402], Grad mag: 4.726324345017287
1 Omega_til = [0.1548091], Grad mag: 4.700296710623308
2 Omega_til = [0.15471152], Grad mag: 4.6743730249579025
3 Omega_til = [0.15462057], Grad mag: 4.648555139703763
4 Omega_til = [0.15453561], Grad mag: 4.622844737455426
5 Omega_til = [0.15445608], Grad mag: 4.597243344872239
6 Omega_til = [0.15438147], Grad mag: 4.5717523446463195
7 Omega_til = [0.15431134], Grad mag: 4.546372986410807
8 Omega_til = [0.15424529], Grad mag: 4.52110639669867
9 Omega_til = [0.15418297], Grad mag: 4.4959535880492325
10 Omega_til = [0.15412406], Grad mag: 4.470915467347907
11 Omega_til = [0.15406828], Grad mag: 4.445992843474946
12 Omega_til = [0.15401538], Grad mag: 4.421186434330246
13 Omega_til = [0.15396512], Grad mag: 4.396496873293585
14 Omega_til = [0.15391731], Grad mag: 4.37192471517338
15 Omega_til = [0.15387176], Grad mag: 4.347470441691009
16 Omega_til = [0.1538283], Grad mag: 4.323134466542796
17 Omega_til = [0.15378678], Grad mag: 4.2

Timer unit: 1e-09 s

Total time: 2.84149 s
File: /var/folders/nn/m4t491h92ss8vwl56z761h6c0000gn/T/ipykernel_64561/3147072443.py
Function: find_min_unitary_og at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents

In [20]:
def spread_recip(lat_vecs, M, decomp=False):
    """
    Args:
        M (np.array):
            overlap matrix
        decomp (bool, optional):
            Whether to compute and return decomposed spread. Defaults to False.

    Returns:
        spread | [spread, Omega_i, Omega_tilde], expc_rsq, expc_r_sq :
            quadratic spread, the expectation of the position squared,
            and the expectation of the position vector squared
    """
    shape = M.shape
    n_states = shape[3]
    nks = M.shape[:-3]
    k_axes = tuple([i for i in range(len(nks))])
    Nk = np.prod(nks)

    # Assumes only one shell for now
    k_shell, _ = get_k_shell(*nks, lat_vecs=lat_vecs, N_sh=1, tol_dp=8, report=False)
    w_b = get_weights(*nks, lat_vecs=lat_vecs, N_sh=1)[0]
    r_n = -(1 / Nk) * w_b * np.sum(
            np.log(np.diagonal(M, axis1=-1, axis2=-2)).imag, axis=k_axes).T @ k_shell[0]
    rsq_n = (1 / Nk) * w_b * np.sum(
        (1 - abs(np.diagonal(M, axis1=-1, axis2=-2)) ** 2 + np.log(np.diagonal(M, axis1=-1, axis2=-2)).imag ** 2), 
        axis=k_axes+tuple([-2]))
    expc_rsq = np.sum(rsq_n)  # <r^2>
    expc_r_sq = np.sum([np.vdot(r_n[n, :], r_n[n, :]) for n in range(r_n.shape[0])])  # <\vec{r}>^2
    spread = expc_rsq - expc_r_sq
    if decomp:
        Omega_i = w_b * n_states * k_shell[0].shape[0] - (1 / Nk) * w_b * np.sum(abs(M) **2)
        Omega_tilde = (1 / Nk) * w_b * ( 
            np.sum((-np.log(np.diagonal(M, axis1=-1, axis2=-2)).imag - k_shell[0] @ r_n.T)**2) + 
            np.sum(abs(M)**2) - np.sum( abs(np.diagonal(M, axis1=-1, axis2=-2))**2)
        )
        return [spread, Omega_i, Omega_tilde], r_n, rsq_n

    else:
        return spread, r_n, rsq_n

In [21]:
def spread_recip2(lat_vecs, M, decomp=False):
    """
    Args:
        M (np.array):
            overlap matrix
        decomp (bool, optional):
            Whether to compute and return decomposed spread. Defaults to False.

    Returns:
        spread | [spread, Omega_i, Omega_tilde], expc_rsq, expc_r_sq :
            quadratic spread, the expectation of the position squared,
            and the expectation of the position vector squared
    """
    shape = M.shape
    n_states = shape[3]
    nks = M.shape[:-3]
    k_axes = tuple([i for i in range(len(nks))])
    Nk = np.prod(nks)

    # Assumes only one shell for now
    k_shell, _ = get_k_shell(*nks, lat_vecs=lat_vecs, N_sh=1, tol_dp=8, report=False)
    w_b = get_weights(*nks, lat_vecs=lat_vecs, N_sh=1)[0]
    
    r_n = np.zeros((n_states, 2), dtype=complex)  # <\vec{r}>_n
    rsq_n = np.zeros(n_states, dtype=complex)  # <r^2>_n

    for n in range(n_states):
        for idx, b in enumerate(k_shell[0]):
            r_n[n, :] += -(1 / Nk) * w_b * b * np.sum(np.log(M[..., idx, n, n]).imag, axis=k_axes)
            rsq_n[n] += np.sum(
                    (1 / Nk) * w_b
                    * (1 - abs(M[..., idx, n, n]) ** 2
                       + np.log(M[..., idx, n, n]).imag ** 2
                       ), axis=k_axes
                )

    expc_rsq = np.sum(rsq_n)  # <r^2>
    expc_r_sq = np.sum([np.vdot(r_n[n, :], r_n[n, :]) for n in range(r_n.shape[0])])  # <\vec{r}>^2
    spread = expc_rsq - expc_r_sq
    if decomp:
        Omega_i = 0
        Omega_tilde = 0
        for idx, b in enumerate(k_shell[0]):
            Omega_i += w_b * n_states
            for n in range(n_states):
                Omega_tilde += np.sum(
                    (1 / Nk) * w_b
                    * (-np.log(M[..., idx, n, n]).imag - np.vdot(b, r_n[n])) ** 2, 
                    axis=k_axes
                )
                for m in range(n_states):
                    Omega_i -= np.sum((1 / Nk) * w_b * abs(M[..., idx, m, n]) ** 2, axis=k_axes)
                    if m != n:
                        Omega_tilde += np.sum((1 / Nk) * w_b * abs(M[..., idx, m, n]) ** 2, axis=k_axes)

        return [spread, Omega_i, Omega_tilde], r_n, rsq_n

    else:
        return spread, r_n, rsq_n

In [22]:
spread_recip(lat_vecs, M, decomp=True)

([0.6822009865245171, array([0.52717388]), array([0.1550271])],
 array([[0.5       , 0.28525905],
        [1.00295841, 1.15640858],
        [1.99704159, 1.15640858]]),
 array([0.55877306, 2.5706067 , 5.55285623]))

In [23]:
spread_recip2(lat_vecs, M, decomp=True)

([(0.6822009865245189+0j), array([0.52717388]), (0.1550271029866492+0j)],
 array([[0.5       +0.j, 0.28525905+0.j],
        [1.00295841+0.j, 1.15640858+0.j],
        [1.99704159+0.j, 1.15640858+0.j]]),
 array([0.55877306+0.j, 2.5706067 +0.j, 5.55285623+0.j]))

In [24]:
%lprun -f spread_recip spread_recip(lat_vecs, M, decomp=True)

Timer unit: 1e-09 s

Total time: 0.002208 s
File: /var/folders/nn/m4t491h92ss8vwl56z761h6c0000gn/T/ipykernel_64561/3999020790.py
Function: spread_recip at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def spread_recip(lat_vecs, M, decomp=False):
     2                                               """
     3                                               Args:
     4                                                   M (np.array):
     5                                                       overlap matrix
     6                                                   decomp (bool, optional):
     7                                                       Whether to compute and return decomposed spread. Defaults to False.
     8                                           
     9                                               Returns:
    10                                                   spread | [spread, Omega_i, Omega_tild

In [25]:
%lprun -f spread_recip2 spread_recip2(lat_vecs, M, decomp=True)

Timer unit: 1e-09 s

Total time: 0.002409 s
File: /var/folders/nn/m4t491h92ss8vwl56z761h6c0000gn/T/ipykernel_64561/3744156696.py
Function: spread_recip2 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def spread_recip2(lat_vecs, M, decomp=False):
     2                                               """
     3                                               Args:
     4                                                   M (np.array):
     5                                                       overlap matrix
     6                                                   decomp (bool, optional):
     7                                                       Whether to compute and return decomposed spread. Defaults to False.
     8                                           
     9                                               Returns:
    10                                                   spread | [spread, Omega_i, Omega_ti

In [26]:
def spread_real(lat_vecs, orbs, w0, decomp=False):
    """
    Spread functional computed in real space with Wannier functions

    Args:
        w0 (np.array): Wannier functions
        supercell (np.array): lattice translation vectors in reduced units
        orbs (np.array): orbital vectors in reduced units
        decomp (boolean): whether to separate gauge (in)variant parts of spread

    Returns:
        Omega: the spread functional
        Omega_inv: (optional) the gauge invariant part of the spread
        Omega_tilde: (optional) the gauge dependent part of the spread
        expc_rsq: \sum_n <r^2>_{n}
        expc_r_sq: \sum_n <\vec{r}>_{n}^2
    """
    # shape = w0.shape # [*nks, idx, orb]
    # nxs = shape[:-2]
    # n_orb = shape[-1]
    # n_states = shape[-2]
    # assuming 2D for now
    nx, ny, n_wfs = w0.shape[0], w0.shape[1], w0.shape[2]
    # translation vectors in reduced units
    supercell = [
        (i, j) for i in range(-nx // 2, nx // 2) for j in range(-ny // 2, ny // 2)
    ]

    r_n = np.zeros((n_wfs, 2), dtype=complex)  # <\vec{r}>_n
    rsq_n = np.zeros(n_wfs, dtype=complex)  # <r^2>_n
    R_nm = np.zeros((2, n_wfs, n_wfs, nx * ny), dtype=complex)

    expc_rsq = 0  # <r^2>
    expc_r_sq = 0  # <\vec{r}>^2

    for n in range(n_wfs):  # "band" index
        for tx, ty in supercell:  # cells in supercell
            for i, orb in enumerate(orbs):  # values of Wannier function on lattice
                pos = (orb[0] + tx) * lat_vecs[0] + (orb[1] + ty) * lat_vecs[1]  # position
                r = np.sqrt(pos[0] ** 2 + pos[1] ** 2)

                w0n_r = w0[tx, ty, n, i]  # Wannier function

                # expectation value of position (vector)
                r_n[n, :] += abs(w0n_r) ** 2 * pos
                rsq_n[n] += r**2 * w0n_r * w0n_r.conj()

                if decomp:
                    for m in range(n_wfs):
                        for j, [dx, dy] in enumerate(supercell):
                            wRm_r = w0[
                                (tx + dx) % nx, (ty + dy) % ny, m, i
                            ]  # translated Wannier function
                            R_nm[:, n, m, j] += w0n_r * wRm_r.conj() * np.array(pos)

        expc_rsq += rsq_n[n]
        expc_r_sq += np.vdot(r_n[n, :], r_n[n, :])

    spread = expc_rsq - expc_r_sq

    if decomp:
        sigma_Rnm_sq = np.sum(np.abs(R_nm) ** 2)
        Omega_inv = expc_rsq - sigma_Rnm_sq
        Omega_tilde = sigma_Rnm_sq - np.sum(
            np.abs(
                np.diagonal(R_nm[:, :, :, supercell.index((0, 0))], axis1=1, axis2=2)
            )** 2
        )

        assert np.allclose(spread, Omega_inv + Omega_tilde)
        return [spread, Omega_inv, Omega_tilde], r_n, rsq_n

    else:
        return spread, r_n, rsq_n

In [27]:
def spread_real2(lat_vecs, orbs, w0, decomp=False):
    """
    Spread functional computed in real space with Wannier functions

    Args:
        w0 (np.array): Wannier functions
        supercell (np.array): lattice translation vectors in reduced units
        orbs (np.array): orbital vectors in reduced units
        decomp (boolean): whether to separate gauge (in)variant parts of spread

    Returns:
        Omega: the spread functional
        Omega_inv: (optional) the gauge invariant part of the spread
        Omega_tilde: (optional) the gauge dependent part of the spread
        expc_rsq: \sum_n <r^2>_{n}
        expc_r_sq: \sum_n <\vec{r}>_{n}^2
    """
    # shape = w0.shape # [*nks, idx, orb]
    # nxs = shape[:-2]
    # n_orb = shape[-1]
    # n_states = shape[-2]
    # assuming 2D for now
    nx, ny, n_wfs = w0.shape[0], w0.shape[1], w0.shape[2]
    # translation vectors in reduced units
    supercell = [
        (i, j) for i in range(-nx // 2, nx // 2) for j in range(-ny // 2, ny // 2)
    ]

    r_n = np.zeros((n_wfs, 2), dtype=complex)  # <\vec{r}>_n
    rsq_n = np.zeros(n_wfs, dtype=complex)  # <r^2>_n
    R_nm = np.zeros((2, n_wfs, n_wfs, nx * ny), dtype=complex)

    expc_rsq = 0  # <r^2>
    expc_r_sq = 0  # <\vec{r}>^2

    for n in range(n_wfs):  # "band" index
        for tx, ty in supercell:  # cells in supercell
            for i, orb in enumerate(orbs):  # values of Wannier function on lattice
                pos = (orb[0] + tx) * lat_vecs[0] + (orb[1] + ty) * lat_vecs[1]  # position
                r = np.sqrt(pos[0] ** 2 + pos[1] ** 2)

                w0n_r = w0[tx, ty, n, i]  # Wannier function

                # expectation value of position (vector)
                r_n[n, :] += abs(w0n_r) ** 2 * pos
                rsq_n[n] += r**2 * w0n_r * w0n_r.conj()

                if decomp:
                    for m in range(n_wfs):
                        for j, [dx, dy] in enumerate(supercell):
                            wRm_r = w0[
                                (tx + dx) % nx, (ty + dy) % ny, m, i
                            ]  # translated Wannier function
                            R_nm[:, n, m, j] += w0n_r * wRm_r.conj() * np.array(pos)

        expc_rsq += rsq_n[n]
        expc_r_sq += np.vdot(r_n[n, :], r_n[n, :])

    spread = expc_rsq - expc_r_sq

    if decomp:
        sigma_Rnm_sq = np.sum(np.abs(R_nm) ** 2)
        Omega_inv = expc_rsq - sigma_Rnm_sq
        Omega_tilde = sigma_Rnm_sq - np.sum(
            np.abs(
                np.diagonal(R_nm[:, :, :, supercell.index((0, 0))], axis1=1, axis2=2)
            )** 2
        )

        assert np.allclose(spread, Omega_inv + Omega_tilde)
        return [spread, Omega_inv, Omega_tilde], r_n, rsq_n

    else:
        return spread, r_n, rsq_n

In [28]:
spread_real(lat_vecs, orbs, W0, decomp=False)

((0.8518335202059397-1.7140078712518223e-18j),
 array([[0.4981783 +0.j, 0.2841202 +0.j],
        [1.00074683+0.j, 1.15459444+0.j],
        [1.99372669+0.j, 1.15459444+0.j]]),
 array([0.6085081 -6.32359057e-19j, 2.61829717-8.02907886e-20j,
        5.59655112-1.00135803e-18j]))

In [29]:
spread_real2(lat_vecs, orbs, W0, decomp=False)

((0.8518335202059397-1.7140078712518223e-18j),
 array([[0.4981783 +0.j, 0.2841202 +0.j],
        [1.00074683+0.j, 1.15459444+0.j],
        [1.99372669+0.j, 1.15459444+0.j]]),
 array([0.6085081 -6.32359057e-19j, 2.61829717-8.02907886e-20j,
        5.59655112-1.00135803e-18j]))

In [30]:
%lprun -f spread_real spread_real(lat_vecs, orbs, W0, decomp=False)

Timer unit: 1e-09 s

Total time: 0.027833 s
File: /var/folders/nn/m4t491h92ss8vwl56z761h6c0000gn/T/ipykernel_64561/4163014881.py
Function: spread_real at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def spread_real(lat_vecs, orbs, w0, decomp=False):
     2                                               """
     3                                               Spread functional computed in real space with Wannier functions
     4                                           
     5                                               Args:
     6                                                   w0 (np.array): Wannier functions
     7                                                   supercell (np.array): lattice translation vectors in reduced units
     8                                                   orbs (np.array): orbital vectors in reduced units
     9                                                   decomp (boolea

In [31]:
%lprun -f spread_real2 spread_real2(lat_vecs, orbs, W0, decomp=False)

Timer unit: 1e-09 s

Total time: 0.022621 s
File: /var/folders/nn/m4t491h92ss8vwl56z761h6c0000gn/T/ipykernel_64561/2655024618.py
Function: spread_real2 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def spread_real2(lat_vecs, orbs, w0, decomp=False):
     2                                               """
     3                                               Spread functional computed in real space with Wannier functions
     4                                           
     5                                               Args:
     6                                                   w0 (np.array): Wannier functions
     7                                                   supercell (np.array): lattice translation vectors in reduced units
     8                                                   orbs (np.array): orbital vectors in reduced units
     9                                                   decomp (bool

In [44]:
def find_optimal_subspace(
    lat_vecs, orbs, outer_states, inner_states, iter_num=100, verbose=False, tol=1e-17, alpha=1
):
    if isinstance(inner_states, wf_array):
        shape = inner_states._wfs.shape  # [*nks, idx, orb]
        inner_states = np.array(inner_states._wfs)
    else:
        shape = inner_states.shape  # [*nks, idx, orb]

    nks = shape[:-2]
    Nk = np.prod(nks)
    n_orb = shape[-1]
    n_states = shape[-2]
    dim_subspace = n_states
    k_idx_arr = list(
        product(*[range(nk) for nk in nks])
    )  # all pairwise combinations of k_indices
   
    # Assumes only one shell for now
    w_b, k_shell, idx_shell = get_weights(*nks, lat_vecs=lat_vecs, N_sh=1)
    w_b = w_b[0]
    bc_phase = get_boundary_phase(*nks, orbs=orbs, idx_shell=idx_shell)

    P = np.einsum("...ni, ...nj->...ij", inner_states, inner_states.conj())

    # Projector on initial subspace at each k (for pbc of neighboring spaces)
    P_nbr = np.zeros((*nks, len(k_shell[0]), n_orb, n_orb), dtype=complex)
    Q_nbr = np.zeros((*nks, len(k_shell[0]), n_orb, n_orb), dtype=complex)
    T_kb = np.zeros((*nks, len(k_shell[0])), dtype=complex)

    for idx, idx_vec in enumerate(idx_shell[0]):  # nearest neighbors
        states_pbc = np.roll(inner_states, shift=tuple(-idx_vec), axis=(0,1)) * bc_phase[..., idx, np.newaxis,  :]
        P_nbr[..., idx, :, :] = np.einsum(
                "...ni, ...nj->...ij", states_pbc, states_pbc.conj()
                )
        Q_nbr[..., idx, :, :] = np.eye(n_orb) - P_nbr[..., idx, :, :]

    P_min = np.copy(P)  # start of iteration
    P_nbr_min = np.copy(P_nbr)  # start of iteration
    Q_nbr_min = np.copy(Q_nbr)  # start of iteration

    # states spanning optimal subspace minimizing gauge invariant spread
    states_min = np.zeros((*nks, dim_subspace, n_orb), dtype=complex)

    M = k_overlap_mat(lat_vecs, orbs, inner_states)
    spread, _, _ = spread_recip(lat_vecs, M, decomp=True)
    omega_I_prev = spread[1]

    for i in range(iter_num):
        P_avg = np.sum(w_b * P_nbr_min, axis=-3)
        Z = outer_states[..., :, :].conj() @ P_avg @ np.transpose(outer_states[..., : ,:], axes=(0,1,3,2))
        eigvals, eigvecs = np.linalg.eigh(Z)  # [val, idx]

        for k_idx in k_idx_arr:
            evals, evecs = eigvals[k_idx], eigvecs[k_idx]
            for idx, n in enumerate(
                np.argsort(evals.real)[-dim_subspace:]
            ):  # keep ntfs wfs with highest eigenvalue
                states_min[k_idx][idx, :] = np.sum(
                    [
                        evecs[i, n] * outer_states[k_idx][i, :]
                        for i in range(evecs.shape[0])
                    ],
                    axis=0,
                )

        P_new = np.einsum("...ni,...nj->...ij", states_min, states_min.conj())
        P_min = alpha * P_new + (1 - alpha) * P_min[k_idx] # for next iteration
        for idx, idx_vec in enumerate(idx_shell[0]):  # nearest neighbors
            states_pbc = np.roll(states_min, shift=tuple(-idx_vec), axis=(0,1)) * bc_phase[..., idx, np.newaxis,  :]
            P_nbr_min[..., idx, :, :] = np.einsum(
                    "...ni, ...nj->...ij", states_pbc, states_pbc.conj()
                    )
            Q_nbr_min[..., idx, :, :] = np.eye(n_orb) - P_nbr_min[..., idx, :, :]
            T_kb[..., idx] = np.trace(P_min[..., :, :] @ Q_nbr_min[..., idx, :, :], axis1=-1, axis2=-2)
        
        omega_I_new = (1 / Nk) * w_b * np.sum(T_kb)

        if omega_I_new > omega_I_prev:
            print("Warning: Omega_I is increasing.")

        if abs(omega_I_prev - omega_I_new) <= tol:
            print("Omega_I has converged within tolerance. Breaking loop")
            return states_min

        if verbose:
            print(f"{i} Omega_I: {omega_I_new.real}")

        omega_I_prev = omega_I_new

    return states_min

In [32]:
def find_optimal_subspace_og(
    lat_vecs, orbs, outer_states, inner_states, iter_num=100, verbose=False, tol=1e-17
):

    if isinstance(inner_states, wf_array):
        shape = inner_states._wfs.shape  # [*nks, idx, orb]
        inner_states = np.array(inner_states._wfs)
    else:
        shape = inner_states.shape  # [*nks, idx, orb]

    nks = shape[:-2]
    Nk = np.prod(nks)
    n_orb = shape[-1]
    n_states = shape[-2]
    dim_subspace = n_states
    k_idx_arr = list(
        product(*[range(nk) for nk in nks])
    )  # all pairwise combinations of k_indices

    # Assumes only one shell for now
    k_shell, idx_shell = get_k_shell(
        *nks, lat_vecs=lat_vecs, N_sh=1, tol_dp=8, report=False
    )
    w_b = get_weights(*nks, lat_vecs=lat_vecs, N_sh=1)[0]

    # Projector on initial subspace at each k
    P = np.zeros((*nks, n_orb, n_orb), dtype=complex)
    Q = np.zeros((*nks, n_orb, n_orb), dtype=complex)
    for k_idx in k_idx_arr:
        P[k_idx][:, :] = np.sum(
            [
                np.outer(inner_states[k_idx][n, :], inner_states[k_idx][n, :].conj())
                for n in range(int(n_states))
            ],
            axis=0,
        )
        Q[k_idx][:, :] = np.eye(P[k_idx].shape[0]) - P[k_idx]

    # Projector on initial subspace at each k (for pbc of neighboring spaces)
    P_nbr = np.zeros((*nks, len(k_shell[0]), n_orb, n_orb), dtype=complex)
    Q_nbr = np.zeros((*nks, len(k_shell[0]), n_orb, n_orb), dtype=complex)
    T_kb = np.zeros((*nks, len(k_shell[0])), dtype=complex)
    for k_idx in k_idx_arr:
        for idx, idx_vec in enumerate(idx_shell[0]):  # nearest neighbors
            k_nbr_idx = np.array(k_idx) + idx_vec
            # apply pbc to index
            mod_idx = np.mod(k_nbr_idx, nks)
            diff = k_nbr_idx - mod_idx
            G = np.divide(np.array(diff), np.array(nks))
            # if the translated k-index contains the -1st or last_idx+1 then we crossed the BZ boundary
            cross_bndry = True if np.any(np.in1d(k_nbr_idx, [-1, *nks])) else False
            if cross_bndry:
                bc_phase = np.array(
                    np.exp(-1j * 2 * np.pi * orbs @ G.T), dtype=complex
                ).T
            else:
                bc_phase = 1

            # apply pbc
            state_pbc = inner_states[tuple(mod_idx)] * bc_phase
            P_nbr[k_idx][idx, :, :] = np.sum(
                [
                    np.outer(state_pbc[n].T, state_pbc[n].conj())
                    for n in range(int(n_states))
                ],
                axis=0,
            )
            Q_nbr[k_idx][idx, :, :] = np.eye(n_orb) - P_nbr[k_idx][idx, :, :]

    P_min = np.copy(P)  # start of iteration
    P_nbr_min = np.copy(P_nbr)  # start of iteration
    Q_nbr_min = np.copy(Q_nbr)  # start of iteration

    # states spanning optimal subspace minimizing gauge invariant spread
    states_min = np.zeros((*nks, dim_subspace, n_orb), dtype=complex)

    M = k_overlap_mat(lat_vecs, orbs, inner_states)
    spread, _, _ = spread_recip(lat_vecs, M, decomp=True)
    omega_I_prev = spread[1]

    # diff = None
    for i in range(iter_num):
        for k_idx in k_idx_arr:
            P_avg = np.sum(w_b * P_nbr_min[k_idx], axis=0)

            # diagonalizing P_avg in outer_states basis
            N = outer_states.shape[-2]
            Z = np.zeros((N, N), dtype=complex)
            for n in range(N):
                for m in range(N):
                    Z[m, n] = outer_states[k_idx][m, :].conj() @ (
                        P_avg @ outer_states[k_idx][n, :]
                    )
            # Z = np.einsum('ni,nj->ij', outer_states[k].conj(), P_avg @ outer_states[k])

            eigvals, eigvecs = np.linalg.eigh(Z)  # [val, idx]
            for idx, n in enumerate(
                np.argsort(eigvals.real)[-dim_subspace:]
            ):  # keep ntfs wfs with highest eigenvalue
                states_min[k_idx][idx, :] = np.sum(
                    [
                        eigvecs[i, n] * outer_states[k_idx][i, :]
                        for i in range(eigvecs.shape[0])
                    ],
                    axis=0,
                )

            P_new = np.einsum("ni,nj->ij", states_min[k_idx], states_min[k_idx].conj())
            alpha = 1  # mixing with previous step to break convergence loop
            P_min[k_idx] = (
                alpha * P_new + (1 - alpha) * P_min[k_idx]
            )  # for next iteration

            for idx, idx_vec in enumerate(idx_shell[0]):  # nearest neighbors
                k_nbr_idx = np.array(k_idx) + idx_vec
                mod_idx = np.mod(k_nbr_idx, nks)
                diff = k_nbr_idx - mod_idx
                G = np.divide(np.array(diff), np.array(nks))
                # if the translated k-index contains the -1st or last_idx+1 then we crossed the BZ boundary
                cross_bndry = True if np.any(np.in1d(k_nbr_idx, [-1, *nks])) else False
                if cross_bndry:
                    bc_phase = np.array(
                        np.exp(-1j * 2 * np.pi * orbs @ G.T), dtype=complex
                    ).T
                else:
                    bc_phase = 1

                # apply pbc
                state_pbc = states_min[tuple(mod_idx)] * bc_phase
                P_nbr_min[k_idx][idx, :, :] = np.einsum(
                    "ni,nj->ij", state_pbc, state_pbc.conj()
                )
                Q_nbr_min[k_idx][idx, :, :] = (
                    np.eye(n_orb) - P_nbr_min[k_idx][idx, :, :]
                )
                T_kb[k_idx][idx] = np.trace(P_min[k_idx] @ Q_nbr_min[k_idx][idx, :, :])

        omega_I_new = (1 / Nk) * w_b * np.sum(T_kb)

        if omega_I_new > omega_I_prev:
            print("Warning: Omega_I is increasing.")

        if abs(omega_I_prev - omega_I_new) <= tol:
            print("omega_I has converged within tolerance. Breaking loop")
            return states_min

        if verbose:
            print(f"{i} Omega_I: {omega_I_new.real}")

        omega_I_prev = omega_I_new

    return states_min

In [33]:
states_min_1 = find_optimal_subspace(lat_vecs, orbs, outer_states, u_tilde_wan, iter_num=100, verbose=True)

0 Omega_I: 0.5224352455932965
1 Omega_I: 0.5185655608322894
2 Omega_I: 0.5149416822333317
3 Omega_I: 0.5114317051940603
4 Omega_I: 0.5079889673197662
5 Omega_I: 0.504598731828422
6 Omega_I: 0.5012619108799339
7 Omega_I: 0.49798763126911627
8 Omega_I: 0.4947889209823177
9 Omega_I: 0.4916798757270802
10 Omega_I: 0.48867378673602585
11 Omega_I: 0.48578198905465403
12 Omega_I: 0.48301327418404705
13 Omega_I: 0.4803737222134296
14 Omega_I: 0.47786681964443983
15 Omega_I: 0.475493746569461
16 Omega_I: 0.4732537429670067
17 Omega_I: 0.4711444912243334
18 Omega_I: 0.46916247631807456
19 Omega_I: 0.46730330365649536
20 Omega_I: 0.4655619672654045
21 Omega_I: 0.46393306868567136
22 Omega_I: 0.4624109910153313
23 Omega_I: 0.46099003422475665
24 Omega_I: 0.45966451817546417
25 Omega_I: 0.45842885934531563
26 Omega_I: 0.45727762651902565
27 Omega_I: 0.45620557987733384
28 Omega_I: 0.4552076971311633
29 Omega_I: 0.4542791896506486
30 Omega_I: 0.45341551094840576
31 Omega_I: 0.45261235938839023
32 Om

In [34]:
states_min_2 = find_optimal_subspace_og(lat_vecs, orbs, outer_states, u_tilde_wan, iter_num=100, verbose=True)

0 Omega_I: [23.05848394]
1 Omega_I: [0.80637211]
2 Omega_I: [0.51982299]
3 Omega_I: [0.52710947]
4 Omega_I: [0.51973469]
5 Omega_I: [0.5098095]
6 Omega_I: [0.50982174]
7 Omega_I: [0.50758791]
8 Omega_I: [0.50304792]
9 Omega_I: [0.50196892]
10 Omega_I: [0.49990016]
11 Omega_I: [0.49676763]
12 Omega_I: [0.49524056]
13 Omega_I: [0.49321554]
14 Omega_I: [0.49073607]
15 Omega_I: [0.48905772]
16 Omega_I: [0.48713804]
17 Omega_I: [0.48505127]
18 Omega_I: [0.48338204]
19 Omega_I: [0.48161338]
20 Omega_I: [0.47980766]
21 Omega_I: [0.47823338]
22 Omega_I: [0.47663539]
23 Omega_I: [0.47505628]
24 Omega_I: [0.47361702]
25 Omega_I: [0.47219227]
26 Omega_I: [0.47080783]
27 Omega_I: [0.46951704]
28 Omega_I: [0.46825818]
29 Omega_I: [0.46704556]
30 Omega_I: [0.46590198]
31 Omega_I: [0.46479663]
32 Omega_I: [0.46373681]
33 Omega_I: [0.46273169]
34 Omega_I: [0.46176544]
35 Omega_I: [0.46084137]
36 Omega_I: [0.45996263]
37 Omega_I: [0.45912073]
38 Omega_I: [0.45831677]
39 Omega_I: [0.45755134]
40 Omega_I

In [36]:
%lprun -f find_optimal_subspace find_optimal_subspace(lat_vecs, orbs, outer_states, u_tilde_wan, iter_num=100, verbose=True)

0 Omega_I: 0.5224352455932965
1 Omega_I: 0.5185655608322894
2 Omega_I: 0.5149416822333317
3 Omega_I: 0.5114317051940603
4 Omega_I: 0.5079889673197662
5 Omega_I: 0.504598731828422
6 Omega_I: 0.5012619108799339
7 Omega_I: 0.49798763126911627
8 Omega_I: 0.4947889209823177
9 Omega_I: 0.4916798757270802
10 Omega_I: 0.48867378673602585
11 Omega_I: 0.48578198905465403
12 Omega_I: 0.48301327418404705
13 Omega_I: 0.4803737222134296
14 Omega_I: 0.47786681964443983
15 Omega_I: 0.475493746569461
16 Omega_I: 0.4732537429670067
17 Omega_I: 0.4711444912243334
18 Omega_I: 0.46916247631807456
19 Omega_I: 0.46730330365649536
20 Omega_I: 0.4655619672654045
21 Omega_I: 0.46393306868567136
22 Omega_I: 0.4624109910153313
23 Omega_I: 0.46099003422475665
24 Omega_I: 0.45966451817546417
25 Omega_I: 0.45842885934531563
26 Omega_I: 0.45727762651902565
27 Omega_I: 0.45620557987733384
28 Omega_I: 0.4552076971311633
29 Omega_I: 0.4542791896506486
30 Omega_I: 0.45341551094840576
31 Omega_I: 0.45261235938839023
32 Om

Timer unit: 1e-09 s

Total time: 1.53279 s
File: /Users/treycole/Codes/WanPy/WanPy/WanPy.py
Function: find_optimal_subspace at line 608

Line #      Hits         Time  Per Hit   % Time  Line Contents
   608                                           def find_optimal_subspace(
   609                                               lat_vecs, orbs, outer_states, inner_states, iter_num=100, verbose=False, tol=1e-17, alpha=1
   610                                           ):
   611                                               """
   612                                               Assumes the states are defined on the same k-mesh and are of Bloch cell periodic character.
   613                                           
   614                                               Args:
   615                                                   lat_vecs: Lattice vectors
   616                                                   orbs: orbtial vectors
   617                                                

In [38]:
(4.25-1.53)/4.25

0.6399999999999999

In [37]:
%lprun -f find_optimal_subspace_og find_optimal_subspace_og(lat_vecs, orbs, outer_states, u_tilde_wan, iter_num=100, verbose=True)

0 Omega_I: [23.05848394]
1 Omega_I: [0.80637211]
2 Omega_I: [0.51982299]
3 Omega_I: [0.52710947]
4 Omega_I: [0.51973469]
5 Omega_I: [0.5098095]
6 Omega_I: [0.50982174]
7 Omega_I: [0.50758791]
8 Omega_I: [0.50304792]
9 Omega_I: [0.50196892]
10 Omega_I: [0.49990016]
11 Omega_I: [0.49676763]
12 Omega_I: [0.49524056]
13 Omega_I: [0.49321554]
14 Omega_I: [0.49073607]
15 Omega_I: [0.48905772]
16 Omega_I: [0.48713804]
17 Omega_I: [0.48505127]
18 Omega_I: [0.48338204]
19 Omega_I: [0.48161338]
20 Omega_I: [0.47980766]
21 Omega_I: [0.47823338]
22 Omega_I: [0.47663539]
23 Omega_I: [0.47505628]
24 Omega_I: [0.47361702]
25 Omega_I: [0.47219227]
26 Omega_I: [0.47080783]
27 Omega_I: [0.46951704]
28 Omega_I: [0.46825818]
29 Omega_I: [0.46704556]
30 Omega_I: [0.46590198]
31 Omega_I: [0.46479663]
32 Omega_I: [0.46373681]
33 Omega_I: [0.46273169]
34 Omega_I: [0.46176544]
35 Omega_I: [0.46084137]
36 Omega_I: [0.45996263]
37 Omega_I: [0.45912073]
38 Omega_I: [0.45831677]
39 Omega_I: [0.45755134]
40 Omega_I

Timer unit: 1e-09 s

Total time: 4.25792 s
File: /var/folders/nn/m4t491h92ss8vwl56z761h6c0000gn/T/ipykernel_64561/709642547.py
Function: find_optimal_subspace_og at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents