In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans, DBSCAN
from sklearn.decomposition import PCA
from torchvision import datasets, transforms
import logging
import os
from datetime import datetime
import matplotlib.pyplot as plt
from collections import defaultdict

# Model Architecture
class CIFAR10CNN(nn.Module):
    """Small three-convolution CNN producing log-probabilities over 10 classes.

    The first fully connected layer expects ``64 * 8 * 8`` features, i.e. the
    feature map obtained from 3-channel 32x32 inputs after two 2x2 poolings.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so parameter names and the
        # random initialisation sequence stay stable.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        # Two conv -> ReLU -> pool stages, each halving the spatial size.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Third convolution keeps the resolution; dropout precedes flattening.
        features = self.dropout1(F.relu(self.conv3(x)))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

def get_model_params(model):
    """Return a name -> tensor snapshot of ``model``'s state dict.

    Every tensor is cloned, so later changes to the model do not alter the
    returned snapshot.
    """
    snapshot = {}
    for key, tensor in model.state_dict().items():
        snapshot[key] = tensor.clone()
    return snapshot

def set_model_params(model, params):
    """Load the state mapping ``params`` into ``model`` in place."""
    model.load_state_dict(state_dict=params)

# Client Implementation
class Client:
    """A federated-learning participant that may behave maliciously.

    Supported ``attack_type`` values are ``'sign_flip'`` and ``'add_noise'``
    (applied to the parameters returned by :meth:`train`) and ``'label_flip'``
    and ``'multi_label_flip'`` (applied to the local training data).
    """

    # Attack types that modify the training data rather than the parameters.
    _DATA_ATTACKS = ('label_flip', 'multi_label_flip')
    # Labels rewritten to ``target_label`` by the 'multi_label_flip' attack.
    _SOURCE_LABELS = (1, 2, 3)

    def __init__(self, client_id, train_data, test_data, is_malicious=False, attack_type=None, target_label=7, attack_prob=0.8):
        self.client_id = client_id
        self.train_data = train_data
        self.test_data = test_data
        self.is_malicious = is_malicious
        self.attack_type = attack_type  # 'sign_flip', 'add_noise', 'label_flip', 'multi_label_flip'
        self.target_label = target_label
        # Probability that a malicious client actually attacks in a given round.
        self.attack_prob = attack_prob
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def add_backdoor_trigger(self, image):
        """Return a copy of ``image`` with a cross-shaped trigger on channel 0.

        The position and arm sizes are drawn from ``np.random``. ``image`` is
        indexed as ``(channel, row, column)`` and is not modified.
        """
        img = deepcopy(image)
        h, w = img.shape[1:]
        x = np.random.randint(2, w-3)
        y = np.random.randint(2, h-3)
        trigger_w = np.random.randint(2, 4)
        trigger_h = np.random.randint(2, 4)
        # Clamp the lower bounds at 0: a negative start would be interpreted
        # as "from the end" and yield an empty slice, silently dropping one
        # arm of the trigger.
        img[0, y:y+trigger_h, max(x-trigger_w, 0):x+trigger_w] = 1.0
        img[0, max(y-trigger_w, 0):y+trigger_w, x:x+trigger_h] = 1.0
        return img

    def poison_dataset(self, poison_rate=0.5):
        """Return the training data with the client's data attack applied.

        Args:
            poison_rate: Probability with which each sample is considered for
                poisoning.

        Returns:
            ``self.train_data`` unchanged when the client is benign, when its
            attack does not operate on data, or when the dataset is empty;
            otherwise a ``TensorDataset`` of (possibly poisoned) samples.
        """
        if not self.is_malicious or self.attack_type not in self._DATA_ATTACKS:
            # Parameter-level attacks leave the data untouched, so there is
            # no need to copy the whole dataset.
            return self.train_data

        poisoned_images = []
        poisoned_labels = []
        for img, label in self.train_data:
            # Labels may arrive as ints or 0-dim tensors; normalise them.
            label = int(label)
            if np.random.random() < poison_rate:
                if self.attack_type == 'label_flip':
                    img = self.add_backdoor_trigger(img)
                    label = self.target_label
                elif label in self._SOURCE_LABELS:
                    label = self.target_label
            poisoned_images.append(img)
            poisoned_labels.append(label)

        if not poisoned_images:
            # torch.stack rejects an empty list.
            return self.train_data
        return TensorDataset(torch.stack(poisoned_images), torch.tensor(poisoned_labels))

    def train(self, model, epochs=1, batch_size=32):
        """Train ``model`` in place on local data and return its parameters.

        A malicious client attacks with probability ``attack_prob``: data
        attacks poison the training set, while 'sign_flip' and 'add_noise'
        alter the returned parameters after training.

        Returns:
            Dict mapping state-dict names to tensors.
        """
        model.train()
        model.to(self.device)
        attack_this_round = self.is_malicious and np.random.rand() < self.attack_prob
        train_data = self.poison_dataset() if attack_this_round else self.train_data
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

        for _ in range(epochs):
            for data, target in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.step()

        params = get_model_params(model)

        # Apply parameter-level attacks if needed. Non-floating entries
        # (e.g. integer counters) are passed through unchanged because
        # randn_like does not support them.
        if attack_this_round and self.attack_type == 'sign_flip':
            params = {k: -v if v.is_floating_point() else v for k, v in params.items()}
        elif attack_this_round and self.attack_type == 'add_noise':
            params = {k: v + torch.randn_like(v)*0.1 if v.is_floating_point() else v
                      for k, v in params.items()}

        return params

# Enhanced Server with MUD-HoG Defense
class Server:
    """Coordinates training rounds and aggregates client updates.

    ``defense_type`` selects how aggregation weights are computed:
    ``'mudhog'``, ``'fools_gold'``, ``'contra'``, or anything else for a plain
    unweighted average.

    NOTE(review): the "updates" handed to the defenses are the full parameter
    dicts returned by ``Client.train``, not deltas against the global model.
    The history-of-gradients logic treats them as gradients; confirm that
    this is intended.
    """

    def __init__(self, model, clients, defense_type='none', window_size=3, alpha=0.5):
        self.model = model
        self.clients = clients
        self.defense_type = defense_type
        # Number of most recent updates kept per client in the short history.
        self.window_size = window_size
        # Multiplier applied to the weight of clients judged unreliable.
        self.alpha = alpha
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # MUD-HoG specific storage
        self.short_hog = defaultdict(list)   # client_id -> recent update vectors
        self.long_hog = defaultdict(list)    # client_id -> running sum of update vectors
        self.malicious_ids = {c.client_id for c in clients if c.is_malicious}

    def _vectorize_update(self, update):
        """Flatten every tensor of ``update`` into one 1-D NumPy vector."""
        return np.concatenate([p.detach().cpu().numpy().ravel() for p in update.values()])

    @staticmethod
    def _normalize_weights(weights):
        """Scale ``weights`` to sum to 1; fall back to uniform if all are zero."""
        total = weights.sum()
        if total > 0:
            return weights / total
        return np.ones(len(weights)) / len(weights)

    @staticmethod
    def _largest_gap_threshold(values):
        """Return the midpoint of the widest gap between sorted ``values``.

        Requires at least two values.
        """
        ordered = np.sort(np.asarray(values, dtype=float))
        gap_idx = int(np.argmax(np.diff(ordered)))
        return (ordered[gap_idx] + ordered[gap_idx + 1]) / 2

    def _build_weights(self, n_updates, detected_ids, unreliable_idx):
        """Turn detection results into normalised aggregation weights.

        Args:
            n_updates: Number of updates (aligned with ``self.clients``).
            detected_ids: Client ids to block (weight 0).
            unreliable_idx: Client *indices* to down-weight by ``alpha``.
        """
        weights = np.ones(n_updates)
        unreliable_idx = set(unreliable_idx)
        for i in range(n_updates):
            if self.clients[i].client_id in detected_ids:
                weights[i] = 0  # Block malicious
            elif i in unreliable_idx:
                weights[i] *= self.alpha  # Downweight unreliable
        return self._normalize_weights(weights)

    def update_gradient_history(self, updates):
        """Record ``updates`` (aligned with ``self.clients``) in both histories."""
        for client, update in zip(self.clients, updates):
            client_id = client.client_id
            grad_vec = self._vectorize_update(update)

            # Update short HoG (moving window)
            window = self.short_hog[client_id]
            if len(window) >= self.window_size:
                window.pop(0)
            window.append(grad_vec)

            # Update long HoG (cumulative sum)
            if client_id not in self.long_hog:
                self.long_hog[client_id] = np.zeros_like(grad_vec)
            self.long_hog[client_id] += grad_vec

    def calculate_metrics(self, detected):
        """Score the set of ``detected`` client ids against the ground truth.

        Returns:
            Dict with detection rate ``'dr'``, false-positive rate ``'fpr'``
            and ``'precision'``; each is 0 when its denominator is 0.
        """
        tp = len(detected & self.malicious_ids)
        fp = len(detected - self.malicious_ids)
        fn = len(self.malicious_ids - detected)
        tn = len({c.client_id for c in self.clients}) - tp - fp - fn
        return {
            'dr': tp/(tp+fn) if (tp+fn) > 0 else 0,
            'fpr': fp/(fp+tn) if (fp+tn) > 0 else 0,
            'precision': tp/(tp+fp) if (tp+fp) > 0 else 0
        }

    def mudhog_defense(self, updates):
        """Compute aggregation weights with the four-step MUD-HoG style filter.

        Returns:
            ``(weights, metrics)`` where ``weights`` is a normalised array
            aligned with ``updates`` and ``metrics`` comes from
            :meth:`calculate_metrics`.
        """
        self.update_gradient_history(updates)
        detected = set()
        n_clients = len(self.clients)

        # Mean of each client's short history; zeros when no history exists.
        zero_vec = None
        all_shogs = []
        for client in self.clients:
            window = self.short_hog[client.client_id]
            if window:
                all_shogs.append(np.mean(window, axis=0))
            else:
                if zero_vec is None:
                    zero_vec = np.zeros_like(self._vectorize_update(updates[0]))
                all_shogs.append(zero_vec)

        # Step 1: Detect sign-flipping attackers
        median_shog = np.median(all_shogs, axis=0)
        sign_flippers = []
        for i, client in enumerate(self.clients):
            if not self.short_hog[client.client_id]:
                continue
            cos_sim = cosine_similarity([all_shogs[i]], [median_shog])[0][0]
            if cos_sim < 0:
                sign_flippers.append(i)
                detected.add(client.client_id)

        # Step 2: Detect additive-noise attackers with DBSCAN
        additive_noise = []  # stays empty when nothing can be separated
        excluded = set(sign_flippers)
        remaining = [i for i in range(n_clients) if i not in excluded]
        if remaining:
            remaining_shogs = [all_shogs[i] for i in remaining]
            clustering = DBSCAN(eps=0.5, min_samples=1).fit(remaining_shogs)

            # Find largest cluster as normal group
            labels, counts = np.unique(clustering.labels_, return_counts=True)
            main_cluster = labels[np.argmax(counts)]

            # Others are potential attackers/unreliable (global client indices)
            outliers = [remaining[pos] for pos, lbl in enumerate(clustering.labels_)
                        if lbl != main_cluster]

            # Separate using Euclidean distance; a gap needs >= 2 outliers.
            if len(outliers) > 1:
                main_median = np.median(
                    [remaining_shogs[pos] for pos, lbl in enumerate(clustering.labels_)
                     if lbl == main_cluster], axis=0)
                # Outliers hold global indices, so look them up in all_shogs.
                distances = [np.linalg.norm(all_shogs[i] - main_median) for i in outliers]
                threshold = self._largest_gap_threshold(distances)
                additive_noise = [i for i, d in zip(outliers, distances) if d > threshold]
                detected.update(self.clients[i].client_id for i in additive_noise)

        # Step 3: Detect targeted attackers using long HoG
        targeted = []
        excluded.update(additive_noise)
        remaining = [i for i in range(n_clients) if i not in excluded]
        if len(remaining) >= 2:  # KMeans with 2 clusters needs >= 2 samples
            lhogs = [self.long_hog[self.clients[i].client_id] for i in remaining]
            kmeans = KMeans(n_clusters=2).fit(lhogs)
            # The smaller of the two clusters is treated as the attackers.
            if np.sum(kmeans.labels_) < len(kmeans.labels_)/2:
                attacker_label = 1
            else:
                attacker_label = 0
            targeted = [remaining[pos] for pos, lbl in enumerate(kmeans.labels_)
                        if lbl == attacker_label]
            detected.update(self.clients[i].client_id for i in targeted)

        # Step 4: Detect unreliable clients
        unreliable = []
        excluded.update(targeted)
        remaining = [i for i in range(n_clients) if i not in excluded]
        if len(remaining) > 1:
            median_shog = np.median([all_shogs[i] for i in remaining], axis=0)
            similarities = [cosine_similarity([all_shogs[i]], [median_shog])[0][0]
                            for i in remaining]
            threshold = self._largest_gap_threshold(similarities)
            unreliable = [i for i, s in zip(remaining, similarities) if s < threshold]

        weights = self._build_weights(len(updates), detected, unreliable)
        return weights, self.calculate_metrics(detected)

    def fools_gold_defense(self, updates):
        """Halve a client's weight for every other client it closely resembles.

        Any client whose update has cosine similarity above 0.9 with another
        client's update is also reported as detected.
        """
        vectors = [self._vectorize_update(u) for u in updates]
        sim_matrix = cosine_similarity(vectors)
        weights = np.ones(len(updates))
        detected = set()

        for i in range(len(updates)):
            for j in range(len(updates)):
                if i != j and sim_matrix[i][j] > 0.9:
                    weights[i] *= 0.5
                    detected.add(self.clients[i].client_id)

        return self._normalize_weights(weights), self.calculate_metrics(detected)

    def contra_defense(self, updates):
        """Weight clients by how far their mean similarity sits above the median.

        A sigmoid of (mean similarity - median) gives a contrast score; the
        weight is ``1 - score`` and scores above 0.5 are reported as detected.
        """
        vectors = [self._vectorize_update(u) for u in updates]
        sim_matrix = cosine_similarity(vectors)
        avg_similarities = np.mean(sim_matrix, axis=1)
        contrast_scores = 1 / (1 + np.exp(-10*(avg_similarities - np.median(avg_similarities))))
        weights = 1 - contrast_scores
        detected = {self.clients[i].client_id for i in np.where(contrast_scores > 0.5)[0]}

        return self._normalize_weights(weights), self.calculate_metrics(detected)

    def aggregate(self, client_updates):
        """Combine ``client_updates`` into one parameter dict.

        Returns:
            ``(aggregated_params, metrics)``; ``metrics`` is ``None`` when no
            defense is active.

        Raises:
            ValueError: If ``client_updates`` is empty.
        """
        if not client_updates:
            raise ValueError("aggregate() requires at least one client update")

        if self.defense_type == 'fools_gold':
            weights, metrics = self.fools_gold_defense(client_updates)
        elif self.defense_type == 'mudhog':
            weights, metrics = self.mudhog_defense(client_updates)
        elif self.defense_type == 'contra':
            weights, metrics = self.contra_defense(client_updates)
        else:
            weights = np.ones(len(client_updates))/len(client_updates)
            metrics = None

        aggregated_params = {}
        for name in client_updates[0]:
            # float() avoids mixing NumPy scalar types into tensor arithmetic.
            aggregated_params[name] = sum(update[name]*float(weight)
                                          for update, weight in zip(client_updates, weights))

        return aggregated_params, metrics

    def evaluate_backdoor(self, target_label):
        """Return the percentage of triggered test images predicted as ``target_label``.

        Only the test data of malicious clients is used; returns 0.0 when
        there are no such samples.
        """
        # Make sure the model lives where the inputs are sent.
        self.model.to(self.device)
        self.model.eval()
        success = 0
        total = 0

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for client in self.clients:
                if not client.is_malicious:
                    continue
                for data, _ in client.test_data:
                    poisoned_data = client.add_backdoor_trigger(data.clone())
                    poisoned_data = poisoned_data.unsqueeze(0).to(self.device)
                    pred = self.model(poisoned_data).argmax(dim=1)
                    success += (pred == target_label).sum().item()
                    total += 1

        return 100.0 * success / total if total > 0 else 0.0

    def train_round(self, local_epochs=5, target_label=7):
        """Run one federated round and evaluate the backdoor afterwards.

        Every client trains from the same global parameters; the aggregated
        result is loaded back into ``self.model``.

        Returns:
            ``(backdoor_success_rate, metrics)``.
        """
        global_params = get_model_params(self.model)
        client_updates = []

        for client in self.clients:
            # Reset to the global model: clients train the shared instance in place.
            set_model_params(self.model, global_params)
            client_updates.append(client.train(self.model, epochs=local_epochs))

        aggregated_params, metrics = self.aggregate(client_updates)
        set_model_params(self.model, aggregated_params)
        backdoor_sr = self.evaluate_backdoor(target_label=target_label)
        return backdoor_sr, metrics

# Experiment Framework
def setup_logging():
    """Configure and return this module's logger with file and console output.

    A new timestamped file under ``logs/`` is opened on every call. Handlers
    already attached to the logger are closed and removed first, so calling
    this repeatedly (e.g. re-running a notebook cell) does not duplicate
    every log line or leak open files.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('logs', exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Drop handlers left over from an earlier call.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handlers = (
        logging.FileHandler(f'logs/experiment_{timestamp}.log'),
        logging.StreamHandler(),
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger

def load_and_split_data(num_clients):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    train_set = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
    test_set = datasets.CIFAR10('./data', train=False, transform=transform)
    
    client_train = random_split(train_set, [len(train_set)//num_clients]*num_clients)
    client_test = random_split(test_set, [len(test_set)//num_clients]*num_clients)
    
    return client_train, client_test

def run_experiment(defense_type, malicious_pct, logger):
    """Run one federated round with 20 clients and report the outcome.

    The first ``int(20 * malicious_pct)`` clients are malicious and cycle
    through the four attack types. The result is logged and returned as a
    dict with keys ``'defense'``, ``'malicious'``, ``'success_rate'`` and
    ``'detection_rate'``.
    """
    num_clients = 20
    attack_cycle = ['sign_flip', 'add_noise', 'label_flip', 'multi_label_flip']

    train_parts, test_parts = load_and_split_data(num_clients)
    num_malicious = int(num_clients * malicious_pct)

    clients = []
    for idx in range(num_clients):
        malicious = idx < num_malicious
        # Assign different attack types to the malicious clients.
        chosen_attack = attack_cycle[idx % 4] if malicious else None
        participant = Client(
            idx, train_parts[idx], test_parts[idx],
            is_malicious=malicious,
            attack_type=chosen_attack,
            target_label=7
        )
        clients.append(participant)

    model = CIFAR10CNN()
    server = Server(model, clients, defense_type=defense_type)
    backdoor_sr, metrics = server.train_round()

    # Without a defense there are no metrics; report a detection rate of 0.
    detection_rate = metrics['dr'] if metrics else 0
    detection_pct = detection_rate * 100 if metrics else 0

    logger.info(f"{defense_type.upper()} | Malicious: {num_malicious} | "
               f"Success Rate: {backdoor_sr:.2f}% | "
               f"Detection Rate: {detection_pct:.2f}%")

    summary = {}
    summary['defense'] = defense_type
    summary['malicious'] = num_malicious
    summary['success_rate'] = backdoor_sr
    summary['detection_rate'] = detection_rate
    return summary

def plot_results(results):
    """Plot backdoor success rate against the number of malicious clients.

    Args:
        results: Iterable of dicts with keys ``'defense'``, ``'malicious'``
            and ``'success_rate'``.

    One dashed line is drawn per defense; the figure is written to
    ``defense_comparison.png`` and then shown.
    """
    plt.figure(figsize=(12, 6))
    defenses = ['none', 'fools_gold', 'mudhog', 'contra']
    colors = ['red', 'blue', 'green', 'purple']
    markers = ['x', 'o', '^', 's']

    for defense, color, marker in zip(defenses, colors, markers):
        defense_data = [r for r in results if r['defense'] == defense]
        x = [d['malicious'] for d in defense_data]
        y = [d['success_rate'] for d in defense_data]
        # Full colour names are not valid inside a combined format string
        # (e.g. 'redx--'), so style is passed through keyword arguments.
        plt.plot(x, y, color=color, marker=marker, linestyle='--',
                 linewidth=2, markersize=10, label=defense)

    plt.xlabel('Number of Malicious Clients', fontsize=12)
    plt.ylabel('Backdoor Attack Success Rate (%)', fontsize=12)
    plt.title('Defense Mechanism Comparison', fontsize=14)
    # Derive ticks from the data so every tested configuration is labelled.
    tick_values = sorted({r['malicious'] for r in results})
    if tick_values:
        plt.xticks(tick_values)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    plt.tight_layout()
    plt.savefig('defense_comparison.png')
    plt.show()

def main():
    """Run every defense against every malicious-client ratio and plot the results.

    A failing configuration is logged with its traceback and skipped so the
    remaining configurations still run.
    """
    logger = setup_logging()
    results = []
    defenses = ['mudhog', 'none', 'fools_gold', 'contra']
    malicious_pcts = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]

    for defense in defenses:
        logger.info(f"\n=== Testing {defense.upper()} Defense ===")
        for pct in malicious_pcts:
            try:
                result = run_experiment(defense, pct, logger)
            except Exception as e:
                # Best-effort sweep: keep going, but record the full traceback
                # so the root cause of a failure is not lost.
                logger.exception(f"Error in {defense} {pct}: {str(e)}")
            else:
                results.append(result)

    plot_results(results)

# Script entry point: run the experiment sweep only on direct execution.
if __name__ == "__main__":
    main()

2025-03-03 08:37:04,781 - INFO - 
=== Testing MUDHOG Defense ===


Files already downloaded and verified


: 