In [2]:
!pip install torch torchvision pennylane matplotlib

Collecting pennylane
  Downloading pennylane-0.43.1-py3-none-any.whl.metadata (11 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray==0.8.0 (from pennylane)
  Downloading autoray-0.8.0-py3-none-any.whl.metadata (6.1 kB)
Collecting pennylane-lightning>=0.43 (from pennylane)
  Downloading pennylane_lightning-0.43.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.43->pennylane)
  Downloading scipy_openblas32-0.3.30.0.7-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.1/57.1

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import pennylane as qml
import numpy as np
import time
import copy
from collections import defaultdict



# --- 0. GLOBAL CONFIGURATION ---

In [4]:
n_qubits = 4  # number of qubits / width of the quantum layer
batch_size = 4
target_classes = [1, 9] # 1=Automobile, 9=Truck
# NOTE: local_epochs is not passed to the client update by the federated
# loops in this notebook, which run exactly one local pass per round.
local_epochs = 1
learning_rate = 0.005

# --- FEDERATED CONFIGURATION (Phase 2) ---
num_clients = 10
federated_rounds = 15
client_participation_rate = 0.5  # fraction of clients sampled in every round

# --- REPRODUCIBILITY ---
# Client sampling uses NumPy's global RNG; weight initialisation and DataLoader
# shuffling use torch's global RNG. Seed both so Restart & Run All is repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- DEVICE SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")
print(f"Hybrid QNN Configuration: {n_qubits} Qubits")

Using device: cuda
Hybrid QNN Configuration: 4 Qubits


# ----------------------------------------------------------------------
#                             DATA PREPARATION
# ----------------------------------------------------------------------

In [5]:
# Preprocessing for the pretrained ResNet encoder: upscale the images to
# 224x224, convert to tensors, then normalise each channel with the standard
# ImageNet mean / standard deviation.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)

def filter_data(dataset, targets):
    """Return a Subset of `dataset` holding only the samples whose label is in `targets`.

    Args:
        dataset: dataset exposing a `targets` sequence aligned with its samples.
        targets: container of labels to keep.

    Returns:
        torch.utils.data.Subset over `dataset`, with positions in original order.
    """
    kept_positions = []
    for position, label in enumerate(dataset.targets):
        if label in targets:
            kept_positions.append(position)
    return torch.utils.data.Subset(dataset, kept_positions)

# Download CIFAR-10 (train and test splits) into ./data, applying the ResNet
# preprocessing pipeline defined above to every image.
trainset_full = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset_full = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Filter the datasets for only Cars and Trucks.
# The resulting Subsets still yield the ORIGINAL CIFAR-10 labels (1 and 9);
# the training / evaluation code remaps them to 0 / 1 per batch.
trainset = filter_data(trainset_full, target_classes)
testset = filter_data(testset_full, target_classes)

# --- Phase 1: CENTRALIZED DATA LOADERS (The "Walk" Baseline) ---
# Used for the centralized training loop only
central_trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Global Test Loader (Used by both phases for consistent evaluation)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
print(f"Filtered Dataset Ready. Total train images: {len(trainset)}.")

100%|██████████| 170M/170M [00:05<00:00, 30.9MB/s]


Filtered Dataset Ready. Total train images: 10000.


# --- Phase 2: NON-IID DATA SHARDING FOR CLIENTS (The "Run" HQFL) ---

In [6]:
def create_non_iid_clients(dataset, num_clients, batch_size):
    """Splits the dataset indices into num_clients partitions with non-IID skew."""

    label_to_indices = defaultdict(list)
    # 1=Car maps to 0, 9=Truck maps to 1.
    filtered_labels = np.array([1 if dataset.dataset.targets[i] == 9 else 0 for i in dataset.indices])

    for i, label in enumerate(filtered_labels):
        label_to_indices[label].append(dataset.indices[i])

    all_indices = label_to_indices[0] + label_to_indices[1]

    client_indices = defaultdict(list)
    indices_per_class = [len(label_to_indices[0]), len(label_to_indices[1])]

    # Skewing logic: Clients 0-4 get more of class 0 (Car), Clients 5-9 get more of class 1 (Truck)
    for i in range(num_clients):
        major_class = 0 if i < num_clients // 2 else 1
        minor_class = 1 - major_class

        num_samples = len(all_indices) // num_clients
        major_samples = int(num_samples * 0.8)
        minor_samples = num_samples - major_samples

        # Simple slicing to distribute data
        start_major = (i % (num_clients // 2)) * (indices_per_class[major_class] // (num_clients // 2))
        end_major = start_major + major_samples

        start_minor = (i % (num_clients // 2)) * (indices_per_class[minor_class] // (num_clients // 2))
        end_minor = start_minor + minor_samples

        major_indices_list = label_to_indices[major_class][start_major:end_major]
        minor_indices_list = label_to_indices[minor_class][start_minor:end_minor]

        client_indices[i].extend(major_indices_list)
        client_indices[i].extend(minor_indices_list)

    # Convert indices to DataLoader objects
    client_dataloaders = {}
    for i in range(num_clients):
        # We need to pass the full original dataset to the Subset constructor
        subset = torch.utils.data.Subset(dataset.dataset, client_indices[i])
        client_dataloaders[i] = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True)

    return client_dataloaders

# Create the Non-IID Client DataLoaders for Phase 2
# (keys of client_dataloaders are the integer client ids 0 .. num_clients - 1)
client_dataloaders = create_non_iid_clients(trainset, num_clients, batch_size)
# Pool of client ids the server samples from in every federated round.
client_ids = list(range(num_clients))
print(f"Created {num_clients} Non-IID client dataloaders for Federated Simulation.")

Created 10 Non-IID client dataloaders for Federated Simulation.



# ----------------------------------------------------------------------
#                           HYBRID QNN ARCHITECTURE
# ----------------------------------------------------------------------

In [7]:
# Simulator device created once at module level; the QNode below is bound to
# it, so every model that wraps `quantum_circuit` shares this device.
dev = qml.device("lightning.qubit", wires=n_qubits)
@qml.qnode(dev, interface="torch")
def quantum_circuit(inputs, weights):
    """Variational circuit executed by the quantum layer of HybridResNet.

    Args:
        inputs: classical features embedded as rotation angles, one per qubit
            (HybridResNet passes n_qubits values scaled into (-pi/2, pi/2)).
        weights: trainable parameters of the entangling layers; HybridResNet
            declares them with shape (3, n_qubits).

    Returns:
        A list holding the Pauli-Z expectation value of each of the n_qubits wires.
    """
    # Angle Embedding (Encodes the n_qubits classical features into n_qubits qubits)
    qml.AngleEmbedding(inputs, wires=range(n_qubits))

    # Basic Entangler Layers (The "Thinking" Part)
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))

    # Measurement: one Pauli-Z expectation value per qubit. These are not class
    # probabilities; HybridResNet feeds them into its final linear layer.
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

class HybridResNet(nn.Module):
    """Hybrid classifier: frozen ResNet18 encoder -> quantum layer -> linear head.

    Pipeline of ``forward``:
      1. ImageNet-pretrained ResNet18 (all parameters frozen) produces 512 features.
      2. ``fc_reduce`` maps them to ``n_qubits`` values, squashed into (-pi/2, pi/2).
      3. ``q_layer`` runs ``quantum_circuit`` and returns ``n_qubits`` expectation values.
      4. ``final_layer`` maps those to 2 class logits.
    """

    def __init__(self):
        super().__init__()

        # A. CLASSICAL TRANSFER LEARNING PART (Encoder)
        # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is
        # the checkpoint that `pretrained=True` used to load.
        self.resnet = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        )
        for param in self.resnet.parameters():
            param.requires_grad = False
        # NOTE(review): freezing requires_grad does not stop the BatchNorm layers
        # from updating their running statistics while the module is in train()
        # mode - confirm this is intended for a "frozen" encoder.

        # Reduce 512 features to n_qubits (4)
        self.fc_reduce = nn.Linear(512, n_qubits)

        # B. QUANTUM PART (Core)
        # Weight shape (3, n_qubits) for BasicEntanglerLayers in quantum_circuit.
        weight_shapes = {"weights": (3, n_qubits)}
        self.q_layer = qml.qnn.TorchLayer(quantum_circuit, weight_shapes)

        # C. FINAL PREDICTION (Decoder)
        self.final_layer = nn.Linear(n_qubits, 2)

    def forward(self, x):
        """Return 2-class logits for a batch of images `x`."""
        # Classical feature extraction and pooling layers. The ResNet stages are
        # called one by one so that the ImageNet head (resnet.fc) is skipped.
        backbone = self.resnet
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)
        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        x = backbone.layer4(x)
        x = backbone.avgpool(x)
        x = torch.flatten(x, 1)

        # Dimension Reduction: tanh bounds the values, the pi/2 factor turns
        # them into rotation angles in (-pi/2, pi/2) for the angle embedding.
        x_reduced = torch.tanh(self.fc_reduce(x)) * (np.pi / 2.0)

        # Quantum Processing
        x_q = self.q_layer(x_reduced)

        # Final Classify
        x = self.final_layer(x_q)
        return x

# ----------------------------------------------------------------------
#                           UTILITY FUNCTIONS
# ----------------------------------------------------------------------

In [8]:

def test_model_accuracy(model, dataloader, device):
    """Evaluates the model on the full test set."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Remap labels: 1(Car)->0, 9(Truck)->1
            binary_labels = torch.where(labels == 1, 0, 1).to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += binary_labels.size(0)
            correct += (predicted == binary_labels).sum().item()

    return 100 * correct / total

In [9]:
def client_update(model, dataloader, criterion, learning_rate, device, local_epochs=1):
    """Train `model` locally with Adam and return a snapshot of its weights.

    Args:
        model: local model to train in place.
        dataloader: iterable of (inputs, labels) batches with original labels.
        criterion: loss taking (logits, binary_labels).
        learning_rate: Adam learning rate.
        device: device the batches are moved to.
        local_epochs: number of passes over `dataloader` (default 1).

    Returns:
        A deep copy of ``model.state_dict()``. ``state_dict()`` itself returns
        tensors that share storage with the live model, so without the copy the
        returned weights would change if the model were trained or reloaded later.
    """
    model.train()
    # Frozen parameters never receive gradients; hand only trainable ones to Adam.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    for _ in range(local_epochs):
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Remap labels: 1(Car)->0, 9(Truck)->1 (result is already on `device`)
            binary_labels = torch.where(labels == 1, 0, 1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, binary_labels)
            loss.backward()
            optimizer.step()

    return copy.deepcopy(model.state_dict())

In [10]:
def server_aggregate(global_model, client_weights):
    """Average client state dicts into `global_model` (unweighted FedAvg).

    Every floating-point (or complex) entry of the global state dict is replaced
    by the arithmetic mean of the corresponding client tensors. All other
    entries - integer or boolean buffers such as BatchNorm's
    ``num_batches_tracked`` - keep the global model's current value, because
    they cannot be averaged in place.

    Args:
        global_model: model whose weights are overwritten with the average.
        client_weights: non-empty list of state dicts with the same keys as
            ``global_model.state_dict()``.

    Returns:
        The same `global_model` object, updated in place.

    Raises:
        ValueError: if `client_weights` is empty (dividing by zero clients would
            otherwise silently fill the model with NaNs).
    """
    if not client_weights:
        raise ValueError("server_aggregate needs at least one client state_dict")

    global_state_dict = global_model.state_dict()
    num_clients = len(client_weights)

    for k in list(global_state_dict.keys()):
        reference = global_state_dict[k]

        # Skip every non-floating tensor, not just int64/int32: in-place true
        # division is undefined for int16/uint8/bool buffers as well.
        if not (reference.is_floating_point() or reference.is_complex()):
            continue

        # Sum into a fresh accumulator so the live model tensor is not modified
        # before load_state_dict, then divide by the number of clients.
        accumulated = torch.zeros_like(reference)
        for client_sd in client_weights:
            accumulated.add_(client_sd[k].to(accumulated.device))
        global_state_dict[k] = accumulated.div_(num_clients)

    global_model.load_state_dict(global_state_dict)
    return global_model

# ----------------------------------------------------------------------
#                  PHASE 1: CENTRALIZED BASELINE
# ----------------------------------------------------------------------

In [None]:
# Phase 1: train one HybridResNet on the full (centralized) car/truck training
# set and report test accuracy after every epoch. This is the baseline the
# federated phases are compared against.
print("\n" + "="*50)
print("             PHASE 1: CENTRALIZED HYBRID BASELINE (WALK)")
print("="*50)

central_model = HybridResNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(central_model.parameters(), lr=learning_rate)
n_epochs = 3 # Number of centralized training epochs for the baseline

for epoch in range(n_epochs):
    # test_model_accuracy() leaves the model in eval mode, so training mode
    # has to be re-enabled at the start of every epoch.
    central_model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(central_trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        # Remap labels: 1(Car)->0, 9(Truck)->1
        binary_labels = torch.where(labels == 1, 0, 1).to(device)

        optimizer.zero_grad()
        outputs = central_model(inputs)
        loss = criterion(outputs, binary_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Evaluate accuracy after each epoch; the printed loss is the mean batch loss
    acc = test_model_accuracy(central_model, testloader, device)
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {running_loss / len(central_trainloader):.4f}, Test Accuracy: {acc:.2f}%")

# Final evaluation of the trained baseline (same test set as the per-epoch evaluation)
final_central_acc = test_model_accuracy(central_model, testloader, device)
print(f"\n--- Phase 1 Complete. Final Centralized Accuracy: {final_central_acc:.2f}% ---")

# Banner announcing the next phase. NOTE(review): the Phase 2 cell prints its
# own banner as well, so this header appears twice in the notebook output.
print("\n" + "="*50)
print("              PHASE 2: FEDERATED HYBRID QNN (RUN)")
print("="*50)



             PHASE 1: CENTRALIZED HYBRID BASELINE (WALK)
Epoch 1/3, Loss: 0.5290, Test Accuracy: 90.65%
Epoch 2/3, Loss: 0.4365, Test Accuracy: 91.25%
Epoch 3/3, Loss: 0.4391, Test Accuracy: 90.05%

--- Phase 1 Complete. Final Centralized Accuracy: 90.05% ---

              PHASE 2: FEDERATED HYBRID QNN (RUN)


# ----------------------------------------------------------------------
#                  PHASE 2: FEDERATED TRAINING (RUN)
# ----------------------------------------------------------------------

In [None]:

# Phase 2: federated training (FedAvg). In every round the server samples a
# subset of clients, each sampled client trains a copy of the global model on
# its own non-IID shard, and the server averages the returned weights.
print("\n" + "="*50)
print("              PHASE 2: FEDERATED HYBRID QNN     ")
print("==================================================")

# Initialize a FRESH global model
global_model = HybridResNet().to(device)
criterion = nn.CrossEntropyLoss()

print(f"Rounds: {federated_rounds}, Participation: {client_participation_rate*100}%, Clients: {num_clients}")

start_time_federated = time.time()

for round_num in range(federated_rounds):

    # 1. Server Selects Clients
    # Sample without replacement; at least one client always participates.
    participating_clients = np.random.choice(client_ids,
                                             max(1, int(num_clients * client_participation_rate)),
                                             replace=False)

    client_weights = []

    # 2. Clients Train Locally
    for client_id in participating_clients:

        # A. Initialize a fresh local model for this client. A separate model
        #    per client keeps the state_dicts collected in client_weights
        #    independent of each other (state_dict tensors can share storage
        #    with the model that produced them). Note that the PennyLane device
        #    `dev` is module-level and therefore shared by every instance.
        local_model = HybridResNet().to(device)

        # B. Load the global weights (state_dict) into the local model

        local_model.load_state_dict(global_model.state_dict())

        # Perform local update (1 local epoch)
        weights = client_update(
            local_model,
            client_dataloaders[client_id],
            criterion,
            learning_rate,
            device
        )

        client_weights.append(weights)

    # 3. Server Aggregates
    global_model = server_aggregate(global_model, client_weights)

    # 4. Evaluate Global Model
    global_acc = test_model_accuracy(global_model, testloader, device)

    print(f"[Round {round_num + 1}/{federated_rounds}] Global Test Accuracy: {global_acc:.2f}%")

# global_acc holds the accuracy measured after the last round.
end_time_federated = time.time()
print(f"\n--- Phase 2 Complete. Final Federated Accuracy: {global_acc:.2f}% (Time: {end_time_federated - start_time_federated:.2f}s) ---")


              PHASE 2: FEDERATED HYBRID QNN     
Rounds: 15, Participation: 50.0%, Clients: 10
[Round 1/15] Global Test Accuracy: 49.35%
[Round 2/15] Global Test Accuracy: 50.00%
[Round 3/15] Global Test Accuracy: 50.00%
[Round 4/15] Global Test Accuracy: 51.25%
[Round 5/15] Global Test Accuracy: 51.20%
[Round 6/15] Global Test Accuracy: 54.25%
[Round 7/15] Global Test Accuracy: 87.55%
[Round 8/15] Global Test Accuracy: 89.30%
[Round 9/15] Global Test Accuracy: 88.10%
[Round 10/15] Global Test Accuracy: 90.60%
[Round 11/15] Global Test Accuracy: 90.75%
[Round 12/15] Global Test Accuracy: 91.15%
[Round 13/15] Global Test Accuracy: 89.65%
[Round 14/15] Global Test Accuracy: 90.55%
[Round 15/15] Global Test Accuracy: 90.20%

--- Phase 2 Complete. Final Federated Accuracy: 90.20% (Time: 1950.80s) ---


In [11]:

def SPSA_Optimizer(model, criterion, inputs, labels, lr, c):
    """Apply one SPSA (simultaneous perturbation) update step to `model` in place.

    All trainable parameters are perturbed simultaneously by ``+delta`` and
    ``-delta``, where every entry of ``delta`` is ``+c`` or ``-c`` with a random
    sign. The gradient is estimated element-wise as
    ``(loss(w + delta) - loss(w - delta)) / (2 * delta)`` and the parameters are
    updated with ``w <- w - lr * estimate``.

    Only forward passes are needed, so the whole step runs under
    ``torch.no_grad()``: no autograd graph is built or retained, and parameters
    are modified directly instead of through ``.data``.

    Args:
        model: module whose parameters with ``requires_grad=True`` are updated.
        criterion: loss taking (model outputs, labels) and returning a scalar.
        inputs: input batch, already on the model's device.
        labels: target batch matching `criterion`.
        lr: step size of the update.
        c: perturbation magnitude.

    Returns:
        Mean of the two perturbed losses, as a Python float.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]

    with torch.no_grad():
        # 1. Perturbation per parameter: random sign (+1 / -1) scaled by c,
        #    sampled directly on the parameter's device.
        deltas = [
            (2.0 * torch.randint(0, 2, param.shape, device=param.device) - 1.0) * c
            for param in trainable
        ]

        # 2. Loss at w + delta
        for param, delta in zip(trainable, deltas):
            param.add_(delta)
        loss_plus = criterion(model(inputs), labels).item()

        # 3. Loss at w - delta (move by -2 * delta from w + delta)
        for param, delta in zip(trainable, deltas):
            param.sub_(2 * delta)
        loss_minus = criterion(model(inputs), labels).item()

        loss_diff = loss_plus - loss_minus

        # 4. Restore w, then take the SPSA step.
        for param, delta in zip(trainable, deltas):
            param.add_(delta)
            gradient_estimate = loss_diff / (2.0 * delta)
            param.sub_(lr * gradient_estimate)

    return (loss_plus + loss_minus) / 2.0

In [None]:
def client_update_spsa(model, dataloader, criterion, learning_rate, device, c_spsa=0.01):
    """Run one local pass over `dataloader` with SPSA steps and return the weights.

    NOTE: a later cell defines `client_update_spsa` again with a default of
    ``c_spsa=0.05``; whichever definition is executed last is the one in effect.

    Args:
        model: local model, updated in place by SPSA_Optimizer.
        dataloader: iterable of (inputs, labels) batches with original labels.
        criterion: loss taking (logits, binary_labels).
        learning_rate: SPSA step size.
        device: device the batches are moved to.
        c_spsa: perturbation size (a fixed small number).

    Returns:
        A deep copy of ``model.state_dict()``, so the returned weights do not
        share storage with the live model.
    """
    model.train()

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Remap labels: 1(Car)->0, 9(Truck)->1 (result is already on `device`)
        binary_labels = torch.where(labels == 1, 0, 1)

        # The averaged loss returned by SPSA_Optimizer is not needed here.
        SPSA_Optimizer(
            model,
            criterion,
            inputs,
            binary_labels,
            learning_rate,
            c_spsa
        )

    return copy.deepcopy(model.state_dict())

# ----------------------------------------------------------------------
#             PHASE 3: HQFL with SPSA (Noise Resilience Test)
# ----------------------------------------------------------------------

In [12]:
def client_update_spsa(model, dataloader, criterion, learning_rate, device, c_spsa=0.05):
    """Run one local pass over `dataloader` with SPSA steps and return the weights.

    Gradient-free counterpart of `client_update`: every batch triggers one
    SPSA_Optimizer step instead of backward() + optimizer.step().

    Args:
        model: local model, updated in place by SPSA_Optimizer.
        dataloader: iterable of (inputs, labels) batches with original labels.
        criterion: loss taking (logits, binary_labels).
        learning_rate: SPSA step size.
        device: device the batches are moved to.
        c_spsa: perturbation size.

    Returns:
        A deep copy of ``model.state_dict()``, so the returned weights do not
        share storage with the live model.
    """
    model.train()

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Remap labels: 1(Car)->0, 9(Truck)->1 (result is already on `device`)
        binary_labels = torch.where(labels == 1, 0, 1)

        # SPSA does not use PyTorch's automatic backward() or optimizer.step()
        SPSA_Optimizer(
            model,
            criterion,
            inputs,
            binary_labels,
            learning_rate,
            c_spsa
        )

    return copy.deepcopy(model.state_dict())

In [13]:


# Phase 3: same federated protocol as Phase 2 (client sampling, FedAvg
# aggregation, per-round evaluation), but clients train with the
# gradient-free SPSA update instead of Adam.
print("\n" + "="*50)
print("              PHASE 3: HQFL + SPSA OPTIMIZER")
print("==================================================")

global_model_spsa = HybridResNet().to(device) # Initialize a FRESH global model
criterion = nn.CrossEntropyLoss()

print(f"Rounds: {federated_rounds}, Optimizer: SPSA, Clients: {num_clients}")

start_time_spsa = time.time()

for round_num in range(federated_rounds):

    # 1. Server Selects Clients
    # Sample without replacement; at least one client always participates.
    participating_clients = np.random.choice(client_ids,
                                             max(1, int(num_clients * client_participation_rate)),
                                             replace=False)

    client_weights = []

    # 2. Clients Train Locally
    for client_id in participating_clients:

        # Clone global weights to local model. A separate model per client keeps
        # the state_dicts collected in client_weights independent of each other
        # (state_dict tensors can share storage with the model that produced them).
        local_model = HybridResNet().to(device)
        local_model.load_state_dict(global_model_spsa.state_dict())

        # --- CALL THE SPSA UPDATE FUNCTION ---
        # c_spsa is not passed, so the default of the client_update_spsa
        # definition executed last is used.
        weights = client_update_spsa(
            local_model,
            client_dataloaders[client_id],
            criterion,
            learning_rate,
            device
        )

        client_weights.append(weights)

    # 3. Server Aggregates (using the same FedAvg logic)
    global_model_spsa = server_aggregate(global_model_spsa, client_weights)

    # 4. Evaluate Global Model
    global_acc = test_model_accuracy(global_model_spsa, testloader, device)

    print(f"[Round {round_num + 1}/{federated_rounds}] SPSA Global Test Accuracy: {global_acc:.2f}%")

# global_acc holds the accuracy measured after the last round.
end_time_spsa = time.time()
print(f"\n--- Phase 3 Complete. Final SPSA Accuracy: {global_acc:.2f}% (Time: {end_time_spsa - start_time_spsa:.2f}s) ---")




              PHASE 3: HQFL + SPSA OPTIMIZER
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 156MB/s]


Rounds: 15, Optimizer: SPSA, Clients: 10
[Round 1/15] SPSA Global Test Accuracy: 50.00%
[Round 2/15] SPSA Global Test Accuracy: 50.00%
[Round 3/15] SPSA Global Test Accuracy: 50.65%
[Round 4/15] SPSA Global Test Accuracy: 54.15%
[Round 5/15] SPSA Global Test Accuracy: 49.90%
[Round 6/15] SPSA Global Test Accuracy: 49.95%
[Round 7/15] SPSA Global Test Accuracy: 63.10%
[Round 8/15] SPSA Global Test Accuracy: 56.60%
[Round 9/15] SPSA Global Test Accuracy: 60.10%
[Round 10/15] SPSA Global Test Accuracy: 59.45%
[Round 11/15] SPSA Global Test Accuracy: 62.65%
[Round 12/15] SPSA Global Test Accuracy: 63.25%
[Round 13/15] SPSA Global Test Accuracy: 70.50%
[Round 14/15] SPSA Global Test Accuracy: 68.15%
[Round 15/15] SPSA Global Test Accuracy: 71.15%

--- Phase 3 Complete. Final SPSA Accuracy: 71.15% (Time: 1123.11s) ---
