The first code cell imports the libraries and defines the constants used later.

In [None]:
import jax.numpy as jnp

from jax import jit

from jax import lax
import jax
from jax import make_jaxpr
from jax import random
from jax import device_put

import numpy as np
from numpy import linalg as LA
import matplotlib.pyplot as plt

# Local dimension d (size of the physical leg) and chi (length of the lam
# vectors and size of the outer legs of the Gamma tensors).
d = 2
chi = 4

# Field strength along each direction.
h_x = np.sqrt(3) / 2
h_y = np.log(2)
h_z = 0.4

# Strength of the spin-spin interactions.
J_xx = J_yy = J_zz = .5

# [two-site interaction labels, on-site field labels]
Model = [["xx"], ["z", "y"]]

# Random initial tensors; drawn one at a time in the same order as before.
lam_0 = np.random.rand(chi)
Gamma_0 = np.random.rand(2*chi**2).reshape(chi, 2, chi)
lam_1 = np.random.rand(chi)
Gamma_1 = np.random.rand(2*chi**2).reshape(chi, 2, chi)

# Random rank-4 two-site gate.
U = np.random.rand(16).reshape(2, 2, 2, 2)

@jit
def division(x, tol=0.):
    """Element-wise reciprocal ``1/x`` that maps (near-)zero entries to 0.

    Args:
      x: array of values to invert.
      tol: entries with ``|x| <= tol`` are treated as zero and mapped to 0.
        The default ``0.`` reproduces the exact-zero test ``x == 0``.

    Returns:
      Array shaped like ``x`` holding ``1/x`` where ``|x| > tol`` and 0
      elsewhere.
    """
    is_zero = jnp.abs(x) <= tol
    # Double-where: swap masked entries for 1 *before* dividing so the
    # discarded branch never contains inf. A single
    # jnp.where(mask, 0., 1 / x) has the right forward value but yields NaN
    # gradients at x == 0 (0 * inf in the backward pass).
    safe_x = jnp.where(is_zero, 1, x)
    return jnp.where(is_zero, 0., 1 / safe_x)

def Model_coefficients(J_xx,J_yy,J_zz,h_x,h_y,h_z):
  """Bundle the model's coupling constants into label-keyed dictionaries.

  Returns:
    (Interactions, Fields): Interactions maps the two-site labels
    "xx"/"yy"/"zz" to J_xx/J_yy/J_zz; Fields maps the on-site labels
    "x"/"y"/"z" to h_x/h_y/h_z.
  """
  couplings = dict(xx=J_xx, yy=J_yy, zz=J_zz)
  fields = dict(x=h_x, y=h_y, z=h_z)
  return couplings, fields
  
# Label -> coefficient lookups consumed by the Hamiltonian builders.
Interactions, Fields = Model_coefficients(
    J_xx=J_xx, J_yy=J_yy, J_zz=J_zz,
    h_x=h_x, h_y=h_y, h_z=h_z,
)
data_type = complex  # not referenced in the code visible here

In [None]:
@jit  # checked against Entropy_non_jit
def Entropy(v):
  """Return -sum(v * log v), treating 0 * log 0 as 0 (JAX, jitted)."""
  # Entries equal to 0 get a log of 0 instead of -inf, so they contribute 0.
  log_v = jnp.where(v == 0, 0., jnp.log(v))
  weighted = log_v * v
  return -jnp.sum(weighted)

def Entropy_non_jit(v):
  """Return -sum(v * log v), treating 0 * log 0 as 0 (NumPy reference).

  The logarithm is evaluated only where ``v != 0``. The previous version
  computed ``np.log(0)`` and ``-inf * 0`` before discarding them, which
  emitted spurious divide-by-zero / invalid-value RuntimeWarnings. Negative
  entries still warn and produce NaN, as before.
  """
  v = np.asarray(v)
  log_v = np.zeros(v.shape, dtype=np.result_type(v, float))
  np.log(v, out=log_v, where=(v != 0))
  return -np.sum(log_v * v)

# Comparison: the jitted Entropy and the NumPy Entropy_non_jit should agree on
# a random length-chi vector (not normalized; only agreement is checked).
# JAX presumably runs in its default float32 mode here (no x64 config appears
# in the code shown), so agreement is only to float32 precision.
rand_vec = np.random.rand(chi)
print(LA.norm(Entropy(rand_vec) - Entropy_non_jit(rand_vec)))

0.0


In [None]:
@jit  # checked against Pauli()
def Pauli_jaxed():
  """Build the Pauli matrices as JAX arrays plus a label -> matrix lookup.

  Returns:
    (S_z, S_x, S_y, S1): the three Pauli matrices and a dict mapping the
    on-site labels "x", "y", "z" and the two-site labels "xx", "yy", "zz"
    to the corresponding matrix.
  """
  sigma_x = jnp.array([[0., 1.], [1., 0.]], dtype=complex)
  sigma_y = jnp.array([[0., complex(0, -1.)], [complex(0, 1.), 0.]], dtype=complex)
  sigma_z = jnp.array([[1., 0.], [0., -1.]], dtype=complex)
  # Two-site labels reuse the single-site matrix; the Hamiltonian builder
  # forms the tensor product itself.
  lookup = {
      "x": sigma_x,
      "xx": sigma_x,
      "y": sigma_y,
      "yy": sigma_y,
      "z": sigma_z,
      "zz": sigma_z,
  }
  return sigma_z, sigma_x, sigma_y, lookup


def Pauli():
  """Build the Pauli matrices as complex128 NumPy arrays plus a lookup.

  Returns:
    (S_z, S_x, S_y, S1): the three Pauli matrices and a dict mapping the
    on-site labels "x", "y", "z" and the two-site labels "xx", "yy", "zz"
    to the corresponding matrix.
  """
  c128 = np.complex128
  sigma_x = np.array([[0., 1.], [1., 0.]], dtype=c128)
  sigma_y = np.array([[0., complex(0, -1.)], [complex(0, 1.), 0.]], dtype=c128)
  sigma_z = np.array([[1., 0.], [0., -1.]], dtype=c128)
  lookup = {
      "x": sigma_x,
      "xx": sigma_x,
      "y": sigma_y,
      "yy": sigma_y,
      "z": sigma_z,
      "zz": sigma_z,
  }
  return sigma_z, sigma_x, sigma_y, lookup
# Module-level Pauli matrices (NumPy, complex128) and label lookup S1, used by
# the Hamiltonian and expectation-value cells below.
S_z, S_x, S_y, S1 = Pauli()

# Comparison: S_y (index 2 of the returned tuple) from the NumPy and JAX builders.
print(LA.norm(Pauli()[2]-Pauli_jaxed()[2]))

0.0


In [None]:
def Create_HamiltonianTEBD_non_jit(Model, S1, Interactions, Fields,d):
    """Build the two-site TEBD Hamiltonian as a rank-4 NumPy tensor.

    Args:
      Model: ``[two_site_labels, on_site_labels]``, e.g. ``[["xx"], ["z", "y"]]``.
      S1: mapping from label to a (d, d) operator matrix.
      Interactions: mapping from two-site label to its coupling constant.
      Fields: mapping from on-site label to its field strength.
      d: local dimension; sets both the output shape and the identity
        used for the on-site terms (previously hard-coded to 2).

    Returns:
      complex128 array ``H`` of shape (d, d, d, d) with
      ``H[i, k, j, l] = sum over terms of A[i, j] * B[k, l]``, so that
      ``H.reshape(d*d, d*d)`` equals the Kronecker-product form
      ``sum 2*J*kron(S, S) + sum h*(kron(S, I) + kron(I, S))``.
    """
    H = np.zeros((d, d, d, d)).astype(np.complex128)
    identity = np.eye(d)

    for label in Model[0]:
        # two-site term 2*J * (S ⊗ S)
        H += np.einsum("ij,kl->ikjl", 2*Interactions[label]*S1[label], S1[label])

    for label in Model[1]:
        # on-site field, applied to each of the two sites
        field_term = Fields[label]*S1[label]
        H += np.einsum("ij,kl->ikjl", field_term, identity)
        H += np.einsum("ij,kl->ikjl", identity, field_term)
    return H

def Create_HamiltonianTEBD_jit(Model, S1, Interactions, Fields,d):
    """JAX counterpart of ``Create_HamiltonianTEBD_non_jit``.

    Not decorated with ``jit`` itself; it uses only ``jax.numpy`` operations.
    ``Model`` (a nested list of string labels) and ``d`` determine the loop
    structure and shapes, so they would have to be static if it were jitted.

    Args:
      Model: ``[two_site_labels, on_site_labels]``.
      S1: mapping from label to a (d, d) operator matrix.
      Interactions: mapping from two-site label to its coupling constant.
      Fields: mapping from on-site label to its field strength.
      d: local dimension; sets the output shape and the identity used for
        the on-site terms (previously hard-coded to 2).

    Returns:
      complex JAX array ``H`` of shape (d, d, d, d) with
      ``H[i, k, j, l] = sum over terms of A[i, j] * B[k, l]``.
    """
    H = jnp.zeros((d, d, d, d)).astype(complex)
    identity = jnp.eye(d)

    for label in Model[0]:
        # two-site term 2*J * (S ⊗ S)
        H += jnp.einsum("ij,kl->ikjl", 2*Interactions[label]*S1[label], S1[label])

    for label in Model[1]:
        # on-site field, applied to each of the two sites
        field_term = Fields[label]*S1[label]
        H += jnp.einsum("ij,kl->ikjl", field_term, identity)
        H += jnp.einsum("ij,kl->ikjl", identity, field_term)
    return H

# Build the Hamiltonian with both implementations from the module-level model
# parameters and check they agree.
H_jit = Create_HamiltonianTEBD_jit(Model, S1, Interactions, Fields,d)
H_non_jit = Create_HamiltonianTEBD_non_jit(Model, S1, Interactions, Fields,d)
# Comparison
print(LA.norm(H_jit-H_non_jit))

0.0


In [None]:

def Compute_Theta_optimized(lam_0,Gamma_0,lam_1,Gamma_1):
    """Contract two Vidal-form sites into the two-site tensor Theta (JAX).

    Computes, summing over the shared bond index b,
        Theta[a, i, j, c] = lam_0[a] Gamma_0[a, i, b] lam_1[b] Gamma_1[b, j, c] lam_0[c]
    in two independent ways and returns both so they can be cross-checked.
    Both routes use only jax.numpy operations (jnp.einsum included), so
    either can be traced by ``jax.jit``. The jitted expectation-value
    function calls this one.

    Returns:
        (Theta, Theta1): Theta from a left-to-right chain of
        ``jnp.tensordot`` calls with diagonal bond matrices; Theta1 from a
        chain of ``jnp.einsum`` calls.
    """
    # einsum route: absorb / contract one factor at a time.
    by_einsum = jnp.einsum("a,aib->aib", lam_0, Gamma_0)
    by_einsum = jnp.einsum("aib,b->aib", by_einsum, lam_1)
    by_einsum = jnp.einsum("aib,bjc->aijc", by_einsum, Gamma_1)
    by_einsum = jnp.einsum("aijc,c->aijc", by_einsum, lam_0)

    # tensordot route: turn each bond-weight vector into a diagonal matrix
    # so every step is a plain last-axis/first-axis (axes=1) contraction.
    factors = (Gamma_0, jnp.diag(lam_1), Gamma_1, jnp.diag(lam_0))
    by_tensordot = jnp.diag(lam_0)
    for factor in factors:
        by_tensordot = jnp.tensordot(by_tensordot, factor, axes=1)

    return by_tensordot, by_einsum

def Compute_Theta_non_jit(lam_0,Gamma_0,lam_1,Gamma_1):
    """NumPy reference for the two-site tensor Theta.

    Returns the array
        Theta[a, i, j, c] = lam_0[a] Gamma_0[a, i, b] lam_1[b] Gamma_1[b, j, c] lam_0[c],
    summed over the shared bond index b.
    """
    lam_0 = np.asarray(lam_0)
    lam_1 = np.asarray(lam_1)
    # Scale the left site by its outer and inner bond weights via
    # broadcasting (same multiplication order as before: lam_0 first).
    left_site = lam_0[:, None, None] * np.asarray(Gamma_0) * lam_1[None, None, :]
    theta = np.einsum("aib,bjc->aijc", left_site, Gamma_1)
    # Right-most bond weight on the last axis.
    return theta * lam_0[None, None, None, :]
# Comparison: both JAX routes of Compute_Theta_optimized against the NumPy
# reference. The dtype of the inputs and float64 machine epsilon are printed
# for scale. The JAX results are presumably float32 (default JAX precision;
# no x64 config appears in the code shown), so differences around 1e-8 are
# expected. Despite the label text, the jnp.einsum route is jittable too.
print(type(lam_0[0]))
print(np.finfo(np.float64).eps,"\n")

print(LA.norm(Compute_Theta_optimized(lam_0,Gamma_0,lam_1,Gamma_1)[0]-Compute_Theta_non_jit(lam_0,Gamma_0,lam_1,Gamma_1)),"Compariosn using jnp.tensordot, where we can jit the function")

print(LA.norm(Compute_Theta_optimized(lam_0,Gamma_0,lam_1,Gamma_1)[1]-Compute_Theta_non_jit(lam_0,Gamma_0,lam_1,Gamma_1)),"Compariosn using jnp.einsum")

<class 'numpy.float64'>
2.220446049250313e-16 

2.420375e-08 Compariosn using jnp.tensordot, where we can jit the function
2.420375e-08 Compariosn using jnp.einsum


In [None]:
def i_trunc_opt(Fs, d=2, chi=None):
  """Truncated SVD of a two-site matrix, split into (A, lam, B) (JAX).

  The previous version hard-coded ``chi = 4``. Any input other than 8x8
  was silently mis-sliced (e.g. a 40x40 input returned only its first 8
  rows of U). chi is now inferred from the input, which gives the same
  result for the 8x8 case.

  Args:
    Fs: matrix of shape (chi_left * d, d * chi_right), e.g. a reshaped
      two-site tensor.
    d: physical dimension. Static under jit, because it fixes output shapes.
    chi: number of singular values to keep. Static under jit. Defaults to
      ``Fs.shape[0] // d``, i.e. the incoming bond dimension.

  Returns:
    (A, lam, B) where A = U[:, :chi] reshaped to (chi_left, d, chi), lam is
    the kept singular values normalized to unit 2-norm, and
    B = Vh[:chi, :] reshaped to (chi, d, chi_right).
  """
  chi_left = Fs.shape[0] // d
  chi_right = Fs.shape[1] // d
  if chi is None:
    chi = chi_left

  U, s, Vh = jnp.linalg.svd(Fs, full_matrices=False)
  kept = s[:chi]
  lam = kept / jnp.linalg.norm(kept)

  A = U[:, :chi].reshape(chi_left, d, chi)
  B = Vh[:chi, :].reshape(chi, d, chi_right)
  return A, lam, B


# d and chi determine array shapes, so they must be compile-time constants.
i_trunc_opt = jit(i_trunc_opt, static_argnames=("d", "chi"))


def i_trunc_non_jit(chi,Fs,d=2):
  """Truncated SVD of a two-site matrix, split into (A, lam, B) (NumPy).

  Keeps the ``chi`` largest singular values of ``Fs``. ``Fs`` must have
  ``chi * d`` rows and ``d * chi`` columns for the reshapes to succeed.

  Returns:
    (A, lam, B): U[:, :chi] reshaped to (chi, d, chi), the kept singular
    values normalized to unit 2-norm, and Vh[:chi] reshaped to
    (chi, d, chi).
  """
  left, singular, right = LA.svd(Fs, full_matrices=False)

  kept = singular[:chi]
  lam = kept / np.linalg.norm(kept)

  A = left[:, :chi].reshape(chi, d, chi)
  B = right[:chi].reshape(chi, d, chi)
  return A, lam, B
# Comparison: truncated SVD from NumPy vs JAX on a random (2*chi) x (2*chi)
# matrix (chi*chi*4 entries; the factor 2 matches the default d=2).
# Singular vectors are unique only up to a sign per vector (and up to rotation
# for degenerate singular values). So the A/B differences below can be large
# even when both SVDs are correct. lam (index 1) is sign-free and is the
# robust check.
Fs = np.random.rand(chi*chi*4).reshape(chi*2,chi*2)
print(type(Fs[0][0]))
print(np.finfo(np.float64).eps,"\n")

print(LA.norm(i_trunc_non_jit(chi,Fs)[0] - i_trunc_opt(Fs)[0]))

print(LA.norm(i_trunc_non_jit(chi,Fs)[1] - i_trunc_opt(Fs)[1]))

print(LA.norm(i_trunc_non_jit(chi,Fs)[2] - i_trunc_opt(Fs)[2]))

print(i_trunc_non_jit(chi,Fs)[0])

<class 'numpy.float64'>
2.220446049250313e-16 

1.1081409e-06
9.1856926e-08
1.1395796e-06
[[[-0.28905441  0.30343869 -0.73163494  0.12319217]
  [-0.36987202  0.45380751  0.31528641  0.16009687]]

 [[-0.31532122 -0.23461519  0.05902784 -0.69862885]
  [-0.46666767 -0.25807165 -0.1087248   0.19073188]]

 [[-0.38916162  0.39354705  0.08975242 -0.47423204]
  [-0.25338519  0.20297756 -0.08078108  0.29496071]]

 [[-0.35615538 -0.60187734 -0.20918018  0.07857846]
  [-0.34632149 -0.14833389  0.5400574   0.34155432]]]


In [None]:
# JAX counterpart of the NumPy A tensor printed in the previous cell. Entries
# should match only up to a sign flip of each A[..., k] slice (SVD sign
# freedom of each column of U).
print(i_trunc_opt(Fs)[0])

[[[-0.27017975 -0.26676416  0.30447516  0.2888867 ]
  [-0.40424645 -0.5102322  -0.27968207 -0.18371873]]

 [[-0.33057955 -0.06783468 -0.6092997   0.5161189 ]
  [-0.45700914  0.3698402   0.06807419  0.17523882]]

 [[-0.42375225 -0.24999335  0.07672869 -0.6301403 ]
  [-0.30860907 -0.12641668  0.61084574  0.357605  ]]

 [[-0.290714    0.47325853  0.15649678 -0.21518742]
  [-0.29347518  0.4739888  -0.22289251 -0.12016868]]]


In [None]:
@jit
def Exp_value_two_site_optimizirano(lam_0,Gamma_0,lam_1,Gamma_1,S_1,S_2):
    """Average of a site-1 and a site-2 one-body expectation on Theta (JAX).

    With Theta from ``Compute_Theta_optimized``, returns (M1 + M2) / 2 where
        M1 = sum Theta[a,i,j,b] S_1[i,k] conj(Theta[a,k,j,b])
        M2 = sum Theta[a,i,j,b] S_2[j,m] conj(Theta[a,i,m,b]).
    The identity on the other site now has Theta's physical dimension
    instead of a hard-coded 2.

    NOTE(review): the result is not divided by <Theta|Theta>. It is an
    expectation value only if Theta is normalized; confirm at the caller.
    Also, S_1's first index is contracted with the ket index, which
    evaluates S_1^T in the usual <psi|O|psi> convention. This is irrelevant
    for real symmetric operators such as S_z; confirm the intended
    convention for S_y.
    """
    Theta, _ = Compute_Theta_optimized(lam_0,Gamma_0,lam_1,Gamma_1)
    # Physical dimensions of site 1 and site 2 (static under jit).
    d1, d2 = Theta.shape[1], Theta.shape[2]

    # S_1 ⊗ 1 as an outer product O[i,k,j,m] = S_1[i,k] * I[j,m]
    Mz = jnp.tensordot(S_1, jnp.eye(d2), axes=0)
    Mz = jnp.einsum("aijb,ikjm->akmb", Theta, Mz)
    Mz1 = jnp.tensordot(Mz, jnp.conj(Theta), axes=([0,1,2,3],[0,1,2,3]))

    # 1 ⊗ S_2
    Mzz = jnp.tensordot(jnp.eye(d1), S_2, axes=0)
    Mzz = jnp.einsum("aijb,ikjm->akmb", Theta, Mzz)
    Mz2 = jnp.tensordot(Mzz, jnp.conj(Theta), axes=([0,1,2,3],[0,1,2,3]))
    return (Mz1 + Mz2)/2

def Exp_value_two_site_non_jit(lam_0,Gamma_0,lam_1,Gamma_1,S_1,S_2):
    """NumPy reference for ``Exp_value_two_site_optimizirano``.

    With Theta from ``Compute_Theta_non_jit``, returns (M1 + M2) / 2 where
        M1 = sum Theta[a,i,j,b] S_1[i,k] conj(Theta[a,k,j,b])
        M2 = sum Theta[a,i,j,b] S_2[j,m] conj(Theta[a,i,m,b]).
    The identity on the other site now has Theta's physical dimension
    instead of a hard-coded 2.

    NOTE(review): not divided by <Theta|Theta>; assumes Theta is normalized.
    Confirm at the caller.
    """
    Theta = Compute_Theta_non_jit(lam_0,Gamma_0,lam_1,Gamma_1)
    # Physical dimensions of site 1 and site 2.
    d1, d2 = Theta.shape[1], Theta.shape[2]

    # S_1 ⊗ 1
    Mz = np.einsum("ik,jm->ikjm",S_1 ,np.eye(d2))
    Mz = np.einsum("aijb,ikjm->akmb",Theta,Mz)
    Mz1 = np.einsum("akmb,akmb->",Mz,np.conj(Theta))

    # 1 ⊗ S_2
    Mzz = np.einsum("ik,jm->ikjm",np.eye(d1) ,S_2)
    Mzz = np.einsum("aijb,ikjm->akmb",Theta,Mzz)
    Mz2 = np.einsum("akmb,akmb->",Mzz,np.conj(Theta))
    return (Mz1+Mz2)/2
# Comparison: <S_z> averaged over the two sites, JAX vs NumPy, on the
# module-level random tensors. The JAX path is presumably float32 (default
# precision), so differences around 1e-6 on this unnormalized Theta are
# float32 rounding.
print(type(Gamma_0[0][0][0]))
print(np.finfo(np.float64).eps,"\n")

print(LA.norm(Exp_value_two_site_optimizirano(lam_0,Gamma_0,lam_1,Gamma_1,S_z,S_z) - Exp_value_two_site_non_jit(lam_0,Gamma_0,lam_1,Gamma_1,S_z,S_z)))

<class 'numpy.float64'>
2.220446049250313e-16 

7.6293945e-06


In [None]:
def apply_two_site_unitary_non_jit(lam_0,Gamma_0,lam_1,Gamma_1,U,chi,d):
    """Apply a two-site gate to the unit cell and re-split it (NumPy).

    Args:
      lam_0, Gamma_0, lam_1, Gamma_1: bond weights and site tensors. The
        Gamma tensors are indexed [left bond, physical, right bond].
      U: rank-4 gate. Its first two indices are contracted with Theta's two
        physical indices ("ijkl,aijb->aklb").
      chi: number of singular values kept after the split.
      d: physical dimension. It is now forwarded to i_trunc_non_jit, which
        previously always used its default d=2.

    Returns:
      (Gamma_0, Gamma_1, lam_1): new site tensors and the new normalized
      middle bond weights. lam_0 is used but not updated.
    """
    Theta = Compute_Theta_non_jit(lam_0,Gamma_0,lam_1,Gamma_1)
    Theta = np.einsum("ijkl,aijb->aklb",U,Theta)
    Theta = Theta.reshape(chi*d,chi*d)

    Sigma,lam_1,Ve = i_trunc_non_jit(chi,Theta,d)

    # 1/lam_0 with entries below 1e-14 mapped to 0 to avoid blow-ups.
    vi = np.divide(1.0, lam_0, out=np.zeros_like(lam_0), where=np.abs(lam_0)>=1E-14)

    # Strip the outer lam_0 weights back off the two new site tensors.
    Gamma_1 = np.einsum("aib,b->aib",Ve,vi)
    Gamma_0 = np.einsum("a,aib->aib",vi,Sigma)

    return Gamma_0,Gamma_1,lam_1

def apply_two_site_unitary_opt(lam_0,Gamma_0,lam_1,Gamma_1,U):
    """Apply a two-site gate to the unit cell and re-split it (JAX).

    chi and d are now read from the input shapes. The previous version
    hard-coded chi = 20 and d = 2, so its reshape failed for any other bond
    dimension (e.g. the module's chi = 4 tensors).

    Returns:
      (Gamma_0, Gamma_1, lam_1): new site tensors and the new middle bond
      weights from i_trunc_opt. lam_0 is used but not updated.
    """
    chi = jnp.shape(lam_0)[0]
    d = jnp.shape(Gamma_0)[1]
    Theta, _ = Compute_Theta_optimized(lam_0,Gamma_0,lam_1,Gamma_1)

    Theta = jnp.einsum("ijkl,aijb->aklb",U,Theta)
    Theta = Theta.reshape(chi*d,chi*d)

    # NOTE(review): i_trunc_opt is called with its default physical
    # dimension, as before. Passing d explicitly requires d to be a static
    # jit argument of i_trunc_opt; revisit if d != 2 is needed.
    Sigma,lam_1,Ve = i_trunc_opt(Theta)

    # 1/lam_0 with zeros mapped to 0, then strip the outer lam_0 weights.
    vi = division(lam_0)
    Gamma_1 = jnp.tensordot(Ve,jnp.diag(vi),axes = 1)
    Gamma_0 = jnp.tensordot(jnp.diag(vi),Sigma,axes = 1)

    return Gamma_0,Gamma_1,lam_1

# Comparison: new middle bond weights lam_1 (index 2) from both
# implementations. Singular values carry no sign ambiguity, so this is a
# robust check. Use the module-level chi and d rather than hard-coding 20
# and 2, so the internal reshape matches the shapes of lam_0/Gamma_0.
print(LA.norm(apply_two_site_unitary_non_jit(lam_0,Gamma_0,lam_1,Gamma_1,U,chi,d)[2] - apply_two_site_unitary_opt(lam_0,Gamma_0,lam_1,Gamma_1,U)[2]))

1.0571295e-07
