adapted from: https://github.com/schance995/dqc-demo/blob/main/mnist_dqc_demo.ipynb

In [1]:
#!pip install pennylane optax flax tqdm transformers datasets

Collecting pennylane
  Using cached PennyLane-0.40.0-py3-none-any.whl.metadata (10 kB)
Collecting optax
  Using cached optax-0.2.4-py3-none-any.whl.metadata (8.3 kB)
Collecting flax
  Using cached flax-0.10.2-py3-none-any.whl.metadata (11 kB)
Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting transformers
  Using cached transformers-4.48.2-py3-none-any.whl.metadata (44 kB)
Collecting datasets
  Using cached datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting numpy<2.1 (from pennylane)
  Using cached numpy-2.0.2-cp312-cp312-macosx_14_0_arm64.whl.metadata (60 kB)
Collecting networkx (from pennylane)
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Using cached rustworkx-0.16.0-cp39-abi3-macosx_11_0_arm64.whl.metadata (10 kB)
Collecting autograd (from pennylane)
  Using cached autograd-1.7.0-py3-none-any.whl.metadata (7.5 kB)
Collecting tomlkit (from pennylane)

In [2]:
from typing import Callable

import jax
import optax
import pennylane as qml
from datasets import load_dataset
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import numpy as jnp
from jax import random as jrand
from tqdm import tqdm
from transformers import FlaxResNetModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
N_QUBITS = 10
N_LAYERS = 6
BATCH_SIZE = 16
LEARNING_RATE = 1e-4

In [4]:
def make_circuit(dev, n_qubits, n_layers):
    @qml.qnode(dev, interface="jax-jit")
    def circuit(x, circuit_weights):
        # data encoding
        for i in range(n_qubits):
            qml.Hadamard(wires=i)
            qml.RY(x[i], wires=i)
        # trainable unitary
        qml.StronglyEntanglingLayers(circuit_weights, wires=range(n_qubits))

        return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

    return jax.jit(jax.vmap(circuit, in_axes=(0, None)))

In [5]:

class QuantumCircuit(nn.Module):
    num_qubits: int
    num_layers: int
    circuit: Callable

    @nn.compact
    def __call__(self, x):
        circuit_weights = self.param(
            'circuit_weights',
            nn.initializers.normal(),
            (self.num_layers, self.num_qubits, 3),
        )
        x = self.circuit(x, circuit_weights)
        x = jnp.array(x).T
        return x

In [6]:

class DressedQuantumClassifier(nn.Module):
    backbone: nn.Module
    circuit: Callable
    num_qubits: int
    num_layers: int
    num_labels: int

    @nn.compact
    def __call__(self, x):
        x = self.backbone(x)
        x = x.pooler_output[:, :, 0, 0]
        # reduce features to fit into quantum circuit
        x = nn.Dense(features=self.num_qubits)(x)
        # rescale inputs
        x = jnp.tanh(x) * jnp.pi / 2
        x = QuantumCircuit(
            num_qubits=self.num_qubits,
            num_layers=self.num_layers,
            circuit=self.circuit,
        )(x)
        x = nn.Dense(features=self.num_labels)(x)
        return x

In [7]:
def create_train_step(model, params, optimizer):
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    @jax.jit
    def predict(params, x):
        logits = model.apply(params, x)
        return logits.argmax(axis=1)

    @jax.jit
    def loss_fn(params, x, y):
        logits = model.apply(params, x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        return loss

    @jax.jit
    def train_step(state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(state.params, x, y)
        state = state.apply_gradients(grads=grads)
        return state, loss

    return train_step, loss_fn, predict, state

In [8]:

dev = qml.device('lightning.qubit', wires=N_QUBITS)
circuit = make_circuit(dev, N_QUBITS, N_LAYERS)

resnet = FlaxResNetModel.from_pretrained('microsoft/resnet-50')

dqc = DressedQuantumClassifier(
    backbone=resnet.module,
    circuit=circuit,
    num_qubits=N_QUBITS,
    num_layers=N_LAYERS,
    num_labels=10,
)


I0000 00:00:1738675229.211264  373957 service.cc:145] XLA service 0x600000e5ef00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1738675229.211289  373957 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1738675229.212788  373957 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1738675229.212800  373957 mps_client.cc:384] XLA backend will use up to 28990554112 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M3 Max


In [9]:
zero_image = jnp.empty((1, 224, 224, 3))
key = jrand.PRNGKey(42)
params = dqc.init(key, zero_image)

ValueError: `EmitPythonCallback` not supported on METAL backend.

In [28]:

params['params']['backbone'] = resnet.params['params']
params['batch_stats']['backbone'] = resnet.params['batch_stats']

In [29]:

optimizer = optax.adam(LEARNING_RATE)
train_step, loss_fn, predict, state = create_train_step(dqc, params, optimizer)

In [30]:
ds = load_dataset('mnist').with_format('jax')

In [31]:
@jax.vmap
def grayscale_to_imagenet_format(x):
    x /= 255
    # resize to resnet size
    x = jax.image.resize(x, (224, 224), method='nearest')
    # copy grayscale channels to rgb
    x = jnp.dstack([x] * 3)
    # normalize to imagenet channels
    x = (x - jnp.array((0.485, 0.456, 0.406))) / jnp.array((0.229, 0.224, 0.225))
    return x

In [32]:

def train_loop(state, ds):
    total_correct = 0
    total_loss = 0
    total_seen = 0
    len_train = len(ds['train']) // BATCH_SIZE

    # close progress bar when epoch is finished
    with tqdm(ds['train'].iter(BATCH_SIZE), desc='train', leave=True, total=len_train) as pbar:
        for batch in pbar:
            x, y = batch['image'], batch['label']
            x = grayscale_to_imagenet_format(x)

            state, loss = train_step(state, x, y)
            yhat = predict(state.params, x)

            total_correct += int(sum(y == yhat))
            total_loss += float(loss)
            total_seen += len(yhat)

            pbar.set_postfix({'Mean acc': float(total_correct / total_seen),
                              'Mean loss': float(total_loss / total_seen)})
            #break

    return state # , float(total_loss / total_seen), float(total_correct / total_seen)

In [33]:
def test_loop(state, ds):
    total_correct = 0
    total_loss = 0
    total_seen = 0
    len_test = len(ds['test']) // BATCH_SIZE

    # close progress bar when epoch is finished
    with tqdm(ds['test'].iter(BATCH_SIZE), desc='test', leave=True, total=len_test) as pbar:
        for batch in pbar:
            x, y = batch['image'], batch['label']
            x = grayscale_to_imagenet_format(x)

            yhat = predict(state.params, x)
            loss = loss_fn(state.params, x, y)

            total_correct += int(sum(y == yhat))
            total_loss += float(loss)
            total_seen += len(yhat)

            pbar.set_postfix({'Mean acc': float(total_correct / total_seen),
                              'Mean loss': float(total_loss / total_seen)})
            #break

    # return float(total_loss / total_seen), float(total_correct / total_seen)

In [34]:
test_loop(state, ds)

In [None]:
state = train_loop(state, ds)

train:   0%|          | 0/3750 [00:00<?, ?it/s]

In [18]:
test_loop(state, ds)

test: 100%|██████████| 625/625 [08:21<00:00,  1.25it/s, Mean acc=0.985, Mean loss=0.0175]
