In [6]:
import jax
import jax.numpy as jnp
import numpy as np
from vec2Unitary import vec2Unitary
from jax.config import config
from jax.scipy.linalg import expm
config.update("jax_enable_x64", True)

np.random.seed(1)

def print_matrix(mat):
    with np.printoptions(precision=3, suppress=True):
        print(mat)

def assert_equal(mat1,mat2,eps = 0.0000001):
    assert jnp.sum((mat1-mat2)**2) < eps

In [7]:
N = 4
beta = 1
J = 0.5
h = 0.3
layer_num = 4
uniParas_num = layer_num * (2**(2*N)-1)

sigma_z = jnp.array([[1,0],[0,-1]])
sigma_x = jnp.array([[0,1],[1,0]])
sigma_y = jnp.array([[0,1],[1,0]])

spin = jnp.array([0,1])
spin_up = jnp.array([[1],[0]])
spin_down= jnp.array([[0],[1]])


In [8]:
#def createH(J,h)

H = jnp.zeros((2**N,2**N))

initLattice = jnp.kron(sigma_z,sigma_z)
for i in range(2,N):
    initLattice = jnp.kron(initLattice,jnp.eye(2))

H += - J * initLattice

for lattice_point in range(1,N-1):
    curr = jnp.eye(2)
    for i in range(1,lattice_point):
        curr = jnp.kron(curr,jnp.eye(2))
    curr = jnp.kron( jnp.kron(curr,sigma_z),sigma_z)
    for i in range(lattice_point+2,N):
        curr = jnp.kron(curr,jnp.eye(2))
    
    assert curr.shape[0] == H.shape[0]
    
    H += -J * curr


initLattice = sigma_x
for i in range(1,N):
    initLattice = jnp.kron(initLattice,jnp.eye(2))

H += - h * initLattice

for lattice_point in range(1,N-1):
    curr = jnp.eye(2)
    for i in range(1,lattice_point):
        curr = jnp.kron(curr,jnp.eye(2))
    curr = jnp.kron(curr,sigma_x)
    for i in range(lattice_point+1,N):
        curr = jnp.kron(curr,jnp.eye(2))
    
    assert curr.shape[0] == H.shape[0]
    
    H += -h * curr

H = jnp.array(H,dtype=jnp.complex128)

In [9]:
# prepare the ensemble basis

formator = '{0:' + '0' + str(N)  +'b}'

state_in_str = [formator.format(i) for i in range(2**N)]


def state_to_vec(s):
    # return a probability with the corresponding state
    if s[0] == '1':
        state = spin_up
        state_mat = spin_up
    else:
        state = spin_down
        state_mat = spin_down
    
    for curr in s[1:]:
        if curr == '1':
            state = jnp.kron(state,spin_up)
            state_mat = jnp.hstack((state_mat,spin_up))
        else:
            state = jnp.kron(state,spin_down)
            state_mat = jnp.hstack((state_mat,spin_down))
    
    return state,state_mat

allstate = jnp.stack([state_to_vec(s)[0] for s in state_in_str])
allstateMat = jnp.stack([state_to_vec(s)[1] for s in state_in_str])


In [17]:
# Utilities

@jax.jit
def stateMat_to_prob(state_mat,probs):
    prob_mat = jnp.multiply(probs,state_mat)
    prob_state = prob_mat[0,:] + prob_mat[1,:]
    return jnp.prod(prob_state)

@jax.jit
def build_ensemble(px,uni):
    single = lambda p,state: p * uni @ jnp.outer(state,state) @  uni.conjugate().T
    mats = jax.vmap(single,(0,0),0)(px,allstate)
    return jnp.sum(mats, axis=0)

@jax.jit
def weighted_expected(operator,px):
    single = lambda state: state.T @ operator @ state
    expectedVs =  jax.vmap(single)(allstate)
    res = jnp.sum(jnp.multiply(expectedVs[:,0,0],px))
    return jnp.real(res)

@jax.jit
def mat_prod(carry,x):
    return carry @ x ,0

@jax.jit
def unitary_prods(uniVec):
    uniVec = uniVec.reshape((layer_num,(2**(2*N)-1)))
    unis = jax.vmap(vec2Unitary,in_axes=(None,0))(N,uniVec)

    res,_ = jax.lax.scan(mat_prod,jnp.eye(2**N,dtype=jnp.complex128),unis)
    return res


stateMat_to_prob_map = jax.vmap(stateMat_to_prob,(0,None),0)

Free parameters initialization:

In [11]:
eps = jnp.array(0.5 * np.random.rand(N))
uniVec = jnp.array(np.random.rand(uniParas_num))
paras = jnp.hstack((eps,uniVec))

In [12]:
def ising(paras):
    eps = paras[:N]
    uniVec = paras[N:]

    p_down = 1/(1+jnp.exp(-beta * eps))
    p_up = jnp.exp(-beta * eps)/(1+jnp.exp(-beta * eps))
    probs = jnp.vstack((p_down,p_up))

    px = stateMat_to_prob_map(allstateMat,probs)
    unitary = unitary_prods(uniVec)
    operator = unitary.conjugate().T @ H @ unitary

    entropy = jnp.sum(jnp.multiply(px,jnp.log(px)))
    loss = entropy + beta* weighted_expected(operator,px)
    #print(entropy)
    return loss




ising(paras),jax.grad(ising)(paras).shape
#ising(paras)



(DeviceArray(-2.71075407, dtype=float64), (1024,))

In [13]:
import scipy.optimize

def value_and_grad_numpy(f):
    def val_grad_f(*args):
        value, grad = jax.value_and_grad(f)(*args)
        return np.float64(value), np.array(grad,dtype=np.float64)
    return val_grad_f
results = scipy.optimize.minimize(value_and_grad_numpy(ising), np.array(paras,dtype=np.float64),
                                  method='L-BFGS-B', jac=True)
print("success:", results.success, "\nniterations:", results.nit, "\nfinal loss:", results.fun)

success: True 
niterations: 147 
final loss: -3.2436713770210712


In [50]:
-jnp.log(jnp.trace(expm(- beta * H)))

DeviceArray(-3.25023374-0.j, dtype=complex128)

In [15]:
# get all inner stuff
paras = jnp.array(results.x)

eps = paras[:N]
uniVec = paras[N:]

p_down = 1/(1+jnp.exp(-beta * eps))
p_up = jnp.exp(-beta * eps)/(1+jnp.exp(-beta * eps))
probs = jnp.vstack((p_down,p_up))

px = stateMat_to_prob_map(allstateMat,probs)
unitary = unitary_prods(uniVec)
operator = unitary.conjugate().T @ H @ unitary

entropy = jnp.sum(jnp.multiply(px,jnp.log(px)))
loss = entropy + beta* weighted_expected(operator,px)

In [60]:
ensemble = build_ensemble(px,unitary)

@jax.jit
def absorbedState(allstate,unitary):
    transUni = unitary.T.conjugate()
    leftFun = lambda state: unitary @ state
    rightFun = lambda state:  state.T @ transUni
    res = jax.vmap(leftFun)(allstate)
    return res

mStates = absorbedState(allstate,unitary)

In [48]:
# trivial calculation

energies = jnp.linalg.eig(H)[0]
states = jnp.linalg.eig(H)[1]

allSx = sigma_x.copy()
for i in range(N-1):
    allSx = jnp.kron(allSx,sigma_x)

omega = 0.451
sf = 0
eps = 10e-3
ops = allSx

for m in range(2**N):
    for m_p in range(2**N):

        if (energies[m_p] - energies[m] - omega) > eps:
            continue

        elem = states[:,m_p].T.conjugate() @ ops @ states[:,m]

        sf += jnp.exp(-energies[m_p] * beta) * elem * elem.conjugate()

sf/jnp.trace(expm(- beta * H))

DeviceArray(1.01378431+0.j, dtype=complex128)

In [64]:

# sum of the energies is 0
jnp.sum(jnp.hstack([(allstate[i].T @ H @ allstate[i])[0] for i in range(2**N)]))

es = []
for m in range(2**N):
    for m_p in range(2**N):

        es += [energies[m_p] - energies[m]]
uniqueEDeltas = jnp.unique(jnp.abs(jnp.array(es).round(decimals=4)))
uniqueEDeltas

DeviceArray([0.    , 0.3199, 0.3952, 0.3953, 0.451 , 0.7152, 0.7709,
             1.0351, 1.1662, 1.4861, 1.8814, 1.9371, 2.2013, 2.6523,
             3.3675], dtype=float64)

In [66]:
# estimation

energies = [state.T.conjugate() @ H @ state for state in mStates]
states = mStates

sf = 0
for m in range(2**N):
    for m_p in range(2**N):

        if (energies[m_p] - energies[m] - omega) > eps:
            continue

        elem = states[:,m_p].T.conjugate() @ ops @ states[:,m]

        sf += jnp.exp(-energies[m_p] * beta) * elem * elem.conjugate() 

sf/jnp.exp(-loss)

DeviceArray([[0.81563282-1.93746254e-17j]], dtype=complex128)

### Put it all together

In [97]:
def strFactor(energies,states,omega,Z,ops):
    eps = 10e-4
    sf = 0
    for m in range(2**N):
        for m_p in range(2**N):
            #print(omega,energies[m_p] - energies[m]- omega)
            if (energies[m_p] - energies[m] - omega) > eps:
                #print('jump')
                continue

            elem = states[:,m_p].T.conjugate() @ ops @ states[:,m]

            sf += jnp.exp(-energies[m_p] * beta) * elem * elem.conjugate()

    return np.real(sf/Z)

In [98]:
exactEnergies = jnp.linalg.eig(H)[0]
exactStates = jnp.linalg.eig(H)[1]

estEnergies = [state.T.conjugate() @ H @ state for state in mStates]
estStates = mStates

exactZ =jnp.trace(expm(- beta * H))
estZ = jnp.exp(-loss)

allSx = sigma_x.copy()
for i in range(N-1):
    allSx = jnp.kron(allSx,sigma_x)

exactSF = [strFactor(exactEnergies,exactStates,omega,exactZ,allSx) for omega in uniqueEDeltas]


In [99]:
estSF = [strFactor(estEnergies,estStates,omega,estZ,allSx) for omega in uniqueEDeltas]

In [103]:
#strFactor(exactEnergies,exactStates,omega,exactZ,allSx)
np.array(exactSF),np.array(estSF).flatten()

(array([1.01378431, 1.01378431, 1.01378431, 1.01378431, 1.01378431,
        1.01378431, 1.01378431, 1.01378431, 1.01378431, 1.01378431,
        1.01378431, 1.01378431, 1.01378431, 1.01378431, 1.01378431]),
 array([0.77684825, 0.80731117, 0.80958715, 0.80958715, 0.81563282,
        0.8949174 , 0.91720312, 0.91832393, 0.97559712, 0.98328268,
        0.98436404, 0.98468355, 0.99607126, 1.00609578, 1.00658384]))

In [30]:
allstate[2].T @ allSx @ allstate[1]

DeviceArray([[0]], dtype=int64)

In [40]:
jnp.where(allSx != 0 )[0]

DeviceArray([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,
              12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
              24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
              36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
              48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
              60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
              72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
              84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
              96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
             108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
             120, 121, 122, 123, 124, 125, 126, 127], dtype=int64)

DeviceArray([[0, 0, 0, ..., 0, 0, 1],
             [0, 0, 0, ..., 0, 1, 0],
             [0, 0, 0, ..., 1, 0, 0],
             ...,
             [0, 0, 1, ..., 0, 0, 0],
             [0, 1, 0, ..., 0, 0, 0],
             [1, 0, 0, ..., 0, 0, 0]], dtype=int64)