In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jaxQCUtil import *
from jax.config import config
from jax.scipy.linalg import expm
config.update("jax_enable_x64", True)
np.random.seed(1)




In [12]:
mapping =jnp.array([[0,0],[0,1]])
testMat = jnp.arange(16).reshape((4,4))

res = Rx(1)

for i in range(1):
    res=jnp.kron(res,Rx(1))

print_matrix(res)


[[ 0.292+0.j     0.   -0.455j  0.   -0.455j -0.708-0.j   ]
 [ 0.   -0.455j  0.292+0.j    -0.708-0.j     0.   -0.455j]
 [ 0.   -0.455j -0.708-0.j     0.292+0.j     0.   -0.455j]
 [-0.708-0.j     0.   -0.455j  0.   -0.455j  0.292+0.j   ]]


In [17]:
res = jnp.identity(2)

for i in range(1):
    res=jnp.kron(res,Rx(1))

print_matrix(res)
print_matrix(Rx(1))

[[0.54+0.j    0.  -0.841j 0.  +0.j    0.  +0.j   ]
 [0.  -0.841j 0.54+0.j    0.  +0.j    0.  +0.j   ]
 [0.  +0.j    0.  +0.j    0.54+0.j    0.  -0.841j]
 [0.  +0.j    0.  +0.j    0.  -0.841j 0.54+0.j   ]]
[[0.54+0.j    0.  -0.841j]
 [0.  -0.841j 0.54+0.j   ]]


In [7]:
initMat = jnp.zeros(2)

for i in range(4):
    if i == 2:
        res = jnp.kron(initMat,Rx(1))
    res = jnp.kron(initMat,Rx(1))

In [2]:
@jax.jit
def vec2Unitary(input):
    '''
    Converts unitary matrix to a vector of parameters (which parameterize the unitary). This vector effectively parameterizes SU(4).


    Args:
        N(int): Matrix size 
        input(float arrary): (matN**2 + matN)//2 - 1 size free parameters

    Returns:
        A complex (matN,matN) array that parameterized by the input vector
    '''

    # append 0 to the end so we can normalize it later
    newVec = jnp.append(input,0)

    temp = jnp.zeros((matN,matN))
    h_mat = temp.at[jnp.tril_indices(matN)].set(newVec)

    tr = jnp.trace(h_mat)
    
    # normalize trace to 0
    h_mat.at[-1].set(-tr)
    h_mat = (h_mat.conjugate().T + h_mat)/2
    unitary = jax.scipy.linalg.expm(1j * h_mat)

    return unitary

In [3]:
N = 5
beta = 1
J = 0.5
h = 0.3

matN = 5
perMatParaN = 2**(2*matN)-1

trivialPlacement = True
layerNum = 1

sigma_z = jnp.array([[1,0],[0,-1]])
sigma_x = jnp.array([[0,1],[1,0]])
sigma_y = jnp.array([[0,1],[1,0]])

spin = jnp.array([0,1])
spin_up = jnp.array([[1],[0]])
spin_down= jnp.array([[0],[1]])



In [4]:
staringState = jnp.zeros((2**N))
staringState.at[-1].set(1)

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 1.], dtype=float64)

In [5]:
@jax.jit
def RxLayer(thetas):
    init = jnp.identity(2)
    uni,_ = jax.lax.scan(Rx,thetas,init)
    return uni

RxLayer(jnp.arange(N))
    

TypeError: scan carry output and input must have identical types, got
DIFFERENT ShapedArray(complex128[2,10]) vs. ShapedArray(int64[5]).

In [4]:
H = jnp.zeros((2**N,2**N))

initLattice = jnp.kron(sigma_z,sigma_z)
for i in range(2,N):
    initLattice = jnp.kron(initLattice,jnp.eye(2))

H += - J * initLattice

for lattice_point in range(1,N-1):
    curr = jnp.eye(2)
    for i in range(1,lattice_point):
        curr = jnp.kron(curr,jnp.eye(2))
    curr = jnp.kron( jnp.kron(curr,sigma_z),sigma_z)
    for i in range(lattice_point+2,N):
        curr = jnp.kron(curr,jnp.eye(2))
    
    assert curr.shape[0] == H.shape[0]
    
    H += -J * curr


initLattice = sigma_x
for i in range(1,N):
    initLattice = jnp.kron(initLattice,jnp.eye(2))

H += - h * initLattice

for lattice_point in range(1,N-1):
    curr = jnp.eye(2)
    for i in range(1,lattice_point):
        curr = jnp.kron(curr,jnp.eye(2))
    curr = jnp.kron(curr,sigma_x)
    for i in range(lattice_point+1,N):
        curr = jnp.kron(curr,jnp.eye(2))
    
    assert curr.shape[0] == H.shape[0]
    
    H += -h * curr

H = jnp.array(H,dtype=jnp.complex128)

In [5]:
# prepare the ensemble basis

formator = '{0:' + '0' + str(N)  +'b}'

state_in_str = [formator.format(i) for i in range(2**N)]


def state_to_vec(s):
    # return a probability with the corresponding state
    if s[0] == '1':
        state = spin_up
        state_mat = spin_up
    else:
        state = spin_down
        state_mat = spin_down
    
    for curr in s[1:]:
        if curr == '1':
            state = jnp.kron(state,spin_up)
            state_mat = jnp.hstack((state_mat,spin_up))
        else:
            state = jnp.kron(state,spin_down)
            state_mat = jnp.hstack((state_mat,spin_down))
    
    return state,state_mat

allstate = jnp.stack([state_to_vec(s)[0] for s in state_in_str])
allstateMat = jnp.stack([state_to_vec(s)[1] for s in state_in_str])


In [6]:
# Utilities

@jax.jit
def stateMat_to_prob(state_mat,probs):
    prob_mat = jnp.multiply(probs,state_mat)
    prob_state = prob_mat[0,:] + prob_mat[1,:]
    return jnp.prod(prob_state)

@jax.jit
def build_ensemble(px):
    single = lambda p,state: p * jnp.outer(state,state)
    mats = jax.vmap(single,(0,0),0)(px,allstate)
    return jnp.sum(mats, axis=0)

@jax.jit
def weighted_expected(operator,px):
    single = lambda state: state.T @ operator @ state
    expectedVs =  jax.vmap(single)(allstate)
    res = jnp.sum(jnp.multiply(expectedVs[:,0,0],px))
    return jnp.real(res)

@jax.jit
def mat_prod(carry,x):
    return carry @ x ,0

@jax.jit
def unitary_prods(uniVec):
    uniVec = uniVec.reshape((layer_num,(2**(2*N)-1)))
    unis = jax.vmap(vec2Unitary,in_axes=(None,0))(N,uniVec)
    res,_ = jax.lax.scan(mat_prod,jnp.eye(2**N,dtype=jnp.complex128),unis)
    return res


stateMat_to_prob_map = jax.vmap(stateMat_to_prob,(0,None),0)

Free parameters initialization:

In [13]:
eps = jnp.array(0.5 * np.random.rand(N))
uniVec = jnp.array(np.random.rand(uniParas_num))
paras = jnp.hstack((eps,uniVec))

In [18]:
vec2Unitary(N,uniVec)

-16.14772350286639
-16.14772350286639


DeviceArray([[ 0.22173209+2.43656628e-01j, -0.0990587 +4.50345594e-02j,
               0.09468157+2.20339562e-01j, ...,
              -0.23921613+1.75591412e-01j,  0.02014918+1.82938343e-01j,
              -0.01045757+4.19483484e-03j],
             [-0.22316293-5.10691371e-02j,  0.33418243+7.19849344e-02j,
              -0.15078023-1.41396342e-01j, ...,
               0.02897866+3.46465599e-01j,  0.18782914+3.80740141e-01j,
               0.01164877+1.20322717e-04j],
             [-0.1055486 -8.46074549e-02j, -0.21371473-1.61121658e-01j,
               0.29067478+3.30218821e-01j, ...,
               0.22848382+2.70225861e-01j, -0.08654695+1.31256530e-02j,
              -0.01649239-1.78697578e-02j],
             ...,
             [ 0.04569813-2.36875698e-01j,  0.03544656+1.54617683e-01j,
              -0.05909613-1.09782667e-01j, ...,
               0.30988148+2.38845104e-01j, -0.19060754-1.69084725e-02j,
              -0.00270196-3.04543922e-02j],
             [-0.1274384 -2.15797947e-

In [8]:
def ising(paras):
    eps = paras[:N]
    uniVec = paras[N:]

    p_down = 1/(1+jnp.exp(-beta * eps))
    p_up = jnp.exp(-beta * eps)/(1+jnp.exp(-beta * eps))
    probs = jnp.vstack((p_down,p_up))

    px = stateMat_to_prob_map(allstateMat,probs)
    unitary = unitary_prods(uniVec)
    operator = unitary.conjugate().T @ H @ unitary

    entropy = jnp.sum(jnp.multiply(px,jnp.log(px)))
    loss = entropy + beta* weighted_expected(operator,px)
    #print(entropy)
    return loss




ising(paras),jax.grad(ising)(paras).shape
#ising(paras)



(DeviceArray(-3.41101604, dtype=float64), (4097,))

In [9]:
import scipy.optimize

def value_and_grad_numpy(f):
    def val_grad_f(*args):
        value, grad = jax.value_and_grad(f)(*args)
        return np.float64(value), np.array(grad,dtype=np.float64)
    return val_grad_f
results = scipy.optimize.minimize(value_and_grad_numpy(ising), np.array(paras,dtype=np.float64),
                                  method='L-BFGS-B', jac=True)
print("success:", results.success, "\nniterations:", results.nit, "\nfinal loss:", results.fun)

success: True 
niterations: 767 
final loss: -4.1016182712021045


In [9]:
-jnp.log(jnp.trace(expm(- beta * H)))

DeviceArray(-4.10161988-0.j, dtype=complex128)

In [10]:
def exact_ising(J,h,beta):
    inside1 = jnp.exp(beta * J) * jnp.cosh(beta*h) 
    inside2 = jnp.sqrt(jnp.exp(2*beta * J) * (jnp.sinh(beta*h)**2) + jnp.exp(-2*beta * J))

    return - (1/beta) * jnp.log(inside1+inside2)

In [11]:

p_down = 1/(1+jnp.exp(-beta * eps))
p_up = jnp.exp(-beta * eps)/(1+jnp.exp(-beta * eps))
probs = jnp.vstack((p_down,p_up))
px = stateMat_to_prob_map(allstateMat,probs)
unitary = unitary_prods(uniVec)
operator = unitary.conjugate().T @ H @ unitary

entropy = jnp.sum(jnp.multiply(px,jnp.log(px)))
#print_matrix(unitary.conjugate().T @ unitary)
#operator[2:4,2:4]
beta* weighted_expected(operator,px)
single = lambda state: state.T @ operator @ state
expectedVs =  jax.vmap(single)(allstate)
res = jnp.sum(jnp.multiply(expectedVs[:,0,0],px))
#print(expectedVs)r
res


DeviceArray(0.02614691-1.05410334e-18j, dtype=complex128)