# MNIST All Experiments (Multi-GPU Parallel)

This notebook runs a comprehensive set of experiments for the AD-SFL project on the MNIST dataset.

**Configurations** (3 seeds × 2 distributions × 4 poison ratios × 2 attacks = 48 runs):
- **Seeds:** 3 random seeds (42, 100, 2024)
- **Distributions:** IID, Non-IID
- **Poison Ratios:** 0.0, 0.3, 0.5, 0.7
- **Attacks:** Label Flip, Backdoor
- **Fixed Parameters:** 
    - Clients: 10
    - Rounds: 25
    - Defense: SafeSplit (Enabled)

**Hardware:**
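- Supports parallel execution on multiple GPUs (e.g., 2× T4 on Kaggle).
- Automatically detects the available GPUs and splits the runs round-robin across them, one worker process per GPU; if no GPU is detected, all runs go to a single worker.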

In [None]:
%%writefile experiment_runner.py
import os
import sys
import traceback

def run_worker_chunk(gpu_id, configs):
    """
    Run a chunk of experiment configs sequentially, pinned to one GPU.

    CUDA_VISIBLE_DEVICES is set to ``gpu_id`` *before* torch and the project's
    ``simulation`` module are imported, so the chosen GPU is the only device
    this process can see (it shows up as cuda:0). The variable is only honoured
    if CUDA has not already been initialised in this process, so this function
    should run in a fresh worker process (e.g. a spawn-based pool with
    ``maxtasksperchild=1``).

    Args:
        gpu_id: Index of the physical GPU to expose to this process.
        configs: List of config dicts, each passed unchanged to
            ``simulation.run_simulation``. If a config carries a ``meta_seed``
            entry, the Python, NumPy and PyTorch global RNGs are reseeded with
            it right before that run, so each seed yields a distinct,
            reproducible run.

    Returns:
        A list with one dict per config (same order) holding:
        - ``config``: the config that was run.
        - ``acc`` / ``asr``: the last entries of ``hist['acc']`` /
          ``hist['asr']`` from ``run_simulation``.
        - ``error``: ``None`` on success.

        A failed run is recorded with acc/asr 0.0 and the exception message
        in ``error``; it does not abort the rest of the chunk.
    """
    # Set GPU isolation BEFORE importing torch/simulation.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    import random

    import numpy as np
    import torch
    # Import simulation here so it sees the selected GPU as cuda:0.
    import simulation

    results = []

    print(f"[Worker GPU {gpu_id}] Started. Processing {len(configs)} configs.")

    for exp_no, config in enumerate(configs, start=1):
        try:
            # Without this, meta_seed is pure metadata: the "different seeds"
            # would all be unseeded and the runs not reproducible.
            seed = config.get('meta_seed')
            if seed is not None:
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
            hist = simulation.run_simulation(config)
            results.append({
                'config': config,
                'acc': hist['acc'][-1],
                'asr': hist['asr'][-1],
                'error': None,
            })
        except Exception as e:
            # Record the failure and keep going so one bad config does not
            # discard the rest of this worker's chunk.
            print(f"[Worker GPU {gpu_id}] Error in exp {exp_no}: {e}")
            traceback.print_exc()
            results.append({
                'config': config,
                'acc': 0.0,
                'asr': 0.0,
                'error': str(e),
            })

    print(f"[Worker GPU {gpu_id}] Finished.")
    return results

In [None]:
# Setup
import sys
import os
import torch
import numpy as np
import scipy.stats as stats
import random
from collections import defaultdict
import torch.multiprocessing as mp
# Ensure we can import from the current directory

sys.path.append(os.getcwd())

# Import our runner module (just created)
import experiment_runner

# Use the 'spawn' start method so every pool worker is a fresh interpreter
# rather than a fork of this kernel (the start method PyTorch recommends when
# workers use CUDA). force=True overrides any method chosen earlier in this
# session; set_start_method only raises RuntimeError when force is False, so
# no try/except guard is needed around this call.
mp.set_start_method('spawn', force=True)

# Auto-install dependencies if needed
# !pip install torch torchvision numpy scipy matplotlib

In [None]:
# Configuration: the experiment grid is every combination of
# seed x distribution x poison ratio x attack.
SEEDS = [42, 100, 2024]

DISTRIBUTIONS = ['iid', 'non_iid']
POISON_RATIOS = [0.0, 0.3, 0.5, 0.7]
ATTACKS = ['label_flip', 'backdoor']

# Held constant across every run
NUM_CLIENTS = 10
ROUNDS = 25
DEFENSE_ENABLED = True
DATASET = 'mnist'

# Only backdoor runs use these; label-flip runs get source_labels=None and
# target_label=0.
BACKDOOR_SOURCE_LABELS = [1, 2, 3]
BACKDOOR_TARGET_LABEL = 8

# Enumerate the grid with seed as the outermost loop and attack as the
# innermost. The meta_* keys duplicate the grid coordinates so results can be
# grouped later.
all_configs = []
for seed in SEEDS:
    for dist in DISTRIBUTIONS:
        for ratio in POISON_RATIOS:
            for attack in ATTACKS:
                if attack == 'backdoor':
                    source_labels = BACKDOOR_SOURCE_LABELS
                    target_label = BACKDOOR_TARGET_LABEL
                else:
                    source_labels = None
                    target_label = 0
                config = dict(
                    num_clients=NUM_CLIENTS,
                    rounds=ROUNDS,
                    poison_ratio=ratio,
                    dataset=DATASET,
                    distribution=dist,
                    defense_enabled=DEFENSE_ENABLED,
                    attack_type=attack,
                    batch_size=32,
                    test_batch_size=256,
                    source_labels=source_labels,
                    target_label=target_label,
                    meta_seed=seed,
                    meta_dist=dist,
                    meta_ratio=ratio,
                    meta_attack=attack,
                )
                all_configs.append(config)

print(f"Total Experiments to Run: {len(all_configs)}")

In [None]:
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same value.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

# Distribution Logic
# One worker process per visible GPU. With no GPU, fall back to a single
# worker whose GPU index (0) is only a placeholder.
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    print(f"Detected {num_gpus} GPUs. Parallelizing...")
    num_workers = num_gpus
    gpu_indices = list(range(num_gpus))
else:
    print("No generic GPU detected. Running sequentially on CPU/single device (slow).")
    num_workers = 1
    gpu_indices = [0]  # placeholder index; there is no GPU to pin

# Nothing is seeded in this (parent) process. Each config carries its seed as
# config['meta_seed'], and applying it is left to the worker side. Workers are
# started with the 'spawn' method, so they begin with fresh interpreter state
# rather than inheriting this process's RNGs.

# Round-robin split: worker w gets configs w, w + num_workers,
# w + 2 * num_workers, ...
chunks = [all_configs[w::num_workers] for w in range(num_workers)]

print(f"Work distribution: {[len(c) for c in chunks]} runs per worker.")

if __name__ == '__main__':
    # Pool workers must be able to import the target function, which is why
    # it lives in experiment_runner.py rather than in this notebook.
    #
    # maxtasksperchild=1 together with chunksize=1 guarantees that each chunk
    # runs in a brand-new process. run_worker_chunk pins its GPU through
    # CUDA_VISIBLE_DEVICES, and that has no effect once CUDA has been
    # initialised in the process (as it would be in a reused pool worker).
    args = list(zip(gpu_indices, chunks))
    with mp.Pool(processes=num_workers, maxtasksperchild=1) as pool:
        results_nested = pool.starmap(
            experiment_runner.run_worker_chunk, args, chunksize=1
        )

    # Flatten the per-worker result lists into one list (worker order kept).
    all_results = [res for worker_res in results_nested for res in worker_res]

    print("\nAll experiments completed.")

In [None]:
# Aggregation and Reporting
# Group (acc, asr) pairs by (distribution, poison ratio, attack). The runs for
# the different seeds of one group become its samples. Failed runs are
# reported and left out.
results_agg = defaultdict(list)

for res in all_results:
    cfg = res['config']
    if not res.get('error'):
        dist, ratio, attack = cfg['meta_dist'], cfg['meta_ratio'], cfg['meta_attack']
        results_agg[(dist, ratio, attack)].append((res['acc'], res['asr']))
    else:
        print(f"Skipping failed run: {res['config']['meta_dist']} | {res['config']['meta_ratio']} | {res['config']['meta_attack']} - Error: {res['error']}")

def mean_confidence_interval(data, confidence=0.95):
    """
    Return the sample mean of *data* and the half-width of its Student-t CI.

    Args:
        data: Sequence of numeric samples (here: one value per seed).
        confidence: Two-sided confidence level, e.g. 0.95 for a 95% interval.

    Returns:
        ``(mean, half_width)``, so the interval is ``mean ± half_width``.
        With fewer than two samples no spread can be estimated and the
        half-width is 0.0. For empty input the mean is NaN.
    """
    a = np.asarray(data, dtype=float)
    n = len(a)
    if n == 0:
        # Avoid np.mean([]) and its "Mean of empty slice" RuntimeWarning.
        return float('nan'), 0.0
    if n < 2:
        return np.mean(a), 0.0
    m = np.mean(a)
    # Use the ``stats`` alias from the setup cell. The bare ``scipy`` name is
    # only bound by a top-level import placed after this definition.
    se = stats.sem(a)
    h = se * stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return m, h

import scipy

# Derive the seed count from the configuration instead of hard-coding it.
print(f"\n\n=== Final Aggregated Results ({len(SEEDS)} Seeds) ===")
print(f"{'Distribution':<12} | {'Ratio':<6} | {'Attack':<12} | {'Accuracy (Mean ± 95% CI)':<30} | {'ASR (Mean ± 95% CI)':<30}")
print("-"*100)

# Keys are (distribution, ratio, attack) tuples; natural tuple order suffices.
sorted_keys = sorted(results_agg)

for key in sorted_keys:
    dist, ratio, attack = key
    values = results_agg[key]

    if not values:
        continue

    accs = [acc for acc, _ in values]
    asrs = [asr for _, asr in values]

    acc_m, acc_ci = mean_confidence_interval(accs)
    asr_m, asr_ci = mean_confidence_interval(asrs)

    # Format each cell first, then pad it to the header's 30-char width, so the
    # columns line up regardless of how many digits the numbers have.
    # NOTE(review): the '%' suffix assumes run_simulation reports acc/asr as
    # percentages rather than fractions; TODO confirm.
    acc_cell = f"{acc_m:.2f} ± {acc_ci:.2f}%"
    asr_cell = f"{asr_m:.2f} ± {asr_ci:.2f}%"
    print(f"{dist:<12} | {ratio:<6} | {attack:<12} | {acc_cell:<30} | {asr_cell}")