<a href="https://colab.research.google.com/github/schance995/dqc-demo/blob/main/mnist_dqc_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hybrid Quantum Transfer Learning with Pennylane, Jax, and Flax

This demo shows how to train a dressed quantum circuit to classify MNIST with Jax, Flax, and HuggingFace. This demo shares some code with the multilabel chest X-ray classifier developed in https://arxiv.org/abs/2405.00156v2.


## Imports

In [1]:
!pip install pennylane "jax[cuda12]" optax flax tqdm transformers datasets



In [2]:
from typing import Callable

import jax
import optax
import pennylane as qml
from datasets import load_dataset
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import numpy as jnp
from jax import random as jrand
from tqdm import tqdm
from transformers import FlaxResNetModel

## Hyperparameters

In [3]:
N_QUBITS = 10
N_LAYERS = 6
BATCH_SIZE = 16
LEARNING_RATE = 1e-4

## Model implementation

### Quantum circuit

`jax.vmap` is used to enable batching

In [4]:
def make_circuit(dev, n_qubits, n_layers):
    @qml.qnode(dev)
    def circuit(x, circuit_weights):
        # data encoding
        for i in range(n_qubits):
            qml.Hadamard(wires=i)
            qml.RY(x[i], wires=i)
        # trainable unitary
        for layer in range(n_layers):
            for i in range(n_qubits):
                qml.RY(circuit_weights[layer, i], wires=i)
            qml.broadcast(qml.CNOT, wires=range(n_qubits), pattern='ring')

        return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

    return jax.vmap(circuit, in_axes=(0, None))

### Flax Linen wrapper for circuit

In [5]:
class QuantumCircuit(nn.Module):
    num_qubits: int
    num_layers: int
    circuit: Callable

    @nn.compact
    def __call__(self, x):
        circuit_weights = self.param(
            'circuit_weights',
            nn.initializers.normal(),
            (self.num_layers, self.num_qubits),
        )
        x = self.circuit(x, circuit_weights)
        x = jnp.array(x).T
        return x

### Dressed quantum classifier

Wraps a quantum circuit module with a transfer learning backbone and classical pre/post-processing layers. This implementation uses [Microsoft's ResNet-50](https://huggingface.co/microsoft/resnet-50) backbone.

In [6]:
class DressedQuantumClassifier(nn.Module):
    backbone: nn.Module
    circuit: Callable
    num_qubits: int
    num_layers: int
    num_labels: int

    @nn.compact
    def __call__(self, x):
        x = self.backbone(x)
        x = x.pooler_output[:, :, 0, 0]
        # reduce features to fit into quantum circuit
        x = nn.Dense(features=self.num_qubits)(x)
        # rescale inputs
        x = jnp.tanh(x) * jnp.pi / 2
        x = QuantumCircuit(
            num_qubits=self.num_qubits,
            num_layers=self.num_layers,
            circuit=self.circuit,
        )(x)
        x = nn.Dense(features=self.num_labels)(x)
        return x

### Training, loss, and inference functions

In [7]:
def create_train_step(model, params, optimizer):
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    @jax.jit
    def predict(params, x):
        logits = model.apply(params, x)
        return logits.argmax(axis=1)

    @jax.jit
    def loss_fn(params, x, y):
        logits = model.apply(params, x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        return loss

    @jax.jit
    def train_step(state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(state.params, x, y)
        state = state.apply_gradients(grads=grads)
        return state, loss

    return train_step, loss_fn, predict, state

## Create the dressed quantum circuit classifier


In [8]:
dev = qml.device('default.qubit', wires=N_QUBITS)
circuit = make_circuit(dev, N_QUBITS, N_LAYERS)

resnet = FlaxResNetModel.from_pretrained('microsoft/resnet-50')

dqc = DressedQuantumClassifier(
    backbone=resnet.module,
    circuit=circuit,
    num_qubits=N_QUBITS,
    num_layers=N_LAYERS,
    num_labels=10,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


## Initialize classifier

Requires any image, we use an array of 0s.

In [9]:
zero_image = jnp.empty((1, 224, 224, 3))
key = jrand.PRNGKey(42)
params = dqc.init(key, zero_image)

Transfer pre-trained weights.

In [10]:
params['params']['backbone'] = resnet.params['params']
params['batch_stats']['backbone'] = resnet.params['batch_stats']

Create optimizer and get the model functions

In [11]:
optimizer = optax.adam(LEARNING_RATE)
train_step, loss_fn, predict, state = create_train_step(dqc, params, optimizer)

## MNIST

Download MNIST and format it to Jax.

In [12]:
ds = load_dataset('mnist').with_format('jax')

We will also format the 28x28 grayscale images to 224x224 RGB images normalized to ImageNet channel statistics.

In [13]:
@jax.vmap
def grayscale_to_imagenet_format(x):
    x /= 255
    # resize to resnet size
    x = jax.image.resize(x, (224, 224), method='nearest')
    # copy grayscale channels to rgb
    x = jnp.dstack([x] * 3)
    # normalize to imagenet channels
    x = (x - jnp.array((0.485, 0.456, 0.406))) / jnp.array((0.229, 0.224, 0.225))
    return x

## Model loops

Technical note on batch size: If the batch size doesn't divide evenly into the dataset size, the model functions will be recompiled every time the last batch is encountered as recompilation is triggered when the input shapes change.

The solution is 1) drop some data so the batch size divides evenly into the dataset size or 2) use the interpreted functions on the unevenly sized batch and the compiled function on all other batches. The 2nd approach was used in https://arxiv.org/abs/2405.00156v2



### Train loop

In [14]:
def train_loop(state, ds):
    total_correct = 0
    total_loss = 0
    total_seen = 0
    len_train = len(ds['train']) // BATCH_SIZE

    # close progress bar when epoch is finished
    with tqdm(ds['train'].iter(BATCH_SIZE), desc='train', leave=True, total=len_train) as pbar:
        for batch in pbar:
            x, y = batch['image'], batch['label']
            x = grayscale_to_imagenet_format(x)

            state, loss = train_step(state, x, y)
            yhat = predict(state.params, x)

            total_correct += int(sum(y == yhat))
            total_loss += float(loss)
            total_seen += len(yhat)

            pbar.set_postfix({'Mean acc': float(total_correct / total_seen),
                              'Mean loss': float(total_loss / total_seen)})
            #break

    return state # , float(total_loss / total_seen), float(total_correct / total_seen)

### Test loop

In [15]:
def test_loop(state, ds):
    total_correct = 0
    total_loss = 0
    total_seen = 0
    len_test = len(ds['test']) // BATCH_SIZE

    # close progress bar when epoch is finished
    with tqdm(ds['test'].iter(BATCH_SIZE), desc='test', leave=True, total=len_test) as pbar:
        for batch in pbar:
            x, y = batch['image'], batch['label']
            x = grayscale_to_imagenet_format(x)

            yhat = predict(state.params, x)
            loss = loss_fn(state.params, x, y)

            total_correct += int(sum(y == yhat))
            total_loss += float(loss)
            total_seen += len(yhat)

            pbar.set_postfix({'Mean acc': float(total_correct / total_seen),
                              'Mean loss': float(total_loss / total_seen)})
            #break

    # return float(total_loss / total_seen), float(total_correct / total_seen)

## Model training and testing

MNIST is easy, so we only run 1 training epoch to demonstrate the DQC implementation. The mean loss should go down and the mean accuracy should go up after training.

It will take a few seconds for the model functions to get JIT-compiled, this is normal.

This takes about 15 minutes on Colab.

### Before training

In [16]:
test_loop(state, ds)

test: 100%|██████████| 625/625 [01:17<00:00,  8.08it/s, Mean acc=0.118, Mean loss=0.144]


### Training

In [17]:
state = train_loop(state, ds)

train: 100%|██████████| 3750/3750 [13:38<00:00,  4.58it/s, Mean acc=0.821, Mean loss=0.0522]


### After training

In [18]:
test_loop(state, ds)

test: 100%|██████████| 625/625 [00:59<00:00, 10.44it/s, Mean acc=0.889, Mean loss=0.0316]


## Your turn!

Try different models, optimizers, circuits, datasets, etc. Can you find a quantum advantage?

## References

- [HuggingFace Flax tutorial](https://huggingface.co/blog/afmck/flax-tutorial)
- [Flax Linen documentation](https://flax-linen.readthedocs.io/en/latest/index.html)
- [Expanding the Horizon: Enabling Hybrid Quantum Transfer Learning for Long-Tailed Chest X-Ray Classification](https://arxiv.org/abs/2405.00156v2) \([code](https://github.com/bioIntelligence-Lab/qumi)\)
