In [24]:
import pennylane as qml
from jax import numpy as jnp
import optax
import catalyst

In [25]:
# Install required packages if needed:
# %pip install tensorflow_datasets optax opencv-python

import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
import cv2
from IPython.display import clear_output
import matplotlib.pyplot as plt

import pennylane as qml
from jax import numpy as jnp
import jax
import optax
import catalyst

###############################
# 1. Data Loading and Preprocessing
###############################

import tensorflow as tf
import numpy as np
import cv2
import tensorflow_datasets as tfds

# --- 1. The Python Preprocessing Function ---
def preprocess_fn(image, label):
    # image is a NumPy array in [0, 255] (uint8)
    image = image.astype(np.float32) / 255.0  # scale to [0,1]
    image = np.squeeze(image)  # remove channel dimension if shape is (28,28,1)
    
    # Resize from 28x28 to 14x14 using cv2.
    small = cv2.resize(image, (14, 14))
    flat = small.flatten()  # now 196 elements
    
    # Pad the flattened array to length 256.
    desired_length = 256  # 2^8
    pad_length = desired_length - flat.shape[0]
    flat = np.pad(flat, (0, pad_length), mode='constant')
    
    # Normalize the vector (avoid division by zero).
    norm = np.linalg.norm(flat)
    if norm > 0:
        flat = flat / norm
        
    # Ensure label is float32.
    label = np.float32(label)
    return flat, label

# --- 2. The Mapping Function for tf.data ---
def tf_preprocess_fn(example):
    # Explicitly extract image and label from the example dictionary.
    image = example['image']
    label = example['label']
    
    # Use tf.numpy_function instead of tf.py_function.
    flat, lbl = tf.numpy_function(
        func=preprocess_fn,
        inp=[image, label],
        Tout=[tf.float32, tf.float32]
    )
    # Set the static shape information.
    flat.set_shape((256,))
    lbl.set_shape(())
    return flat, lbl

# --- 3. Load and Process the Dataset ---
# Load MNIST dataset using TFDS.
train_ds_raw = tfds.load("mnist", split="train", shuffle_files=True)
test_ds_raw  = tfds.load("mnist", split="test", shuffle_files=False)

# (Optional) Filter for digits 0 and 1 for a binary classification task.
def filter_fn(example):
    return tf.math.logical_or(tf.equal(example['label'], 0),
                              tf.equal(example['label'], 1))

train_ds_raw = train_ds_raw.filter(filter_fn)
test_ds_raw  = test_ds_raw.filter(filter_fn)

# Map the preprocessing function over the dataset.
train_ds = train_ds_raw.map(tf_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
test_ds  = test_ds_raw.map(tf_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)

# Batch and prefetch.
batch_size = 32
train_ds = train_ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
test_ds  = test_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# Filter for digits 0 and 1.
train_ds = train_ds_raw.filter(filter_fn)
test_ds  = test_ds_raw.filter(filter_fn)

# Apply the preprocessing function.
# Note: we use map(..., num_parallel_calls=tf.data.AUTOTUNE) for efficiency.
def tf_preprocess_fn(example):
    flat, label = tf.py_function(func=preprocess_fn, inp=[example], Tout=[tf.float32, tf.float32])
    # Ensure shapes are set.
    flat.set_shape((256,))
    label.set_shape(())
    return flat, label

train_ds = train_ds.map(tf_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
test_ds  = test_ds.map(tf_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)

# Batch and prefetch the datasets.
batch_size = 32
train_ds = train_ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
test_ds  = test_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

###############################
# 2. Define the Quantum Model
###############################

# We use 8 wires (so that the amplitude embedding vector has length 2^8 = 256).
n_wires = 8
dev = qml.device("lightning.qubit", wires=n_wires)

@qml.qnode(dev, interface="jax")
def circuit(data, weights):
    """
    Quantum circuit:
      - Embed the (preprocessed) image using amplitude embedding.
      - Apply a layer of parameterized rotations on each qubit.
      - Apply a ring of CNOT gates for entanglement.
      - Measure <Z> on qubit 0.
    """
    # --- Data Embedding ---
    qml.AmplitudeEmbedding(data, wires=range(n_wires), normalize=True)

    # --- Variational Ansatz ---
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RZ(weights[i, 2], wires=i)
    # Apply entanglement.
    for i in range(n_wires):
        qml.CNOT(wires=[i, (i + 1) % n_wires])
        
    # --- Measurement ---
    return qml.expval(qml.PauliZ(0))

# To vectorize over a batch of images (each of length 256), we use Catalyst's vmap.
# Here, the first argument of circuit (data) is batched over axis 0.
batched_circuit = qml.qjit(catalyst.vmap(circuit, in_axes=(0, None)))

def my_model(data, weights, bias):
    """
    Our model runs the quantum circuit on each image (data) and shifts the result by a bias.
    The resulting logit is passed through a sigmoid so that the model outputs a probability.
    """
    # batched_circuit returns an array of shape (batch_size,)
    logits = batched_circuit(data, weights) + bias
    return jax.nn.sigmoid(logits)

@qml.qjit
def loss_fn(params, data, targets):
    """
    Binary cross-entropy loss.
    """
    predictions = my_model(data, params["weights"], params["bias"])
    epsilon = 1e-10
    loss = -jnp.mean(targets * jnp.log(predictions + epsilon) +
                      (1 - targets) * jnp.log(1 - predictions + epsilon))
    return loss

###############################
# 3. Optimizer and Update Step
###############################

# Initialize the trainable parameters.
init_weights = jnp.ones([n_wires, 3])
init_bias = jnp.array(0.0)
params = {"weights": init_weights, "bias": init_bias}

# Create an Adam optimizer.
opt = optax.adam(learning_rate=0.3)

@qml.qjit
def update_step(i, args):
    """
    One optimization step:
      - Compute the gradients of loss_fn.
      - Update the parameters using Optax.
    """
    params, opt_state, data, targets = args
    grads = catalyst.grad(loss_fn, method="fd")(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state, data, targets)

def train_step(params, opt_state, batch):
    """
    Process one batch: update the model parameters and compute loss and accuracy.
    """
    # Each batch is a tuple (data, labels), where:
    #   - data has shape (batch_size, 256)
    #   - labels has shape (batch_size,)
    data, labels = batch
    # Convert to JAX arrays.
    data = jnp.array(data)
    labels = jnp.array(labels)
    
    # Update parameters.
    params, opt_state, _, _ = update_step(0, (params, opt_state, data, labels))
    # Compute loss.
    loss_val = loss_fn(params, data, labels)
    # Compute accuracy.
    predictions = my_model(data, params["weights"], params["bias"])
    acc = jnp.mean((predictions > 0.5) == labels)
    return params, opt_state, float(loss_val), float(acc)

def eval_step(params, batch):
    """
    Evaluate on one batch (without updating parameters).
    """
    data, labels = batch
    data = jnp.array(data)
    labels = jnp.array(labels)
    loss_val = loss_fn(params, data, labels)
    predictions = my_model(data, params["weights"], params["bias"])
    acc = jnp.mean((predictions > 0.5) == labels)
    return float(loss_val), float(acc)

###############################
# 4. Training and Evaluation Loop
###############################

# We'll record metrics in a dictionary.
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

# Training parameters.
train_steps = 500  # total training steps
eval_every = 100   # evaluate on test set every 100 steps

# Initialize optimizer state.
opt_state = opt.init(params)

# Training loop.
step = 0
for epoch in range(10):  # or as many epochs as desired
    # Iterate over the training dataset.
    for batch in train_ds.as_numpy_iterator():
        params, opt_state, batch_loss, batch_acc = train_step(params, opt_state, batch)
        step += 1

        # Log training metrics every eval_every steps.
        if step % eval_every == 0 or step == train_steps:
            # Evaluate on the training set (a few batches).
            train_losses = []
            train_accs = []
            for train_batch in train_ds.take(5).as_numpy_iterator():
                l, a = eval_step(params, train_batch)
                train_losses.append(l)
                train_accs.append(a)
            avg_train_loss = np.mean(train_losses)
            avg_train_acc = np.mean(train_accs)
            metrics_history['train_loss'].append(avg_train_loss)
            metrics_history['train_accuracy'].append(avg_train_acc)

            # Evaluate on the test set (a few batches).
            test_losses = []
            test_accs = []
            for test_batch in test_ds.take(5).as_numpy_iterator():
                l, a = eval_step(params, test_batch)
                test_losses.append(l)
                test_accs.append(a)
            avg_test_loss = np.mean(test_losses)
            avg_test_acc = np.mean(test_accs)
            metrics_history['test_loss'].append(avg_test_loss)
            metrics_history['test_accuracy'].append(avg_test_acc)

            # Clear output and plot the metrics.
            clear_output(wait=True)
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            ax1.set_title('Loss')
            ax2.set_title('Accuracy')
            ax1.plot(metrics_history['train_loss'], label='train_loss')
            ax1.plot(metrics_history['test_loss'], label='test_loss')
            ax2.plot(metrics_history['train_accuracy'], label='train_accuracy')
            ax2.plot(metrics_history['test_accuracy'], label='test_accuracy')
            ax1.legend()
            ax2.legend()
            plt.show()

        # Stop training after the designated number of steps.
        if step >= train_steps:
            break
    if step >= train_steps:
        break

print("Training complete!")


TypeError: in user code:

    File "/var/folders/6q/mzgmvrhn3l76p312t6zn3mxw0000gn/T/ipykernel_5574/2892362907.py", line 97, in tf_preprocess_fn  *
        flat, label = tf.py_function(func=preprocess_fn, inp=[example], Tout=[tf.float32, tf.float32])

    TypeError: Tensors in list passed to 'input' of 'EagerPyFunc' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid. Tensors: [{'image': <tf.Tensor 'args_0:0' shape=(28, 28, 1) dtype=uint8>, 'label': <tf.Tensor 'args_1:0' shape=() dtype=int64>}]
