# Quantum MNIST Classification Tutorial

This notebook provides a step-by-step walkthrough of building and training a hybrid quantum-classical neural network for MNIST digit classification.

## What We'll Build

We'll create a hybrid neural network that combines:
- Classical neural network layers for preprocessing high-dimensional image data
- Quantum circuits with trainable parameters for feature processing
- Classical layers for final classification

This architecture demonstrates how quantum computing can be integrated into machine learning pipelines.

## 1. Setup and Imports

First, let's import all necessary libraries and set up our environment.

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# Add src directory to path
sys.path.append('../src')

from config import Config
from data_utils import MNISTDataLoader, set_seed, get_device
from models import SimplifiedHybridQNN, ClassicalNN, model_summary
from train import Trainer, create_optimizer
from visualize import (
    plot_training_curves,
    plot_confusion_matrix,
    evaluate_model,
    print_evaluation_results
)
from quantum_circuit import HybridQuantumCircuit

print("All imports successful!")

## 2. Understanding Quantum Circuits

Before training our model, let's understand what a quantum circuit looks like.

A quantum circuit is composed of:
- **Qubits**: The quantum analog of classical bits
- **Quantum Gates**: Operations that manipulate qubit states
- **Measurements**: Extracting classical information from quantum states
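To make these three ideas concrete, here is a minimal NumPy sketch. It is independent of the Qiskit-based code in this project. It represents one qubit as a 2-component statevector, applies an RY rotation gate, and computes measurement probabilities with the Born rule.

```python
import numpy as np

# A single qubit is a normalized 2-component complex vector; |0> = [1, 0].
ket0 = np.array([1.0, 0.0], dtype=complex)

def ry(theta):
    """Matrix of the RY(theta) rotation gate."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

# Applying a gate is a matrix-vector multiplication.
state = ry(np.pi / 3) @ ket0

# Measurement: outcome probabilities are squared amplitude magnitudes (Born rule).
# Analytically, P(0) = cos^2(theta / 2) = 0.75 here.
probs = np.abs(state) ** 2
print("P(0) =", probs[0], " P(1) =", probs[1])

# Sampling shots mimics repeated measurement on a device or a shot-based simulator.
rng = np.random.default_rng(0)
shots = rng.choice([0, 1], size=1000, p=probs)
print("Estimated P(1) from 1000 shots:", shots.mean())
```

A single measurement only returns 0 or 1; the probabilities, and expectation values built from them, can only be estimated from many shots.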

Our hybrid circuit has two main parts:
1. **Feature Map**: Encodes classical data into quantum states
2. **Variational Circuit**: Contains trainable parameters (like weights in neural networks)
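The two parts can be illustrated on two qubits with plain NumPy. This is a hypothetical toy circuit, not the project's `HybridQuantumCircuit` (which repeats its blocks `feature_reps` and `var_reps` times and may use different gates). We assume angle encoding with RY(x_i) as the feature map, one layer of trainable RY(θ_i) rotations followed by a CNOT as the variational circuit, and the expectation ⟨Z⟩ on the second qubit as the output.

```python
import numpy as np

def ry(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

# CNOT with qubit 0 as control and qubit 1 as target (basis order |q0 q1>).
CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]])
Z1 = np.kron(np.eye(2), np.diag([1, -1]))  # Pauli-Z on qubit 1

def circuit_expectation(x, theta):
    """Toy 2-qubit circuit: feature map RY(x), variational RY(theta) + CNOT, measure <Z1>."""
    state = np.zeros(4)
    state[0] = 1.0                                        # start in |00>
    state = np.kron(ry(x[0]), ry(x[1])) @ state           # feature map: data become angles
    state = np.kron(ry(theta[0]), ry(theta[1])) @ state   # trainable rotations
    state = CNOT @ state                                  # entangling gate
    return state @ Z1 @ state                             # amplitudes are real, no conjugate needed

x = np.array([0.3, 1.2])       # two input features (e.g. outputs of a classical encoder)
theta = np.array([0.5, -0.2])  # trainable parameters
print(circuit_expectation(x, theta))
```

For this particular circuit the output works out to cos(x0 + θ0) · cos(x1 + θ1): the CNOT makes the measured qubit depend on both features and both parameters, and training would adjust θ to shape that function.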

In [None]:
# Create and visualize a quantum circuit
hybrid_circuit = HybridQuantumCircuit(
    n_qubits=4,
    feature_dim=4,
    feature_reps=2,
    var_reps=3
)

circuit = hybrid_circuit.create_circuit()
print("Quantum Circuit Information:")
print(f"  Number of qubits: {circuit.num_qubits}")
print(f"  Circuit depth: {circuit.depth()}")
print(f"  Number of gates: {circuit.size()}")
print(f"  Trainable parameters: {hybrid_circuit.get_num_parameters()}")

# Draw the circuit
try:
    fig = circuit.draw(output='mpl', fold=20)
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Could not draw circuit: {e}")

## 3. Data Preparation

MNIST contains 70,000 grayscale images of handwritten digits (0-9), each 28×28 pixels, i.e. 784 values when flattened.

For this tutorial, we'll use binary classification (distinguishing between two digits) because:
- It trains faster on quantum simulators
- It's easier to visualize and understand
- The same principles apply to multi-class problems
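`MNISTDataLoader` handles the subsetting for us, and its implementation isn't shown in this notebook. As a hypothetical sketch of what binary subsetting typically involves, the helper below keeps only two digits, optionally subsamples, and remaps the labels to 0 and 1 so they work with a 2-class `nn.CrossEntropyLoss`. The remapping matters once you pick digits other than 0 and 1 (e.g. 3 vs 8). Random arrays stand in for MNIST so the snippet runs without a download.

```python
import numpy as np

def make_binary_subset(images, labels, class_a, class_b, n_samples=None, seed=0):
    """Hypothetical helper: keep digits class_a/class_b and relabel them as 0/1."""
    mask = (labels == class_a) | (labels == class_b)
    x, y = images[mask], labels[mask]
    if n_samples is not None and n_samples < len(y):
        idx = np.random.default_rng(seed).choice(len(y), size=n_samples, replace=False)
        x, y = x[idx], y[idx]
    y = (y == class_b).astype(np.int64)   # class_a -> 0, class_b -> 1
    return x, y

# Stand-in data with MNIST's shapes: 28x28 images, labels 0-9.
rng = np.random.default_rng(42)
fake_images = rng.random((1000, 28, 28), dtype=np.float32)
fake_labels = rng.integers(0, 10, size=1000)

x, y = make_binary_subset(fake_images, fake_labels, class_a=3, class_b=8, n_samples=50)
print(x.shape, np.unique(y))
```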

In [None]:
# Set random seed for reproducibility
set_seed(42)

# Configuration
Config.DATASET_TYPE = 'binary'
Config.BINARY_CLASS_A = 0
Config.BINARY_CLASS_B = 1
Config.BINARY_TRAIN_SIZE = 500
Config.BINARY_TEST_SIZE = 100
Config.BATCH_SIZE = 32
Config.NUM_EPOCHS = 10

print(f"Training binary classifier: {Config.BINARY_CLASS_A} vs {Config.BINARY_CLASS_B}")

In [None]:
# Load data
data_loader = MNISTDataLoader(
    data_dir='../data',
    batch_size=Config.BATCH_SIZE
)

train_loader, val_loader, test_loader = data_loader.create_binary_classification_dataset(
    class_a=Config.BINARY_CLASS_A,
    class_b=Config.BINARY_CLASS_B,
    train_size=Config.BINARY_TRAIN_SIZE,
    test_size=Config.BINARY_TEST_SIZE
)

print("\nDataset loaded successfully!")
data_loader.get_dataset_info(train_loader)

### Visualize Sample Images

In [None]:
# Get a batch of images
images, labels = next(iter(train_loader))

# Plot first 10 images
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    if i < len(images):
        img = images[i].squeeze()
        ax.imshow(img, cmap='gray')
        ax.set_title(f'Label: {labels[i].item()}')
        ax.axis('off')

plt.tight_layout()
plt.show()

print(f"Image shape: {images[0].shape}")
print(f"Total pixels: {images[0].numel()}")

## 4. Build the Hybrid Quantum-Classical Model

Our hybrid model architecture:

```
Input (784 pixels)
    ↓
Classical Preprocessing (784 → 128 → 32 → 4)
    ↓
Quantum Circuit (4 qubits with trainable parameters)
    ↓
Classical Postprocessing (2 → 16 → 2 classes)
    ↓
Output (class predictions)
```
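To make the data flow concrete, here is a hypothetical, self-contained PyTorch sketch with the same shape. It is not the project's `SimplifiedHybridQNN`: the quantum layer is simulated directly as a 16-amplitude statevector in PyTorch instead of Qiskit. The `tanh` scaling of encoder outputs to angles, the RY + CNOT-chain ansatz, and reading out ⟨Z⟩ on the first two qubits to get the 2 values the classical head receives are all assumptions. Because everything is ordinary tensor math, autograd computes exact gradients for the circuit parameters.

```python
import math
import torch
import torch.nn as nn

def apply_ry(state, theta, qubit):
    """Apply RY(theta) to one qubit of a batched real statevector of shape (B, 2, ..., 2)."""
    c, s = torch.cos(theta / 2), torch.sin(theta / 2)                           # theta: (B,)
    gate = torch.stack([torch.stack([c, -s], -1), torch.stack([s, c], -1)], -2)  # (B, 2, 2)
    moved = state.movedim(qubit + 1, -1)                 # put the target qubit's axis last
    out = torch.einsum("bij,b...j->b...i", gate, moved)  # matrix acts on that axis
    return out.movedim(-1, qubit + 1)

def apply_cnot(state, control, target):
    """CNOT: flip the target qubit on the half of the state where control = 1."""
    zero, one = state.select(control + 1, 0), state.select(control + 1, 1)
    t_axis = target + 1 if target < control else target  # axes shift once control's axis is removed
    return torch.stack([zero, one.flip(t_axis)], dim=control + 1)

def z_expectation(state, qubit):
    """<Z> on one qubit: P(qubit = 0) - P(qubit = 1)."""
    probs = state.pow(2).movedim(qubit + 1, 1).reshape(state.shape[0], 2, -1).sum(-1)
    return probs[:, 0] - probs[:, 1]

class ToyQuantumLayer(nn.Module):
    """Hypothetical stand-in for the quantum layer: angle encoding + trainable RY/CNOT layers."""
    def __init__(self, n_qubits=4, n_layers=2, n_outputs=2):
        super().__init__()
        self.n_qubits, self.n_outputs = n_qubits, n_outputs
        self.weights = nn.Parameter(0.1 * torch.randn(n_layers, n_qubits))  # rotation angles

    def forward(self, x):                          # x: (B, n_qubits) rotation angles
        b = x.shape[0]
        state = torch.zeros(b, 2 ** self.n_qubits, dtype=x.dtype, device=x.device)
        state[:, 0] = 1.0                          # |00...0>
        state = state.reshape(b, *([2] * self.n_qubits))
        for q in range(self.n_qubits):             # feature map: one feature per qubit
            state = apply_ry(state, x[:, q], q)
        for layer in self.weights:                 # variational layers
            for q in range(self.n_qubits):
                state = apply_ry(state, layer[q].expand(b), q)
            for q in range(self.n_qubits - 1):     # CNOT chain entangles neighbours
                state = apply_cnot(state, q, q + 1)
        return torch.stack([z_expectation(state, q) for q in range(self.n_outputs)], dim=1)

class ToyHybridNet(nn.Module):
    """Mirrors the diagram: 784 -> 128 -> 32 -> 4 -> [4-qubit layer] -> 2 -> 16 -> 2."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(),
                                     nn.Linear(128, 32), nn.ReLU(), nn.Linear(32, 4), nn.Tanh())
        self.quantum = ToyQuantumLayer(n_qubits=4, n_outputs=2)
        self.head = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))

    def forward(self, x):
        angles = math.pi * self.encoder(x)        # tanh output in [-1, 1] -> angles in [-pi, pi]
        return self.head(self.quantum(angles))

model = ToyHybridNet()
logits = model(torch.rand(8, 1, 28, 28))
print(logits.shape)  # torch.Size([8, 2])
```

Simulating the circuit inside PyTorch only works at this tiny scale; for real backends the circuit is executed separately and gradients are obtained by other means.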

The classical preprocessing compresses the 784-pixel image down to 4 features, one per qubit. We keep the circuit to 4 qubits because simulating quantum circuits on classical hardware is expensive: a full statevector simulation tracks 2^n amplitudes for n qubits, so memory and run time roughly double with every qubit added.
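A quick calculation shows why qubit count is the bottleneck. Assuming a dense statevector simulator that stores one complex128 amplitude (16 bytes) per basis state, memory grows as 16 · 2^n bytes:

```python
def statevector_bytes(n_qubits, bytes_per_amplitude=16):
    """Memory for a dense statevector: 2**n amplitudes, complex128 = 16 bytes each."""
    return (2 ** n_qubits) * bytes_per_amplitude

for n in (4, 10, 20, 30, 40):
    print(f"{n:2d} qubits -> {statevector_bytes(n):,d} bytes")
```

Four qubits need only 256 bytes, but 30 qubits already need 16 GiB (2^30 × 16 bytes), and each extra qubit doubles that.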

In [None]:
# Get device
device = get_device()

# Create hybrid quantum model
quantum_model = SimplifiedHybridQNN(
    n_qubits=4,
    n_classes=2
).to(device)

print("\nQuantum Hybrid Model:")
model_summary(quantum_model)

## 5. Build Classical Baseline for Comparison

To judge whether the quantum layer adds anything, we need a classical baseline: a standard fully connected neural network with a roughly similar number of parameters.
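To see what "similar" means here, we can count parameters by hand. This assumes both models are plain stacks of `nn.Linear` layers: 784 → 128 → 64 → 32 → 2 for `ClassicalNN`, and the classical parts of the architecture diagram for the hybrid model. A `Linear(a, b)` layer has `a*b` weights plus `b` biases.

```python
def mlp_params(sizes):
    """Number of weights + biases in a stack of fully connected layers."""
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))

classical = mlp_params([784, 128, 64, 32, 2])
hybrid_classical_parts = mlp_params([784, 128, 32, 4]) + mlp_params([2, 16, 2])
print("Classical baseline:", classical)                        # 110882
print("Hybrid, classical layers only:", hybrid_classical_parts) # 104822
```

Compare these numbers with what `model_summary` prints. For the hybrid model, the difference is the quantum circuit's own rotation angles plus anything these plain-`Linear` assumptions miss.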

In [None]:
# Create classical baseline model
classical_model = ClassicalNN(
    hidden_sizes=[128, 64, 32],
    n_classes=2
).to(device)

print("\nClassical Baseline Model:")
model_summary(classical_model)

## 6. Training the Models

Now we'll train both models with the same hyperparameters and training procedure, so that differences in the results come from the architectures rather than from the training setup.

Training involves:
1. Forward pass: Computing predictions
2. Loss calculation: Measuring how wrong the predictions are
3. Backward pass: Computing gradients
4. Parameter update: Adjusting weights to reduce loss
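These four steps map directly onto a standard PyTorch loop. The project's `Trainer` wraps this loop (presumably adding validation and logging). The sketch below is a generic, hypothetical version that uses a tiny linear model, random data, and Adam (the optimizer `create_optimizer` builds is not shown here), so it runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                      # stand-in for either model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

inputs = torch.randn(64, 4)
targets = (inputs[:, 0] > 0).long()          # a learnable toy rule

losses = []
for epoch in range(20):
    optimizer.zero_grad()                    # clear gradients from the previous step
    logits = model(inputs)                   # 1. forward pass
    loss = criterion(logits, targets)        # 2. loss calculation
    loss.backward()                          # 3. backward pass (autograd computes gradients)
    optimizer.step()                         # 4. parameter update
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```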

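In the hybrid model, the trainable parameters of the quantum layer are the rotation angles of its gates. They are updated in the same optimization loop as the weights of the classical layers around it.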
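How gradients reach those angles depends on the backend. On real hardware you cannot backpropagate through the device. One common approach is the parameter-shift rule: for gates of the form exp(-iθP/2), with P a Pauli operator (RY is one), the exact derivative of an expectation value f(θ) is [f(θ + π/2) − f(θ − π/2)] / 2. Whether this project uses parameter shift or another method isn't shown in the notebook. Here is a minimal check on a single qubit, where f(θ) = ⟨Z⟩ after RY(θ)|0⟩ = cos θ:

```python
import numpy as np

def expectation_z(theta):
    """<Z> for the state RY(theta)|0> = [cos(theta/2), sin(theta/2)]."""
    amp0, amp1 = np.cos(theta / 2), np.sin(theta / 2)
    return amp0 ** 2 - amp1 ** 2      # equals cos(theta)

def parameter_shift_grad(f, theta, shift=np.pi / 2):
    """Parameter-shift rule for a gate generated by a Pauli operator."""
    return (f(theta + shift) - f(theta - shift)) / 2

theta = 0.7
print(parameter_shift_grad(expectation_z, theta), -np.sin(theta))  # the two values agree
```

Unlike a finite-difference estimate, the shifts here are large (±π/2) and the result is exact. Each parameter still costs two extra circuit evaluations, which is one reason quantum layers train slowly.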

### Train Quantum Model

In [None]:
# Setup training for quantum model
quantum_optimizer = create_optimizer(quantum_model, Config)
criterion = nn.CrossEntropyLoss()

quantum_trainer = Trainer(
    model=quantum_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=quantum_optimizer,
    device=device,
    config=Config,
    model_name='quantum_hybrid_notebook'
)

# Train the model
quantum_history = quantum_trainer.train(Config.NUM_EPOCHS)

### Train Classical Model

In [None]:
# Setup training for classical model
classical_optimizer = create_optimizer(classical_model, Config)

classical_trainer = Trainer(
    model=classical_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=classical_optimizer,
    device=device,
    config=Config,
    model_name='classical_baseline_notebook'
)

# Train the model
classical_history = classical_trainer.train(Config.NUM_EPOCHS)

## 7. Visualize Training Progress

Training curves help us understand:
- How quickly the model learns
- Whether the model is overfitting (training accuracy much higher than validation)
- Whether we need to train longer or adjust hyperparameters

In [None]:
# Plot quantum model training curves
plot_training_curves(quantum_history)
plt.suptitle('Quantum Hybrid Model Training', fontsize=16, fontweight='bold', y=1.02)
plt.show()

In [None]:
# Plot classical model training curves
plot_training_curves(classical_history)
plt.suptitle('Classical Model Training', fontsize=16, fontweight='bold', y=1.02)
plt.show()

## 8. Evaluate on Test Set

Now let's evaluate both models on unseen test data to get an unbiased measure of performance.
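Our test set is small, so it helps to know how much a test accuracy can be trusted. One standard tool is a binomial confidence interval. The sketch below uses the Wilson score interval. As far as this notebook shows, `evaluate_model` does not compute one, and the counts in the example are illustrative placeholders, not results.

```python
import math

def wilson_interval(n_correct, n_total, z=1.96):
    """Approximate 95% Wilson score interval for a proportion such as accuracy."""
    p = n_correct / n_total
    denom = 1 + z ** 2 / n_total
    center = (p + z ** 2 / (2 * n_total)) / denom
    half = z * math.sqrt(p * (1 - p) / n_total + z ** 2 / (4 * n_total ** 2)) / denom
    return center - half, center + half

# Illustrative numbers: 95 of 100 test images classified correctly.
low, high = wilson_interval(95, 100)
print(f"accuracy 0.95, 95% CI roughly [{low:.3f}, {high:.3f}]")
```

Even at 95% accuracy, the interval for 100 images spans about nine percentage points, so two models a point or two apart are not clearly different.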

In [None]:
# Evaluate quantum model
quantum_results = evaluate_model(quantum_model, test_loader, device, n_classes=2)
print_evaluation_results(quantum_results, "Quantum Hybrid Model")

In [None]:
# Evaluate classical model
classical_results = evaluate_model(classical_model, test_loader, device, n_classes=2)
print_evaluation_results(classical_results, "Classical Model")

## 9. Confusion Matrices

Confusion matrices show exactly which classes the model confuses. Diagonal entries count correct predictions; off-diagonal entries count errors.
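The counting behind a confusion matrix is simple. This minimal NumPy sketch puts true labels on the rows and predictions on the columns. That is the same convention scikit-learn's `confusion_matrix` uses, but whether `plot_confusion_matrix` uses that orientation is an assumption. The labels and predictions are made-up toy values.

```python
import numpy as np

def confusion_matrix(labels, predictions, n_classes=2):
    """Rows = true class, columns = predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (labels, predictions), 1)   # increment cm[true, pred] once per sample
    return cm

labels      = np.array([0, 0, 0, 1, 1, 1, 1])
predictions = np.array([0, 0, 1, 1, 1, 0, 1])
print(confusion_matrix(labels, predictions))
# [[2 1]
#  [1 3]]
```

The entry in row 0, column 1 counts digits of class A that the model labeled as class B.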

In [None]:
# Quantum model confusion matrix
plot_confusion_matrix(
    quantum_results['labels'],
    quantum_results['predictions'],
    class_names=[f'Digit {Config.BINARY_CLASS_A}', f'Digit {Config.BINARY_CLASS_B}']
)
plt.suptitle('Quantum Model Confusion Matrix', fontsize=14, fontweight='bold', y=1.02)
plt.show()

In [None]:
# Classical model confusion matrix
plot_confusion_matrix(
    classical_results['labels'],
    classical_results['predictions'],
    class_names=[f'Digit {Config.BINARY_CLASS_A}', f'Digit {Config.BINARY_CLASS_B}']
)
plt.suptitle('Classical Model Confusion Matrix', fontsize=14, fontweight='bold', y=1.02)
plt.show()

## 10. Model Comparison

Let's directly compare the performance metrics of both models.
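For reference, here is how the binary metrics in the table are defined, computed from scratch with label 1 (digit `BINARY_CLASS_B`) treated as the positive class. This is an assumption: `evaluate_model` may average over both classes (e.g. macro averaging), in which case its numbers will differ from this single-class version.

```python
import numpy as np

def binary_metrics(labels, predictions, positive=1):
    """Accuracy, precision, recall and F1 with one class treated as positive."""
    labels, predictions = np.asarray(labels), np.asarray(predictions)
    tp = np.sum((predictions == positive) & (labels == positive))
    fp = np.sum((predictions == positive) & (labels != positive))
    fn = np.sum((predictions != positive) & (labels == positive))
    accuracy = np.mean(predictions == labels)
    precision = tp / (tp + fp) if tp + fp else 0.0   # of predicted positives, how many were right
    recall = tp / (tp + fn) if tp + fn else 0.0      # of actual positives, how many were found
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1

print(binary_metrics([0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, 0, 1]))
```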

In [None]:
# Create comparison table
import pandas as pd

comparison_data = {
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1-Score'],
    'Quantum Model': [
        f"{quantum_results['accuracy']*100:.2f}%",
        f"{quantum_results['precision']:.4f}",
        f"{quantum_results['recall']:.4f}",
        f"{quantum_results['f1_score']:.4f}"
    ],
    'Classical Model': [
        f"{classical_results['accuracy']*100:.2f}%",
        f"{classical_results['precision']:.4f}",
        f"{classical_results['recall']:.4f}",
        f"{classical_results['f1_score']:.4f}"
    ]
}

comparison_df = pd.DataFrame(comparison_data)
print("\nModel Performance Comparison:")
print("=" * 60)
print(comparison_df.to_string(index=False))
print("=" * 60)

## 11. Key Takeaways

### What We Learned:

1. **Hybrid Architecture**: We successfully integrated quantum circuits into a classical neural network using PyTorch and Qiskit.

2. **Quantum Circuits**: The quantum component uses parameterized gates that are trained just like weights in classical neural networks.

3. **Current Limitations**: Quantum simulators are computationally expensive, which is why we:
   - Used only 4 qubits
   - Trained on a subset of MNIST
   - Focused on binary classification

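4. **Performance**: The quantum model will most likely perform about as well as the classical baseline, not clearly better. This is expected because:
   - Most of the hybrid model is classical: the 784 → 128 → 32 → 4 encoder holds the vast majority of its parameters, and the 4-qubit circuit only processes the 4 features it produces
   - With the default digits, 0 and 1 are easy to tell apart, so both models can score highly and leave little room for a difference
   - Proposed quantum advantages are tied to specific problem structures, and none has been established for small image-classification tasks like this one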

### Future Improvements:

- Try different quantum circuit designs (ansatz)
- Experiment with more qubits (when hardware allows)
- Test on problems where quantum computing might excel
- Use real quantum hardware instead of simulators

### Why This Matters:

This project demonstrates the feasibility of hybrid quantum-classical machine learning. As quantum hardware improves, these techniques may help with problems that are intractable for classical computers.

## 12. Experiment: Try It Yourself!

Now that you understand the basics, try modifying the code:

1. Change which digits to classify (e.g., 3 vs 8 instead of 0 vs 1)
2. Adjust the number of training samples
3. Modify the quantum circuit parameters
4. Try multi-class classification (3+ classes)

Use the cells below to experiment!

In [None]:
# Your experiments here
