In [54]:
from flax import nnx
import jax
import jax.numpy as jnp
import pennylane as qml

In [55]:
N_QUBITS = 10
N_LAYERS = 6
BATCH_SIZE = 16
LEARNING_RATE = 1e-4

In [56]:
def make_circuit(dev, n_qubits, n_layers):
    @qml.qnode(dev)
    def circuit(x, circuit_weights):
        # data encoding
        for i in range(n_qubits):
            qml.Hadamard(wires=i)
            qml.RY(x[i], wires=i)
            
        # trainable unitary
        for layer in range(n_layers):
            for i in range(n_qubits):
                qml.RY(circuit_weights[layer, i, 0], wires=i)  
                qml.RZ(circuit_weights[layer, i, 1], wires=i) 
                qml.RX(circuit_weights[layer, i, 2], wires=i)  

            for idx in range(n_qubits-1):
                qml.CNOT(wires=[idx, idx+1])
            qml.CNOT(wires=[n_qubits-1, 0])

        return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

    return jax.vmap(circuit, in_axes=(0, None)) # We return the vectorized circuit and indicate that the batch is in the first dimension.


In [57]:
class QuantumCircuit(nnx.Module):
  def __init__(self, num_qubits, num_layers, device, rngs = nnx.Rngs(42)):
    key = rngs.params()

    weight_shapes = (num_layers, num_qubits, 3)
    self.weights = jnp.array([
        nnx.Param(jax.random.uniform(key, shape=weight_shapes).astype(jnp.float32))
    ])
    self.circuit = make_circuit(device, num_qubits, num_layers)

  def __call__(self, x: jax.Array):
    return self.circuit(x, self.weights)

In [58]:
dev = qml.device('default.qubit', wires=N_QUBITS)
qc = QuantumCircuit(N_QUBITS, N_LAYERS, dev)

In [60]:
qc(jnp.empty((4, 10)))

[Array([[-0.03767726, -0.02900788, -0.0329062 ],
        [-0.03767726, -0.02900788, -0.0329062 ],
        [-0.03767726, -0.02900788, -0.0329062 ],
        [-0.03767726, -0.02900788, -0.0329062 ]], dtype=float32),
 Array([[ 0.02241525, -0.0077256 , -0.01964784],
        [ 0.02241525, -0.0077256 , -0.01964784],
        [ 0.02241525, -0.0077256 , -0.01964784],
        [ 0.02241525, -0.0077256 , -0.01964784]], dtype=float32),
 Array([[-0.0238865 ,  0.01320735,  0.01969719],
        [-0.0238865 ,  0.01320735,  0.01969719],
        [-0.0238865 ,  0.01320735,  0.01969719],
        [-0.0238865 ,  0.01320735,  0.01969719]], dtype=float32),
 Array([[ 0.0151712 , -0.06157526, -0.01836556],
        [ 0.0151712 , -0.06157526, -0.01836556],
        [ 0.0151712 , -0.06157526, -0.01836556],
        [ 0.0151712 , -0.06157526, -0.01836556]], dtype=float32),
 Array([[ 0.00035176,  0.00778925, -0.02163187],
        [ 0.00035176,  0.00778925, -0.02163187],
        [ 0.00035176,  0.00778925, -0.02163187],
 