In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

key = random.PRNGKey(80)
random.uniform(key)

def print_matrix(mat):
    with np.printoptions(precision=3, suppress=True):
        print(mat)

def assert_equal(mat1,mat2,eps = 0.0000001):
    assert jnp.sum((mat1-mat2)**2) < eps

In [3]:
np.random.seed(1)
N = 4
beta = 1

sigma_z = jnp.array([[1,0],[0,-1]])
sigma_x = jnp.array([[0,1],[1,0]])
sigma_y = jnp.array([[0,1],[1,0]])

# Generate x 

spin_choices = jnp.hstack(([[0],[1]],)*N)
spin = jnp.array([0,1])
spin_up = jnp.array([[1],[0]])
spin_down= jnp.array([[0],[1]])

epsilons = np.random.rand(2,N)
eps = jnp.array(0.5 * np.random.rand(N))
p_down = 1/(1+jnp.exp(-beta * eps))
p_up = jnp.exp(-beta * eps)/(1+jnp.exp(-beta * eps))
probs = jnp.vstack((p_down,p_up))

"""
def choose_spin(prob):
    key = random.split()
    choice = random.choice(key,spin,p=prob)
    up_func = lambda x: spin_up
    down_func = lambda x: spin_down
    if choice == 0:
        return spin_up
    else:
        return spin_down
    return jax.lax.cond(choice,up_func,down_func,0)
    rand_spin =  jax.vmap(choose_spin)(probs.T)
    rand_spin = rand_spin.T
"""
p_down.shape,p_up.shape


((4,), (4,))

In [4]:
# create H

J = 0.1

H = jnp.zeros((2**N,2**N))

initLattice = jnp.kron(sigma_z,sigma_z)
for i in range(2,N):
    initLattice = jnp.kron(initLattice,jnp.eye(2))

H += - J * initLattice

for lattice_point in range(1,N-1):
    curr = jnp.eye(2)
    for i in range(1,lattice_point):
        curr = jnp.kron(curr,jnp.eye(2))
    curr = jnp.kron( jnp.kron(curr,sigma_z),sigma_z)
    for i in range(lattice_point+2,N):
        curr = jnp.kron(curr,jnp.eye(2))
    
    assert curr.shape[0] == H.shape[0]
    
    H += -J * curr



In [5]:

x0 = jnp.zeros((2**N,1))
x0  = x0.at[0].set(1)
x0.T @ H @ x0

DeviceArray([[-0.3]], dtype=float32)

In [6]:
vec = random.uniform(key,(3,1))
vec

DeviceArray([[0.737398  ],
             [0.5460452 ],
             [0.05000734]], dtype=float32)

It holds that $\det(e^{A}) = e^{Tr(A)}$. The trace of the parametric hermitian H should be vanishing to ensure the unitary has unit det.

In [7]:
vec = random.uniform(key,(3,1))
jnp.diag(vec[0:1,0],k=-1) + jnp.diag(vec[1:,0])


vec = random.uniform(key,(6,1))
jnp.diag(vec[0:1,0],k=-2) + jnp.diag(vec[1:3,0],k=-1) + jnp.diag(vec[3:,0])

DeviceArray([[0.13638556, 0.        , 0.        ],
             [0.07172787, 0.08502686, 0.        ],
             [0.5093068 , 0.09822512, 0.8820245 ]], dtype=float32)

In [8]:
#Converts unitary matrix to a vector of parameters (which parameterize the 
#unitary). This vector effectively parameterizes SU(4).

#Unitary is V = exp(1i*B); B is Hermitian


mat_size = int(2**N)

input = random.uniform(key,(1,mat_size**2-1))

input = input[0,:]

vec = jnp.append(input,-jnp.sum(input[-mat_size+1:]))

h_mat = jnp.diag(vec[-mat_size:])

assert jnp.trace(h_mat) < 0.00001

count = 0
#real part
for i in range(1,mat_size):
    h_mat += jnp.diag(vec[count:count+i],k=-mat_size+i)
    count += i

#img part
for i in range(1,mat_size):
    h_mat += jnp.diag(1j * vec[count:count+i],k=-mat_size+i)
    count += i

h_mat = (h_mat.conjugate().T + h_mat)/2

assert jnp.trace(h_mat) < 0.000001
assert_equal(h_mat,h_mat.conjugate().T)

unitary = jax.scipy.linalg.expm(1j * h_mat)


assert_equal(unitary.conjugate().T @ unitary,jnp.eye(mat_size))


In [9]:
# prepare the ensemble

formator = '{0:' + '0' + str(N)  +'b}'

state_in_str = [formator.format(i) for i in range(2**N)]


def state_to_vec(s):
    # return a probability with the corresponding state
    if s[0] == '1':
        state = spin_up
        state_mat = spin_up
    else:
        state = spin_down
        state_mat = spin_down
    
    for curr in s[1:]:
        if curr == '1':
            state = jnp.kron(state,spin_up)
            state_mat = jnp.hstack((state_mat,spin_up))
        else:
            state = jnp.kron(state,spin_down)
            state_mat = jnp.hstack((state_mat,spin_down))
    
    return state,state_mat

allstate = [state_to_vec(s)[0] for s in state_in_str]
allstateMat = [state_to_vec(s)[1] for s in state_in_str]
#jnp.hstack(tuple(allstate)),state_in_str


def stateMat_to_prob(state_mat):
    prob_mat = jnp.multiply(probs,state_mat)
    prob = prob_mat[0,:] + prob_mat[1,:]
    return jnp.prod(prob)

allprobs = jnp.hstack((stateMat_to_prob(mat) for mat in allstateMat))

jnp.sum(allprobs)




DeviceArray(1., dtype=float32)

DeviceArray([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [9]:
vec = np.random.rand(2,N)
spins = vec/ np.vstack((np.linalg.norm(vec,axis=0),)*2)
porb_up = spins[0,:] 
dist = np.prod()

TypeError: _prod_dispatcher() missing 1 required positional argument: 'a'

In [25]:
def bloch_to_rho(vec):
    return (jnp.identity(2) + vec[0] * sigma_x + vec[1]  * sigma_y)/2

def bloch_to_rho_prod(blochVecs):

    rho_prod = bloch_to_rho(blochVecs[0,:])
    for i in range(1,blochVecs.shape[-1]):
        rho_prod = jnp.kron(rho_prod,bloch_to_rho(blochVecs[:,i]))
    return rho_prod

bloch_to_rho_prod(np.random.rand(2,10)).shape

(1024, 1024)

In [27]:

Pr = (jnp.identity(2) - sigma_z)/2


rho_prod = (1/(1 + jnp.exp(- beta*eps[0]))) * jnp.expm1(- beta* eps[0] * Pr)

def rho_prod_sacn(carry,eps_r):
    y = (1/(1 + jnp.exp(- beta*eps_r))) * jnp.expm1(- beta* eps_r * Pr)
    return jnp.kron(carry,y),0



for i in range(1,N):
    eps_r = eps[i]
    y = (1/(1 + jnp.exp(- beta*eps_r))) * jnp.expm1(- beta* eps_r * Pr)
    rho_prod = jnp.kron(rho_prod,y)

#jax.lax.scan(rho_prod_sacn,init,eps[1:])
y



DeviceArray([[-0.        , -0.        ],
             [-0.        , -0.00684679]], dtype=float32)

In [26]:
jnp.sum(jnp.abs(rho_prod)>0)

DeviceArray(1, dtype=int32)

In [33]:
np.outer(x,x.T) @ x

array([0.30151134, 0.30151134, 0.        , 0.        , 0.30151134,
       0.30151134, 0.30151134, 0.30151134, 0.30151134, 0.        ,
       0.        , 0.30151134, 0.        , 0.30151134, 0.30151134,
       0.        , 0.        , 0.30151134, 0.        , 0.        ])

In [None]:
np.random( )