# Hybrid Quantum-Classical Neural Network for CIFAR-10

This notebook implements a hybrid model that combines a classical CNN for feature extraction with a 4-qubit variational quantum circuit, followed by a linear classifier. Each CIFAR-10 image goes through the following steps:

1. Classical CNN extracts 8 features from images
2. Linear layer converts 8 features to 4 quantum angles
3. A 4-qubit quantum circuit encodes the angles and returns 4 Pauli-Z expectation values
4. A final linear layer maps these 4 values to 10 class logits (one per CIFAR-10 class)

## 1. Setup Dependencies

Import required libraries and set up the environment:

In [None]:
# Prerequisites (conda), run once in a terminal:
# conda create -n cnn4 python=3.11 pytorch::pytorch pytorch::torchvision ipykernel
# conda activate cnn4
# pip install pennylane amazon-braket-pennylane-plugin

In [1]:
# %pip install pennylane==0.32.0 amazon-braket-pennylane-plugin

import torch
import torchvision.transforms as T
import torchvision.datasets as datasets
import torch.nn as nn
import pennylane as qml
from pennylane.qnn import TorchLayer

## 2. Data Loading and Preprocessing

Download the CIFAR-10 training set (reusing the local copy if present) and draw one shuffled batch of 8 images to test the model with:

In [2]:
# Define data transformations
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

# Load CIFAR-10 dataset
ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=True)

# Get a batch of images
imgs, _ = next(iter(loader))            # imgs: [8,3,32,32]
print(f"Loaded image batch shape: {imgs.shape}")

Files already downloaded and verified
Loaded image batch shape: torch.Size([8, 3, 32, 32])
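With a mean and standard deviation of 0.5 on every channel, `T.Normalize` computes `(x - 0.5) / 0.5` per channel, so the `[0, 1]` range produced by `T.ToTensor()` becomes `[-1, 1]`. A minimal NumPy sketch of that arithmetic (the `normalize` helper is a hypothetical stand-in, not the torchvision implementation):

```python
import numpy as np

def normalize(img, mean, std):
    """Per-channel (x - mean) / std on a [C, H, W] array, mirroring T.Normalize."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (img - mean) / std

# A [3, 1, 3] "image" holding the minimum, midpoint and maximum of ToTensor's range
img = np.tile(np.array([0.0, 0.5, 1.0], dtype=np.float32), (3, 1, 1))
out = normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(out[0, 0])  # [-1.  0.  1.]
```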


## 3. Feature Extractor Implementation

Implement a small CNN that converts 32x32 RGB images into 8-dimensional feature vectors:

In [3]:
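class TinyCNN(nn.Module):
    """Maps a [B, 3, 32, 32] image batch to [B, 8] feature vectors."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # 3x3 convolutions with padding=1 keep the 32x32 spatial size
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),   # -> [B, 8, 32, 32]
            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),  # -> [B, 16, 32, 32]
            nn.Flatten(),                               # -> [B, 16*32*32]
            nn.Linear(16 * 32 * 32, 16), nn.ReLU(),     # -> [B, 16]
            nn.Linear(16, 8)                            # -> [B, 8]
        )

    def forward(self, x):
        return self.net(x)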

# Create feature extractor instance
feat_extractor = TinyCNN()
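Because both convolutions use a 3x3 kernel with `padding=1`, the feature maps stay 32x32, so the flattened vector has 16·32·32 = 16,384 entries and the first `Linear` layer holds almost all of the extractor's parameters. A pure-Python sketch of that arithmetic (the helper names are illustrative); `sum(p.numel() for p in feat_extractor.parameters())` should report the same total:

```python
def conv_out_size(size, kernel, padding=0, stride=1):
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(c_in, c_out, kernel):
    """Kernel weights plus one bias per output channel."""
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out

side = conv_out_size(32, kernel=3, padding=1)    # after the first conv
side = conv_out_size(side, kernel=3, padding=1)  # after the second conv
flat = 16 * side * side

counts = {
    "conv1": conv_params(3, 8, 3),
    "conv2": conv_params(8, 16, 3),
    "fc1": linear_params(flat, 16),
    "fc2": linear_params(16, 8),
}
total = sum(counts.values())
print(side, flat, total)                         # 32 16384 263688
print(f"fc1 share: {counts['fc1'] / total:.1%}") # fc1 share: 99.4%
```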

## 4. Quantum Device Configuration

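Set up the quantum device using Amazon Braket's SV1 state-vector simulator. This requires AWS credentials that boto3 can find (for example, configured with `aws configure`); without them, creating the device fails with `NoCredentialsError`, as in the output below: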

In [5]:
import os

# Set the AWS region (us-east-1 is used here as an example)
os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'

# Initialize a 4-qubit device backed by the SV1 simulator
dev = qml.device("braket.aws.qubit",
                 device_arn="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
                 wires=4)

NoCredentialsError: Unable to locate credentials
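Without credentials, the same circuit can run on PennyLane's local `default.qubit` simulator, which, unlike a remote Braket device, also supports `diff_method="backprop"` with the Torch interface. A minimal sketch of choosing between the two, assuming a hypothetical helper `choose_device_config`; its credential check only looks at the two standard AWS environment variables, so it is a heuristic and misses other sources boto3 searches (such as `~/.aws/credentials`, SSO, or instance roles):

```python
import os

SV1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"

def choose_device_config(env=None):
    """Hypothetical helper: return (device name, device kwargs, diff_method).

    Heuristic only: checks AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.
    """
    env = os.environ if env is None else env
    if env.get("AWS_ACCESS_KEY_ID") and env.get("AWS_SECRET_ACCESS_KEY"):
        # Remote simulator: gradients must come from extra circuit runs
        return "braket.aws.qubit", {"device_arn": SV1_ARN}, "parameter-shift"
    # Local state-vector simulator: Torch can backpropagate through it
    return "default.qubit", {}, "backprop"

name, kwargs, diff_method = choose_device_config()
print(name, diff_method)
# In the notebook (requires PennyLane):
# dev = qml.device(name, wires=4, **kwargs)
# @qml.qnode(dev, interface="torch", diff_method=diff_method)
```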

## 5. Quantum Circuit Definition

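Define the variational circuit: RX gates encode the 4 input angles, trainable RY and RZ rotations follow on each qubit, a chain of CNOTs entangles neighbouring qubits, and the Pauli-Z expectation value of each qubit is returned: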

In [None]:
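# Define the quantum circuit. A remote Braket device cannot use
# diff_method="backprop" (that needs a simulator such as default.qubit),
# so gradients are computed with the parameter-shift rule instead.
@qml.qnode(dev, interface="torch", diff_method="parameter-shift")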
def qnode(inputs, weights):
    # Input encoding
    for i in range(4):
        qml.RX(inputs[i], wires=i)
    
    # Trainable rotation gates
    for i in range(4):
        qml.RY(weights[i,0], wires=i)
        qml.RZ(weights[i,1], wires=i)
    
    # Entangling gates
    for i in range(3):
        qml.CNOT(wires=[i, i+1])
    
    # Measure observables
    return [qml.expval(qml.PauliZ(w)) for w in range(4)]

# Define weight shapes and create quantum layer
weight_shapes = {"weights": (4,2)}
qlayer = TorchLayer(qnode, weight_shapes)
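Because all single-qubit rotations act on $|0000\rangle$ before any entangling gate, and a CNOT maps a Pauli-Z on its target to a product of Z operators, this particular circuit has a closed form: $\langle Z_k\rangle = \prod_{j\le k}\cos(x_j)\cos(w_{j,0})$, with $x$ = `inputs` and $w$ = `weights`. One consequence is that the RZ angles `weights[:, 1]` do not influence the output, since RZ is diagonal and commutes with every product of Z operators. The NumPy state-vector sketch below is written from scratch (independent of PennyLane, with the gate matrices assumed to follow PennyLane's RX/RY/RZ conventions) and checks the formula for random angles:

```python
import numpy as np

N_QUBITS = 4  # axis w of the state tensor corresponds to wire w

def rx(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

def ry(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def rz(t):
    return np.diag([np.exp(-0.5j * t), np.exp(0.5j * t)])

def apply_1q(state, gate, wire):
    """Apply a 2x2 gate to one wire of a state with shape (2,) * N_QUBITS."""
    state = np.tensordot(gate, state, axes=([1], [wire]))  # gate output on axis 0
    return np.moveaxis(state, 0, wire)

def apply_cnot(state, control, target):
    """Swap the target's 0/1 amplitudes wherever the control bit is 1."""
    state = state.copy()
    idx = [slice(None)] * N_QUBITS
    idx[control] = 1
    sub = state[tuple(idx)]  # control axis removed
    axis = target if target < control else target - 1
    state[tuple(idx)] = np.flip(sub, axis=axis).copy()
    return state

def circuit(inputs, weights):
    """Same gate sequence as `qnode`, returning <Z> on each wire."""
    state = np.zeros((2,) * N_QUBITS, dtype=complex)
    state[(0,) * N_QUBITS] = 1.0  # |0000>
    for i in range(N_QUBITS):
        state = apply_1q(state, rx(inputs[i]), i)
    for i in range(N_QUBITS):
        state = apply_1q(state, ry(weights[i, 0]), i)
        state = apply_1q(state, rz(weights[i, 1]), i)
    for i in range(N_QUBITS - 1):
        state = apply_cnot(state, i, i + 1)
    probs = np.abs(state) ** 2
    # <Z_w> = P(bit w is 0) - P(bit w is 1)
    return np.array([probs.take(0, axis=w).sum() - probs.take(1, axis=w).sum()
                     for w in range(N_QUBITS)])

def closed_form(inputs, weights):
    """<Z_k> = product over j <= k of cos(inputs[j]) * cos(weights[j, 0])."""
    return np.cumprod(np.cos(inputs) * np.cos(weights[:, 0]))

rng = np.random.default_rng(0)
x = rng.uniform(-np.pi, np.pi, N_QUBITS)
w = rng.uniform(-np.pi, np.pi, (N_QUBITS, 2))
print(np.allclose(circuit(x, w), closed_form(x, w)))  # True
```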

## 6. Hybrid Model Implementation

Create the complete hybrid model that combines the classical CNN, quantum processing, and final classification layer:

In [None]:
class HybridModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = feat_extractor        # CNN feature extractor
        self.pre = nn.Linear(8, 4)       # Convert features to quantum angles
        self.q = qlayer                  # Quantum processing
        self.classifier = nn.Linear(4, 10)  # Final classification layer
    
    def forward(self, x):
        b = x.shape[0]                    # Batch size
        f = self.cnn(x)                   # Extract features: [b, 8]
        ang = self.pre(f)                 # Project to 4 angles: [b, 4]
        ang = torch.tanh(ang) * torch.pi  # Squash into the open interval (-π, π)
        # Run the quantum circuit once per sample: [b, 4] expectation values
        qout = torch.stack([self.q(ang[i]) for i in range(b)])
        return self.classifier(qout)      # Class logits: [b, 10]

# Create model instance
model = HybridModel()
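With `diff_method="parameter-shift"`, each gradient with respect to an RX, RY or RZ angle costs two extra circuit evaluations (the two-term shift rule). A rough pure-Python estimate of the circuit executions one training step would trigger with the per-sample loop above, assuming all 4 input angles and all 8 weights are differentiated and no results are reused; PennyLane's actual count may differ:

```python
def circuits_per_step(batch_size, n_input_angles=4, n_weights=8, shifts_per_param=2):
    """Rough count of circuit executions for one forward + backward pass."""
    forward = batch_size  # one execution per sample in the loop
    backward = batch_size * (n_input_angles + n_weights) * shifts_per_param
    return forward + backward

print(circuits_per_step(8))  # 200
```

On a remote simulator this multiplier, rather than the classical layers, is likely to dominate the run time and cost of training.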

## 7. Model Testing

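Perform a forward pass through the model to verify that it returns one row of 10 logits per image. With the SV1 device, each of the 8 samples runs as a separate circuit execution on Amazon Braket, which may incur charges: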

In [None]:
# Perform a forward pass
out = model(imgs)
print(f"Input shape: {imgs.shape}")
print(f"Output shape: {out.shape}")  # expected: torch.Size([8, 10])
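The classifier outputs unnormalized logits, not probabilities. In PyTorch, `torch.softmax(out, dim=1)` converts them, while `nn.CrossEntropyLoss` expects the raw logits during training. A minimal NumPy sketch of the same numerically stable softmax, using made-up logits rather than model output:

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax: subtract the max before exponentiating."""
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

logits = np.array([[2.0, 1.0, 0.0],
                   [0.0, 0.0, 0.0]])  # illustrative values only
probs = softmax(logits, axis=1)
print(probs)
print(probs.sum(axis=1))
```

Subtracting the row maximum leaves the result unchanged mathematically but prevents `np.exp` from overflowing on large logits.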