In [1]:
import numpy as np
import jax 
from jax import jit
from jax import numpy as jnp
import pennylane as qml
import time

# Enable 64-bit floats in JAX (it defaults to 32-bit) so simulated expectation
# values are computed in double precision.
jax.config.update("jax_enable_x64", True)

# Fixed seed so the random benchmark parameters are reproducible across runs.
SEED = 42
key = jax.random.PRNGKey(SEED)

In [2]:
n_qubits = 16
# Author's observation: with lightning.qubit the JIT path became much slower
# while the no-JIT path became faster, so default.qubit is used for this benchmark.
# NOTE(review): cause not established here — confirm before relying on it.
dev = qml.device("default.qubit", wires=n_qubits)

# With jax > 0.4.28, jax.vmap over this QNode fails unless the interface is
# explicitly "jax-jit", because the other path relies on the deprecated
# `jax.core.ConcreteArray`. The interface is pinned here instead of left to
# auto-detection so the vmap cells below keep working on newer JAX releases.
@qml.qnode(dev, interface="jax-jit")
def circuit(param):
    """Ring-entangled benchmark circuit.

    Applies RX(param) to every wire, then a CNOT chain 0->1->...->(n-1), then a
    closing CNOT from the last wire back to wire 0.

    Args:
        param: rotation angle shared by every RX gate (called with one scalar
            element of the parameter batch at a time, or mapped with jax.vmap).

    Returns:
        Expectation value of PauliZ on wire 0.
    """
    for wire in range(n_qubits):
        qml.RX(param, wires=wire)

    for wire in range(n_qubits - 1):
        qml.CNOT(wires=[wire, wire + 1])

    # Close the ring: the last qubit also entangles back onto qubit 0.
    qml.CNOT(wires=[n_qubits - 1, 0])

    return qml.expval(qml.PauliZ(0))

# Just-in-Time Compilation

In [3]:
# Wrap the QNode with jax.jit; tracing and XLA compilation are deferred until the first call.
jit_circuit = jit(circuit)

# 100 rotation angles drawn uniformly from [0, 1) using the seeded key.
batch_params = jax.random.uniform(key, (100,))

In [4]:
# Baseline: call the un-compiled QNode once per parameter.
# time.perf_counter is monotonic and high-resolution. time.time can jump if the
# wall clock is adjusted. JAX dispatches work asynchronously, so each result is
# blocked on to make the interval cover the actual computation.
start = time.perf_counter()
for param in batch_params:
    jax.block_until_ready(circuit(param))
no_jit_time = time.perf_counter() - start

In [5]:
# JIT path: the first iteration traces and compiles `circuit`, so this timing
# includes the one-off compilation cost. Later iterations reuse the compiled
# program because every parameter has the same shape and dtype.
# Results are blocked on because JAX dispatches asynchronously. Otherwise the
# timer could stop before the work finishes.
start = time.perf_counter()
for param in batch_params:
    jax.block_until_ready(jit_circuit(param))
first_time = time.perf_counter() - start

In [6]:
# Summarize the sequential-loop benchmark (labels padded so the colons align).
report_lines = [
    "Time to do 100 executions:",
    f"\tNo JIT: {no_jit_time:0.8f} seconds",
    f"\t   JIT: {first_time:0.8f} seconds",
]
print("\n".join(report_lines))

Time to do 100 executions:
	No JIT: 4.01821589 seconds
	   JIT: 0.34297895 seconds


# Automatic Vectorization

In [7]:
# Vectorize over the leading axis of the parameter batch: one variant maps the
# plain QNode, the other maps its jitted wrapper.
vcircuit, vcircuit_jit = (jax.vmap(fn, in_axes=0) for fn in (circuit, jit_circuit))

In [8]:
# Larger batch (1000 angles in [0, 1)) for the vectorization benchmark. This
# rebinds `batch_params`, replacing the 100-element batch used above.
# `key` was already consumed for the first batch, and JAX keys should not be
# reused, so an independent key is derived with fold_in. `key` itself is left
# unchanged, so re-running this cell gives the same batch.
batch_params = jax.random.uniform(jax.random.fold_in(key, 1), shape=(1000,))

In [9]:
# Baseline for the larger batch: a plain Python loop over the un-compiled QNode.
# perf_counter gives a monotonic, high-resolution clock. Each result is blocked
# on so JAX's asynchronous dispatch cannot end the measurement early.
start = time.perf_counter()
for param in batch_params:
    jax.block_until_ready(circuit(param))
no_vjit_time = time.perf_counter() - start

In [10]:
# vmap alone: one batched call replaces the Python loop, without jax.jit.
# The returned array is blocked on before stopping the clock, because JAX
# dispatches asynchronously. block_until_ready returns its argument unchanged.
start = time.perf_counter()
batched_results = jax.block_until_ready(vcircuit(batch_params))
vmap_time = time.perf_counter() - start

In [11]:
# vmap over the jitted QNode. This is the first call with a batched input, so the
# timing includes tracing and compiling the batched program.
# The result is blocked on so the timer covers the actual execution.
start = time.perf_counter()
batched_jit_results = jax.block_until_ready(vcircuit_jit(batch_params))
vjit_time = time.perf_counter() - start

In [12]:
# Summarize the batched benchmark (labels padded so the colons align).
report_lines = [
    "Time to do 1000 executions:",
    f"\tNo JIT, No VMAP: {no_vjit_time:0.8f} seconds",
    f"\t           VMAP: {vmap_time:0.8f} seconds",
    f"\t   VMAP and JIT: {vjit_time:0.8f} seconds",
]
print("\n".join(report_lines))

Time to do 1000 executions:
	No JIT, No VMAP: 17.40339684 seconds
	           VMAP: 5.89912891 seconds
	   VMAP and JIT: 0.16277909 seconds
