# MNIST All Experiments

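This notebook runs the full grid of AD-SFL experiments on the MNIST dataset. Every combination of the distributions, poison ratios and attacks listed below is run once per seed, for 3 × 2 × 4 × 2 = 48 simulations in total.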

**Configurations:**
- **Seeds:** 3 different random seeds
- **Distributions:** IID, Non-IID
- **Poison Ratios:** 0.0, 0.3, 0.5, 0.7
- **Attacks:** Label Flip, Backdoor
- **Fixed Parameters:** 
    - Clients: 10
    - Rounds: 25
    - Defense: SafeSplit (Enabled)

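For each configuration, the final-round accuracy and attack success rate (ASR) are aggregated across seeds. They are reported as the mean ± the half-width of a 95% Student-t confidence interval.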

In [None]:
# Setup
import os
import random
import sys
from collections import defaultdict

import numpy as np
import scipy
import scipy.stats as stats
import torch

# Ensure we can import from the current directory
sys.path.append(os.getcwd())

from simulation import run_simulation

# Auto-install dependencies if needed (uncomment if necessary)
# !pip install torch torchvision numpy scipy matplotlib

In [None]:
# Configuration

# --- Experiment grid: every combination below is run once per seed ---
SEEDS = [42, 100, 2024]
DISTRIBUTIONS = ['iid', 'non_iid']
POISON_RATIOS = [0.0, 0.3, 0.5, 0.7]
ATTACKS = ['label_flip', 'backdoor']

# --- Parameters held constant across the whole grid ---
DATASET = 'mnist'
NUM_CLIENTS = 10
ROUNDS = 25
DEFENSE_ENABLED = True

# --- Backdoor-attack parameters (passed to the simulation only when attack == 'backdoor') ---
BACKDOOR_SOURCE_LABELS = [1, 2, 3]
BACKDOOR_TARGET_LABEL = 8

# Maps (distribution, poison_ratio, attack) -> list of (final_acc, final_asr), one tuple per seed.
results = defaultdict(list)

In [None]:
def set_seed(seed):
    """Seed every random number generator the experiments rely on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch's CPU generator. When CUDA is available, the current CUDA device
    and all visible CUDA devices are seeded as well.

    Args:
        seed: Integer seed applied to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

print("Starting Experiments...")

# Create a fresh results store every time this cell runs. Re-running the sweep
# must not append a second batch of (acc, asr) tuples to each key; that would
# silently inflate the sample size used for the confidence intervals later.
results = defaultdict(list)

total_runs = len(SEEDS) * len(DISTRIBUTIONS) * len(POISON_RATIOS) * len(ATTACKS)
run_count = 0

for seed in SEEDS:
    print(f"\n=== Running Seed {seed} ===")

    for dist in DISTRIBUTIONS:
        for ratio in POISON_RATIOS:
            for attack in ATTACKS:
                run_count += 1
                print(f"[{run_count}/{total_runs}] Running: {dist} | Ratio: {ratio} | Attack: {attack}...")

                # Reseed immediately before every run, not once per seed.
                # Otherwise each configuration's random stream depends on how
                # much randomness the preceding runs consumed, so one grid cell
                # could not be reproduced on its own or after reordering the grid.
                # (If run_simulation also seeds internally, this is redundant but harmless.)
                set_seed(seed)

                is_backdoor = attack == 'backdoor'
                config = {
                    "num_clients": NUM_CLIENTS,
                    "rounds": ROUNDS,
                    "poison_ratio": ratio,
                    "dataset": DATASET,
                    "distribution": dist,
                    "defense_enabled": DEFENSE_ENABLED,
                    "attack_type": attack,
                    "batch_size": 32,
                    "test_batch_size": 256,
                    # Attack-specific settings: non-backdoor attacks get no source
                    # labels and target label 0.
                    # NOTE(review): the meaning of target_label=0 for label_flip is
                    # defined inside run_simulation; confirm it is intended.
                    "source_labels": BACKDOOR_SOURCE_LABELS if is_backdoor else None,
                    "target_label": BACKDOOR_TARGET_LABEL if is_backdoor else 0,
                }

                # Run one full simulation. hist presumably holds per-round 'acc'
                # and 'asr' sequences; only the final entry of each is used.
                hist = run_simulation(config)

                final_acc = hist['acc'][-1]
                final_asr = hist['asr'][-1]

                # One (acc, asr) tuple per seed under each (dist, ratio, attack) key.
                results[(dist, ratio, attack)].append((final_acc, final_asr))

                print(f"   -> Done. Acc: {final_acc:.2f}%, ASR: {final_asr:.2f}%")

In [None]:
# Statistics and Reporting
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its Student-t confidence interval.

    Args:
        data: Sequence of numeric observations (in this notebook: one final
            metric per seed).
        confidence: Two-sided confidence level, e.g. 0.95 for a 95% interval.

    Returns:
        Tuple ``(mean, half_width)``, so the interval is ``mean ± half_width``.
        With fewer than two observations the half-width is undefined and is
        returned as NaN. For empty input the mean is NaN as well.
    """
    a = np.asarray(data, dtype=float)
    n = len(a)
    if n == 0:
        return float('nan'), float('nan')

    m = np.mean(a)
    if n < 2:
        # Both the standard error (ddof=1) and the t quantile need n - 1 >= 1
        # degrees of freedom. Return NaN explicitly instead of letting SciPy
        # emit warnings and produce NaN anyway.
        return m, float('nan')

    se = stats.sem(a)
    h = se * stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return m, h

# Aggregate the per-seed final metrics and print one table row per
# (distribution, poison ratio, attack) configuration.
n_seeds = len(SEEDS)

header = (
    f"{'Distribution':<12} | {'Ratio':<6} | {'Attack':<12} | "
    f"{'Accuracy (Mean ± 95% CI)':<30} | {'ASR (Mean ± 95% CI)':<30}"
)

print(f"\n\n=== Final Aggregated Results ({n_seeds} Seeds) ===")
print(header)
print("-" * len(header))

# Keys are (dist, ratio, attack) tuples, so the default tuple ordering sorts by
# distribution, then ratio, then attack.
for dist, ratio, attack in sorted(results):
    values = results[(dist, ratio, attack)]
    if not values:
        # Only possible if the defaultdict was probed with a missing key.
        continue

    accs = [acc for acc, _ in values]
    asrs = [asr for _, asr in values]

    acc_m, acc_ci = mean_confidence_interval(accs)
    asr_m, asr_ci = mean_confidence_interval(asrs)

    # Build each "mean ± ci%" cell first, then pad it, so the columns line up
    # with the header regardless of how many digits the numbers have.
    acc_cell = f"{acc_m:.2f} ± {acc_ci:.2f}%"
    asr_cell = f"{asr_m:.2f} ± {asr_ci:.2f}%"

    # Flag rows whose sample count differs from the configured number of seeds
    # (e.g. a partially completed sweep), since their CIs are not comparable.
    note = "" if len(values) == n_seeds else f"  (n={len(values)})"

    print(f"{dist:<12} | {ratio:<6} | {attack:<12} | {acc_cell:<30} | {asr_cell}{note}")