In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, ops
from jax import random
import numpy as np
import jax
from jax.scipy.linalg import expm
from jax.experimental import ode
import eigAD


In [2]:
# Tuning parameters (module-level constants; grad does not differentiate them)
numStates = 25   # basis size; the matrix construction below needs an odd value
numSteps = 31    # NOTE(review): not used by the integration itself (odeint picks its own steps) — candidate for removal
numBands = 5     # number of lowest bands whose quasi-energies enter the loss
fr = 25.18       # scale factor multiplying H in the evolution equation (units not shown — TODO document)
kpoints = 250    # number of k samples
kvec = jnp.linspace(start=-1, stop=0, num=kpoints)  # k samples on [-1, 0], endpoints included

def tridiag(a, b, c, k1=-1, k2=0, k3=1):
    """Assemble a banded matrix from three diagonals.

    ``a``, ``b`` and ``c`` are placed with ``np.diag`` on diagonal offsets
    ``k1``, ``k2`` and ``k3`` (by default the sub-, main and super-diagonal)
    and the three resulting square matrices are summed.
    """
    lower = np.diag(a, k1)
    middle = np.diag(b, k2)
    upper = np.diag(c, k3)
    return lower + middle + upper


# Constant real matrices of the Hamiltonian in the truncated basis.
# The construction below only yields numStates x numStates matrices when
# numStates is odd and >= 3; any other value fails later with a less obvious
# broadcasting / negative-dimension error, so check it up front.
if numStates < 3 or numStates % 2 == 0:
    raise ValueError(f"numStates must be an odd integer >= 3, got {numStates}")

# M1: sub-diagonal (0, 2, 0, 4, ..., 0, numStates-1), super-diagonal its negative,
# zero main diagonal.
C = np.array([v for i in range(2, numStates, 2) for v in (0, i)])
D = np.zeros(numStates)
M1 = tridiag(C, D, -1 * C)

# M2: -diag(0, 2**2, 2**2, 4**2, 4**2, ..., (numStates-1)**2, (numStates-1)**2).
E = [0] + [i**2 for i in range(2, numStates, 2) for _ in range(2)]
M2 = np.array(-np.diag(E))

# M3: symmetric, with (sqrt(2), 1, 1, ..., 1) on the +/-2 off-diagonals.
F = np.concatenate((np.array([np.sqrt(2)]), np.ones(numStates - 3)))
M3 = np.diag(F, -2) + np.diag(F, 2)

In [3]:
# Arguments of computeFloquetLoss below: freq = driving frequency,
# alpha = driving strength.
# Cast the constant matrices to complex64; the Hamiltonian assembled from them
# contains imaginary terms (the -2j*k*M1 contribution).
M1, M2, M3 = (jnp.asarray(mat, dtype=jnp.complex64) for mat in (M1, M2, M3))
def computeFloquetLoss(freq, alpha, A):
    """Spread of the Floquet quasi-energies tracked for the lowest bands.

    For every k in the module-level ``kvec`` the one-period propagator of the
    driven Hamiltonian is integrated with ``ode.odeint`` and diagonalised with
    ``eigAD.eig`` to obtain quasi-energies. Each of the ``numBands`` lowest
    eigenstates of the undriven Hamiltonian (``jnp.linalg.eigh`` returns them
    in ascending order) is matched to the Floquet state it overlaps most with,
    and that Floquet state's quasi-energy is selected.

    Relies on module-level ``numStates``, ``numBands``, ``fr``, ``kvec``,
    ``M1``, ``M2``, ``M3`` and on the project module ``eigAD``.

    Args:
        freq: driving frequency; the integration window is one period 1/freq.
        alpha: driving strength; the M3 term is scaled by
            1 + alpha * sin(2*pi*freq*t).
        A: overall coefficient of the M3 term.

    Returns:
        Scalar ``jnp.std`` of the selected quasi-energies over all k samples
        and all tracked bands.
    """
    Ttot = 1 / freq  # integrate over exactly one driving period
    ftot = 1 / Ttot  # corresponding frequency, used to rescale quasi-energies

    @jax.jit
    def createHmat(t, k):
        # Driven Hamiltonian at time t and momentum k.
        modfunc = 1 + jnp.sum(alpha * jnp.sin(2 * jnp.pi * freq * t))
        newMat = (k**2) * jnp.identity(numStates) - 2 * 1j * k * M1 - M2 - (1/4) * M3 * A * modfunc
        return newMat

    @jax.jit
    def se(unitary, t, k):
        # dU/dt = -i * (2*pi*fr) * H(t, k) @ U, in odeint's func(y, t, *args) form.
        return -1j * createHmat(t, k) @ unitary * (2 * jnp.pi * fr)

    @jax.jit
    def perKstep(k):
        # Propagator U(Ttot) for one k, starting from the identity at t = 0.
        unitaryInit = jnp.identity(M1.shape[0], dtype=jnp.complex64)
        return ode.odeint(se, unitaryInit, jnp.array([0.0, Ttot]), k)[1, :, :]

    # One-period propagators for every k, stacked along axis 0 (same order as kvec).
    res = vmap(perKstep)(kvec)

    def eigWrapper(mat):
        return eigAD.eig(mat)

    # eigAD.eig is jitted for the CPU backend, then batched over k.
    eigWrapper = jax.jit(eigWrapper, backend='cpu')
    eigWrapper = vmap(eigWrapper)
    b, vF = eigWrapper(res)
    # Writing the propagator eigenvalues as exp(-i * 2*pi*fr * eps * Ttot), this
    # recovers eps (folded into one zone by the principal branch of log).
    rawEfloquet = jnp.real(1j * jnp.log(b) * (ftot / fr) / (2 * np.pi))

    @jax.jit
    def blochStates(i):
        k = kvec[i]
        # NOTE(review): assumes eigAD.eig returns eigenvectors as columns
        # (numpy.linalg.eig convention) — confirm against eigAD.
        currF = vF[i, :, :]
        # Undriven Hamiltonian (modfunc == 1) at this k.
        H0 = (k**2) * jnp.identity(numStates) - 2 * 1j * k * M1 - M2 - (1/4) * M3 * A
        _, vS = jnp.linalg.eigh(H0)  # eigenvector columns, ascending eigenvalues
        vS = jnp.transpose(vS)       # one Bloch state per row
        Cvec = jnp.matmul(vS, jnp.conjugate(currF))
        # |overlap|^2 of Bloch state (row) with Floquet state (column). Take the
        # (exactly real) value as a real array so argmax compares magnitudes.
        Pvec = jnp.real(jnp.multiply(Cvec, jnp.conjugate(Cvec)))
        inds = jnp.argmax(Pvec, axis=1)
        return rawEfloquet[i, inds[:numBands]]

    # One index per k sample, so this stays consistent with kvec / kpoints.
    bandsF = vmap(blochStates)(jnp.arange(kvec.shape[0]))

    return jnp.std(bandsF)
    
    


In [4]:
# d(loss)/d(freq) at freq=70.0, alpha=0.5, A=3.5 (argnums=0 selects freq).
grad(computeFloquetLoss, argnums=0)(70.0, 0.5, 3.5)

DeviceArray(-0.00303411, dtype=float32)