# QAOA implementation with qutip and JAX

In [15]:
!pip install qutip



In [16]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('..')
import numpy as np

import jax
import jax.numpy as jnp

from jax import grad, value_and_grad
from jax.experimental import optimizers
from jax.scipy.linalg import expm

from qutip.operators import sigmax, qeye
from qutip.tensor import tensor

# from util import load_problem_instance

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
def plus_state(q):
    """Return the uniform superposition over q qubits as a (2**q, 1) column vector.

    Every amplitude is 1/sqrt(2**q), so the returned state has unit norm.
    """
    dim = 2**q
    return jnp.ones((dim, 1)) / jnp.sqrt(dim)

In [18]:
def cost_func(x, cost, p, q, s):
    """Real part of <psi|C|psi> for the QAOA state built from parameters x.

    x holds the gamma angles followed by the beta angles (split into two halves).
    s is the initial state the circuit is applied to. cost is the diagonal of C,
    applied elementwise. Returns a real scalar.
    """
    gammas, betas = jnp.split(x, 2)
    psi = variational_state(cost, p, q, s, gammas, betas)
    # vdot conjugates and flattens its first argument, so no transpose is needed:
    # this is sum_k conj(psi_k) * cost_k * psi_k.
    expectation = jnp.vdot(psi, cost * psi)
    return jnp.real(expectation)

In [19]:
def loss_fn(x, cost, p, q, s, target=-4):
    """Squared distance between the QAOA cost expectation and a target value.

    Args:
        x, cost, p, q, s: forwarded unchanged to cost_func.
        target: value the expectation is pushed toward. The default of -4 is the
            minimum entry of the cost vector defined in this notebook. For a
            different problem instance pass e.g. ``jnp.min(cost)`` instead of
            relying on the default.

    Returns:
        Scalar ``(target - cost_func(...)) ** 2``.
    """
    f = cost_func(x, cost, p, q, s)
    return (target - f) ** 2

In [20]:
def variational_state(cost, p, q, s, gamma_list, beta_list):
    """Apply alternating cost and mixer layers to the initial state s.

    One layer is applied per (gamma, beta) pair; zip stops at the shorter list.
    Note: p is accepted but not used here -- the layer count comes from the
    lengths of gamma_list / beta_list.
    """
    state = s
    for gamma, beta in zip(gamma_list, beta_list):
        # U_C returns the diagonal of the cost unitary, so it is applied elementwise.
        phased = U_C(gamma, cost) * state
        # U_B returns a dense matrix, so it is applied by matrix product.
        state = U_B(beta, q) @ phased
    return state

In [21]:
def U_C(gamma, cost):
    """Diagonal of the cost unitary exp(-i * gamma * C), given C's diagonal `cost`.

    The result has the same shape as `cost`; every entry has modulus 1.
    """
    phase = gamma * cost
    return jnp.exp(-1j * phase)

In [22]:
def U_B(beta, q):
    """Dense mixer unitary exp(-i * beta * B) with B = sum_i X_i on q qubits.

    The single-qubit terms X_i act on different qubits and therefore commute, so
    the exponential factorizes exactly into q copies of the single-qubit rotation
    exp(-i * beta * X) = cos(beta) * I - i * sin(beta) * X.

    This avoids rebuilding B with qutip and running a dense expm on every call.
    That cost was significant: the function is called once per layer on every
    loss/gradient evaluation. The construction uses only jnp, so it remains
    differentiable with respect to beta.

    Args:
        beta: mixer angle (scalar).
        q: number of qubits (non-negative Python int).

    Returns:
        Complex (2**q, 2**q) matrix. q == 0 yields [[1]].
    """
    cos_b = jnp.cos(beta) + 0j
    minus_i_sin_b = -1j * jnp.sin(beta)
    single_qubit = jnp.array([[cos_b, minus_i_sin_b], [minus_i_sin_b, cos_b]])
    unitary = jnp.ones((1, 1), dtype=single_qubit.dtype)
    # All factors are identical, so the ordering of the Kronecker product is irrelevant.
    for _ in range(q):
        unitary = jnp.kron(unitary, single_qubit)
    return unitary

In [23]:
def step(step, opt_state):
    """Run one optimizer update on loss_fn; return (loss_value, new_opt_state).

    `step` is the optimizer step index passed straight to opt_update.
    Relies on notebook globals:
      - cost, p, q, s: the problem definition;
      - opt_update, get_params: from optimizers.adam in the training cell.
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    current_params = get_params(opt_state)
    value, grads = loss_and_grad(current_params, cost, p, q, s)
    return value, opt_update(step, grads, opt_state)

## Gradient-based optimization

In [24]:
# Diagonal of the cost operator: one value per computational basis state
# (16 entries -> 4 qubits). Stored as a column vector so it broadcasts against the
# (2**q, 1) state vectors used throughout.
cost = jnp.array(
    [0, -3, -2, -3, -3, -4, -3, -2, -2, -3, -4, -3, -3, -2, -3, 0]
).reshape(-1, 1)

In [25]:
p = 1  # number of QAOA layers: one gamma and one beta angle per layer
num_states = cost.shape[0]
q = num_states.bit_length() - 1
# The cost vector must hold exactly one value per basis state (2**q entries).
# int(np.log2(...)) would silently truncate for other lengths.
assert num_states == 2**q, f"cost has {num_states} entries, expected a power of two"
s = plus_state(q)
# Seeded so that Restart & Run All reproduces the same parameters.
x = jnp.array(np.random.default_rng(0).random(2 * p))

## Timing
### 11 qubits, timed with `%timeit cost_func(x, cost, p, q, s)`
#### Unoptimized: U_C built as a full matrix
9.09 s ± 506 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#### Optimized: U_C as a diagonal (vector), applied elementwise
3.27 s ± 249 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [26]:
# %timeit cost_func(x, cost, p, q, s)

In [27]:
learning_rate = 1e-1
num_steps = 100
epochs = 1
# Seeded so that Restart & Run All reproduces the same optimization trajectory.
params = jnp.array(np.random.default_rng(0).random(2 * p))

opt_init, opt_update, get_params = optimizers.adam(learning_rate)
opt_state = opt_init(params)

loss_history = []
for epoch in range(epochs):
    for i in range(num_steps):
        # Adam's bias correction depends on the step index. It must keep increasing
        # across epochs; restarting at 0 each epoch would reset the correction.
        global_step = epoch * num_steps + i
        value, opt_state = step(global_step, opt_state)
        loss_history.append(float(value))

Traced<ConcreteArray(1.0)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(1., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(1.0)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(1., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(0.99999994)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(0.99999994, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(0.9999999)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(0.9999999, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(1.0)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(1., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(0.9999998)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(0.9999998, dtype=float32)
       tangent 

In [28]:
# get_params is the optimizer's public accessor for the current parameters.
# Indexing opt_state.packed_state relies on its internal layout.
final_params = get_params(opt_state)
f = cost_func(final_params, cost, p, q, s)
f

1.0000001
-3.237106
