In [1]:
import numpy as np
import jax 
from jax import jit
from jax import numpy as jnp
import pennylane as qml
import time

# Number of wires in the benchmark circuit; the QNode definition below reads this global.
n_qubits = 16

# Use 64-bit floats in JAX. Per the JAX docs this should be set at startup,
# before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Statevector simulator shared by the QNodes below.
# NOTE: the author observed that `lightning.qubit` made the JIT-compiled version
# much slower while making the non-JIT version faster; the cause was not investigated.
dev = qml.device("default.qubit", wires=n_qubits)

# Just-in-Time Compilation

In [2]:
@qml.qnode(dev, interface="jax")
def circuit(param):
    """Ring-entangling benchmark circuit.

    Applies the same RX(param) rotation to every wire. It then applies a ring
    of CNOTs (0->1, 1->2, ..., (n_qubits-1)->0) and measures PauliZ on wire 0.

    Args:
        param: rotation angle shared by all RX gates.

    Returns:
        Expectation value of PauliZ on wire 0.
    """
    for wire in range(n_qubits):
        qml.RX(param, wires=wire)

    # The modulo closes the ring: the final iteration emits CNOT(n_qubits-1, 0).
    for control in range(n_qubits):
        qml.CNOT(wires=[control, (control + 1) % n_qubits])

    return qml.expval(qml.PauliZ(0))

# Compiled variant. The first call traces and compiles; later calls with the
# same input type reuse the compiled version.
jit_circuit = jax.jit(circuit)

In [3]:
# Baseline: wall time for 100 calls of the QNode without JIT.
# perf_counter is monotonic and high-resolution (unlike time.time()), so it is
# the appropriate clock for benchmarking. block_until_ready waits for any pending
# JAX computation before the timer is read; it is harmless if the result is
# already concrete.
start = time.perf_counter()
for _ in range(100):
    jax.block_until_ready(circuit(0.123))
no_jit_time = time.perf_counter() - start

In [4]:
# Wall time for 100 calls of the jitted QNode.
# On a fresh kernel the first call traces and compiles the circuit, so this total
# INCLUDES that one-off cost. Re-running only this cell reuses the compiled cache
# and will report a smaller number.
# JAX dispatches jitted calls asynchronously. Without blocking on each result,
# the timer could stop before the computation has actually finished.
start = time.perf_counter()
for _ in range(100):
    jax.block_until_ready(jit_circuit(0.123))
first_time = time.perf_counter() - start

In [5]:

# Summary of the two timing cells (seconds for 100 calls each).
# `first_time` includes the one-off JIT trace/compile when measured on a fresh kernel.
print("Time to do 100 executions")
print(f"\tNo JIT: {no_jit_time:0.8f} seconds")
print(f"\t   JIT: {first_time:0.8f} seconds")

Time to do 100 execugtions
	No JIT: 3.44696403 seconds
	   JIT: 0.32424426 seconds


# Automatic Vectorization

# Automatic Differentiation