In [1]:
"""Minimal example: Running SEER attack with Breaching framework."""

import torch
import breaching

def _final_loss(stats):
    """Return the last value recorded under ``stats["loss"]``, or None.

    The attacker's ``stats`` object is not guaranteed to contain a non-empty
    ``"loss"`` history (e.g. a ``defaultdict(list)`` with no such entry), so
    this avoids an IndexError/KeyError after the attack has already run.
    """
    try:
        history = stats["loss"]
    except (KeyError, TypeError):
        return None
    if not history:
        return None
    return history[-1]


def _format_metric(value):
    """Format a scalar metric to 4 decimals; return None for non-scalars.

    Tensors are only converted when they hold exactly one element, because
    ``Tensor.item()`` raises for multi-element tensors.
    """
    if isinstance(value, torch.Tensor):
        if value.numel() != 1:
            return None
        value = value.item()
    if isinstance(value, (int, float)):
        return f"{value:.4f}"
    return None


def run_seer_attack():
    """Run a minimal SEER attack example with the Breaching framework.

    Builds a config with SEER overrides, constructs the FL case (user, server,
    model, loss), simulates the FL protocol via ``server.run_protocol``, runs
    the attacker's reconstruction on the shared updates, and prints the
    metrics returned by ``breaching.analysis.report``.

    Requires the ``breaching`` package to be importable.

    Returns:
        The metrics dictionary produced by ``breaching.analysis.report``.
    """

    # Step 1: Load configuration
    # NOTE(review): assumes the installed breaching version provides an
    # `attack=seer` config with a `param_selection.frac` key — TODO confirm.
    cfg = breaching.get_config(overrides=[
        "attack=seer",                           # Use SEER attack
        "case=1_single_image_small",             # Small single image case
        "attack.param_selection.frac=0.001",     # Select 0.1% of parameters
        "attack.optim.max_iterations=500",       # Reduce iterations for quick test
    ])

    # Step 2: Setup device
    setup = dict(
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        dtype=torch.float32,
    )

    # Step 3: Construct case (model, loss, user, server)
    user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)

    print("="*80)
    print("SEER Attack - Minimal Example")
    print("="*80)
    print(f"\nModel: {model.__class__.__name__}")
    print(f"User data points: {cfg.case.user.num_data_points}")
    print(f"Parameter selection fraction: {cfg.attack.param_selection.frac}")
    print(f"Max iterations: {cfg.attack.optim.max_iterations}")

    # Step 4: Create SEER attacker
    attacker = breaching.attacks.prepare_attack(model, loss_fn, cfg.attack, setup)
    print(f"\nAttacker: {attacker}")

    # Step 5: Run federated learning protocol
    print("\nRunning FL protocol...")
    shared_user_data, payloads, true_user_data = server.run_protocol(user)

    print(f"Shared gradient shape: {shared_user_data[0]['gradients'][0].shape}")
    print(f"True data shape: {true_user_data['data'].shape}")

    # Step 6: Run SEER attack
    print("\nRunning SEER attack...")
    reconstructed_data, stats = attacker.reconstruct(
        payloads,
        shared_user_data,
        getattr(server, "secrets", None),  # not every server type defines secrets
        dryrun=False,
    )

    print(f"\nReconstruction shape: {reconstructed_data['data'].shape}")
    final_loss = _final_loss(stats)
    if final_loss is None:
        # Don't crash after an expensive attack just because no loss was logged.
        print("Final loss: <not reported by attacker>")
    else:
        print(f"Final loss: {float(final_loss):.6f}")

    # Step 7: Evaluate attack
    print("\nEvaluating reconstruction quality...")
    metrics = breaching.analysis.report(
        reconstructed_data,
        true_user_data,
        payloads,
        model,
        cfg_case=cfg.case,
        setup=setup,
    )

    print("\nAttack metrics:")
    for key, value in metrics.items():
        formatted = _format_metric(value)
        if formatted is not None:  # skip non-scalar entries
            print(f"  {key}: {formatted}")

    return metrics


if __name__ == "__main__":
    # Entry point: run the demo when executed as a script (a notebook kernel
    # also sets __name__ to "__main__", so this runs inside the cell too).
    run_seer_attack()

ModuleNotFoundError: No module named 'breaching'