# Task 5

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import flax
from flax import linen as nn
from flax.training import train_state
import optax
from transformers import AutoProcessor, FlaxResNetModel
import dm_pix as pix
import pennylane as qml

from tqdm.auto import tqdm, trange
from grain import python as pygrain

from PIL import Image

import os
import json
from dataclasses import dataclass
from functools import partial

## Configuration

In [None]:
# Number of examples per batch (given to pygrain.Batch for both loaders).
batch_size = 16
# Number of training epochs; also scales the length of the LR schedule.
num_epochs = 10
# Peak learning rate of the one-cycle schedule.
lr = 1e-4
# Root directory that holds one sub-directory per dataset split.
data_dir = 'data'
# Seed of the top-level JAX PRNG key.
seed = 42
# Number of worker processes intended for the data loaders.
dataloader_workers = 2
# Number of target classes (used when computing per-class sampling weights).
num_classes = 6
# Number of qubits of the simulated device / wires of the LearnableQNN layer.
qnn_wires = 4
# `layers` setting handed to LearnableQNN.
qnn_layers = 2
# Side length, in pixels, of the square images fed to the model.
image_size = 224
# Statistics used by normalize_images after pixel values are divided by 255.
# NOTE(review): single-element arrays, so the same values are broadcast over
# every channel -- presumably dataset-wide intensity statistics; confirm.
mean = jnp.array([0.24085431])
var = jnp.array([0.01992414])

In [None]:
# Root PRNG key for the notebook, derived from the configured `seed`.
key = jax.random.PRNGKey(seed)

## Dataset

In [None]:
def normalize_images(images: jax.Array) -> jax.Array:
	"""Scale pixel values by 1/255 and standardize them.

	Standardization uses the module-level `mean` and `var` constants rather
	than statistics computed from `images`.

	Args:
		images: Batch of images; axes 1 and 2 are passed as the `axis`
			argument of `jax.nn.standardize`.

	Returns:
		The standardized images.
	"""
	scaled = images / 255
	return jax.nn.standardize(scaled, mean=mean, variance=var, axis=(1, 2))

@jax.jit
def train_transform(images: jax.Array, key: jax.Array) -> jax.Array:
	"""Apply the training-time augmentation pipeline to a batch of images.

	The batch is resized to 512x512, randomly flipped horizontally and
	vertically, randomly cropped to `image_size` x `image_size`, and finally
	normalized with `normalize_images`.

	Args:
		images: Batch of images with four axes, unpacked as (n, h, w, c).
		key: JAX PRNG key that drives all random augmentations.

	Returns:
		The augmented and normalized batch.
	"""
	n, _, _, c = images.shape

	# Every random op gets its own subkey. Reusing one key for all three would
	# make their random draws correlated (e.g. both flips would always take
	# the same decision), which defeats the purpose of the augmentation.
	flip_lr_key, flip_ud_key, crop_key = jax.random.split(key, 3)

	images = jax.image.resize(images, (n, 512, 512, c), method='bicubic')
	images = pix.random_flip_left_right(flip_lr_key, images)
	images = pix.random_flip_up_down(flip_ud_key, images)
	images = pix.random_crop(crop_key, images, (n, image_size, image_size, c))
	images = normalize_images(images)

	return images

@jax.jit
def test_transform(images: jax.Array, key: jax.Array) -> jax.Array:
	"""Apply the deterministic evaluation-time preprocessing to a batch.

	Args:
		images: Batch of images with four axes (batch, height, width, channels).
		key: Unused; accepted so the signature matches `train_transform` and
			both can be called the same way.

	Returns:
		The batch resized to `image_size` x `image_size` and normalized.
	"""
	batch, _, _, channels = images.shape
	target_shape = (batch, image_size, image_size, channels)

	resized = jax.image.resize(images, target_shape, method='bicubic')
	return normalize_images(resized)

In [None]:
class ImageTransform(pygrain.RandomMapTransform):
	"""Grain transform that runs a JAX image transform on (images, labels) data."""

	def __init__(self, transform_fn):
		# Callable invoked as transform_fn(images, key).
		self.transform_fn = transform_fn

	def random_map(self, data: tuple[np.ndarray, np.ndarray], rng: np.random.Generator) -> tuple[jax.Array, jax.Array]:
		"""Convert the data to JAX arrays and transform the images.

		Args:
			data: Pair of (images, labels).
			rng: NumPy generator used to derive the JAX PRNG key.

		Returns:
			The transformed images and the labels, both as JAX arrays.
		"""
		raw_images, raw_labels = data
		images = jnp.array(raw_images)
		labels = jnp.array(raw_labels)

		# Three-axis input has no channel axis yet; append one of size 1.
		if images.ndim == 3:
			images = images[..., None]

		seed_value = rng.integers(0, 2**32)
		transformed = self.transform_fn(images, jax.random.PRNGKey(seed_value))
		return transformed, labels

In [None]:
class ImageDataSource(pygrain.RandomAccessDataSource[tuple[Image.Image, int]]):
	"""Random-access source of (image, label) pairs for one dataset split.

	The split lives in `<path>/<split>/` and is indexed by
	`<path>/<split>/<split>.json`, a JSON object whose keys are image file
	names (relative to the split directory) and whose values are labels.
	"""

	def __init__(self, path, split, num_classes = 6):
		"""Read the split's JSON index.

		Args:
			path: Root data directory.
			split: Name of the split; used both as sub-directory and as the
				stem of the JSON index file.
			num_classes: Number of classes; stored on the instance.
		"""
		self.image_dir = os.path.join(path, split)
		# Decode explicitly as UTF-8 instead of relying on the platform's
		# default text encoding.
		with open(os.path.join(self.image_dir, f'{split}.json'), encoding='utf-8') as f:
			data = json.load(f)
		self.images = tuple(data.keys())
		self.labels = np.array(tuple(data.values()))
		self.num_classes = num_classes

	def __len__(self) -> int:
		"""Return the number of records in the split."""
		return len(self.images)

	def __getitem__(self, idx) -> tuple[Image.Image, int]:
		"""Load record `idx` as an RGB image together with its label."""
		image_path = os.path.join(self.image_dir, self.images[idx])
		# `convert` returns an independent copy, so the file handle opened by
		# `Image.open` can be closed deterministically right away.
		with Image.open(image_path) as raw_image:
			image = raw_image.convert('RGB')
		label = self.labels[idx].item()
		return image, label

### Weighted Index Sampler

Because the dataset is highly imbalanced, we use a weighted sampler that draws each example with probability inversely proportional to the frequency of its class, so that every class is represented equally in expectation. This class is intended to work similarly to the [PyTorch `WeightedRandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler).

In [None]:
class WeightedIndexSampler(pygrain.Sampler):
	"""Grain sampler that draws record keys with replacement from given weights.

	All `len(weights) * num_epochs` record keys are drawn once, at
	construction time, from a Philox generator seeded with `seed`.
	"""

	def __init__(self, weights: np.ndarray, seed: int, num_epochs: int = 1):
		"""Pre-draw the record keys.

		Args:
			weights: Per-record sampling probabilities, passed as `p` to
				`Generator.choice`.
			seed: Seed for the sampling generator and for the per-record
				generators handed out by `__getitem__`.
			num_epochs: Number of passes; the sampler yields
				`len(weights) * num_epochs` indices in total.
		"""
		assert num_epochs > 0
		self._num_records = len(weights)
		self._max_index = self._num_records * num_epochs
		self._weights = weights
		self._seed = seed
		self._rng = np.random.Generator(np.random.Philox(self._seed))
		self._record_keys = self._rng.choice(self._num_records, size=self._max_index, replace=True, p=self._weights)

	def __getitem__(self, index: int) -> pygrain.RecordMetadata:
		"""Return the metadata (index, record key, rng) for position `index`.

		Raises:
			IndexError: If `index` is outside `[0, len(self))`.
		"""
		if not 0 <= index < self._max_index:
			# The valid range is half-open: `_max_index` itself is rejected.
			raise IndexError(
				f"RecordMetadata object index is out of bounds; Got index {index},"
				f" allowed indices should be in [0, {self._max_index})"
			)

		# Hand out a plain Python int rather than a NumPy scalar.
		record_key = int(self._record_keys[index])
		# Per-record generator, reproducible from (seed, index) alone.
		rng = np.random.Generator(np.random.Philox(key=self._seed + index))
		return pygrain.RecordMetadata(index, record_key, rng)

	def __len__(self) -> int:
		"""Return the total number of indices this sampler yields."""
		return self._max_index

In [None]:
# --- Training split ---------------------------------------------------------
train_dataset = ImageDataSource(data_dir, 'train', num_classes=num_classes)

# Inverse-frequency weights: every example of class c gets weight 1 / count(c),
# then the per-example weights are normalized into a probability vector.
train_class_counts = np.array([(train_dataset.labels == i).sum() for i in range(num_classes)])
train_class_p = 1 / train_class_counts
train_data_p = train_class_p[train_dataset.labels]
train_data_p /= train_data_p.sum()

# Sampler and loader are built per epoch (each epoch passes its own seed), so
# only factories are created here.
train_sampler_fn = partial(
	WeightedIndexSampler,
	weights=train_data_p,
	num_epochs=1,
)
train_loader_fn = partial(
	pygrain.DataLoader,
	data_source=train_dataset,
	operations=[
		pygrain.Batch(batch_size=batch_size, drop_remainder=False),
		ImageTransform(train_transform),
	],
	worker_count=dataloader_workers,
	shard_options=pygrain.NoSharding(),
)
# NOTE(review): this is one more than the number of full batches, so it
# overcounts by one when the dataset size is an exact multiple of batch_size.
# Kept as-is because the value only drives the progress bar and LR schedule.
train_steps_per_epoch = len(train_dataset) // batch_size + 1

# --- Test split -------------------------------------------------------------
test_dataset = ImageDataSource(data_dir, 'test', num_classes=num_classes)
test_sampler = pygrain.IndexSampler(
	num_records=len(test_dataset),
	num_epochs=1,
	shard_options=pygrain.NoSharding(),
	shuffle=False,
	seed=0,
)
test_loader = pygrain.DataLoader(
	data_source=test_dataset,
	operations=[
		pygrain.Batch(batch_size=batch_size, drop_remainder=False),
		ImageTransform(test_transform),
	],
	sampler=test_sampler,
	worker_count=dataloader_workers,
)
test_steps_per_epoch = len(test_dataset) // batch_size + 1

## Metrics

We define a simple metrics aggregator below.

In [None]:
@dataclass
class Metric:
	"""Running statistics of a single named metric."""

	# Sum of all recorded values.
	total: float = 0.0
	# Most recently recorded value.
	previous: float = 0.0
	# Number of recorded values.
	counter: int = 0

In [None]:
class Metrics:
	def __init__(self, metrics: list[str]) -> None:
		self.keys = metrics
		self.history = []
		self.reset()

	def reset(self) -> None:
		if hasattr(self, 'metrics'):
			self.history.append(self.epoch_dict)
		self.metrics = {k: Metric() for k in self.keys}

	def update(self, metrics: dict[str, float|int]) -> None:
		for k, v in metrics.items():
			self.metrics[k].total += v
			self.metrics[k].previous = v
			self.metrics[k].counter += 1

	@property
	def epoch_dict(self) -> dict[str, float]:
		return {k: v.total / v.counter for k, v in self.metrics.items()}

	@property
	def epoch(self) -> str:
		return '\t'.join([f'{k}: {v.total / v.counter:.4f}' for k, v in self.metrics.items()])

	@property
	def previous(self) -> str:
		return ', '.join([f'{k}: {v.previous:.4f}' for k, v in self.metrics.items()])

In [None]:
def calc_acc(preds: jnp.ndarray, labels: jnp.ndarray) -> float:
	"""Return the fraction of rows whose arg-max over the last axis equals the label."""
	predicted_classes = preds.argmax(axis=-1)
	hits = predicted_classes == labels
	return hits.mean().item()

In [None]:
def update_metrics(metrics, loss, preds, labels):
	"""Record the loss and accuracy of one step in `metrics`.

	Args:
		metrics: Aggregator exposing `update(dict)`.
		loss: Scalar loss of the step.
		preds: Model outputs; the arg-max over the last axis is the prediction.
		labels: Integer class labels.

	Returns:
		The same `metrics` object, for convenience.
	"""
	metrics.update({
		# Store a plain Python float so running totals and the history hold
		# floats, as the aggregator is annotated, rather than array objects.
		'loss': float(loss),
		'accuracy': calc_acc(preds, labels),
	})
	return metrics

# Models

In [None]:
dev = qml.device("default.qubit", wires=qnn_wires)

@partial(jax.jit, static_argnames=('wires',))
@qml.qnode(dev)
def learnable_qnn_circuit(param, phi, wires):
	"""Parameterized circuit mapping one patch of `wires` values to a scalar.

	Args:
		param: Array indexed as `param[layer][wire][k]` with `k` in 0..2.
			`param[0]` holds the encoding parameters; every following slice
			`param[1:]` parameterizes one rotation + entangling layer.
		phi: Input values, one per wire.
		wires: Number of wires to act on (static under `jax.jit`).

	Returns:
		The expectation value of Pauli-Z on wire 0.
	"""
	# Learnable data encoding: an affine function of the input sets the RY
	# angle, followed by a free RX rotation.
	encoding = param[0]
	for wire in range(wires):
		qml.RY(np.pi * (encoding[wire][0] * phi[wire] + encoding[wire][1]), wires=wire)
		qml.RX(encoding[wire][2], wires=wire)

	# Variational layers: a general single-qubit rotation on every wire, then
	# a ring of CNOTs. Starts at slot 1 so that no parameter slice is skipped.
	for layer_weights in param[1:]:
		for wire in range(wires):
			qml.Rot(*layer_weights[wire], wires=wire)
		for wire in range(wires):
			qml.CNOT(wires=[wire, (wire+1) % wires])

	return qml.expval(qml.PauliZ(0))


class LearnableQNN(nn.Module):
	"""Applies `learnable_qnn_circuit` to non-overlapping square image patches.

	Attributes are documented inline; the output has one circuit expectation
	value per group of `wires` patch values, flattened per batch element.
	"""
	# Number of wires; must be a perfect square (patch side is its square root).
	wires: int
	# Number of rotation + entangling layers applied after the encoding layer.
	layers: int

	def setup(self):
		kernel_width = round(self.wires ** 0.5)
		if kernel_width * kernel_width != self.wires:
			raise ValueError(f'wires must be a perfect square, got {self.wires}')
		self.kernel_width = kernel_width

		# Layout matches the circuit's `param[layer][wire][k]` indexing: slot 0
		# is the encoding layer, slots 1..layers are the variational layers.
		# (A `(wires, layers, 3)` layout would be indexed out of bounds by the
		# circuit, which JAX silently clamps instead of reporting.)
		self.qnn_params = self.param('qnn_params', nn.initializers.uniform(scale=2*jnp.pi), (self.layers + 1, self.wires, 3))

	def __call__(self, x):
		"""Run the circuit over all patches of `x` and return shape (n, -1)."""
		n = x.shape[0]
		patches = pix.extract_patches(
			images=x,
			sizes=(1, self.kernel_width, self.kernel_width, 1),
			strides=(1, self.kernel_width, self.kernel_width, 1),
			rates=(1, 1, 1, 1),
			padding='VALID',
		)
		# NOTE(review): for multi-channel input the patch depth is presumably
		# kernel_width**2 * channels, so rows of `wires` values may mix
		# channels -- confirm the intended grouping.
		patches = patches.reshape(-1, self.wires)
		return jax.vmap(learnable_qnn_circuit, in_axes=(None, 0, None))(self.qnn_params, patches, self.wires).reshape(n, -1)

In [None]:
class BasicLinearModel(nn.Module):
	"""Single dense classification head named 'head' with a zero-initialized kernel."""
	# Number of output units of the head.
	num_classes: int

	@nn.compact
	def __call__(self, x: jnp.ndarray):
		"""Project `x` to `num_classes` outputs."""
		head = nn.Dense(features=self.num_classes, name='head', kernel_init=nn.zeros)
		return head(x)

In [None]:
class Sequential(nn.Module):
	"""Chains modules: the output of each layer is the input of the next."""
	# Layers applied in list order.
	layers: list[nn.Module]

	@nn.compact
	def __call__(self, x: jnp.ndarray):
		"""Feed `x` through every layer in order and return the final output."""
		out = x
		for stage in self.layers:
			out = stage(out)
		return out

# Model Utilities

In [None]:
def load_hf(model_class, model_name):
	"""Load a pretrained model and its processor by name.

	Args:
		model_class: Class exposing `from_pretrained`; the loaded object must
			provide `.module` and `.params`.
		model_name: Identifier passed to both `from_pretrained` calls.

	Returns:
		Tuple `(module, variables, processor)`.
	"""
	processor = AutoProcessor.from_pretrained(model_name)
	pretrained = model_class.from_pretrained(model_name)
	return pretrained.module, pretrained.params, processor

In [None]:
def create_train_state(module, rng, lr, num_epochs, train_steps_per_epoch, print_summary):
	"""Initialize `module` and wrap it in a `TrainState` with a Yogi optimizer.

	Args:
		module: Flax module to initialize and train.
		rng: PRNG key for parameter initialization (and the optional summary).
		lr: Peak learning rate of the cosine one-cycle schedule.
		num_epochs: Number of epochs; with `train_steps_per_epoch` it fixes
			the schedule length.
		train_steps_per_epoch: Number of optimizer steps per epoch.
		print_summary: If true, print the module's tabulated summary.

	Returns:
		A `TrainState` holding the initialized parameters and the optimizer.
	"""
	# Dummy single-image, three-channel batch used only for shape inference.
	sample_batch = jnp.empty((1, image_size, image_size, 3))
	params = module.init(rng, sample_batch)['params']

	total_steps = num_epochs * train_steps_per_epoch
	solver = optax.yogi(
		optax.cosine_onecycle_schedule(
			transition_steps=total_steps,
			peak_value=lr,
			pct_start=.1,
			final_div_factor=1000,
		)
	)

	if print_summary:
		print(module.tabulate(rng, sample_batch, compute_flops=True, compute_vjp_flops=True))

	return train_state.TrainState.create(apply_fn=module.apply, params=params, tx=solver)

In [None]:
@jax.jit
def train_step(state, images, labels):
	"""Run one optimization step.

	Args:
		state: Train state providing `apply_fn`, `params` and `apply_gradients`.
		images: Input batch.
		labels: Integer class labels.

	Returns:
		Tuple `(new_state, loss, preds)` where `loss` is the mean softmax
		cross-entropy and `preds` are the logits.
	"""
	def loss_and_logits(params):
		logits = state.apply_fn({ 'params': params }, images)
		per_example = optax.losses.softmax_cross_entropy_with_integer_labels(
			logits=logits, labels=labels
		)
		return per_example.mean(), logits

	# Logits ride along as auxiliary output so they need not be recomputed.
	grad_fn = jax.value_and_grad(loss_and_logits, has_aux=True)
	(loss, preds), grads = grad_fn(state.params)
	return state.apply_gradients(grads=grads), loss, preds

@jax.jit
def test_step(state, images, labels):
	"""Evaluate one batch without updating the parameters.

	Returns:
		Tuple `(state, loss, preds)`; `state` is returned unchanged so the
		result has the same form as that of `train_step`.
	"""
	logits = state.apply_fn({ 'params': state.params }, images)
	per_example = optax.losses.softmax_cross_entropy_with_integer_labels(
		logits=logits, labels=labels
	)
	return state, per_example.mean(), logits

In [None]:
def run_epoch(epoch_type, state, train_loader, train_steps_per_epoch, metrics):
	"""Run one full pass over a loader, either training or evaluating.

	Args:
		epoch_type: 'Train' to update parameters, 'Test' to only evaluate.
		state: Current train state.
		train_loader: Iterable of (images, labels) batches (used for both
			epoch types despite its name).
		train_steps_per_epoch: Expected number of batches, for the progress bar.
		metrics: Aggregator updated per batch and reset at the end.

	Returns:
		The train state after the pass.
	"""
	assert epoch_type in ['Train', 'Test']
	step_fn = train_step if epoch_type == 'Train' else test_step

	pbar = tqdm(train_loader, total=train_steps_per_epoch, desc=epoch_type, leave=False)
	for images, labels in pbar:
		state, loss, preds = step_fn(state, images, labels)
		update_metrics(metrics, loss, preds, labels)
		pbar.set_postfix_str(metrics.previous)

	# Print the epoch summary, then archive it and start a fresh epoch.
	tqdm.write(f'   -> {epoch_type}:\t{metrics.epoch}')
	metrics.reset()
	return state

In [None]:
# One independent loss/accuracy aggregator per (model, phase) combination.
qnn_train_metrics, qnn_test_metrics, cnn_train_metrics, cnn_test_metrics = (
	Metrics(['loss', 'accuracy']) for _ in range(4)
)

# Initialization

In [None]:
# Derive one initialization key per model and advance the root key.
key, qnn_init_key, cnn_init_key = jax.random.split(key, 3)

In [None]:
# QNN feature extractor followed by a linear head. The head width follows the
# configured `num_classes` instead of a hard-coded value.
qnn_module = Sequential(layers=[LearnableQNN(wires=qnn_wires, layers=qnn_layers), BasicLinearModel(num_classes=num_classes)])
qnn_state = create_train_state(qnn_module, qnn_init_key, lr, num_epochs, train_steps_per_epoch, True)

# Training Loop

In [None]:
# Train the QNN model for `num_epochs` epochs, evaluating after each one.
for epoch in trange(1, num_epochs+1, desc='QNN'):
	# A fresh sampler seeded with the epoch number is built for every epoch,
	# and a new loader is created around it.
	train_sampler = train_sampler_fn(seed=epoch)
	train_loader = train_loader_fn(sampler=train_sampler)

	tqdm.write(f'Epoch {epoch}/{num_epochs}')

	# The same test loader object is reused for every evaluation pass.
	qnn_state = run_epoch('Train', qnn_state, train_loader, train_steps_per_epoch, qnn_train_metrics)
	qnn_state = run_epoch('Test', qnn_state, test_loader, test_steps_per_epoch, qnn_test_metrics)