# Task 3 - MNIST Dataset

In this notebook, we present our solution to task 3, where we train a QML model on the MNIST dataset. We make the following improvements over the presented tutorial:

* Testing on the Fashion MNIST dataset
* Allow the QNN parameters to be trained
* Use JAX instead of Keras
* Use `jax.vmap` to speed up the QNN
* Add augmentations to the training process

In [None]:
import pennylane as qml
from pennylane import numpy as np

import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax

import grain.python as pygrain
import dm_pix
import orbax.checkpoint

import os
import gzip
import requests
from functools import partial
from dataclasses import dataclass
from itertools import combinations
import json

from tqdm.notebook import tqdm, trange
import matplotlib.pyplot as plt

## Configuration

In [None]:
# Training hyper-parameters.
n_epochs = 10
batch_size = 32
lr = 1e-3  # used as the peak value of the one-cycle learning-rate schedule
mnist = False  # True -> MNIST, False -> Fashion-MNIST (see the loader selection below)
qnn_wires = 9  # number of qubits; also the number of pixels in one square patch
qnn_layers = 4  # size of the "layers" axis of the QNN parameter tensor

# The quantum kernel is square, so the wire count has to be a perfect square.
assert qnn_wires == int(qnn_wires**0.5)**2

SAVE_PATH = "data/"  # directory the dataset files are downloaded to / read from
np.random.seed(0)  # seeds the global NumPy RNG (through PennyLane's NumPy wrapper)
key = jax.random.PRNGKey(0)  # root JAX key; split later for model initialisation

## Data

We use the [Google Grain](https://github.com/google/grain) library to load our dataset. During training, we apply random flips and small random rotations as augmentations to improve generalization.

In [None]:
def download(url: str, fname: str, chunk_size=1024):
	"""Download ``url`` to ``fname`` with a progress bar; do nothing if ``fname`` exists.

	Args:
		url: Address of the file to fetch.
		fname: Destination path. Missing parent directories are created.
		chunk_size: Number of bytes requested per streamed chunk.

	Raises:
		requests.HTTPError: If the server answers with an error status.
	"""
	if os.path.exists(fname):
		return

	# A fresh checkout has no data directory yet; create it instead of failing in open().
	target_dir = os.path.dirname(fname)
	if target_dir:
		os.makedirs(target_dir, exist_ok=True)

	# Stream into a temporary sibling file and rename it only once the transfer is
	# complete. Otherwise an interrupted download would leave a truncated file that
	# the existence check above treats as valid on every later run.
	partial_fname = f'{fname}.part'
	try:
		with requests.get(url, stream=True, timeout=60) as resp:
			# Fail loudly on 4xx/5xx instead of saving the error page as the dataset.
			resp.raise_for_status()
			total = int(resp.headers.get('content-length', 0))
			with open(partial_fname, 'wb') as file, tqdm(
				desc=fname,
				total=total,
				unit='iB',
				unit_scale=True,
				unit_divisor=1024,
			) as bar:
				for data in resp.iter_content(chunk_size=chunk_size):
					size = file.write(data)
					bar.update(size)
		os.replace(partial_fname, fname)
	except BaseException:
		# Also covers KeyboardInterrupt: never leave a half-written file behind.
		if os.path.exists(partial_fname):
			os.remove(partial_fname)
		raise

def load_mnist(path):
	"""Make sure ``mnist.npz`` is present under ``path`` and return it opened with ``np.load``."""
	archive = os.path.join(path, 'mnist.npz')
	download('https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz', archive)
	return np.load(archive)

def load_fashion_mnist(path):
	"""Download the four gzipped Fashion-MNIST files into ``path`` and decode them.

	Returns:
		A dict with keys ``x_train``, ``y_train``, ``x_test`` and ``y_test``. Labels are
		flat ``uint8`` arrays; images are ``uint8`` arrays reshaped to ``(n, 28, 28)``,
		where ``n`` is the number of labels of the same split.
	"""
	base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
	archives = {
		'y_train': "train-labels-idx1-ubyte.gz",
		'x_train': "train-images-idx3-ubyte.gz",
		'y_test': "t10k-labels-idx1-ubyte.gz",
		'x_test': "t10k-images-idx3-ubyte.gz",
	}

	# Fetch every archive first; files already on disk are skipped by download().
	for archive in archives.values():
		download(f'{base}{archive}', os.path.join(path, archive))

	def read_payload(archive, header_size):
		# Decompress the file and view everything after the header as raw bytes.
		with gzip.open(os.path.join(path, archive), 'rb') as stream:
			return np.frombuffer(stream.read(), np.uint8, offset=header_size)

	decoded = {}
	for split in ('train', 'test'):
		# Label files skip an 8-byte header, image files a 16-byte header.
		labels = read_payload(archives[f'y_{split}'], 8)
		pixels = read_payload(archives[f'x_{split}'], 16)
		decoded[f'x_{split}'] = pixels.reshape(len(labels), 28, 28)
		decoded[f'y_{split}'] = labels
	return decoded

In [None]:
# Select the loader together with the normalisation statistics (mean, std) that belong
# to the chosen dataset; the first triple is MNIST, the second Fashion-MNIST.
load_fn, mean, std = (load_mnist, 0.1307, 0.3081) if mnist else (load_fashion_mnist, 0.2860, 0.3530)

In [None]:
class ImageDataSource(pygrain.RandomAccessDataSource[tuple[np.ndarray, np.ndarray]]):
	"""Random-access Grain data source over one split of the dataset chosen via ``load_fn``."""

	def __init__(self, path, split):
		"""Load the arrays of ``split`` (used to build the keys ``x_<split>`` / ``y_<split>``)."""
		arrays = load_fn(path)
		# Append a trailing channel axis to the image array.
		self.images = arrays[f'x_{split}'][..., np.newaxis]
		self.labels = arrays[f'y_{split}']

	def __len__(self) -> int:
		"""Number of images in the split."""
		return len(self.images)

	def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
		"""Return the ``(image, label)`` pair stored at ``idx``."""
		return self.images[idx], self.labels[idx]

In [None]:
class ImageTransform(pygrain.RandomMapTransform):
	"""Optional random augmentation followed by normalisation of an image batch.

	Args:
		mean: Mean subtracted from the images after they are scaled by 1/255.
		var: Variance (not std) the centred images are normalised with.
		augment: Master switch; when falsy no random augmentation is applied.
		flip_x: Enable ``dm_pix.random_flip_up_down``.
		flip_y: Enable ``dm_pix.random_flip_left_right``.
		rotate: Maximum absolute rotation angle in degrees; ``0`` disables rotation.
	"""

	def __init__(self, mean, var, augment, flip_x=False, flip_y=False, rotate=0):
		self.augment = augment
		self.flip_x = flip_x
		self.flip_y = flip_y
		self.rotate = rotate
		self.mean = jnp.array(mean)
		self.var = jnp.array(var)

	def random_map(self, data: tuple[np.ndarray, np.ndarray], rng: np.random.Generator) -> tuple[jax.Array, jax.Array]:
		"""Augment (if enabled) and normalise one ``(images, labels)`` element.

		The leading axis of ``images`` is treated as the batch axis: one rotation angle
		is drawn per entry along it, so this transform has to run after batching.
		"""
		images, labels = data
		images, labels = jnp.array(images), jnp.array(labels)

		if self.augment:
			key = jax.random.PRNGKey(rng.integers(0, 2**32))
			# Reusing one key for several draws makes them correlated, so every
			# augmentation gets its own independent sub-key.
			flip_x_key, flip_y_key, rotate_key = jax.random.split(key, 3)

			# NOTE(review): each flip receives a single key for the whole batch, so the
			# flip decision is presumably shared by all images in it - confirm against dm_pix.
			if self.flip_x:
				images = dm_pix.random_flip_up_down(flip_x_key, images)

			if self.flip_y:
				images = dm_pix.random_flip_left_right(flip_y_key, images)

			if self.rotate:
				# `shape` must be a tuple; a bare int is not accepted by every JAX release.
				degrees = jax.random.uniform(
					rotate_key,
					shape=(images.shape[0],),
					minval=-self.rotate,
					maxval=self.rotate,
				)
				angle = degrees / 180 * jnp.pi
				# Pad the corners uncovered by the rotation with 0: the background of both
				# datasets is black, so filling with 255 would paint bright artefacts.
				images = jax.vmap(
					partial(dm_pix.rotate, mode='constant', cval=0),
					in_axes=(0, 0), out_axes=0
				)(images, angle)

		images /= 255
		# Mean and variance are supplied explicitly, so `standardize` reduces over no
		# axis; it only computes (x - mean) / sqrt(variance + epsilon).
		images = jax.nn.standardize(
			images,
			mean=self.mean,
			variance=self.var,
		)
		return images, labels

### Initializing data loaders

We use the known mean and std values of the selected dataset (MNIST or Fashion MNIST). However, we pass the variance (std²) instead of the std, as that is the input expected by `jax.nn.standardize`.

In [None]:
# --- Training pipeline -------------------------------------------------------
train_dataset = ImageDataSource(SAVE_PATH, 'train')
# The sampler covers exactly one pass over the data, so a new sampler (and loader)
# is built for every epoch; the seed is supplied at that point.
train_sampler_fn = partial(
	pygrain.IndexSampler,
	num_records=len(train_dataset),
	num_epochs=1,
	shard_options=pygrain.NoSharding(),
	shuffle=True,
)
train_loader_fn = partial(
	pygrain.DataLoader,
	data_source=train_dataset,
	operations=[
		# Batch first: ImageTransform draws one rotation angle per batch entry.
		pygrain.Batch(batch_size=batch_size, drop_remainder=False),
		# Positional args: augment=True, flip_x=True, flip_y=False, rotate=10 degrees.
		ImageTransform(mean, std ** 2, True, True, False, 10),
	],
	worker_count=2,
)
# Ceiling division: the last partial batch is kept (drop_remainder=False), but no
# extra step may be counted when the dataset size is a multiple of the batch size
# (e.g. 60000 / 32 is exactly 1875). The count also sets the LR schedule length.
train_steps_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size

# --- Evaluation pipeline: fixed order, no augmentation -------------------------
test_dataset = ImageDataSource(SAVE_PATH, 'test')
test_sampler = pygrain.IndexSampler(
	num_records=len(test_dataset),
	num_epochs=1,
	shard_options=pygrain.NoSharding(),
	shuffle=False,
	seed=0,
)
test_loader = pygrain.DataLoader(
	data_source=test_dataset,
	operations=[
		pygrain.Batch(batch_size=batch_size, drop_remainder=False),
		ImageTransform(mean, std ** 2, False),
	],
	sampler=test_sampler,
	worker_count=2,
)
test_steps_per_epoch = (len(test_dataset) + batch_size - 1) // batch_size

## Metrics

We define a simple metrics aggregator below.

In [None]:
@dataclass
class Metric:
	"""Running statistics for one tracked quantity."""
	# Sum of all values recorded since the last reset.
	total: float = 0.0
	# Most recently recorded value.
	previous: float = 0.0
	# Number of values recorded since the last reset.
	counter: int = 0

In [None]:
class Metrics:
	def __init__(self, metrics: list[str]) -> None:
		self.keys = metrics
		self.history = []
		self.reset()

	def reset(self) -> None:
		if hasattr(self, 'metrics'):
			self.history.append(self.epoch_dict)
		self.metrics = {k: Metric() for k in self.keys}

	def update(self, metrics: dict[str, float|int]) -> None:
		for k, v in metrics.items():
			self.metrics[k].total += v
			self.metrics[k].previous = v
			self.metrics[k].counter += 1

	def save(self, path: str) -> None:
		with open(path, 'w') as f:
			json.dump(self.history, f)

	@property
	def epoch_dict(self) -> dict[str, float]:
		return {k: v.total / v.counter for k, v in self.metrics.items()}

	@property
	def epoch(self) -> str:
		return '\t'.join([f'{k}: {v.total / v.counter:.4f}' for k, v in self.metrics.items()])

	@property
	def previous(self) -> str:
		return ', '.join([f'{k}: {v.previous:.4f}' for k, v in self.metrics.items()])

In [None]:
def calc_acc(preds: jnp.ndarray, labels: jnp.ndarray) -> float:
	"""Fraction of rows in ``preds`` whose arg-max over the last axis equals the label."""
	predicted = preds.argmax(axis=-1)
	hits = predicted == labels
	return hits.mean().item()

In [None]:
def update_metrics(metrics, loss, preds, labels):
	"""Record the loss and accuracy of one batch in ``metrics`` and return ``metrics``.

	Args:
		metrics: Aggregator exposing ``update(dict)``.
		loss: Scalar batch loss; anything ``float()`` accepts, e.g. a 0-d array.
		preds: Per-class scores, passed to ``calc_acc``.
		labels: Integer class labels, passed to ``calc_acc``.
	"""
	accuracy = calc_acc(preds, labels)
	metrics.update({
		# The aggregator is typed for plain numbers; handing over the array scalar
		# would keep arrays in the running totals and in the saved history.
		'loss': float(loss),
		'accuracy': accuracy,
	})
	return metrics

## Models

Here, we implement a learnable QNN circuit. We use `dm_pix.extract_patches` to extract the patches of the image for the convolution. Then, we use `jax.vmap` to execute the QNN in a vectorized form, returning an array of shape `(batch_size, qnn_output_x * qnn_output_y * qnn_output_channels)`. We use a linear layer as a head for this model.

We also define a CNN (LeNet-5) with a similar FLOPs requirement as the QNN to compare the effectiveness given the same amount of classical compute power.

In [None]:
dev = qml.device("default.qubit", wires=qnn_wires)

@partial(jax.jit, static_argnames=('wires',))
@qml.qnode(dev)
def learnable_qnn_circuit(param, phi, wires):
	"""Trainable circuit applied to one flattened image patch.

	Args:
		param: Array indexed as ``[layer][wire][k]`` with ``k`` in ``0..2``. Layer 0
			parameterises the data encoding (input scale, input offset, RX angle);
			every further layer holds the three ``Rot`` angles of one entangling layer.
		phi: Patch values, one entry per wire.
		wires: Number of qubits used (static under ``jax.jit``).

	Returns:
		Expectation value of Pauli-Z on wire 0.
	"""
	# Encoding layer: a learnable affine map of each pixel drives an RY rotation.
	encoding = param[0]
	for wire in range(wires):
		qml.RY(jnp.pi * (encoding[wire][0] * phi[wire] + encoding[wire][1]), wires=wire)
		qml.RX(encoding[wire][2], wires=wire)

	# All remaining layers are used (the previous `param[2:]` left layer 1 untrained).
	for layer_weights in param[1:]:
		for wire in range(wires):
			qml.Rot(*layer_weights[wire], wires=wire)
		# Ring of CNOTs entangling each wire with its neighbour.
		for wire in range(wires):
			qml.CNOT(wires=[wire, (wire+1) % wires])

	return qml.expval(qml.PauliZ(0))


class LearnableQNN(nn.Module):
	"""Quantum "convolution": runs the circuit on every non-overlapping square patch.

	Attributes:
		wires: Number of qubits; must be a perfect square (patch is sqrt(wires) wide).
		layers: Size of the layer axis of the parameter tensor (encoding layer included).
	"""
	wires: int
	layers: int

	def setup(self):
		# Layer-major layout, matching the circuit's `param[layer][wire]` indexing.
		# With the former (wires, layers, 3) layout the wire index ran past the end of
		# the axis, which JAX silently clamps, so several wires shared parameters and
		# the number of entangling layers followed `wires` rather than `layers`.
		self.qnn_params = self.param('qnn_params', nn.initializers.uniform(scale=2*jnp.pi), (self.layers, self.wires, 3))
		self.kernel_width = int(self.wires**0.5)

	def __call__(self, x):
		"""Map a batch of images to one circuit output per patch, flattened per image."""
		n = x.shape[0]
		patches = dm_pix.extract_patches(
			images=x,
			sizes=(1, self.kernel_width, self.kernel_width, 1),
			strides=(1, self.kernel_width, self.kernel_width, 1),
			rates=(1, 1, 1, 1),
			padding='VALID',
		)
		# One row of `wires` values per patch, across the whole batch.
		patches = patches.reshape(-1, self.wires)
		# Vectorise over patches only; parameters and wire count are shared.
		return jax.vmap(learnable_qnn_circuit, in_axes=(None, 0, None))(self.qnn_params, patches, self.wires).reshape(n, -1)

In [None]:
class BasicLinearModel(nn.Module):
	"""Classification head: a single dense layer named ``head`` with a zero-initialised kernel."""
	num_classes: int

	@nn.compact
	def __call__(self, x: jnp.ndarray):
		head = nn.Dense(features=self.num_classes, kernel_init=nn.zeros, name='head')
		return head(x)

In [None]:
class Sequential(nn.Module):
	"""Feeds the input through ``layers`` one after another and returns the final output."""
	layers: list[nn.Module]

	@nn.compact
	def __call__(self, x: jnp.ndarray):
		out = x
		for stage in self.layers:
			out = stage(out)
		return out

In [None]:
class LeNet5(nn.Module):
	"""LeNet-5 style CNN: two conv/ReLU/avg-pool stages, two hidden dense layers, linear output."""
	num_classes: int

	@nn.compact
	def __call__(self, x: jnp.ndarray):
		# Feature extractor: 5x5 unpadded convolutions with 6, then 16 filters.
		for n_filters in (6, 16):
			x = nn.Conv(features=n_filters, kernel_size=(5, 5), strides=(1, 1), padding='VALID')(x)
			x = nn.relu(x)
			x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')

		# Flatten everything but the batch axis.
		x = x.reshape((x.shape[0], -1))

		# Classifier: two hidden layers with ReLU, then the class scores.
		for width in (120, 84):
			x = nn.relu(nn.Dense(width)(x))
		return nn.Dense(self.num_classes)(x)

## Utilities

In [None]:
def create_train_state(module, rng, lr, n_epochs, train_steps_per_epoch, print_summary):
	"""Initialise ``module`` and bundle its parameters with a Yogi optimiser.

	Args:
		module: Flax module to initialise on a dummy ``(1, 28, 28, 1)`` input.
		rng: PRNG key used for initialisation (and for the optional summary).
		lr: Peak learning rate of the cosine one-cycle schedule.
		n_epochs: Number of epochs; with ``train_steps_per_epoch`` it fixes the schedule length.
		train_steps_per_epoch: Optimiser steps per epoch.
		print_summary: When truthy, print the module table including FLOP estimates.

	Returns:
		A ``train_state.TrainState`` holding ``module.apply``, the parameters and the optimiser.
	"""
	dummy_batch = jnp.ones((1, 28, 28, 1))
	params = module.init(rng, dummy_batch)['params']

	total_steps = n_epochs * train_steps_per_epoch
	solver = optax.yogi(
		optax.cosine_onecycle_schedule(
			transition_steps=total_steps,
			peak_value=lr,
			pct_start=.1,
			final_div_factor=1000,
		)
	)

	if print_summary:
		print(module.tabulate(rng, dummy_batch, compute_flops=True, compute_vjp_flops=True))

	return train_state.TrainState.create(apply_fn=module.apply, params=params, tx=solver)

In [None]:
@jax.jit
def train_step(state, images, labels):
	"""Run one optimisation step and return ``(new_state, loss, preds)``.

	``loss`` is the mean softmax cross-entropy of the batch; ``preds`` are the logits
	computed with the parameters from before the update.
	"""
	def loss_fn(params):
		logits = state.apply_fn({ 'params': params }, images)
		per_example = optax.losses.softmax_cross_entropy_with_integer_labels(
			logits=logits, labels=labels
		)
		return per_example.mean(), logits

	# Differentiate with respect to the parameters only; the logits ride along as aux output.
	grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
	(loss, preds), grads = grad_fn(state.params)
	return state.apply_gradients(grads=grads), loss, preds

@jax.jit
def test_step(state, images, labels):
	"""Evaluate one batch without updating anything; returns ``(state, loss, preds)``."""
	logits = state.apply_fn({ 'params': state.params }, images)
	per_example = optax.losses.softmax_cross_entropy_with_integer_labels(
		logits=logits, labels=labels
	)
	return state, per_example.mean(), logits

In [None]:
def run_epoch(epoch_type, state, train_loader, train_steps_per_epoch, metrics):
	"""Iterate once over a loader, training or evaluating, and return the resulting state.

	Args:
		epoch_type: ``'Train'`` to apply ``train_step``, ``'Test'`` to apply ``test_step``.
		state: Train state passed to (and returned by) every step.
		train_loader: Iterable of ``(images, labels)`` batches; also used for test data.
		train_steps_per_epoch: Expected number of batches, used as the progress-bar total.
		metrics: Aggregator that is updated per batch, printed, and reset at the end.
	"""
	assert epoch_type in ['Train', 'Test']
	step_fn = train_step if epoch_type == 'Train' else test_step

	pbar = tqdm(train_loader, total=train_steps_per_epoch, desc=epoch_type, leave=False)
	for images, labels in pbar:
		state, loss, preds = step_fn(state, images, labels)
		update_metrics(metrics, loss, preds, labels)
		# Show the values of the batch just processed next to the bar.
		pbar.set_postfix_str(metrics.previous)

	tqdm.write(f'   -> {epoch_type}:\t{metrics.epoch}')
	metrics.reset()
	return state

In [None]:
# One aggregator per (model, split) pair so the four histories stay separate.
qnn_train_metrics = Metrics(['loss', 'accuracy'])
qnn_test_metrics = Metrics(['loss', 'accuracy'])
cnn_train_metrics = Metrics(['loss', 'accuracy'])
cnn_test_metrics = Metrics(['loss', 'accuracy'])

## Initialization

In [None]:
# Split the root key into a fresh `key` plus one initialisation key per model.
# Note: this rebinds `key`, so re-running the cell yields different init keys.
key, qnn_init_key, cnn_init_key = jax.random.split(key, 3)

In [None]:
# Quantum model: the patch-wise QNN feature extractor followed by a linear head.
qnn_module = Sequential(layers=[LearnableQNN(wires=qnn_wires, layers=qnn_layers), BasicLinearModel(num_classes=10)])
# The trailing True prints the module summary.
qnn_state = create_train_state(qnn_module, qnn_init_key, lr, n_epochs, train_steps_per_epoch, True)

In [None]:
# Classical baseline, built with the same learning rate and schedule length as the QNN.
cnn_module = LeNet5(num_classes=10)
# The trailing True prints the module summary.
cnn_state = create_train_state(cnn_module, cnn_init_key, lr, n_epochs, train_steps_per_epoch, True)

## Training Loop

In [122]:
# Train and evaluate the QNN model.
for epoch in trange(1, n_epochs+1, desc='QNN'):
	# The sampler only covers one pass over the data, so build a new one per epoch;
	# seeding it with the epoch index gives a different shuffle each epoch.
	train_sampler = train_sampler_fn(seed=epoch)
	train_loader = train_loader_fn(sampler=train_sampler)

	tqdm.write(f'Epoch {epoch}/{n_epochs}')

	qnn_state = run_epoch('Train', qnn_state, train_loader, train_steps_per_epoch, qnn_train_metrics)
	qnn_state = run_epoch('Test', qnn_state, test_loader, test_steps_per_epoch, qnn_test_metrics)

In [None]:
# Train and evaluate the CNN baseline with the same per-epoch sampler seeds as the QNN loop.
for epoch in trange(1, n_epochs+1, desc='CNN'):
	# The sampler only covers one pass over the data, so build a new one per epoch.
	train_sampler = train_sampler_fn(seed=epoch)
	train_loader = train_loader_fn(sampler=train_sampler)

	tqdm.write(f'Epoch {epoch}/{n_epochs}')

	cnn_state = run_epoch('Train', cnn_state, train_loader, train_steps_per_epoch, cnn_train_metrics)
	cnn_state = run_epoch('Test', cnn_state, test_loader, test_steps_per_epoch, cnn_test_metrics)