## Startup


In [None]:
try:
    import breaching
except ModuleNotFoundError:
    import os
    os.chdir("..")  # fall back to the parent directory if breaching is not importable from here
    import breaching
    
    
import torch
%load_ext autoreload
%autoreload 2

# Redirects logs directly into the jupyter notebook
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

## Initialize cfg object and system setup:

This will load the full configuration object. This includes the configuration for the use case and threat model as *cfg.case* and the hyperparameters and implementation of the attack as *cfg.attack*. All parameters can be modified below, or overridden with *overrides=* as if they were command-line arguments.

In [None]:
cfg = breaching.get_config(overrides=["case=4_fedavg_small_scale", "case/data=CIFAR10"])
      
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

## Modify config options here

Using *.attribute* access, you can read and modify any configuration value for the attack or the case:

In [None]:
print("cfg experiment: ", cfg.case['name'])
print("cfg model: ", cfg.case['model'])
print("DATA: ", cfg.case['data']['name'])
print("IMPLEMENTATION: ", cfg.case['impl'])
print("USER: ", cfg.case['user'])
print("SERVER: ", cfg.case['server'])
print("ATTACK: ", cfg.attack)

In [None]:
cfg.case.data.partition = "random"
cfg.case.user.user_idx = 1
cfg.case.model = 'resnet50'
cfg.case.server.pretrained = True
cfg.case.user.provide_labels = True
# These settings govern the total amount of user data and how it is used over multiple local update steps:
cfg.case.user.num_data_points = 1  # Default 4 [n in fedAVG]
cfg.case.user.num_local_updates = 5  # Default 4 [E in fedAVG]
cfg.case.user.num_data_per_local_update_step = 1  # Default 2 [B in fedAVG]
cfg.case.user.local_learning_rate = 1e-4 # Default  1e-3 
cfg.case.user.provide_local_hyperparams = True # Default True

# Total variation regularization needs to be smaller on CIFAR-10:
cfg.attack.regularization.total_variation.scale = 1e-4 # Default 1e-3
cfg.attack.optim.max_iterations = 24000 # Default 24000
cfg.attack.optim.step_size = 1

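As a rough illustration of what the total variation scale above controls: total variation sums the differences between neighboring pixels, so a larger scale pushes the reconstruction toward smoother images. The sketch below is a plain-Python, anisotropic variant on a tiny 2-D image. The function `total_variation` and the example images are hypothetical, and the exact formulation used by the attack may differ.

```python
def total_variation(image):
    """Anisotropic total variation of a 2-D image given as a list of rows."""
    height, width = len(image), len(image[0])
    # Sum of absolute differences between vertically adjacent pixels.
    vertical = sum(
        abs(image[r + 1][c] - image[r][c])
        for r in range(height - 1)
        for c in range(width)
    )
    # Sum of absolute differences between horizontally adjacent pixels.
    horizontal = sum(
        abs(image[r][c + 1] - image[r][c])
        for r in range(height)
        for c in range(width - 1)
    )
    return vertical + horizontal


smooth = [[0.5, 0.5], [0.5, 0.5]]  # constant image, no variation
noisy = [[0.0, 1.0], [1.0, 0.0]]   # checkerboard, maximal variation

scale = 1e-4  # same value as the total variation scale set above
penalty_smooth = scale * total_variation(smooth)
penalty_noisy = scale * total_variation(noisy)
print(penalty_smooth, penalty_noisy)
```
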
In [None]:
print("cfg model: ", cfg.case['model'])
print("ATTACK: ", cfg.attack)
print("SERVER: ", cfg.case['server'])
print("USER: ", cfg.case['user'])

## Instantiate all parties

The following lines generate "server", "user" and "attacker" objects and print an overview of their configurations.

In [None]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

## Simulate an attacked FL protocol

This exchange simulates a single query in a federated learning protocol. The server sends out a **server_payload** and the user computes an update based on their private local data. This user update is **shared_data**; in the simplest case it contains the parameter gradient of the model. **true_user_data** is also returned by *.compute_local_updates*, but it is not forwarded to the server or attacker and is used only for the analysis.

In [None]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

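To make the roles of n, E and B concrete, here is a minimal sketch of what a FedAvg user might compute locally. It uses a toy one-parameter least-squares model in plain Python rather than the library's implementation. The function `local_fedavg_update`, the batching order, and the choice to share the parameter difference are all assumptions made for illustration.

```python
def local_fedavg_update(w_server, data, num_local_updates, batch_size, lr):
    """Toy FedAvg user for the 1-D model y ~ w * x (hypothetical helper)."""
    w = w_server
    n = len(data)  # n: number of local data points
    for step in range(num_local_updates):  # E: number of local update steps
        # Cycle through the local data in mini-batches of size B.
        start = (step * batch_size) % n
        batch = [data[(start + i) % n] for i in range(batch_size)]
        # Gradient of the batch mean of 0.5 * (w * x - y)^2 with respect to w.
        grad = sum((w * x - y) * x for x, y in batch) / len(batch)
        w = w - lr * grad
    # Assumption: the user shares the parameter difference, never the data.
    return w - w_server


# Mirrors the settings above: n = 1, E = 5, B = 1 (the learning rate is arbitrary).
update = local_fedavg_update(
    w_server=0.0, data=[(1.0, 2.0)], num_local_updates=5, batch_size=1, lr=0.1
)
print(update)
```
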
## Reconstruct user data

Now we launch the attack, reconstructing user data based only on the **server_payload** and the **shared_data**.


In [None]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], {}, dryrun=cfg.dryrun)

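The attack above is optimization-based. As a much simpler illustration of why a shared update can leak private data, the sketch below uses a different, analytic technique on a toy linear model with squared loss: for a single data point, the weight gradient is the input scaled by the bias gradient, so the input can be read off directly. All names here (`shared_gradient`, `reconstruct_input`, the example values) are hypothetical and not part of the library.

```python
def shared_gradient(w, b, x, y):
    """Gradient of 0.5 * (w.x + b - y)^2 with respect to (w, b)."""
    residual = sum(wi * xi for wi, xi in zip(w, x)) + b - y
    return [residual * xi for xi in x], residual


def reconstruct_input(grad_w, grad_b):
    """Recover x using grad_w = residual * x and grad_b = residual."""
    if grad_b == 0:
        raise ValueError("zero residual: the gradient carries no information about x")
    return [g / grad_b for g in grad_w]


w, b = [0.5, -1.0, 2.0], 0.1                  # model parameters sent by the server
x_private, y_private = [1.0, 2.0, 3.0], 0.0   # user data, never shared

grad_w, grad_b = shared_gradient(w, b, x_private, y_private)  # what the user shares
x_rec = reconstruct_input(grad_w, grad_b)                     # what the attacker recovers
print(x_rec)
```
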
Next we'll evaluate metrics, comparing the *reconstructed_user_data* to the *true_user_data*.

In [None]:
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)

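Mean squared error (MSE) and peak signal-to-noise ratio (PSNR) are two common ways to compare a reconstruction against the original data. The sketch below computes both on flat lists of pixel values in plain Python. It is an illustration only: the helper names and example values are hypothetical, and this section does not state which metrics the report actually computes.

```python
import math


def mse(a, b):
    """Mean squared error between two equally long pixel lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr(a, b, max_value=1.0):
    """Peak signal-to-noise ratio in dB for pixel values in [0, max_value]."""
    error = mse(a, b)
    if error == 0:
        return math.inf  # identical inputs
    return 10 * math.log10(max_value ** 2 / error)


true_pixels = [0.0, 0.5, 1.0, 0.25]
reconstructed_pixels = [0.1, 0.5, 0.9, 0.25]

print(mse(true_pixels, reconstructed_pixels))
print(psnr(true_pixels, reconstructed_pixels))
```
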
And finally, we also plot both the reconstructed data and original data:

In [None]:
# Workaround: user.plot does not work properly here without this setting.
# It tells the Intel OpenMP runtime to tolerate duplicate copies of the library.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
user.plot(true_user_data, print_labels=True, scale=True)

In [None]:
user.plot(reconstructed_user_data, print_labels=True, scale=True)