In [1]:
import numpy as np
import qiskit.quantum_info as qi

from qiskit_dynamics import dispatch
from qiskit_dynamics.dispatch import Array

### Array class

The `Array` class wraps arrays from different ndarray backends. Two backends are currently supported: the NumPy `ndarray` (`backend='numpy'`) and the JAX `DeviceArray` (`backend='jax'`). This functionality should eventually be moved into the `qiskit.quantum_info` module.

In [2]:
# Wrap a NumPy ndarray in an Array. No `backend` argument is given here;
# the output shows it resolved to backend='numpy'.

a = Array(np.arange(10))
a

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], backend='numpy')

In [3]:
# `.data` gives the underlying (unwrapped) array object -- here a plain NumPy ndarray.

a.data

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [4]:
# NumPy ufuncs (np.sin, np.cos) accept an Array and return an Array with the
# same backend, so ordinary arithmetic composes as usual.

np.sin(a) + 3 * np.cos(a)

Array([ 3.        ,  2.4623779 , -0.33914308, -2.82885748, -2.71773336,
       -0.10793772,  2.60109536,  2.91869336,  0.55285815, -2.3212723 ],
      backend='numpy')

In [5]:
# Work with numpy array functions: non-ufunc NumPy functions such as np.real
# also accept an Array and return an Array with the same backend.
# (Previously this cell used the `.real` attribute, which does not exercise
# array-function dispatch at all.)

np.real(a)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], backend='numpy')

In [6]:
# Methods of the wrapped array (here ndarray.reshape) can be called directly
# on the Array; the result comes back wrapped as an Array.

a.reshape((2, 5))

Array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]], backend='numpy')

### Using Array backends

In [7]:
# List the Array backends that are available (the output shows 'numpy' and 'jax').

dispatch.available_backends()

('numpy', 'jax')

In [8]:
# Enable double precision (64-bit) JAX, then create an Array with an explicit
# jax backend. The int64 dtype in the next cell's output reflects the x64 flag.
# NOTE(review): JAX recommends setting jax_enable_x64 at startup, before any
# JAX arrays are created -- consider moving this to the top of the notebook.
import jax

jax.config.update("jax_enable_x64", True)

b = Array(np.arange(10), backend='jax')
b



Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], backend='jax')

In [9]:
# For the jax backend, `.data` is the underlying JAX array (a DeviceArray here).

b.data

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)

In [10]:
# The same NumPy ufunc expression works on a jax-backed Array, and the result
# keeps the jax backend.

np.sin(b) + 3 * np.cos(b)

Array([ 3.        ,  2.4623779 , -0.33914308, -2.82885748, -2.71773336,
       -0.10793772,  2.60109536,  2.91869336,  0.55285815, -2.3212723 ],
      backend='jax')

In [11]:
# NumPy array functions such as np.dot also dispatch on jax-backed Arrays and
# return a jax-backed Array.

np.dot(b, b)

Array(285, backend='jax')

### Using Array in other classes

In [12]:
# Set Jax as the default backend, so later Array(...) calls without an explicit
# `backend` argument (e.g. inside `obj` below) use jax.

dispatch.set_default_backend('jax')

In [13]:
def obj(theta):
    """Toy objective |Tr[(cos(theta/2) I + sin(theta/2) Y) Y]|.

    Because Y @ Y = I and Tr[Y] = 0, this equals 2 * |sin(theta/2)|.
    It is written with Array operations so it can be passed through the
    wrapped jax jit/grad transformations.
    """
    # qiskit Operators are wrapped in Array until qinfo integration is done
    identity = Array(qi.Operator.from_label('I'))
    pauli_y = Array(qi.Operator.from_label('Y'))

    # theta is wrapped in Array so Jax jit/grad tracers are dispatched
    # correctly; the half-angle is computed once and reused.
    half_theta = Array(theta) / 2
    op = np.cos(half_theta) * identity + np.sin(half_theta) * pauli_y

    return np.abs(np.trace(np.dot(op, pauli_y)))

In [14]:
# Evaluate the objective once; expected value is 2*|sin(0.05)| (approx. 0.09996).

obj(0.1)

Array(0.09995834)

### Wrapping third-party library functions to work with `Array`

In [15]:
# Wrap the jax transformations with dispatch.wrap(..., decorator=True) so the
# transformed functions work with Array inputs/outputs (see the Array results below).
jit, grad, value_and_grad = (
    dispatch.wrap(transform, decorator=True)
    for transform in (jax.jit, jax.grad, jax.value_and_grad)
)

# Compiled objective, its gradient, and (value, gradient) together
f, g, h = jit(obj), grad(obj), value_and_grad(obj)

f(0.1), g(0.1), h(0.1)

(Array(0.09995834), Array(0.99875026), (Array(0.09995834), Array(0.99875026)))

### JAX `odeint`

In [16]:
# NOTE(review): odeint is imported from jax.experimental, so its location may
# change between JAX versions -- pin/verify the JAX version if this breaks.
from jax.experimental.ode import odeint as jax_odeint

# Wrap jax odeint function so it can be called with Array arguments
# (its result is shown as an Array below).
odeint = dispatch.wrap(jax_odeint)

In [17]:
def sample_rhs(y, t):
    """Right-hand side of the linear ODE dy/dt = y.

    odeint calls the right-hand side as ``func(y, t)``, so ``t`` must be
    accepted even though this system does not depend on time.
    """
    del t  # time-independent right-hand side
    return y

# Initial state y0 and the times t at which odeint reports the solution
# (one output row per entry of t).
y0 = Array([0., 1., 2.], dtype=float)
t = Array([0., 1., 2.], dtype=float)

In [18]:
# Solve dy/dt = y with tight tolerances. The exact solution is y0 * exp(t),
# which the output matches (e.g. the t=1 row is [0, e, 2e]).
results = odeint(sample_rhs, y0, t, atol=1e-10, rtol=1e-10)
results

Array([[ 0.        ,  1.        ,  2.        ],
       [ 0.        ,  2.71828183,  5.43656366],
       [ 0.        ,  7.3890561 , 14.7781122 ]])