# On improving Löwdin perturbation theory with the kernel polynomial method (KPM)

## Imports

In [None]:
import kwant
import numpy as np
import scipy.linalg as la
import scipy.sparse as sla

import itertools

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import sympy

## Prepare a random Hamiltonian

In [None]:
def H0_random(nA=4, nB=100, gap=1, epsilonA=0.2, epsilonB=10):
    """Build a random Hermitian matrix with a two-group spectrum.

    The spectrum is made of

    * ``nA`` eigenvalues drawn uniformly from
      ``[-epsilonA / 2, epsilonA / 2)`` (the quasi-degenerate group), and
    * up to ``nB`` eigenvalues drawn uniformly from
      ``[-epsilonB / 2, epsilonB / 2)``, of which only those satisfying
      ``|E| > gap / 2`` are kept.

    Because of that filter the returned matrix can be smaller than
    ``nA + nB``.  The diagonal spectrum is then rotated with the matrix
    ``U`` returned by ``kwant.rmt.circular`` (a random unitary according to
    Kwant's documentation), and ``U^dagger @ diag(E) @ U`` is returned.
    """
    spectrum_a = epsilonA * np.random.random(nA) - epsilonA / 2

    # Draw the remaining levels first, then discard the ones inside the gap.
    drawn_b = epsilonB * np.random.random(nB) - epsilonB / 2
    outside_gap = np.abs(drawn_b) > gap / 2
    spectrum_b = drawn_b[outside_gap]

    spectrum = np.concatenate((spectrum_a, spectrum_b))
    rotation = kwant.rmt.circular(spectrum.size)

    return rotation.T.conj() @ np.diag(spectrum) @ rotation


def H1_random(n, v=1):
    """Return a random perturbation matrix of size ``n``.

    Thin wrapper around ``kwant.rmt.gaussian``; ``v`` is forwarded unchanged
    as that function's ``v`` keyword argument.
    """
    perturbation = kwant.rmt.gaussian(n, v=v)
    return perturbation

In [None]:
# 100 evenly spaced values in [0, 0.1]; presumably perturbation strengths --
# not used in any of the cells visible here.
alphas = np.linspace(0, .1, 100)
# Seed NumPy's global RNG, which H0_random draws its energies from.
np.random.seed(0)

H0 = H0_random()        # Unperturbed Hamiltonian H_0 (default sizes and gap)
H1 = H1_random(len(H0)) # Perturbation H', same dimension as H_0

# Re-wrap the perturbation as a dict keyed by the sympy expression `1`;
# this dict is what gets passed to the second_order_* routines below.
H1 = {sympy.sympify('1'): H1}

In [None]:
%%time
# Full diagonalization of H_0: eigenvalues and the matching eigenvectors.
ev, evec = la.eigh(H0)

# Open energy window around zero. With the H0_random defaults the
# quasi-degenerate levels lie within +-0.1 and all others beyond +-0.5,
# so the window picks out exactly the quasi-degenerate group.
window = (-.25, +.25)
indices = np.flatnonzero((window[0] < ev) & (ev < window[1])).tolist()

# 2nd order explicit

In [None]:
from codes.lowdin import second_order_explicit

In [None]:
import sympy

In [None]:
%%time
# Reference result from the explicit routine: it receives the complete
# spectrum (ev, evec) plus the positions of the states inside the window.
M2 = second_order_explicit(H1, indices, ev, evec)

In [None]:
# Show the entry stored under key (1, 1); presumably the term that is
# second order in the perturbation keyed by `1` -- TODO confirm against
# codes.lowdin.
M2[1, 1]

# 2nd order KPM

In [None]:
from codes.lowdin import second_order_kpm

In [None]:
%%time
# KPM variant: receives H0 itself and only the eigenpairs selected by
# `indices`, together with the number of moments to use.
M2_kpm = second_order_kpm(H0, H1, ev[indices], evec[:, indices], 
                          num_moments=1500)

In [None]:
# Same entry as displayed for the explicit result M2, for a side-by-side
# comparison.
M2_kpm[1, 1]

# Convergence

In [None]:
def difference(kpm):
    """Return the summed norm of the deviation between ``kpm`` and ``M2``.

    ``kpm`` must be a mapping with exactly the same keys as the module-level
    reference ``M2``.  For every key the default ``scipy.linalg.norm`` of
    ``kpm[key] - M2[key]`` is computed, and the norms are added up.
    """
    assert set(M2) == set(kpm)
    return sum(la.norm(term - M2[label]) for label, term in kpm.items())

In [None]:
# Moment counts to scan: 100, 600, ..., 4600.
moments = range(100, 5000, 500)


# For each moment count, recompute the KPM result and record its total
# deviation from the explicit reference (see `difference`).
ds = []
for num_moment in moments:
    print(num_moment)  # progress indicator
    kpm = second_order_kpm(
        H0, H1, ev[indices], evec[:, indices], num_moments=num_moment
    )
    ds.append(difference(kpm))
    

In [None]:
# Deviation from the explicit result versus the number of moments,
# drawn on a logarithmic y-axis.
plt.plot(moments, ds, 'o-')
plt.ylabel('|explicit - kpm|')
plt.xlabel('# moments')
plt.yscale('log')
plt.grid()