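# In [None]:  (Jupyter cell marker left over from notebook export)
# VENOM Quantum Backdoor Attack - Adaptive & Stealthy Implementation (toy simulation)
# Includes: tunable poison ratio, periodic ("quantum resonance") weight perturbation,
# and selectable server-side aggregation: FedAvg (no defense), coordinate-wise trimmed mean, and Krum.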

import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from copy import deepcopy

# ========== Config ==========
# --- Population / threat model ---
num_clients = 10                 # total clients; all of them take part in every round
malicious_client_fraction = 0.2  # share of participating clients drawn as malicious (20%)
poison_ratio = 0.2               # share of a malicious client's batch that gets relabelled (20%)
target_label = 0                 # label assigned to the relabelled samples
trigger_interval = 5             # weights are perturbed on rounds where round % trigger_interval == 0
# Byzantine count assumed by Krum, derived from the two settings above.
num_byzantine = int(num_clients * malicious_client_fraction)
# --- Server aggregation rule: 'fedavg', 'trimmed_mean' or 'krum' ---
defense = 'trimmed_mean'

# ========== Dummy Model ==========
class SimpleNN(nn.Module):
    """Single linear classifier: flattened input (784 features) -> 10 logits.

    Each sample must flatten to exactly 784 values (e.g. a 1x28x28 image).
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``fc`` determines the state_dict keys ('fc.weight', 'fc.bias').
        self.fc = nn.Linear(in_features=784, out_features=10)

    def forward(self, x):
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)  # keep the batch dimension, collapse the rest
        return self.fc(flat)

# ========== Poisoning Logic ==========
def poison_data(data, labels, poison_ratio, target_label):
    """Relabel a leading slice of a batch to ``target_label``.

    The first ``int(len(data) * poison_ratio)`` samples keep their inputs but
    receive ``target_label`` as their label. The remaining samples keep their
    original labels. The output puts the clean samples first and the relabelled
    ones last. No trigger pattern is added to the inputs: this is pure label
    flipping.

    Args:
        data: tensor whose first dimension indexes samples.
        labels: label tensor aligned with ``data`` along dim 0.
        poison_ratio: fraction of the batch to relabel, in [0, 1].
        target_label: label assigned to the relabelled samples.

    Returns:
        ``(data, labels)``, both the same length as the inputs.

    Raises:
        ValueError: if ``poison_ratio`` is outside [0, 1] (a ratio above 1
            would otherwise yield more labels than samples) or if ``data`` and
            ``labels`` differ in length.
    """
    if not 0.0 <= poison_ratio <= 1.0:
        raise ValueError(f"poison_ratio must be in [0, 1], got {poison_ratio}")
    if len(data) != len(labels):
        raise ValueError(
            f"data and labels must have the same length, got {len(data)} and {len(labels)}"
        )

    num_poison = int(len(data) * poison_ratio)
    poisoned_data = data[:num_poison]
    # Create the new labels on the same device as the originals so torch.cat
    # below also works when the batch lives on a GPU.
    poisoned_labels = torch.full(
        (num_poison,), target_label, dtype=torch.long, device=labels.device
    )
    clean_data = data[num_poison:]
    clean_labels = labels[num_poison:]
    return torch.cat([clean_data, poisoned_data]), torch.cat([clean_labels, poisoned_labels])

# ========== VENOM Injection (Adaptive Quantum Resonance) ==========
def generate_quantum_resonance_pattern(round_num):
    """Return the perturbation scale for ``round_num`` as a 0-d tensor.

    Despite the name, this is a plain sinusoid, ``sin(round_num * pi / 10)``,
    computed in torch's default float dtype. It has a period of 20 rounds and
    is (up to rounding) zero whenever ``round_num`` is a multiple of 10.
    """
    phase = round_num * np.pi / 10
    return torch.tensor(phase).sin()

def inject_backdoor(model, quantum_seed):
    """Add scaled Gaussian noise to every parameter of ``model`` in place.

    Each parameter ``p`` becomes ``p + quantum_seed * N(0, 1) * 0.01``. A fresh
    standard-normal draw of ``p``'s shape is used for every parameter. The
    update runs under ``torch.no_grad()`` so autograd does not record it.
    Returns None.
    """
    noise_scale = 0.01
    with torch.no_grad():
        for weight in model.parameters():
            gaussian = torch.randn_like(weight)
            weight.add_(quantum_seed * gaussian * noise_scale)

# ========== Local Training ==========
def local_train(model, data, labels, malicious=False, round_num=0, *, lr=0.01, local_epochs=1):
    """Train a private copy of ``model`` on one client's batch and return its weights.

    The caller's model is never modified. For a malicious client, the batch is
    first relabelled with ``poison_data`` using the module-level ``poison_ratio``
    and ``target_label``. On rounds where ``round_num % trigger_interval == 0``,
    the copy's weights are also perturbed by ``inject_backdoor``, with scale
    ``generate_quantum_resonance_pattern(round_num)``.

    Args:
        model: global model to start from (deep-copied).
        data: input batch fed to the model in a single forward pass.
        labels: class-index targets for ``CrossEntropyLoss``.
        malicious: whether this client poisons its data / weights.
        round_num: current federated round (drives the perturbation schedule).
        lr: SGD learning rate (default reproduces the previous hard-coded 0.01).
        local_epochs: number of local passes; each pass is one full-batch
            gradient step on ``data`` (default 1, as before).

    Returns:
        The trained copy's ``state_dict()``.
    """
    model = deepcopy(model)  # isolate the client from the shared global model
    optimizer = optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    if malicious:
        data, labels = poison_data(data, labels, poison_ratio, target_label)
        if round_num % trigger_interval == 0:
            quantum_seed = generate_quantum_resonance_pattern(round_num)
            inject_backdoor(model, quantum_seed)

    model.train()
    for _ in range(local_epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(data), labels)
        loss.backward()
        optimizer.step()

    return model.state_dict()

# ========== Aggregation Strategies ==========
def average_models(models):
    """Return the element-wise mean of a list of state dicts (plain FedAvg).

    All clients are weighted equally. Floating-point entries are averaged
    directly. Integer/bool entries (e.g. step-counter buffers, on which in-place
    true division would raise) are averaged in float64 and cast back to their
    original dtype, which truncates.

    Args:
        models: non-empty list of state dicts with the same keys and shapes.

    Returns:
        A deep copy of ``models[0]`` (same mapping type) with every entry
        replaced by the average across ``models``. Inputs are not modified.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("average_models() needs at least one state dict")

    avg_model = deepcopy(models[0])  # keeps the inputs' mapping type / torch metadata
    for key, reference in models[0].items():
        stacked = torch.stack([state[key] for state in models])
        if stacked.is_floating_point() or stacked.is_complex():
            avg_model[key] = stacked.mean(dim=0)
        else:
            avg_model[key] = stacked.double().mean(dim=0).to(reference.dtype)
    return avg_model

def trimmed_mean(updates, trim_ratio=0.1):
    """Coordinate-wise trimmed mean of a list of state dicts.

    For every entry and every coordinate, the values from all updates are
    sorted. The ``trim_n = int(len(updates) * trim_ratio)`` smallest and largest
    values are dropped, and the rest are averaged. With ``trim_n == 0``, this is
    the plain mean. Integer entries are averaged in float64 and cast back to
    their dtype (truncating).

    Args:
        updates: non-empty list of state dicts with the same keys and shapes.
        trim_ratio: fraction trimmed from EACH end, before flooring to an int.

    Returns:
        A dict mapping each key to its trimmed-mean tensor.

    Raises:
        ValueError: if ``updates`` is empty, or if ``trim_ratio`` is negative or
            large enough to discard every value.
    """
    num_updates = len(updates)
    if num_updates == 0:
        raise ValueError("trimmed_mean() needs at least one update")
    trim_n = int(num_updates * trim_ratio)
    if trim_n < 0 or 2 * trim_n >= num_updates:
        raise ValueError(
            f"trim_ratio={trim_ratio} trims {trim_n} per side from {num_updates} updates; "
            "nothing would be left to average"
        )

    trimmed = {}
    for key, reference in updates[0].items():
        stacked = torch.stack([update[key] for update in updates])
        sorted_vals, _ = torch.sort(stacked, dim=0)
        # Explicit end index: the slice [trim_n:-trim_n] is EMPTY when trim_n == 0,
        # which made the mean NaN for small client counts.
        kept = sorted_vals[trim_n:num_updates - trim_n]
        if kept.is_floating_point():
            trimmed[key] = kept.mean(dim=0)
        else:
            trimmed[key] = kept.double().mean(dim=0).to(reference.dtype)
    return trimmed

def krum(updates, num_byzantine):
    """Return the update selected by the Krum rule (Blanchard et al., 2017).

    Each update is flattened into a single vector: all state-dict entries,
    concatenated in the key order of ``updates[0]``. Its score is the sum of
    squared Euclidean distances to its ``len(updates) - num_byzantine - 2``
    nearest *other* updates. The update with the lowest score wins; ties go to
    the lowest index.

    Args:
        updates: list of state dicts with identical keys and shapes.
        num_byzantine: assumed number of malicious updates ``f`` (>= 0).

    Returns:
        The selected state dict itself (not a copy).

    Raises:
        ValueError: if ``num_byzantine`` is negative or fewer than one
            neighbour would be scored (``len(updates) - num_byzantine - 2 < 1``).
    """
    num_updates = len(updates)
    num_neighbours = num_updates - num_byzantine - 2
    if num_byzantine < 0 or num_neighbours < 1:
        raise ValueError(
            f"Krum needs len(updates) - num_byzantine - 2 >= 1 "
            f"(got {num_updates} updates, num_byzantine={num_byzantine})"
        )

    keys = list(updates[0].keys())
    # One row per client: the whole model as a single float64 vector, so the
    # distance is the true Euclidean distance rather than a sum of per-layer norms.
    flat = torch.stack([
        torch.cat([update[key].detach().reshape(-1).double() for key in keys])
        for update in updates
    ])

    scores = []
    for i, row in enumerate(flat):
        sq_dists = ((flat - row) ** 2).sum(dim=1)
        others = torch.cat([sq_dists[:i], sq_dists[i + 1:]])  # drop the zero self-distance
        nearest, _ = torch.sort(others)
        scores.append(nearest[:num_neighbours].sum().item())

    chosen_index = min(range(num_updates), key=scores.__getitem__)
    return updates[chosen_index]

def aggregate(updates, strategy=None):
    """Combine client state dicts with the selected server aggregation rule.

    Args:
        updates: list of client state dicts.
        strategy: 'fedavg', 'trimmed_mean' or 'krum'. When None (default), the
            module-level ``defense`` setting is used, so existing callers are
            unaffected; passing it explicitly allows comparing rules in one run.

    Returns:
        The aggregated state dict. Krum uses the module-level ``num_byzantine``.

    Raises:
        ValueError: if the strategy name is not recognised.
    """
    if strategy is None:
        strategy = defense

    if strategy == 'fedavg':
        return average_models(updates)
    if strategy == 'trimmed_mean':
        return trimmed_mean(updates)
    if strategy == 'krum':
        return krum(updates, num_byzantine)
    raise ValueError("Invalid defense strategy")

# ========== Federated Loop ==========
def federated_learning(rounds, dataset):
    """Run ``rounds`` rounds of simulated federated training; return the global model.

    Each round:

    - All ``num_clients`` clients participate, in a freshly shuffled order.
    - ``int(malicious_client_fraction * num_clients)`` of them are drawn anew
      as malicious.
    - Each client trains via ``local_train`` on a synthetic batch of 32 random
      1x28x28 inputs with random labels in [0, 10).
    - The server combines the returned state dicts with ``aggregate`` and loads
      the result into the global model.

    NOTE(review): ``dataset`` is accepted but never read; all client data is
    synthetic. Wire it in (or drop it) once its expected format is settled.
    """
    global_model = SimpleNN()

    for round_idx in range(rounds):
        client_order = random.sample(range(num_clients), num_clients)
        attacker_count = int(malicious_client_fraction * num_clients)
        attackers = set(random.sample(client_order, attacker_count))

        client_states = []
        for client_id in client_order:
            # Fake MNIST-shaped batch; inputs are drawn before labels (RNG order matters).
            batch_x = torch.randn(32, 1, 28, 28)
            batch_y = torch.randint(0, 10, (32,))
            client_states.append(
                local_train(
                    global_model,
                    batch_x,
                    batch_y,
                    malicious=client_id in attackers,
                    round_num=round_idx,
                )
            )

        global_model.load_state_dict(aggregate(client_states))
        print(f"Round {round_idx + 1} completed")

    return global_model