# QML on MNIST Classification

## Overview

The aim of this tutorial is not about the machine learning perspective on better design of QML method for MNIST classification. Instead, we use a simple parameterized circuit and demonstrate the QML-related technical ingredients of ``tensorcircuit``. Nevertheless, this note is by no means a good practice on QML.

[WIP note]

## Setup

In [1]:
from functools import partial
import numpy as np
import tensorflow as tf
import jax
from jax.config import config

config.update("jax_enable_x64", True)
from jax import numpy as jnp
import optax
import tensorcircuit as tc

In [2]:
tc.set_backend("tensorflow")
tc.set_dtype("complex128")

('complex128', 'float64')

## Data Processing

We utilize MNIST data and resize them to 3*3 to fit into a 9-qubit circuit.
The testbed we use is a binary classification task, digit 1 vs. 5.
And since this tutorial is not about good practice on QML, we leave the validation set away.
And we only collect 100 data points for a small demo.

In [3]:
# numpy data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis] / 255.0


def filter_pair(x, y, a, b):
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    y = y == a
    return x, y


x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
x_train_bin = np.array(x_train_small > 0.5, dtype=np.float32)
x_train_bin = np.squeeze(x_train_bin)[:100]

In [4]:
# tensorflow data

x_train_tf = tf.reshape(tf.constant(x_train_bin, dtype=tf.float64), [-1, 9])
y_train_tf = tf.constant(y_train[:100], dtype=tf.float64)

# jax data

x_train_jax = jnp.array(x_train_bin, dtype=np.float64).reshape([100, -1])
y_train_jax = jnp.array(y_train[:100], dtype=np.float64).reshape([100])



## Using ``vectorized_value_and_grad`` API

In [5]:
nlayers = 3


def qml_loss(x, y, weights, nlayers):
    n = 9
    weights = tc.backend.cast(weights, "complex128")
    x = tc.backend.cast(x, "complex128")
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=x[i])
    for j in range(nlayers):
        for i in range(n - 1):
            c.cnot(i, i + 1)
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])
    ypred = c.expectation([tc.gates.z(), (4,)])
    ypred = tc.backend.real(ypred)
    ypred = (tc.backend.real(ypred) + 1) / 2.0
    return -y * tc.backend.log(ypred) - (1 - y) * tc.backend.log(1 - ypred), ypred

In [6]:
def get_qml_vvag():
    qml_vvag = tc.backend.vectorized_value_and_grad(
        qml_loss, argnums=(2,), vectorized_argnums=(0, 1), has_aux=True
    )
    qml_vvag = tc.backend.jit(qml_vvag, static_argnums=(3,))
    return qml_vvag


qml_vvag = get_qml_vvag()
qml_vvag(x_train_tf, y_train_tf, tf.ones([nlayers * 2, 9], dtype=tf.float64), nlayers)

((<tf.Tensor: shape=(100,), dtype=float64, numpy=
  array([0.8433698 , 0.56257199, 0.54653163, 0.56257199, 0.82036163,
         0.56257199, 0.56257199, 0.58030506, 0.82036163, 0.56257199,
         0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.54653163,
         0.54653163, 0.56257199, 0.56257199, 0.58030506, 0.82036163,
         0.54653163, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
         0.56257199, 0.56257199, 0.85182866, 0.56257199, 0.82036163,
         0.82036163, 0.56257199, 0.8433698 , 0.56257199, 0.8433698 ,
         0.56257199, 0.85182866, 0.56257199, 0.82036163, 0.54653163,
         0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.8433698 ,
         0.58030506, 0.56257199, 0.82036163, 0.8433698 , 0.8433698 ,
         0.54653163, 0.56257199, 0.82036163, 0.86501404, 0.56257199,
         0.56257199, 0.8433698 , 0.56257199, 0.85182866, 0.82036163,
         0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.56257199,
         0.56257199, 0.82036163, 0.8433698 , 0.843369

In [7]:
# %timeit qml_vvag(x_train_tf, y_train_tf, tf.ones([nlayers*2, 9], dtype=tf.float64), nlayers)

### Jax Backend Compatibility

In [8]:
tc.set_backend("jax")

<tensorcircuit.backends.jax_backend.JaxBackend at 0x7ffb04a71820>

In [9]:
qml_vvag = get_qml_vvag()
qml_vvag(
    x_train_jax, y_train_jax, jnp.ones([nlayers * 2, 9], dtype=np.float64), nlayers
)

((DeviceArray([0.8433698 , 0.56257199, 0.54653163, 0.56257199, 0.82036163,
               0.56257199, 0.56257199, 0.58030506, 0.82036163, 0.56257199,
               0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.54653163,
               0.54653163, 0.56257199, 0.56257199, 0.58030506, 0.82036163,
               0.54653163, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
               0.56257199, 0.56257199, 0.85182866, 0.56257199, 0.82036163,
               0.82036163, 0.56257199, 0.8433698 , 0.56257199, 0.8433698 ,
               0.56257199, 0.85182866, 0.56257199, 0.82036163, 0.54653163,
               0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.8433698 ,
               0.58030506, 0.56257199, 0.82036163, 0.8433698 , 0.8433698 ,
               0.54653163, 0.56257199, 0.82036163, 0.86501404, 0.56257199,
               0.56257199, 0.8433698 , 0.56257199, 0.85182866, 0.82036163,
               0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.56257199,
               0.56257199

In [10]:
# %timeit qml_vvag(x_train_jax, y_train_jax, jnp.ones([nlayers * 2, 9], dtype=np.float64), nlayers)

### Training Using ``tf.data``

In [11]:
# switch back to tensorflow
tc.set_backend("tensorflow")
qml_vvag = get_qml_vvag()
qml_vvag = tc.backend.jit(qml_vvag, static_argnums=(3,))

In [12]:
mnist_data = (
    tf.data.Dataset.from_tensor_slices((x_train_tf, y_train_tf))
    .repeat(200)
    .shuffle(100)
    .batch(32)
)

In [13]:
opt = tf.keras.optimizers.Adam(1e-2)
w = tf.Variable(
    initial_value=tf.random.normal(shape=(2 * nlayers, 9), stddev=0.5, dtype=tf.float64)
)
for i, (xs, ys) in zip(range(2000), mnist_data):
    (losses, ypreds), grad = qml_vvag(xs, ys, w, nlayers)
    if i % 20 == 0:
        print(tf.reduce_mean(losses))
        opt.apply_gradients([(grad[0], w)])

tf.Tensor(0.689301607482696, shape=(), dtype=float64)
tf.Tensor(0.6825438352666904, shape=(), dtype=float64)
tf.Tensor(0.6815497367036047, shape=(), dtype=float64)
tf.Tensor(0.6632433448327015, shape=(), dtype=float64)
tf.Tensor(0.6641348270253142, shape=(), dtype=float64)
tf.Tensor(0.6779914200102861, shape=(), dtype=float64)
tf.Tensor(0.6550256969249619, shape=(), dtype=float64)
tf.Tensor(0.6801325087248677, shape=(), dtype=float64)
tf.Tensor(0.6190616725052769, shape=(), dtype=float64)
tf.Tensor(0.6711760566099414, shape=(), dtype=float64)
tf.Tensor(0.6965496746836946, shape=(), dtype=float64)
tf.Tensor(0.6443036572691725, shape=(), dtype=float64)
tf.Tensor(0.6060956714527996, shape=(), dtype=float64)
tf.Tensor(0.6728839286340991, shape=(), dtype=float64)
tf.Tensor(0.6584085272471567, shape=(), dtype=float64)
tf.Tensor(0.6600981577311038, shape=(), dtype=float64)
tf.Tensor(0.6581071758186605, shape=(), dtype=float64)
tf.Tensor(0.6609348320181809, shape=(), dtype=float64)
tf.Tensor(0

## Using ``tf.keras`` API

In [14]:
from tensorcircuit import keras


def qml_y(x, weights, nlayers):
    n = 9
    weights = tc.backend.cast(weights, "complex128")
    x = tc.backend.cast(x, "complex128")
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=x[i])
    for j in range(nlayers):
        for i in range(n - 1):
            c.cnot(i, i + 1)
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])
    ypred = c.expectation([tc.gates.z(), (4,)])
    ypred = tc.backend.real(ypred)
    ypred = (tc.backend.real(ypred) + 1) / 2.0
    return ypred


ql = keras.QuantumLayer(partial(qml_y, nlayers=nlayers), [(2 * nlayers, 9)])

In [15]:
# keras interface with value and grad paradigm


@tf.function
def my_vvag(xs, ys):
    with tf.GradientTape() as tape:
        ypred = ql(xs)
        loss = tf.keras.losses.BinaryCrossentropy()(ys, ypred)
    return loss, tape.gradient(loss, ql.variables)


my_vvag(x_train_tf, y_train_tf)

(<tf.Tensor: shape=(), dtype=float64, numpy=0.7179324626922607>,
 [<tf.Tensor: shape=(6, 9), dtype=float64, numpy=
  array([[-1.97741333e-02, -3.24903196e-03, -1.19449484e-02,
           1.34411790e-02, -2.29378194e-03,  9.24968875e-04,
           3.41827505e-04,  1.38777878e-17, -6.93889390e-18],
         [-1.85390086e-02,  3.81940052e-03, -3.05341288e-02,
          -1.79981829e-03, -5.77913396e-02, -3.71762005e-03,
          -5.10097165e-03, -1.71303943e-17, -1.73472348e-18],
         [ 5.04193508e-03, -1.77846516e-02,  2.26429668e-02,
          -1.41076421e-02, -3.13874407e-02,  1.37515418e-03,
           2.08166817e-17,  2.42861287e-17, -1.73472348e-18],
         [ 2.67860892e-02,  1.92311176e-02, -2.44580361e-02,
          -5.08346256e-02, -1.15289797e-02, -8.99461139e-03,
           3.46944695e-18, -5.20417043e-18, -6.93889390e-18],
         [ 9.54097912e-18, -3.46944695e-18,  1.04083409e-17,
          -1.73472348e-18, -2.53960212e-03,  1.31188463e-17,
           5.20417043e-18, 

In [16]:
# %timeit my_vvag(x_train_tf, y_train_tf)

In [17]:
# keras interface with keras training paradigm

model = tf.keras.Sequential([ql])

model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.01),
    metrics=[tf.keras.metrics.BinaryAccuracy()],
)

model.fit(x_train_tf, y_train_tf, batch_size=32, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


<keras.callbacks.History at 0x7ffaf92858b0>

### Quantum-Classical Hybrid Model in Keras

In [18]:
def qml_ys(x, weights, nlayers):
    n = 9
    weights = tc.backend.cast(weights, "complex128")
    x = tc.backend.cast(x, "complex128")
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=x[i])
    for j in range(nlayers):
        for i in range(n - 1):
            c.cnot(i, i + 1)
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])
    ypreds = []
    for i in range(n):
        ypred = c.expectation([tc.gates.z(), (i,)])
        ypred = tc.backend.real(ypred)
        ypred = (tc.backend.real(ypred) + 1) / 2.0
        ypreds.append(ypred)
    return tc.backend.stack(ypreds)

In [19]:
ql = tc.keras.QuantumLayer(partial(qml_ys, nlayers=nlayers), [(2 * nlayers, 9)])
model = tf.keras.Sequential([ql, tf.keras.layers.Dense(1, activation="sigmoid")])

In [20]:
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.01),
    metrics=[tf.keras.metrics.BinaryAccuracy()],
)

model.fit(x_train_tf, y_train_tf, batch_size=32, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x7ffac3142e20>

### Hybrid Model in Jax

In [21]:
tc.set_backend("jax")

<tensorcircuit.backends.jax_backend.JaxBackend at 0x7ffb04a71820>

In [22]:
key = jax.random.PRNGKey(42)
key, *subkeys = jax.random.split(key, num=4)
params = {
    "qweights": jax.random.normal(subkeys[0], shape=[nlayers * 2, 9]),
    "cweights:w": jax.random.normal(subkeys[1], shape=[9]),
    "cweights:b": jax.random.normal(subkeys[2], shape=[1]),
}

In [23]:
def qml_hybrid_loss(x, y, params, nlayers):
    weights = params["qweights"]
    w = params["cweights:w"]
    b = params["cweights:b"]
    ypred = qml_ys(x, weights, nlayers)
    ypred = tc.backend.reshape(ypred, [-1, 1])
    ypred = w @ ypred + b
    ypred = jax.nn.sigmoid(ypred)
    ypred = ypred[0]
    loss = -y * tc.backend.log(ypred) - (1 - y) * tc.backend.log(1 - ypred)
    return loss

In [24]:
qml_hybrid_loss_vag = tc.backend.jit(
    tc.backend.vvag(qml_hybrid_loss, vectorized_argnums=(0, 1), argnums=2),
    static_argnums=3,
)

In [25]:
qml_hybrid_loss_vag(x_train_jax, y_train_jax, params, nlayers)

(DeviceArray([3.73282398, 0.02421603, 0.02899787, 0.02421603, 4.08996787,
              0.03069481, 0.02421603, 0.01688146, 4.08996787, 0.03069481,
              4.08996787, 0.02421603, 4.08996787, 0.02421603, 0.02899787,
              0.03354042, 0.02421603, 0.02421603, 0.01688146, 4.08996787,
              0.03354042, 0.02421603, 0.02421603, 0.03069481, 0.02421603,
              0.02421603, 0.03069481, 3.73798651, 0.02421603, 3.68810189,
              4.08996787, 0.03069481, 3.73282398, 0.03069481, 3.73282398,
              0.02421603, 3.49674264, 0.02421603, 4.08996787, 0.02899787,
              0.02421603, 0.02421603, 0.03069481, 0.03069481, 3.73282398,
              0.02533775, 0.03069481, 3.68810189, 3.73282398, 3.49896983,
              0.02899787, 0.03069481, 4.08996787, 3.41172721, 0.02421603,
              0.02421603, 3.73282398, 0.02421603, 3.73798651, 3.68810189,
              4.08996787, 0.03069481, 4.08996787, 0.02421603, 0.03069481,
              0.02421603, 3.68810189, 

In [26]:
optimizer = optax.adam(5e-3)
opt_state = optimizer.init(params)
for i, (xs, ys) in zip(range(2000), mnist_data):  # using tf data loader here
    xs = xs.numpy()
    ys = ys.numpy()
    v, grads = qml_hybrid_loss_vag(xs, ys, params, nlayers)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 30 == 0:
        print(jnp.mean(v))

1.2979572281332594
0.8331012068009501
0.6805939758448183
0.5897353928152392
0.6460840124038746
0.6093143713632384
0.6671721223530598
0.5863347320393952
0.5465362554431986
0.5594138744621404
0.5493311423294576
0.5228166702417829
0.6176455570797168
0.5256494465741394
0.5359881696740493
0.5787532611935906
0.49082340457493323
0.4062487079116086
0.5802733401377229
0.4762524476616207
0.5404245247888219
