# Task 3 - MNIST Dataset

In this notebook, we present our solution to task 3, where we train a QML model on the MNIST dataset. We make the following improvements to the tutorial presented

* Testing on the Fashion MNIST dataset
* Allow the QNN parameters to be trained
* Use JAX instead of Keras
* Use `jax.vmap` to speed up the QNN
* Add augmentations to the training process

In [1]:
import pennylane as qml
from pennylane import numpy as np

import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
import optax

import grain.python as pygrain
import dm_pix

import os
import gzip
import requests
from functools import partial
from dataclasses import dataclass
from itertools import combinations

from tqdm.notebook import tqdm, trange
import matplotlib.pyplot as plt

## Configuration

In [2]:
n_epochs = 30
batch_size = 32
lr = 1e-3
mnist = False

SAVE_PATH = "data/"
np.random.seed(0)
key = jax.random.PRNGKey(0)

## Data

We use the [Google Grain](https://github.com/google/grain) library to load our dataset. We use rotational augmentations to improve generalization.

In [3]:
def download(url: str, fname: str, chunk_size=1024):
	if os.path.exists(fname):
		return

	resp = requests.get(url, stream=True)
	total = int(resp.headers.get('content-length', 0))
	with open(fname, 'wb') as file, tqdm(
		desc=fname,
		total=total,
		unit='iB',
		unit_scale=True,
		unit_divisor=1024,
	) as bar:
		for data in resp.iter_content(chunk_size=chunk_size):
			size = file.write(data)
			bar.update(size)

def load_mnist(path):
	download('https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz', os.path.join(path, 'mnist.npz'))
	return np.load(os.path.join(path, 'mnist.npz'))

def load_fashion_mnist(path):
	base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
	files = [
		"train-labels-idx1-ubyte.gz",
		"train-images-idx3-ubyte.gz",
		"t10k-labels-idx1-ubyte.gz",
		"t10k-images-idx3-ubyte.gz",
	]

	def open_gzip(fname):
		return gzip.open(os.path.join(path, fname), 'rb')

	for fname in files:
		download(f'{base}{fname}', os.path.join(path, fname))

	with open_gzip(files[0]) as lbpath:
		y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

	with open_gzip(files[1]) as imgpath:
		x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(
			len(y_train), 28, 28
		)

	with open_gzip(files[2]) as lbpath:
		y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

	with open_gzip(files[3]) as imgpath:
		x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(
			len(y_test), 28, 28
		)

	return {
		'x_train': x_train,
		'y_train': y_train,
		'x_test': x_test,
		'y_test': y_test,
	}

In [4]:
load_fn, mean, std = (load_mnist, 0.1307, 0.3081) if mnist else (load_fashion_mnist, 0.2860, 0.3530)

In [5]:
class ImageDataSource(pygrain.RandomAccessDataSource[tuple[np.ndarray, np.ndarray]]):
	def __init__(self, path, split):
		data = load_fn(path)
		self.images = data[f'x_{split}'][..., np.newaxis]
		self.labels = data[f'y_{split}']

	def __len__(self) -> int:
		return len(self.images)

	def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
		image, label = self.images[idx], self.labels[idx]
		return image, label

In [6]:
class ImageTransform(pygrain.RandomMapTransform):
	def __init__(self, mean, var, augment, flip_x=False, flip_y=False, rotate=0):
		self.augment = augment
		self.flip_x = flip_x
		self.flip_y = flip_y
		self.rotate = rotate
		self.mean = jnp.array(mean)
		self.var = jnp.array(var)

	# @jax.jit
	def random_map(self, data: tuple[np.ndarray, np.ndarray], rng: np.random.Generator) -> tuple[jax.Array, jax.Array]:
		images, labels = data
		images, labels = jnp.array(images), jnp.array(labels)

		if self.augment:
			key = jax.random.PRNGKey(rng.integers(0, 2**32))

			if self.flip_x:
				images = dm_pix.random_flip_up_down(key, images)

			if self.flip_y:
				images = dm_pix.random_flip_left_right(key, images)

			if self.rotate:
				angle = jax.random.uniform(key, shape=images.shape[0], minval=-self.rotate, maxval=self.rotate) / 180 * jnp.pi
				images = jax.vmap(
					partial(dm_pix.rotate, mode='constant', cval=255),
					in_axes=[0, 0], out_axes=0
				)(images, angle)

		images /= 255
		images = jax.nn.standardize(
			images,
			mean=self.mean,
			variance=self.var,
			axis=(2, 3)
		)
		return images, labels

### Initializing data loaders

We use the known mean and std values for the MNIST dataset. However, we use the variance instead of the std as that is the expected input for `dm-pix`.

In [7]:
train_dataset = ImageDataSource(SAVE_PATH, 'train')
train_sampler_fn = partial(
	pygrain.IndexSampler,
	num_records=len(train_dataset),
	num_epochs=1,
	shard_options=pygrain.NoSharding(),
	shuffle=True,
)
train_loader_fn = partial(
	pygrain.DataLoader,
	data_source=train_dataset,
	operations=[
		pygrain.Batch(batch_size=batch_size, drop_remainder=False),
		ImageTransform(mean, std ** 2, True, True, False, 10),
	],
	worker_count=2,
)
train_steps_per_epoch = len(train_dataset) // batch_size + 1

test_dataset = ImageDataSource(SAVE_PATH, 'test')
test_sampler = pygrain.IndexSampler(
	num_records=len(test_dataset),
	num_epochs=1,
	shard_options=pygrain.NoSharding(),
	shuffle=False,
	seed=0,
)
test_loader = pygrain.DataLoader(
	data_source=test_dataset,
	operations=[
		pygrain.Batch(batch_size=batch_size, drop_remainder=False),
		ImageTransform(mean, std ** 2, False),
	],
	sampler=test_sampler,
	worker_count=2,
)
test_steps_per_epoch = len(test_dataset) // batch_size + 1

## Metrics

We add a simple metrics aggregator below

In [8]:
@dataclass
class Metric:
	total: float = 0.
	previous: float = 0.
	counter: int = 0

class Metrics:
	def __init__(self, metrics: list[str]) -> None:
		self.keys = metrics
		self.reset()

	def reset(self) -> None:
		self.metrics = {k: Metric() for k in self.keys}

	def update(self, metrics: dict[str, float|int]) -> None:
		for k, v in metrics.items():
			self.metrics[k].total += v
			self.metrics[k].previous = v
			self.metrics[k].counter += 1

	@property
	def epoch_dict(self) -> dict[str, float]:
		return {k: v.total / v.counter for k, v in self.metrics.items()}

	@property
	def epoch(self) -> str:
		return '\t'.join([f'{k}: {v.total / v.counter:.4f}' for k, v in self.metrics.items()])

	@property
	def previous(self) -> str:
		return ', '.join([f'{k}: {v.previous:.4f}' for k, v in self.metrics.items()])

def calc_acc(preds: jnp.ndarray, labels: jnp.ndarray) -> float:
	return (preds.argmax(axis=-1) == labels).mean().item()

# QML Model

Here, we implement a learnable QNN circuit. We use `dm_pix.extract_patches` to extract the patches of the image for the convolution. Then, we use `jax.vmap` to execute the QNN in a vectorized form, returning an array of shape `(batch_size, qnn_output_x * qnn_output_y * qnn_output_channels)`.

In [9]:
dev = qml.device("default.qubit", wires=4)

@jax.jit
@qml.qnode(dev)
def learnable_qnn_circuit(param, phi):
	for wire in range(4):
		qml.RY(np.pi * phi[wire], wires=wire)
		qml.Rot(*param[wire], wires=wire)

	for pair in combinations(range(4), 2):
		qml.CNOT(wires=pair)

	return qml.expval(qml.PauliZ(0))


class LearnableQNN(nn.Module):
	def setup(self):
		self.params = self.param('params', nn.initializers.uniform(scale=2*jnp.pi), (4, 3))

	def __call__(self, x):
		n = x.shape[0]
		patches = dm_pix.extract_patches(
			images=x,
			sizes=(1, 2, 2, 1),
			strides=(1, 2, 2, 1),
			rates=(1, 1, 1, 1),
			padding='VALID',
		)
		patches = patches.reshape(-1, 4)
		return jax.vmap(learnable_qnn_circuit, in_axes=(None, 0))(self.params, patches).reshape(n, -1)

# Classical Model

We use a simple linear model to implement the classical head for the model.

In [10]:
class BasicLinearModel(nn.Module):
	num_classes: int

	@nn.compact
	def __call__(self, x: jnp.ndarray):
		x = nn.Dense(
			self.num_classes, name='head', kernel_init=nn.zeros
		)(x)
		return x

In [11]:
class Sequential(nn.Module):
	layers: list[nn.Module]

	@nn.compact
	def __call__(self, x: jnp.ndarray):
		for layer in self.layers:
			x = layer(x)
		return x

# Initialization

In [12]:
key, init_key = jax.random.split(key)

model_qnn = LearnableQNN()
model_classic = BasicLinearModel(num_classes=10)
model = Sequential(layers=[model_qnn, model_classic])
variables = model.init(init_key, jnp.empty((1, 28, 28, 1)))
params = variables['params']

In [13]:
lr_schedule = optax.cosine_onecycle_schedule(
	transition_steps=n_epochs * train_steps_per_epoch,
	peak_value=lr,
	pct_start=.2,
	final_div_factor=1000,
)
solver = optax.yogi(lr_schedule)
solver_state = solver.init(params)


In [14]:
train_metrics = Metrics(['loss', 'acc'])
val_metrics = Metrics(['loss', 'acc'])

In [15]:
def forward_and_loss(variables, images, labels):
	preds = model.apply(variables, images)
	loss = optax.losses.softmax_cross_entropy_with_integer_labels(preds, labels).mean()
	return loss, preds

# Training Loop

In [16]:
for epoch in trange(1, n_epochs+1):
	train_metrics.reset()
	val_metrics.reset()

	train_sampler = train_sampler_fn(seed=epoch)
	train_loader = train_loader_fn(sampler=train_sampler)

	for images, labels in (pbar := tqdm(train_loader, total=train_steps_per_epoch, desc='Training', leave=False)):
		(loss, preds), grad = jax.value_and_grad(forward_and_loss, has_aux=True)({ 'params': params }, images, labels)
		updates, solver_state = solver.update(grad['params'], solver_state, params)
		params = optax.apply_updates(params, updates)

		train_metrics.update({
			'loss': loss.item(),
			'acc': calc_acc(preds, labels),
		})
		pbar.set_postfix_str(train_metrics.previous)

	tqdm.write(f'epoch {epoch}: {train_metrics.epoch}')

	for images, labels in (pbar := tqdm(test_loader, total=test_steps_per_epoch, desc='Validation', leave=False)):
		loss, preds = forward_and_loss({ 'params': params }, images, labels)

		val_metrics.update({
			'loss': loss.item(),
			'acc': calc_acc(preds, labels),
		})
		pbar.set_postfix_str(val_metrics.previous)

	tqdm.write(f'  -> val: {val_metrics.epoch}')

  0%|          | 0/30 [00:00<?, ?it/s]

Training:   0%|          | 0/1876 [00:00<?, ?it/s]

epoch 1: loss: 2.0208	acc: 0.3495


Validation:   0%|          | 0/313 [00:00<?, ?it/s]

  -> val: loss: 1.8338	acc: 0.3455


Training:   0%|          | 0/1876 [00:00<?, ?it/s]

epoch 2: loss: 1.5882	acc: 0.4571


Validation:   0%|          | 0/313 [00:00<?, ?it/s]

  -> val: loss: 1.7338	acc: 0.3666


Training:   0%|          | 0/1876 [00:00<?, ?it/s]

epoch 3: loss: 1.3213	acc: 0.5320


Validation:   0%|          | 0/313 [00:00<?, ?it/s]

  -> val: loss: 1.8638	acc: 0.3743


Training:   0%|          | 0/1876 [00:00<?, ?it/s]

epoch 4: loss: 1.1905	acc: 0.5713


Validation:   0%|          | 0/313 [00:00<?, ?it/s]

  -> val: loss: 2.0465	acc: 0.3851


Training:   0%|          | 0/1876 [00:00<?, ?it/s]

epoch 5: loss: 1.1415	acc: 0.5819


Validation:   0%|          | 0/313 [00:00<?, ?it/s]

  -> val: loss: 2.2597	acc: 0.3789


Training:   0%|          | 0/1876 [00:00<?, ?it/s]

KeyboardInterrupt: 