In [2]:
import jax.numpy as jnp

N = 10  # Assuming some value for N
# Create a range of values from 0 to N-3
i_values = jnp.arange(N-3)[:, None]  # Reshape to use broadcasting
# Create the loc_array_bulk using broadcasting
loc_array_bulk = i_values + jnp.array([0, 1, 2])


In [3]:
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Define a pytree structure
params = [
        jnp.array([1.0, 2.0, 3.0]), 
        jnp.array([4.0, 5.0]),
        2.5,
        jnp.array([0.5, 1.5])
    ]


# Flatten the pytree into a single 1D array
flat_params, unflatten = ravel_pytree(params)

# Now flat_params is a single 1D array containing all the parameters
print("Flattened parameters:", flat_params)

# Do something with flat_params (e.g., pass through an optimizer)

# Then unflatten the optimized parameters back to the original structure
new_params = unflatten(flat_params)

# new_params has the same structure as params
print("Unflattened parameters:", new_params)


Flattened parameters: [1.  2.  3.  4.  5.  2.5 0.5 1.5]
Unflattened parameters: [Array([1., 2., 3.], dtype=float32), Array([4., 5.], dtype=float32), Array(2.5, dtype=float32), Array([0.5, 1.5], dtype=float32)]


In [78]:
import numpy as np
t1 = np.random.rand(1000000)
t2 = np.random.rand(1000000)*(1-t1)
t3 = np.random.rand(1000000)*(1-t1-t2)
t4 = 1-t1-t2-t3
p1 = np.random.rand(1000000)*2*np.pi
p2 = np.random.rand(1000000)*2*np.pi
p3 = np.random.rand(1000000)*2*np.pi
p4 = np.random.rand(1000000)*2*np.pi
x1 = np.sqrt(t1)*np.exp(1j*p1)
x2 = np.sqrt(t2)*np.exp(1j*p2)
x3 = np.sqrt(t3)*np.exp(1j*p3)
x4 = np.sqrt(t4)*np.exp(1j*p4)

In [79]:
np.sum(np.abs(x1)**2+np.abs(x2)**2+np.abs(x3)**2+np.abs(x4)**2)

1000000.0

In [80]:
px1 = np.abs(x1)**2
px2 = np.abs(x2)**2
px3 = np.abs(x3)**2
px4 = np.abs(x4)**2
px1x4 = px1+px4
px2x3 = px2+px3
C = 2*x1*x4-2*x2*x3
cp1 = (1+np.sqrt(1-np.abs(C)**2))/2
cp2 = (1-np.sqrt(1-np.abs(C)**2))/2
a1 = -px1*np.log(px1)-px2*np.log(px2)-px3*np.log(px3)-px4*np.log(px4)+px1x4*np.log(px1x4)+px2x3*np.log(px2x3)
b1 = -cp1*np.log(cp1)-cp2*np.log(cp2)

In [81]:
print(x1[636373], x2[636373], x3[636373], x4[636373])

(-0.6955850017902521+0.19609206259357648j) (0.05270686824463031-0.3134852381631305j) (-0.10491433363726148+0.5133506686988841j) (0.2637844027233604-0.18038919643270798j)


In [82]:
print(np.sum(a1>b1))

1000000


In [20]:
import numpy as np
a0 = np.random.rand(10000)
a1 = np.random.rand(10000)
a2 = np.random.rand(10000)
a3 = np.random.rand(10000)
x0 = a0**2+a1**2+a2**2+a3**2
x1 = np.sqrt((a0**2+a1**2-a2**2-a3**2)**2+4*(a0*a2+a1*a3)**2)
Sa = -a0*np.log(a0)-a1*np.log(a1)-a2*np.log(a2)-a3*np.log(a3)
Sx = -x0*np.log(x0)-x1*np.log(x1)

In [17]:
print(np.sum(a0+a1+a2+a3<=1))
np.sum((Sa>Sx)&(a0+a1+a2+a3<=1))

431


417

In [1]:
import jax
import jax.numpy as jnp
from jax import random
def unitary(key, n):
    a, b = random.normal(key, (2, n, n))
    z = a + b * 1j
    q, r = jnp.linalg.qr(z)
    d = jnp.diag(r)
    return q * d / abs(d)

In [10]:
from jax import vmap
a = vmap(unitary, (0, None), 0)(random.split(random.PRNGKey(0), 9), 2).reshape(3,3,2,2)

In [11]:
for i in a:
    for j in i: 
        print(j@j.conj().T)

[[1.0000000e+00+0.0000000e+00j 5.9604645e-08-2.9802322e-08j]
 [5.9604645e-08+2.9802322e-08j 1.0000000e+00+0.0000000e+00j]]
[[ 1.0000000e+00+0.0000000e+00j -2.9802322e-08+2.9802322e-08j]
 [-2.9802322e-08-2.9802322e-08j  1.0000000e+00+0.0000000e+00j]]
[[1.0000002e+00+0.0000000e+00j 2.2351742e-08+2.0489097e-08j]
 [2.2351742e-08-2.0489097e-08j 1.0000001e+00+0.0000000e+00j]]
[[1.0000001e+00+0.j 8.9406967e-08+0.j]
 [8.9406967e-08+0.j 1.0000000e+00+0.j]]
[[9.9999964e-01+0.j 2.9802322e-08+0.j]
 [2.9802322e-08+0.j 1.0000000e+00+0.j]]
[[1.0000000e+00+0.0000000e+00j 2.9802322e-08-1.1129305e-07j]
 [2.9802322e-08+1.1129305e-07j 9.9999988e-01+0.0000000e+00j]]
[[ 9.9999988e-01+0.0000000e+00j -4.0978193e-08+1.4901161e-08j]
 [-4.0978193e-08-1.4901161e-08j  9.9999994e-01+0.0000000e+00j]]
[[ 9.99999821e-01+0.0000000e+00j -1.07102096e-07-1.5832484e-08j]
 [-1.07102096e-07+1.5832484e-08j  1.00000000e+00+0.0000000e+00j]]
[[1.0000002e+00+0.0000000e+00j 5.9604645e-08-2.9802322e-08j]
 [5.9604645e-08+2.9802322e-