In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, ops
from jax import random
import numpy as np
import jax
from jax.scipy.linalg import expm
from jax.scipy import linalg


In [2]:
# tunning parameter (constant in grad)
numStates=25
numSteps=31
numBands=5
fr=25.18
kpoints=250
kvec=jnp.linspace(-1,0,kpoints)

def tridiag(a, b, c, k1=-1, k2=0, k3=1):
    return jnp.diag(a, k1) + jnp.diag(b, k2) + jnp.diag(c, k3)


C = []
for i in range(2,numStates,2):
    C = C + [0,i]
C = jnp.array(C)
D = jnp.zeros(numStates)
M1 = tridiag(C,D,-1 * C)

E = [0]
for i in range(2,numStates,2):
    E = E + [(i)**2,(i)**2]
M2 = jnp.array(-np.diag(E))


F = jnp.concatenate((jnp.array([jnp.sqrt(2)]) , jnp.ones(numStates-3)))
M3 = jnp.diag(F,-2) + jnp.diag(F,2)

# Pade

In [6]:
padeCoefA = (0,1,5/2,74/33,19/22,29/220,7/1320)
padeCoefB = (1,3,75/22,20/11,5/11,1/221,1/924)

@jax.jit
def logm(mat):
    Npq = jnp.zeros(mat.shape)
    Dpq = jnp.identity(mat.shape[0])
    
    for i in range(1,7):
        curr = jnp.linalg.matrix_power(mat - jnp.identity(mat.shape[0]),i)
        Npq += padeCoefA[i] * curr
        Dpq += padeCoefB[i] * curr
        
    return jnp.linalg.inv(Dpq) @ Npq


In [23]:
from scipy import linalg
import numpy as np

N = 100
test = 0.01 * np.random.rand(size=(N,N)) + np.identity(N)

logm(test) - linalg.logm(test)

DeviceArray([[-3.54647636e-05,  8.81552696e-05,  1.03950500e-04, ...,
              -4.49167565e-05,  6.63399696e-05,  8.94577242e-05],
             [ 7.15814531e-05, -3.08379531e-05,  2.48178840e-05, ...,
               4.83244658e-05,  1.38767064e-05,  8.08392651e-05],
             [-2.30625272e-04, -8.71913508e-05, -3.56007367e-05, ...,
               5.88322291e-05, -6.77937642e-05,  6.16796315e-05],
             ...,
             [ 7.96262175e-05,  1.00784469e-04,  2.58907676e-06, ...,
              -1.27159059e-04, -8.23885202e-05, -1.79056078e-05],
             [-1.57834031e-04,  8.24555755e-05, -4.32824250e-04, ...,
               1.02482736e-04, -7.65584409e-05,  6.71390444e-05],
             [ 1.87754631e-04,  5.66802919e-06,  1.94675289e-04, ...,
              -1.83104072e-04,  6.45220280e-05,  1.46273524e-05]],            dtype=float32)

In [25]:
logm(test)

DeviceArray([[-0.05335252,  0.03544155, -0.04392498, ...,  0.01543865,
               0.02274147,  0.00441582],
             [-0.03129308,  0.06824718,  0.08331676, ...,  0.03264381,
               0.03987109,  0.00487409],
             [ 0.0036211 ,  0.00889917, -0.03060531, ...,  0.00169048,
              -0.00846129,  0.04844356],
             ...,
             [ 0.02152383,  0.00537034,  0.06124863, ...,  0.06865254,
              -0.06855287, -0.027758  ],
             [-0.00974592, -0.06546516, -0.00678593, ...,  0.03173976,
               0.03235643,  0.02157707],
             [ 0.09501714,  0.02144121,  0.00904912, ...,  0.00275127,
               0.03614387,  0.00952093]], dtype=float32)

In [9]:
# freqs: driving frequencies
# alphas: driving strength

def computeFloquetLoss(freq,alpha,A):
    lcm = 1/freq
    Ttot =  lcm 
    ftot = 1/Ttot
    
    dT = (1/freq)/numSteps

    N = int(jnp.ceil(Ttot/dT))

    tVec = jnp.linspace(0,Ttot,N)
    dT = tVec[1] - tVec[0]
    tVec = tVec[:-1]
    tVec = tVec + dT/2
    dTau = (2 * jnp.pi * fr) * dT

    @jax.jit
    def perKstep(k):
        unitary = jnp.identity(M1.shape[0])
        for tt in range(tVec.shape[0]):
            t = tVec[tt]
            unitary = jnp.matmul(unitary,expm(-1j * dTau * createHmat(t,k)))
        return unitary

    @jax.jit
    def createHmat(t,k):
        modfunc = 1 + jnp.sum(alpha * jnp.sin(2 * jnp.pi * freq * t))
        newMat = (k**2) * jnp.identity(numStates) - 2* 1j * k * M1 -M2  - (1/4) * M3 * A * modfunc
        return newMat

    @jax.jit
    def genUni():
        kMap = vmap(perKstep)
        return kMap(kvec)

    res = genUni()

    def eigWrapper(mat):
        return jnp.linalg.eigvals(mat)

    eig = jax.jit(eigWrapper,backend='cpu')
    b =jnp.linalg.eigvals(res)
    rawEfloquet = jnp.real(1j*jnp.log(b)* (ftot/fr) / (2*np.pi))
    

    
    return jnp.std(rawEfloquet)
    
    


In [None]:
grad(computeFloquetLoss)(70.0,0.5,3.0)

NotImplementedError: Nonsymmetric eigendecomposition is only implemented on the CPU backend

In [None]:
@jax.jit
def blochStates(i):
    k = kvec[i]
    currF = vF[i,:,:] 
    H0 =  (k**2) * jnp.identity(numStates) - 2* 1j * k * M1 -M2  - (1/4) * M3 * A 
    a,vS = jnp.linalg.eigh(H0)
    vS = jnp.transpose(vS)
    Cvec = jnp.matmul(vS,jnp.conjugate(currF))
    Pvec = jnp.multiply(Cvec,jnp.conjugate(Cvec))
    inds = jnp.argmax(Pvec,axis=1)
    Efloquet = rawEfloquet[i,inds[:numBands]]
    return Efloquet
bandsF = vmap(blochStates)(jnp.arange(250))