In [1]:
import pennylane as qml
from jax import numpy as jnp
# import numpy as np
import optax
import catalyst

n_wires = 4

# Regression targets: one value per sample in the batch.
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])
n_samples = targets.shape[0]

# Synthetic inputs: sin(t)**3 on an evenly spaced grid over [-2, 2).
# Row i holds the n_samples angles that `circuit` feeds to RY on wire i, so the
# number of samples per wire is tied to len(targets) rather than to n_wires.
# (Previously a fixed 20-point grid was reshaped to (n_wires, -1), which for
# n_wires = 4 gave 5 samples per wire against only 4 targets. With n_wires = 5
# this grid equals the old mgrid[-2:2:0.2] up to float rounding.)
data = jnp.sin(
    jnp.linspace(-2.0, 2.0, n_wires * n_samples, endpoint=False)
    .reshape(n_wires, n_samples)
) ** 3

# Guard the invariant the rest of the notebook relies on.
assert data.shape == (n_wires, n_samples), data.shape

dev = qml.device("default.qubit", wires=n_wires)
# NOTE(review): Catalyst's qjit only compiles against Catalyst-compatible
# devices (e.g. "lightning.qubit"); confirm this device if `circuit` is
# qjit-compiled later in the notebook.


@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz.

    Every wire is first rotated by RY using ``data[wire]``. Then each wire gets
    an RX -> RY -> RX rotation driven by ``weights[wire, 0..2]``, followed by a
    CNOT from that wire onto the next one (wrapping around at the end). The
    circuit returns the expectation value of the sum of PauliZ over all wires.

    Args:
        data: indexed per wire; presumably row ``data[wire]`` is a vector of
            angles used via parameter broadcasting — TODO confirm.
        weights: indexed as ``weights[wire, k]`` with k in {0, 1, 2}.
    """

    # Defining a for_loop body only builds a callable; no gates are queued
    # until it is invoked, so both loops can be declared up front.
    @qml.for_loop(0, n_wires, 1)
    def encode(wire):
        qml.RY(data[wire], wires=wire)

    @qml.for_loop(0, n_wires, 1)
    def variational_layer(wire):
        first_rx, middle_ry, last_rx = weights[wire, 0], weights[wire, 1], weights[wire, 2]
        qml.RX(first_rx, wires=wire)
        qml.RY(middle_ry, wires=wire)
        qml.RX(last_rx, wires=wire)
        next_wire = (wire + 1) % n_wires
        qml.CNOT(wires=[wire, next_wire])

    encode()
    variational_layer()

    # A sum of local Zs is measured because a single local Z would only be
    # sensitive to the parameters on its own qubit.
    z_terms = [qml.PauliZ(w) for w in range(n_wires)]
    return qml.expval(qml.sum(*z_terms))

: 

In [1]:
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
  elsewhere. The defaults appear to be rounded versions of the standard SELU
  constants (alpha ~= 1.6733, lambda ~= 1.0507).

  Args:
    x: array-like input.
    alpha: scale of the negative (exponential) branch.
    lmbda: overall output scale.

  Returns:
    A jnp array, x-shaped (broadcast via jnp.where).
  """
  # jnp.where evaluates BOTH branches. Without the clamp, exp(x) overflows to
  # inf for large positive x, and the gradient of the unselected branch becomes
  # 0 * inf = NaN. Clamping to <= 0 keeps that branch finite; expm1 is also
  # more accurate than exp(x) - 1 near 0.
  negative_branch = alpha * jnp.expm1(jnp.minimum(x, 0.0))
  return lmbda * jnp.where(x > 0, x, negative_branch)

# Evaluate selu on the float grid 0.0, 1.0, ..., 4.0 and print the result.
x = jnp.arange(5.0)
selu_x = selu(x)
print(selu_x)

[0.        1.05      2.1       3.1499999 4.2      ]
