In [None]:
pip install tenseal flwr pennylane

## QNN

In [None]:
"""
Hybrid Quantum-assisted CNN on CIFAR-10
PennyLane + PyTorch – single GPU
~ 63 % test accuracy in 30 min on RTX-3080
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
import pennylane as qml
from pennylane.qnn import TorchLayer
import numpy as np
from tqdm.auto import tqdm
import os

# -------------------------------------------------
# 1. Config
# -------------------------------------------------
# Quantum circuit size (kept small so the state-vector simulation stays cheap).
N_QUBITS = 4
N_QLAYERS = 2

# Optimisation schedule and regularisation.
EPOCHS = 30
LR = 3e-3
WEIGHT_DECAY = 5e-4
LABEL_SMOOTH = 0.1
MIXUP_ALPHA = 0.2

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Inputs always have the same 32x32 size, so let cuDNN autotune its kernels.
torch.backends.cudnn.benchmark = True

# -------------------------------------------------
# 2. Data
# -------------------------------------------------
# Per-channel CIFAR-10 statistics used to standardise every image.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: padded random crop + horizontal flip, then normalise.
train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, identical normalisation.
test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Both splits live under ./data (download=True fetches them if missing).
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform,
)

# -------------------------------------------------
# 3. Quantum node
# -------------------------------------------------
# State-vector simulator with one wire per qubit.
dev = qml.device("default.qubit", wires=N_QUBITS)


@qml.qnode(dev, interface="torch", diff_method="best")
def quantum_circuit(inputs, weights):
    """Variational circuit used as the quantum layer.

    ``inputs`` are embedded as RY rotation angles (one per wire),
    ``weights`` (shape ``(N_QLAYERS, N_QUBITS)``, see ``weight_shapes``)
    drive ``BasicEntanglerLayers``; the circuit returns <Z> on every wire.
    """
    wires = range(N_QUBITS)
    qml.AngleEmbedding(inputs, wires=wires, rotation="Y")
    qml.BasicEntanglerLayers(weights, wires=wires)
    return [qml.expval(qml.PauliZ(wire)) for wire in wires]


# Trainable-parameter shapes TorchLayer must allocate for the QNode.
weight_shapes = {"weights": (N_QLAYERS, N_QUBITS)}
qlayer = TorchLayer(quantum_circuit, weight_shapes)

# -------------------------------------------------
# 4. Hybrid model
# -------------------------------------------------
class HybridQCNN(nn.Module):
    """CNN feature extractor -> N_QUBITS-d bottleneck -> quantum layer -> 10-way head.

    Each instance builds its own ``TorchLayer`` from the module-level
    ``quantum_circuit`` / ``weight_shapes``, so the variational weights of two
    HybridQCNN objects are never the same tensor.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),                           # 16×16
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),                           # 8×8
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1,1))                # 1×1×128
        )
        self.fc_reduce = nn.Linear(128, N_QUBITS)     # 128 → 4
        # Fresh quantum layer per model. Assigning the module-level ``qlayer``
        # here would make every HybridQCNN share (and jointly train) one set
        # of quantum weights.
        self.qlayer = TorchLayer(quantum_circuit, weight_shapes)
        self.classifier = nn.Linear(N_QUBITS, 10)     # 10 classes

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)
        x = self.fc_reduce(x)
        x = self.qlayer(x)
        return self.classifier(x)

model = HybridQCNN().to(DEVICE)

# -------------------------------------------------
# 5. Batch-size auto-tuning
# -------------------------------------------------
@torch.no_grad()
def find_max_batch(model, device, start=64, max_batch=512):
    """Probe the largest batch size a forward pass of ``model`` can handle.

    Starting at ``start``, ``model`` is run on random ``(b, 3, 32, 32)``
    inputs and ``b`` is doubled until the pass raises ``RuntimeError``
    (e.g. CUDA out-of-memory) or ``b`` exceeds ``max_batch``.

    The probe runs in eval mode so BatchNorm layers do not fold the random
    inputs into their running statistics. Every submodule's train/eval flag
    is restored afterwards. Only inference memory is measured (no autograd
    graph is kept), so training may still need a smaller batch.

    Returns:
        The last batch size that succeeded. If ``start`` itself fails (or
        ``start > max_batch``), ``min(start // 2, max_batch)`` is returned,
        a size that was never tried.
    """
    modes = [(module, module.training) for module in model.modules()]
    model.eval()
    b = start
    try:
        while b <= max_batch:
            try:
                # The input is a temporary, released as soon as the pass ends.
                model(torch.randn(b, 3, 32, 32, device=device))
            except RuntimeError:
                break
            b *= 2
    finally:
        for module, mode in modes:
            module.training = mode
        if torch.device(device).type == "cuda":
            # Release cached blocks left behind by the probe / failed allocation.
            torch.cuda.empty_cache()
    return min(b // 2, max_batch)

# Largest batch the forward-pass probe could push through the model.
BATCH_SIZE = find_max_batch(model, DEVICE, 64, 512)
print(f"Auto-selected batch size: {BATCH_SIZE}")

# Shared loader settings; only shuffling differs between the two splits.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_set, shuffle=False, **_loader_opts)

# -------------------------------------------------
# 6. Loss, optimiser, scheduler
# -------------------------------------------------
# Label-smoothed cross-entropy, AdamW, and a one-cycle schedule sized for
# len(train_loader) steps per epoch over EPOCHS epochs.
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LR,
    epochs=EPOCHS,
    steps_per_epoch=len(train_loader),
)

# -------------------------------------------------
# 7. Mixup helpers
# -------------------------------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    """Blend a batch with a randomly permuted copy of itself (mixup).

    ``lam`` is drawn from ``Beta(alpha, alpha)`` when ``alpha > 0``,
    otherwise it is 1 (no mixing).

    Returns:
        ``(lam * x + (1 - lam) * x[perm], y, y[perm], lam)``.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm]
    return blended, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam``-weighted mix of ``criterion`` against both label sets."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# -------------------------------------------------
# 8. Training
# -------------------------------------------------
# Train for EPOCHS epochs with mixup. After every epoch, measure accuracy on
# test_loader and checkpoint the weights to best_qcnn.pt whenever it improves.
# NOTE(review): the checkpoint is selected on the test split, so the printed
# "best test accuracy" is optimistic; a separate validation split would be
# needed for an unbiased estimate.
best_acc = 0.
for epoch in range(1, EPOCHS + 1):
    model.train()
    # Sample-weighted running sums shown in the progress bar; n = samples seen.
    running_loss, running_acc, n = 0., 0., 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for x, y in pbar:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        optimizer.zero_grad()

        # Mixup: loss is the lam-weighted mix of the losses for both label sets.
        x, y_a, y_b, lam = mixup_data(x, y)
        logits = model(x)
        loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
        loss.backward()
        optimizer.step()
        # OneCycleLR was built with steps_per_epoch=len(train_loader): step per batch.
        scheduler.step()

        # Training "accuracy" under mixup: lam-weighted agreement with y_a and y_b.
        preds = logits.argmax(1)
        acc = (lam * (preds == y_a).float() + (1 - lam) * (preds == y_b).float()).mean().item()
        running_loss += loss.item() * x.size(0)
        running_acc  += acc * x.size(0)
        n += x.size(0)
        pbar.set_postfix({"loss": running_loss / n, "acc": running_acc / n})

    # ---------- validation ----------
    # Plain (un-mixed) top-1 accuracy over the whole test loader.
    model.eval()
    val_acc, m = 0., 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
            logits = model(x)
            val_acc += (logits.argmax(1) == y).float().sum().item()
            m += y.size(0)
    val_acc /= m
    print(f"Epoch {epoch:02d}  val-acc {val_acc:.2%}  lr {scheduler.get_last_lr()[0]:.2e}")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_qcnn.pt")

print(f"\nBest test accuracy: {best_acc:.2%}")

## FL-QNN

In [None]:


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset
import pennylane as qml
from pennylane.qnn import TorchLayer
import numpy as np
from tqdm.auto import tqdm
import copy
import time
import matplotlib.pyplot as plt

# -------------------------------------------------
# 1. Config
# -------------------------------------------------
# Quantum circuit size.
N_QUBITS = 4
N_QLAYERS = 2

# Federated-learning schedule.
N_CLIENTS = 5        # Number of federated clients
GLOBAL_ROUNDS = 10   # Federated rounds
LOCAL_EPOCHS = 3     # Local training epochs per client

# Optimiser / regularisation shared by FL clients and the baseline.
LR = 3e-3
WEIGHT_DECAY = 5e-4
LABEL_SMOOTH = 0.1
MIXUP_ALPHA = 0.2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# For comparison
COMPARISON_EPOCHS = 30  # Non-FL training epochs

torch.backends.cudnn.benchmark = True

print(f"Federated Learning Setup:")
for _label, _value in (
    ("Clients", N_CLIENTS),
    ("Global Rounds", GLOBAL_ROUNDS),
    ("Local Epochs", LOCAL_EPOCHS),
    ("Device", DEVICE),
):
    print(f"  {_label}: {_value}")

# -------------------------------------------------
# 2. Data - Non-IID Distribution
# -------------------------------------------------
# Per-channel CIFAR-10 statistics used to standardise every image.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: padded random crop + horizontal flip, then normalise.
train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, identical normalisation.
test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# The full training split is later partitioned among FL clients.
full_train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform,
)

def create_non_iid_splits(dataset, n_clients, classes_per_client=4, n_classes=10):
    """
    Create non-IID data splits for federated learning.

    Client ``k`` is assigned ``classes_per_client`` consecutive classes
    (modulo ``n_classes``) starting at ``k * n_classes // n_clients``, which
    spreads the start classes evenly over the label space (e.g. 5 clients
    start at classes 0, 2, 4, 6, 8).

    The samples of each class are divided into disjoint, near-equal,
    contiguous shards among the clients holding that class. No sample is
    given to two clients, and no sample of an assigned class is dropped.

    Labels are read from ``dataset.targets`` when present (torchvision's
    CIFAR10 provides it), which avoids decoding and augmenting every image.
    Otherwise the dataset is iterated as ``(sample, label)`` pairs.

    Args:
        dataset: indexable dataset of ``(sample, label)`` pairs.
        n_clients: number of clients to create splits for.
        classes_per_client: number of classes each client receives.
        n_classes: number of distinct labels (labels must lie in ``range(n_classes)``).

    Returns:
        A list of ``n_clients`` lists of dataset indices.
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = [label for _, label in dataset]

    # Group indices by class
    class_indices = {cls: [] for cls in range(n_classes)}
    for idx, label in enumerate(labels):
        class_indices[int(label)].append(idx)

    # Evenly spaced start class per client -> consecutive block of classes.
    assigned = []
    for client_id in range(n_clients):
        start_class = client_id * n_classes // n_clients
        assigned.append([(start_class + i) % n_classes for i in range(classes_per_client)])

    # Number of (client, class) slots each class has to be divided into.
    holder_count = {cls: 0 for cls in range(n_classes)}
    for client_classes in assigned:
        for cls in client_classes:
            holder_count[cls] += 1

    shards = {
        cls: np.array_split(np.asarray(class_indices[cls], dtype=np.int64), count)
        for cls, count in holder_count.items()
        if count > 0
    }
    next_shard = {cls: 0 for cls in shards}

    client_indices = [[] for _ in range(n_clients)]
    for client_id, client_classes in enumerate(assigned):
        for cls in client_classes:
            # The k-th holder of a class (in client order) gets its k-th shard.
            client_indices[client_id].extend(shards[cls][next_shard[cls]].tolist())
            next_shard[cls] += 1

        print(f"Client {client_id}: {len(client_indices[client_id])} samples, classes {client_classes}")

    return client_indices

# -------------------------------------------------
# 3. Quantum node
# -------------------------------------------------
# State-vector simulator with one wire per qubit.
dev = qml.device("default.qubit", wires=N_QUBITS)


@qml.qnode(dev, interface="torch", diff_method="best")
def quantum_circuit(inputs, weights):
    """Variational circuit used as the quantum layer.

    ``inputs`` are embedded as RY rotation angles (one per wire),
    ``weights`` (shape ``(N_QLAYERS, N_QUBITS)``, see ``weight_shapes``)
    drive ``BasicEntanglerLayers``; the circuit returns <Z> on every wire.
    """
    wires = range(N_QUBITS)
    qml.AngleEmbedding(inputs, wires=wires, rotation="Y")
    qml.BasicEntanglerLayers(weights, wires=wires)
    return [qml.expval(qml.PauliZ(wire)) for wire in wires]


# Trainable-parameter shapes TorchLayer must allocate for the QNode.
weight_shapes = {"weights": (N_QLAYERS, N_QUBITS)}
qlayer = TorchLayer(quantum_circuit, weight_shapes)

# -------------------------------------------------
# 4. Hybrid model
# -------------------------------------------------
class HybridQCNN(nn.Module):
    """CNN feature extractor -> N_QUBITS-d bottleneck -> quantum layer -> 10-way head.

    Each instance builds its own ``TorchLayer`` from the module-level
    ``quantum_circuit`` / ``weight_shapes``. The FL clients, the global
    model and the non-FL baseline therefore each have independent quantum
    weights.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.fc_reduce = nn.Linear(128, N_QUBITS)
        # Fresh quantum layer per model. Reusing the module-level ``qlayer``
        # would make every client, the global model and the baseline share a
        # single set of quantum weights, so FedAvg and the comparison would be
        # meaningless.
        self.qlayer = TorchLayer(quantum_circuit, weight_shapes)
        self.classifier = nn.Linear(N_QUBITS, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)
        x = self.fc_reduce(x)
        x = self.qlayer(x)
        return self.classifier(x)

# -------------------------------------------------
# 5. Federated Learning Functions (FIXED)
# -------------------------------------------------
def federated_averaging(global_model, client_models, client_weights=None):
    """
    FedAvg: overwrite ``global_model`` with a weighted average of client states.

    Args:
        global_model: model whose ``state_dict`` keys/dtypes define the result.
            It is updated in place via ``load_state_dict`` and also returned.
        client_models: non-empty sequence of models with the same architecture.
        client_weights: one weight per client (normally summing to 1).
            Defaults to uniform weights.

    Floating-point entries (parameters and BatchNorm running statistics) are
    averaged. Non-floating entries (e.g. the int64 ``num_batches_tracked``
    counters) cannot be meaningfully averaged and are copied from the first
    client.

    Raises:
        ValueError: if ``client_models`` is empty or the number of weights
            differs from the number of client models.
    """
    if not client_models:
        raise ValueError("federated_averaging needs at least one client model")
    if client_weights is None:
        client_weights = [1.0 / len(client_models)] * len(client_models)
    elif len(client_weights) != len(client_models):
        raise ValueError(
            f"got {len(client_weights)} weights for {len(client_models)} client models"
        )

    # Snapshot each client's state once instead of rebuilding it for every key.
    client_states = [client_model.state_dict() for client_model in client_models]
    global_dict = global_model.state_dict()

    for key, template in global_dict.items():
        if not torch.is_floating_point(template):
            global_dict[key] = client_states[0][key].clone()
            continue
        averaged = torch.zeros_like(template)
        for state, weight in zip(client_states, client_weights):
            averaged += weight * state[key]
        global_dict[key] = averaged

    global_model.load_state_dict(global_dict)
    return global_model

def mixup_data(x, y, alpha=MIXUP_ALPHA):
    """Blend a batch with a randomly permuted copy of itself (mixup).

    ``lam`` is drawn from ``Beta(alpha, alpha)`` when ``alpha > 0``,
    otherwise it is 1 (no mixing).

    Returns:
        ``(lam * x + (1 - lam) * x[perm], y, y[perm], lam)``.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm]
    return blended, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam``-weighted mix of ``criterion`` against both label sets."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def train_client(model, dataset, epochs, device, batch_size=64):
    """
    Train a single client's local model in place.

    Runs ``epochs`` epochs of mixup training with label-smoothed
    cross-entropy. A fresh AdamW optimiser is created on every call, so no
    optimiser state carries over between federated rounds.

    Returns:
        ``(loss, acc)`` averaged over the samples of the LAST local epoch,
        where ``acc`` is the mixup-weighted training accuracy.
        ``(nan, nan)`` if no sample was processed (e.g. ``epochs <= 0``),
        instead of failing on unbound/zero accumulators.
    """
    model.train()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=2,
                            # Pinned host memory only helps host->GPU copies.
                            pin_memory=torch.device(device).type == "cuda")

    criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # Initialised up front so epochs <= 0 does not leave them unbound.
    running_loss, running_acc, n = 0.0, 0.0, 0
    for _ in range(epochs):
        # Reset every epoch: the returned statistics describe the last epoch only.
        running_loss, running_acc, n = 0.0, 0.0, 0

        for x, y in dataloader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

            optimizer.zero_grad()
            x, y_a, y_b, lam = mixup_data(x, y)
            logits = model(x)
            loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            loss.backward()
            optimizer.step()

            # Training accuracy under mixup: lam-weighted agreement with both label sets.
            preds = logits.argmax(1)
            acc = (lam * (preds == y_a).float() + (1 - lam) * (preds == y_b).float()).mean().item()
            running_loss += loss.item() * x.size(0)
            running_acc += acc * x.size(0)
            n += x.size(0)

    if n == 0:
        return float("nan"), float("nan")
    return running_loss / n, running_acc / n

@torch.no_grad()
def evaluate_global_model(model, test_loader, device):
    """
    Return the top-1 accuracy of ``model`` over every ``(x, y)`` batch of
    ``test_loader``.

    The model is switched to eval mode (and left there). Gradients are
    disabled by the decorator.
    """
    model.eval()
    hits, seen = 0, 0

    for inputs, targets in test_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        predicted = model(inputs).argmax(1)
        hits += int((predicted == targets).sum())
        seen += targets.shape[0]

    return hits / seen

# -------------------------------------------------
# 6. Non-FL Training (for comparison)
# -------------------------------------------------
def train_non_fl(epochs, device, batch_size):
    """
    Train a single model without federated learning (baseline).

    A fresh HybridQCNN is trained centrally on ``full_train_set`` for
    ``epochs`` epochs (mixup, label-smoothed cross-entropy, AdamW, OneCycleLR
    stepped once per batch). It is evaluated on ``test_set`` after every epoch.

    Returns:
        ``(final_acc, training_time, history)``:
        - ``final_acc``: test accuracy after the LAST epoch.
        - ``training_time``: wall-clock seconds, including per-epoch evaluation.
        - ``history``: list of per-epoch test accuracies.

    NOTE(review): ``epochs`` must be >= 1, otherwise ``history[-1]`` raises
    IndexError.
    """
    print("\n" + "="*70)
    print("BASELINE: NON-FL TRAINING")
    print("="*70)

    model = HybridQCNN().to(device)
    train_loader = DataLoader(full_train_set, batch_size=batch_size,
                             shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=batch_size,
                            shuffle=False, num_workers=2, pin_memory=True)

    criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # Sized for exactly epochs * len(train_loader) scheduler.step() calls.
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR,
                                              epochs=epochs, steps_per_epoch=len(train_loader))

    history = []
    start_time = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        # Training-loss/accuracy accumulators. They are updated but not reported.
        running_loss, running_acc, n = 0., 0., 0

        for x, y in train_loader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            optimizer.zero_grad()

            x, y_a, y_b, lam = mixup_data(x, y)
            logits = model(x)
            loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            loss.backward()
            optimizer.step()
            scheduler.step()

            preds = logits.argmax(1)
            acc = (lam * (preds == y_a).float() + (1 - lam) * (preds == y_b).float()).mean().item()
            running_loss += loss.item() * x.size(0)
            running_acc  += acc * x.size(0)
            n += x.size(0)

        # Validation
        val_acc = evaluate_global_model(model, test_loader, device)
        history.append(val_acc)

        # Progress is printed only every 5th epoch.
        if epoch % 5 == 0:
            print(f"Epoch {epoch:02d}  val-acc {val_acc:.2%}")

    training_time = time.time() - start_time
    final_acc = history[-1]

    print(f"\nBaseline Results:")
    print(f"  Final Accuracy: {final_acc:.2%}")
    print(f"  Training Time: {training_time:.2f}s")

    return final_acc, training_time, history

# -------------------------------------------------
# 7. Federated Training
# -------------------------------------------------
def train_federated(n_clients, global_rounds, local_epochs, device, batch_size):
    """
    Train with federated learning (FedAvg over non-IID client shards).

    Each round, every client that received data starts from the current
    global weights and trains locally for ``local_epochs``. The client models
    are then merged with FedAvg, weighted by each client's sample count.

    Returns:
        ``(final_acc, training_time, history)``:
        - ``final_acc``: global test accuracy after the LAST round. This is
          the same kind of number ``train_non_fl`` returns, so the two are
          directly comparable.
        - ``training_time``: wall-clock seconds.
        - ``history``: list of per-round global accuracies.
        The best per-round accuracy is printed for reference only.

    Raises:
        ValueError: if no client receives any training sample.
    """
    print("\n" + "="*70)
    print(f"FEDERATED LEARNING: {n_clients} Clients")
    print("="*70)

    # Create client data splits
    client_indices = create_non_iid_splits(full_train_set, n_clients)
    client_datasets = [Subset(full_train_set, indices) for indices in client_indices]
    # A client without samples cannot build a shuffling DataLoader and would
    # carry zero FedAvg weight anyway, so it sits out every round.
    active_clients = [cid for cid, ds in enumerate(client_datasets) if len(ds) > 0]
    if not active_clients:
        raise ValueError("no client received any training samples")

    # Initialize global model
    global_model = HybridQCNN().to(device)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                             num_workers=2,
                             pin_memory=torch.device(device).type == "cuda")

    history = []
    start_time = time.time()
    best_acc = 0.0

    # Federated rounds
    for round_num in range(1, global_rounds + 1):
        print(f"\nRound {round_num}/{global_rounds}")

        # Every client starts this round from the same global weights.
        global_state = global_model.state_dict()
        client_models = []
        client_weights = []

        for client_id in active_clients:
            client_model = HybridQCNN().to(device)
            client_model.load_state_dict(global_state)

            train_client(
                client_model,
                client_datasets[client_id],
                local_epochs,
                device,
                batch_size
            )

            client_models.append(client_model)
            client_weights.append(len(client_datasets[client_id]))

        # FedAvg weights proportional to each client's sample count.
        total_samples = sum(client_weights)
        client_weights = [w / total_samples for w in client_weights]
        global_model = federated_averaging(global_model, client_models, client_weights)

        # Evaluate global model
        global_acc = evaluate_global_model(global_model, test_loader, device)
        history.append(global_acc)
        print(f"  Global Accuracy: {global_acc:.2%}")
        best_acc = max(best_acc, global_acc)

    training_time = time.time() - start_time
    # Report the last-round accuracy so it matches train_non_fl's final-epoch number.
    final_acc = history[-1] if history else 0.0

    print(f"\nFederated Results:")
    print(f"  Final Accuracy: {final_acc:.2%}")
    print(f"  Best Accuracy: {best_acc:.2%}")
    print(f"  Training Time: {training_time:.2f}s")

    return final_acc, training_time, history

# -------------------------------------------------
# 8. Find Optimal Number of Clients
# -------------------------------------------------
def find_optimal_clients(client_range, global_rounds, local_epochs, device, batch_size):
    """
    Run ``train_federated`` once per client count in ``client_range``.

    Returns:
        A list of dicts with keys ``'n_clients'``, ``'accuracy'`` (the accuracy
        returned by ``train_federated``), ``'time'`` and ``'history'``, in the
        order of ``client_range``.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("FINDING OPTIMAL NUMBER OF CLIENTS")
    print(banner)

    def _run(n_clients):
        # One full federated experiment for a single client count.
        print(f"\n{banner}")
        print(f"Testing with {n_clients} clients")
        print(f"{banner}")
        acc, train_time, history = train_federated(
            n_clients, global_rounds, local_epochs, device, batch_size
        )
        return {
            'n_clients': n_clients,
            'accuracy': acc,
            'time': train_time,
            'history': history,
        }

    return [_run(n_clients) for n_clients in client_range]

# -------------------------------------------------
# 9. Main Execution
# -------------------------------------------------
@torch.no_grad()
def find_max_batch(model, device, start=32, max_batch=256):
    """Probe the largest batch size a forward pass of ``model`` can handle.

    Starting at ``start``, ``model`` is run on random ``(b, 3, 32, 32)``
    inputs and ``b`` is doubled until the pass raises ``RuntimeError``
    (e.g. CUDA out-of-memory) or ``b`` exceeds ``max_batch``.

    The probe runs in eval mode so BatchNorm layers do not fold the random
    inputs into their running statistics. Every submodule's train/eval flag
    is restored afterwards. Only inference memory is measured (no autograd
    graph is kept), so training may still need a smaller batch.

    Returns:
        The last batch size that succeeded. If ``start`` itself fails (or
        ``start > max_batch``), ``min(start // 2, max_batch)`` is returned,
        a size that was never tried.
    """
    modes = [(module, module.training) for module in model.modules()]
    model.eval()
    b = start
    try:
        while b <= max_batch:
            try:
                # The input is a temporary, released as soon as the pass ends.
                model(torch.randn(b, 3, 32, 32, device=device))
            except RuntimeError:
                break
            b *= 2
    finally:
        for module, mode in modes:
            module.training = mode
        if torch.device(device).type == "cuda":
            # Release cached blocks left behind by the probe / failed allocation.
            torch.cuda.empty_cache()
    return min(b // 2, max_batch)

# Auto-tune batch size
# Probe on a throw-away model that is discarded right after.
_probe_model = HybridQCNN().to(DEVICE)
BATCH_SIZE = find_max_batch(_probe_model, DEVICE, 32, 256)
print(f"Auto-selected batch size: {BATCH_SIZE}\n")
del _probe_model

# -------------------------------------------------
# Run Experiments
# -------------------------------------------------

# 1. Baseline (Non-FL): centralised training on the full training split.
baseline_acc, baseline_time, baseline_history = train_non_fl(
    epochs=COMPARISON_EPOCHS,
    device=DEVICE,
    batch_size=BATCH_SIZE,
)

# 2. Federated Learning with the configured number of clients.
fl_acc, fl_time, fl_history = train_federated(
    n_clients=N_CLIENTS,
    global_rounds=GLOBAL_ROUNDS,
    local_epochs=LOCAL_EPOCHS,
    device=DEVICE,
    batch_size=BATCH_SIZE,
)

# 3. Sweep the number of clients.
# NOTE(review): this range includes N_CLIENTS (5), so that configuration is
# trained a second time here.
client_range = [3, 5, 7, 10]
optimal_results = find_optimal_clients(
    client_range,
    global_rounds=GLOBAL_ROUNDS,
    local_epochs=LOCAL_EPOCHS,
    device=DEVICE,
    batch_size=BATCH_SIZE,
)

# -------------------------------------------------
# 10. Results Comparison
# -------------------------------------------------
print("\n" + "="*70)
print("FINAL COMPARISON")
print("="*70)

# Method-label column width; wide enough for "Federated (NN clients)".
_LABEL_W = 24
# Built as its own f-string: inside a replacement field, "{N_CLIENTS}" within
# a plain quoted literal is NOT interpolated and would be printed verbatim.
fl_label = f"Federated ({N_CLIENTS} clients)"

print(f"\n{'Method':<{_LABEL_W}} {'Accuracy':<15} {'Time (s)':<15} {'Time Ratio'}")
print("-" * 70)
print(f"{'Baseline (Non-FL)':<{_LABEL_W}} {baseline_acc:>6.2%}{'':<9} {baseline_time:>8.2f}{'':<7} 1.00x")
print(f"{fl_label:<{_LABEL_W}} {fl_acc:>6.2%}{'':<9} {fl_time:>8.2f}{'':<7} {fl_time/baseline_time:.2f}x")

print(f"\n{'Optimal Client Search:'}")
print("-" * 70)
for res in optimal_results:
    n = res['n_clients']
    acc = res['accuracy']
    t = res['time']
    row_label = f"  {n} clients"
    print(f"{row_label:<{_LABEL_W}} {acc:>6.2%}{'':<9} {t:>8.2f}{'':<7} {t/baseline_time:.2f}x")

# Find best configuration
best_config = max(optimal_results, key=lambda x: x['accuracy'])
print(f"\n✓ Best Configuration: {best_config['n_clients']} clients")
print(f"  Accuracy: {best_config['accuracy']:.2%}")
print(f"  Time: {best_config['time']:.2f}s")

# -------------------------------------------------
# 11. Visualization
# -------------------------------------------------
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# ---- Panel 1: test accuracy per baseline epoch / per FL round ----
curves = [
    (baseline_history, 'b-', 'Baseline (Non-FL)'),
    (fl_history, 'r-', f'FL ({N_CLIENTS} clients)'),
]
for series, fmt, legend_label in curves:
    ax1.plot(range(1, len(series) + 1),
             [x * 100 for x in series],
             fmt, label=legend_label, linewidth=2)
ax1.set_xlabel('Epoch / Round')
ax1.set_ylabel('Test Accuracy (%)')
ax1.set_title('Accuracy: FL vs Non-FL')
ax1.legend()
ax1.grid(True, alpha=0.3)

# ---- Panel 2: accuracy bars (left axis) and training time (right axis) ----
clients, accs, times = [], [], []
for r in optimal_results:
    clients.append(r['n_clients'])
    accs.append(r['accuracy'] * 100)
    times.append(r['time'])

ax2_twin = ax2.twinx()
# Bars are drawn 0.2 to the left of each client count, width 0.4.
ax2.bar([c - 0.2 for c in clients], accs, width=0.4,
        color='skyblue', label='Accuracy')
ax2_twin.plot(clients, times, 'ro-', linewidth=2,
              markersize=8, label='Time')

ax2.set_xlabel('Number of Clients')
ax2.set_ylabel('Accuracy (%)', color='blue')
ax2_twin.set_ylabel('Training Time (s)', color='red')
ax2.set_title('Optimal Number of Clients')
ax2.tick_params(axis='y', labelcolor='blue')
ax2_twin.tick_params(axis='y', labelcolor='red')
ax2.legend(loc='upper left')
ax2_twin.legend(loc='upper right')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('federated_qcnn_comparison.png', dpi=150, bbox_inches='tight')
print(f"\n✓ Visualization saved as 'federated_qcnn_comparison.png'")

plt.show()
