In [99]:
from functools import partial
import jax
import jax.numpy as jnp

def local_element_indices_2d(num_body, pauli_array, loc_array):
    if pauli_array.shape[-1] != num_body:
        raise ValueError(f"Array has incorrect body of interactions {pauli_array.shape[-1]}. Expected body of interactions is {num_body}.")

    count_3s = jnp.sum(pauli_array == 3, axis = 1)
    count_1s = jnp.sum(pauli_array == 1, axis = 1)

    pauli_array_xz = {}
    xloc_arrays = {}
    zloc_arrays = {}
    yloc_arrays = {}
    xy_loc_arrays = {}

    for i in range(num_body+1):    #z_number
        for j in range (num_body+1-i):  #x_number
            mask = ((count_3s == i)&(count_1s == j))
            pauli_array_xz[i, j] = pauli_array[mask]

            mask_x = (pauli_array_xz[i, j] == 1)
            mask_y = (pauli_array_xz[i, j] == 2)
            mask_z = (pauli_array_xz[i, j] == 3)

            if mask_x.sum() != 0:
                xloc_arrays[i, j] = loc_array[mask][mask_x].reshape(-1, j, 2)
            elif mask_y.sum() != 0 or mask_z.sum() !=0:
                if mask_y.sum()!=0:
                    xloc_arrays[i, j] = jnp.array([[]]).reshape(pauli_array_xz[i ,j].shape[0], 0, 2).astype(int)
                else:
                    xloc_arrays[i, j] = jnp.array([[]]).reshape(pauli_array_xz[i, j].shape[0], 0, 2).astype(int)

            if mask_y.sum() !=0:
                yloc_arrays[i, j] = loc_array[mask][mask_y].reshape(-1, num_body-i-j, 2).astype(int)
            elif mask_x.sum() != 0 or mask_z.sum() !=0:
                if mask_x.sum()!=0:
                    yloc_arrays[i, j] = jnp.array([[]]).reshape(pauli_array_xz[i, j].shape[0], 0, 2).astype(int)
                else:
                    yloc_arrays[i, j] = jnp.array([[]]).reshape(pauli_array_xz[i, j].shape[0], 0, 2).astype(int)

            if mask_z.sum()!=0:
                zloc_arrays[i, j] = loc_array[mask][mask_z].reshape(-1, i, 2).astype(int)
            elif mask_x.sum() != 0 or mask_y.sum() !=0:
                if mask_y.sum()!=0:
                    zloc_arrays[i, j] = jnp.array([[]]).reshape(pauli_array_xz[i, j].shape[0], 0, 2).astype(int)
                else:
                    zloc_arrays[i, j] = jnp.array([[]]).reshape(pauli_array_xz[i, j].shape[0], 0, 2).astype(int)
    print("xlco_arrays:", xloc_arrays)
    print("yloc_arrays:", yloc_arrays)
    for ind in (xloc_arrays):
        xy_loc_arrays[ind] = jnp.concatenate((xloc_arrays[ind], yloc_arrays[ind]), axis=1).astype(int)
    return  xy_loc_arrays, yloc_arrays, zloc_arrays
@jax.jit
def total_samples_2d(samples, xyloc):
    def scan_array_element(sample_element, xyloc_arrays_element):
        scan_samples = sample_element.at[xyloc_arrays_element[:,0], xyloc_arrays_element[:,1]].set((sample_element[xyloc_arrays_element[:,0], xyloc_arrays_element[:,1]]+1)%2)
        return sample_element, scan_samples
    sample_tmp = samples
    for xyloc_ind in xyloc:
        if xyloc[xyloc_ind].size != 0:
            sample_tmp = jnp.append(sample_tmp, jax.lax.scan(scan_array_element, samples, xyloc[xyloc_ind])[1])
    return sample_tmp.reshape(-1, samples.shape[0], samples.shape[1])

@jax.jit
def new_coe_2d(samples, coe_array, yloc, zloc):
    def ycoe(sample_element, yloc_arrays_element):
        scan_coe_tmp_y = ((-1)**sample_element[yloc_arrays_element[:,0],yloc_arrays_element[:,1]]*1j).prod()
        return sample_element, scan_coe_tmp_y
    def zcoe(sample_element, zloc_arrays_element):
        scan_coe_tmp_z = ((-1)**sample_element[zloc_arrays_element[:,0], zloc_arrays_element[:,1]]).prod()
        return sample_element, scan_coe_tmp_z
    coe_tmp_y = jnp.array([1])
    coe_tmp_z = jnp.array([1])
    for yloc_ind in yloc:
        if yloc[yloc_ind].shape[0] != 0:
            coe_tmp_y = jnp.append(coe_tmp_y, jax.lax.scan(ycoe, samples, yloc[yloc_ind])[1])
        else:
            coe_tmp_y = jnp.append(coe_tmp_y, 1)
    for zloc_ind in zloc:
        if zloc_ind[1] == 0:
            coe_tmp_z = coe_tmp_z.at([0]).set(jax.lax.scan(zcoe, samples, zloc[zloc_ind])[1].sum())
        elif zloc[zloc_ind].shape[0] != 0:
            coe_tmp_z = jnp.append(coe_tmp_z, jax.lax.scan(zcoe, samples, zloc[zloc_ind])[1])
        else:
            coe_tmp_z = jnp.append(coe_tmp_z, 1)
    return coe_tmp_y*coe_tmp_z*jnp.concatenate((jnp.array([1]), coe_array), axis=0)

In [118]:
from jax import random
from jax.random import PRNGKey, randint
import numpy as np
pauli_array = jnp.repeat(jnp.array([1,3,3,3,3])[None], 16, axis = 0).astype(int)
print("pauli_array:", pauli_array)
loc_array = jnp.array([[]])
for i in range (1, 5):
    for j in range (1, 5):
        loc_array = jnp.append(loc_array, jnp.array([[i,j],[i,j+1],[i+1,j],[i-1,j], [i, j-1]]))
loc_array = loc_array.reshape(16, 5, 2)
print("loc_array:", loc_array)
xy_loc, yloc, zloc = local_element_indices_2d(5, pauli_array, loc_array)
print("zloc:", zloc)

pauli_array: [[1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]
 [1 3 3 3 3]]
loc_array: [[[1. 1.]
  [1. 2.]
  [2. 1.]
  [0. 1.]
  [1. 0.]]

 [[1. 2.]
  [1. 3.]
  [2. 2.]
  [0. 2.]
  [1. 1.]]

 [[1. 3.]
  [1. 4.]
  [2. 3.]
  [0. 3.]
  [1. 2.]]

 [[1. 4.]
  [1. 5.]
  [2. 4.]
  [0. 4.]
  [1. 3.]]

 [[2. 1.]
  [2. 2.]
  [3. 1.]
  [1. 1.]
  [2. 0.]]

 [[2. 2.]
  [2. 3.]
  [3. 2.]
  [1. 2.]
  [2. 1.]]

 [[2. 3.]
  [2. 4.]
  [3. 3.]
  [1. 3.]
  [2. 2.]]

 [[2. 4.]
  [2. 5.]
  [3. 4.]
  [1. 4.]
  [2. 3.]]

 [[3. 1.]
  [3. 2.]
  [4. 1.]
  [2. 1.]
  [3. 0.]]

 [[3. 2.]
  [3. 3.]
  [4. 2.]
  [2. 2.]
  [3. 1.]]

 [[3. 3.]
  [3. 4.]
  [4. 3.]
  [2. 3.]
  [3. 2.]]

 [[3. 4.]
  [3. 5.]
  [4. 4.]
  [2. 4.]
  [3. 3.]]

 [[4. 1.]
  [4. 2.]
  [5. 1.]
  [3. 1.]
  [4. 0.]]

 [[4. 2.]
  [4. 3.]
  [5. 2.]
  [3. 2.]
  [4. 1.]]

 [[4. 3.]
  [4. 4.]
  [5. 3.]
  [3. 3.]

In [119]:
from jax import vmap
samples = jnp.array([[1,1,1,0,0,1],[0,1,1,1,0,1],[1,0,1,1,1,0],[1,1,0,1,1,1],[0,1,1,1,1,1],[0,1,0,1,1,1]])
batch_totoal_samples_2d = vmap(total_samples_2d, (0, None))
total_samples_2d(samples, xy_loc)

Array([[[1, 1, 1, 0, 0, 1],
        [0, 1, 1, 1, 0, 1],
        [1, 0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [0, 1, 0, 1, 1, 1]],

       [[1, 1, 1, 0, 0, 1],
        [0, 0, 1, 1, 0, 1],
        [1, 0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [0, 1, 0, 1, 1, 1]],

       [[1, 1, 1, 0, 0, 1],
        [0, 1, 0, 1, 0, 1],
        [1, 0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [0, 1, 0, 1, 1, 1]],

       [[1, 1, 1, 0, 0, 1],
        [0, 1, 1, 0, 0, 1],
        [1, 0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [0, 1, 0, 1, 1, 1]],

       [[1, 1, 1, 0, 0, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [0, 1, 0, 1, 1, 1]],

       [[1, 1, 1, 0, 0, 1],
        [0, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 1, 0],
        [1, 1, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [0

In [120]:
print(yloc[4,1].shape)
coe_array = jnp.ones(16)
new_coe_2d(samples, coe_array, yloc, zloc, 5)

(16, 0, 2)
{(4, 1): Traced<ShapedArray(int32[16,4,2])>with<DynamicJaxprTrace(level=1/0)>}
coe_tmp_y: Traced<ShapedArray(complex64[17])>with<DynamicJaxprTrace(level=1/0)>
coe_tmp_z: Traced<ShapedArray(int32[17])>with<DynamicJaxprTrace(level=1/0)>


Array([ 1.+0.j,  1.+0.j,  1.+0.j,  1.+0.j, -1.+0.j,  1.+0.j,  1.+0.j,
        1.+0.j,  1.+0.j,  1.+0.j,  1.+0.j, -1.+0.j,  1.+0.j, -1.+0.j,
        1.+0.j,  1.+0.j,  1.+0.j], dtype=complex64)

In [115]:
jnp.repeat(jnp.array([1,3,3,3,3]), jnp.array([1,2,3,4,5]), axis=0)
jnp.repeat(jnp.log(sample_amp), jnp.ones(numsamples)*Nx*Ny+1, axis=0)

Array([1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], dtype=int32)

In [88]:
import jax.numpy as jnp

Ny, Nx = 4, 4  # Define Ny and Nx with assumed values for demonstration

# Pre-allocate bulk locations with array operations
I, J = jnp.meshgrid(jnp.arange(1, Ny-1), jnp.arange(1, Nx-1), indexing='ij')
bulk_coordinates = jnp.stack([I, J, I, J+1, I+1, J, I-1, J, I, J-1], axis=-1)
loc_array_bulk = bulk_coordinates.reshape(-1, 5, 2)

# Create edges excluding corners
edge_top = jnp.stack([jnp.ones(Nx-2), jnp.arange(1, Nx-1)], axis=-1)
edge_bottom = jnp.stack([jnp.full((Nx-2,), Ny-1), jnp.arange(1, Nx-1)], axis=-1)
edge_left = jnp.stack([jnp.arange(1, Ny-1), jnp.zeros(Ny-2)], axis=-1)
edge_right = jnp.stack([jnp.arange(1, Ny-1), jnp.full((Ny-2,), Nx-1)], axis=-1)
edges = jnp.concatenate([edge_top, edge_bottom, edge_left, edge_right])
loc_array_edge = jnp.repeat(edges[:, None, :], 4, axis=1) # Repeat the edge pattern

# Fix the edge coordinates with appropriate values
loc_array_edge = loc_array_edge.at[:, 1:, 0].add(1)  # Increment i or j based on direction
loc_array_edge = loc_array_edge.at[:, 2, 1].set(0)    # Set j=0 for left edge increments
loc_array_edge = loc_array_edge.at[:, 3, 0].set(Ny-1) # Set i=Ny-1 for bottom edge increments
loc_array_edge = loc_array_edge.reshape(-1, 4, 2)

# Pre-defined corner locations
loc_array_corner = jnp.array([[[0,0],[0,1],[1,0]],
                              [[0, Nx-1],[0, Nx-2], [1, Nx-1]],
                              [[Ny-1, 0],[Ny-1, 1],[Ny-2,0]],
                              [[Ny-1, Nx-1],[Ny-1, Nx-2],[Ny-2, Nx-1]]])

# Pauli arrays can be created directly since their pattern doesn't depend on Ny or Nx
pauli_array_bulk = jnp.tile(jnp.array([[1,3,3,3,3]]), ((Ny-2)*(Nx-2), 1))
pauli_array_edge = jnp.tile(jnp.array([[1,3,3,3]]), ((Ny+Nx-4)*2, 1))
pauli_array_corner = jnp.tile(jnp.array([[1,3,3]]), (4, 1))

# It's assumed local_element_indices_2d function exists and operates on the inputs provided
# xy_loc_bulk, yloc_bulk, zloc_bulk = local_element_indices_2d(5, pauli_array_bulk, loc_array_bulk)
# xy_loc_edge, yloc_edge, zloc_edge = local_element_indices_2d(4, pauli_array_edge, loc_array_edge)
# xy_loc_corner, yloc_corner, zloc_corner = local_element_indices_2d(3, pauli_array_corner, loc_array_corner)


In [94]:
jnp.meshgrid(jnp.arange(1, Ny-1), jnp.arange(1, Nx-1), indexing='ij')

[Array([[1, 1],
        [2, 2]], dtype=int32),
 Array([[1, 2],
        [1, 2]], dtype=int32)]

In [97]:
# Add edge coordinates excluding corners for the left and right sides of the grid
edge_coordinates = []
for i in range(1, Ny - 1):
    edge_coordinates.extend([[i, 0], [i, 1], [i + 1, 0], [i - 1, 0]])
    edge_coordinates.extend([[i, Nx - 1], [i, Nx - 2], [i + 1, Nx - 1], [i - 1, Nx - 1]])

# Add edge coordinates excluding corners for the top and bottom of the grid
for j in range(1, Nx - 1):
    edge_coordinates.extend([[0, j], [1, j], [0, j - 1], [0, j + 1]])
    edge_coordinates.extend([[Ny - 1, j], [Ny - 2, j], [Ny - 1, j - 1], [Ny - 1, j + 1]])

# Convert list to a JAX array
loc_array_edge = jnp.array(edge_coordinates).reshape(-1, 4, 2)

In [98]:
loc_array_edge

Array([[[1, 0],
        [1, 1],
        [2, 0],
        [0, 0]],

       [[1, 3],
        [1, 2],
        [2, 3],
        [0, 3]],

       [[2, 0],
        [2, 1],
        [3, 0],
        [1, 0]],

       [[2, 3],
        [2, 2],
        [3, 3],
        [1, 3]],

       [[0, 1],
        [1, 1],
        [0, 0],
        [0, 2]],

       [[3, 1],
        [2, 1],
        [3, 0],
        [3, 2]],

       [[0, 2],
        [1, 2],
        [0, 1],
        [0, 3]],

       [[3, 2],
        [2, 2],
        [3, 1],
        [3, 3]]], dtype=int32)